collabore-tunnel/main.py

343 lines
12 KiB
Python
Raw Normal View History

2023-01-16 14:59:30 +00:00
"""collabore tunnel SSH server"""
import asyncio
import logging
import os
import random
import string
import sys
2023-05-13 15:03:41 +00:00
import time
2023-01-16 14:59:30 +00:00
from asyncio import AbstractEventLoop
2023-05-13 15:03:41 +00:00
from collections import deque
2023-01-16 14:59:30 +00:00
from os import path
from types import FrameType
from typing import AnyStr, Optional, Tuple
2023-05-13 15:03:41 +00:00
from _asyncio import Task
2023-01-16 14:59:30 +00:00
import asyncssh
from asyncssh import SSHKey, SSHServerConnection
2023-05-13 15:03:41 +00:00
from asyncssh.channel import (
SSHUNIXChannel,
SSHUNIXSession,
SSHUNIXSessionFactory,
)
2023-01-16 14:59:30 +00:00
from asyncssh.listener import create_unix_forward_listener
from asyncssh.misc import MaybeAwait
from loguru import logger
from loguru._handler import Handler
2023-05-13 15:03:41 +00:00
2023-01-16 14:59:30 +00:00
unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel")
server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re")
config_dir: str = os.getenv("CONFIG_DIRECTORY", ".")
2023-05-13 15:03:41 +00:00
welcome_banner_file: str = os.getenv("WELCOME_BANNER_FILE", "./welcome_banner.txt")
rate_limit_count: int = int(os.getenv("RATE_LIMIT_COUNT", "5"))
rate_limit_interval: int = int(os.getenv("RATE_LIMIT_INTERVAL", "60"))
max_connections_per_ip: int = int(os.getenv("MAX_CONNECTIONS_PER_IP", "5"))
timeout: int = int(os.getenv("TIMEOUT", "120"))
2023-01-16 14:59:30 +00:00
ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0")
ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22"))
log_level: str = os.getenv("LOG_LEVEL", "INFO")
log_depth: int = int(os.getenv("LOG_DEPTH", "2"))
2023-05-13 15:03:41 +00:00
def read_welcome_banner() -> str:
"""Read the welcome banner from a file"""
if not os.path.exists(welcome_banner_file):
return welcome_banner
with open(welcome_banner_file, "r", encoding="UTF-8") as file:
return file.read()
welcome_banner: str = read_welcome_banner()
class RateLimiter:
"""Rate limiter handling class"""
def __init__(self, max_requests: int, interval: int):
"""Init class"""
self.max_requests: int = max_requests
self.interval: int = interval
self.timestamps: deque = deque()
def is_rate_limited(self) -> bool:
"""Check if rate limited"""
now: float = time.time()
while self.timestamps and self.timestamps[0] < now - self.interval:
self.timestamps.popleft()
if len(self.timestamps) >= self.max_requests:
return True
self.timestamps.append(now)
return False
class ConcurrentConnections:
"""Concurrent connection handling class"""
def __init__(self):
"""Init class"""
self.ip_connections: dict = {}
def increment(self, ip_addr: str) -> None:
"""Increment the number of concurrent connections for an IP"""
if ip_addr not in self.ip_connections:
self.ip_connections[ip_addr] = 1
else:
self.ip_connections[ip_addr] += 1
def decrement(self, ip_addr: str) -> None:
"""Decrement the number of concurrent connections for an IP"""
self.ip_connections[ip_addr] -= 1
def get(self, ip_addr: str) -> int:
"""Get the number of concurent connection for an IP"""
return self.ip_connections.get(ip_addr, 0)
ip_address_connections = ConcurrentConnections()
def check_concurrent_connections(ip_addr: str) -> bool:
"""Checking for concurrent connections"""
return ip_address_connections.get(ip_addr) >= max_connections_per_ip
2023-01-16 14:59:30 +00:00
class SSHServer(asyncssh.SSHServer):
"""SSH server protocol handler class"""
2023-05-13 15:03:41 +00:00
rate_limiters: dict = {}
2023-01-16 14:59:30 +00:00
def __init__(self):
"""Init class"""
self.conn: SSHServerConnection
self.socket_path: str
2023-05-13 15:03:41 +00:00
self.ip_addr: str
def check_rate_limit(self, ip_addr: str) -> bool:
"""Check if rate limited"""
if ip_addr not in self.rate_limiters:
self.rate_limiters[ip_addr] = RateLimiter(
rate_limit_count, rate_limit_interval
)
return self.rate_limiters[ip_addr].is_rate_limited()
2023-01-16 14:59:30 +00:00
def connection_made(self, conn: SSHServerConnection) -> None:
"""Called when a connection is made"""
self.conn = conn
2023-05-13 15:03:41 +00:00
self.ip_addr, _ = conn.get_extra_info("peername")
if self.check_rate_limit(self.ip_addr):
conn.set_extra_info(rate_limited=True)
if check_concurrent_connections(self.ip_addr):
conn.set_extra_info(connection_limited=True)
ip_address_connections.increment(self.ip_addr)
2023-01-16 14:59:30 +00:00
def connection_lost(self, exc: Optional[Exception]) -> None:
"""Called when a connection is lost or closed"""
if exc:
logging.info("The connection has been terminated: %s", str(exc))
try:
os.remove(self.socket_path)
except AttributeError:
pass
2023-05-13 15:03:41 +00:00
ip_address_connections.decrement(self.ip_addr)
2023-01-16 14:59:30 +00:00
def begin_auth(self, username: str) -> MaybeAwait[bool]:
"""Authentication has been requested by the client"""
return False
def password_auth_supported(self) -> bool:
"""Return whether or not password authentication is supported"""
return True
def generate_socket_path(self) -> str:
"""Return the path of a socket whose name has been randomly generated"""
socket_name = get_random_slug(16)
self.socket_path = os.path.join(unix_sockets_dir, f"{socket_name}.sock")
self.conn.set_extra_info(socket_name=socket_name)
return self.socket_path
def unix_server_requested(self, listen_path: str):
"""Handle a request to listen on a UNIX domain socket"""
rewrite_path: str = self.generate_socket_path()
async def tunnel_connection(
session_factory: SSHUNIXSessionFactory[AnyStr],
) -> Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]:
return await self.conn.create_unix_connection(session_factory, listen_path)
try:
return create_unix_forward_listener(
self.conn, asyncio.get_event_loop(), tunnel_connection, rewrite_path
)
except OSError as create_unix_forward_listener_exception:
logging.error(
"An error occurred while creating the forward listener: %s",
str(create_unix_forward_listener_exception),
)
async def handle_ssh_client(process) -> None:
"""Function called every time a client connects to the SSH server"""
socket_name: str = process.get_extra_info("socket_name")
2023-05-13 15:03:41 +00:00
rate_limited: bool = process.get_extra_info("rate_limited")
connection_limited: bool = process.get_extra_info("connection_limited")
2023-01-16 14:59:30 +00:00
response: str = ""
2023-05-13 15:03:41 +00:00
async def process_timeout(process):
"""Function to terminate the connection automatically
after a specific period of time (in minutes)"""
await asyncio.sleep(timeout * 60)
response = (
f"Timeout: you were automatically ejected after {timeout} minutes of use.\n"
)
2023-01-16 14:59:30 +00:00
process.stdout.write(response + "\n")
2023-05-13 15:03:41 +00:00
process.logger.info(
f"The user was automatically ejected after {timeout} minutes of use"
2023-01-16 14:59:30 +00:00
)
2023-05-13 15:03:41 +00:00
process.close()
if not rate_limited:
if not connection_limited:
if not socket_name:
response = "Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
process.stdout.write(response + "\n")
process.logger.info(
"The user was ejected because they did not connect in port forwarding mode"
)
process.exit(1)
return
no_tls: str = f"{socket_name}.{server_hostname}"
tls: str = f"https://{socket_name}.{server_hostname}"
response = f"{welcome_banner}\nYour local service has been exposed to the public\n\
Internet address: {no_tls}\nTLS termination: {tls}\n"
process.stdout.write(response + "\n")
process.logger.info(f"Exposed on {no_tls}")
read_task: Task = asyncio.create_task(process.stdin.read())
timeout_task: Task = asyncio.create_task(process_timeout(process))
done, pending = await asyncio.wait(
[read_task, timeout_task], return_when=asyncio.FIRST_COMPLETED
)
for task in done:
try:
await task
except asyncssh.BreakReceived:
pass
for task in pending:
task.cancel()
process.exit(0)
else:
response = (
"Per-IP connection limit: too many connections running over this IP.\n"
)
process.stdout.write(response + "\n")
process.logger.warning("Rejected connection due to per-IP connection limit")
process.exit(1)
return
else:
response = "Rate limited: please try later.\n"
process.stdout.write(response + "\n")
process.logger.warning("Rejected connection due to rate limit")
process.exit(1)
2023-01-16 14:59:30 +00:00
return
async def start_ssh_server() -> None:
"""Start the SSH server"""
ssh_key_file: str = path.join(config_dir, "id_rsa_host")
await asyncssh.create_server(
SSHServer,
host=ssh_server_host,
port=ssh_server_port,
server_host_keys=[ssh_key_file],
process_factory=handle_ssh_client,
agent_forwarding=False,
allow_scp=False,
keepalive_interval=30,
)
logging.info("SSH server started successfully.")
def check_unix_sockets_dir() -> None:
"""If the directory for UNIX sockets does not exist, it is created"""
if not path.exists(unix_sockets_dir):
os.mkdir(unix_sockets_dir)
logging.warning(
"The %s folder does not exist, it has been created.", unix_sockets_dir
)
else:
logging.info("The %s folder exist.", unix_sockets_dir)
def generate_ssh_key() -> None:
"""If the SSH key of the server does not exist, it is generated"""
ssh_host_key: str = path.join(config_dir, "id_rsa_host")
logging.info("Loading the SSH key")
if not path.exists(ssh_host_key):
logging.warning(
"The SSH key for the host was not found, generation in progress..."
)
key: SSHKey = asyncssh.generate_private_key("ssh-rsa")
private_key: bytes = key.export_private_key()
with open(ssh_host_key, "wb") as ssh_host_key_data:
ssh_host_key_data.write(private_key)
logging.info("The key was successfully created!")
else:
logging.info("SSH key has been found")
class InterceptHandler(logging.Handler):
"""Intercept logging call"""
def emit(self, record):
"""Find caller from where originated the logged message"""
frame: FrameType = logging.currentframe()
depth: int = log_depth
while frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
logger.opt(exception=record.exc_info).log(log_level, record.getMessage())
def init_logging():
"""Init logging with a custom handler"""
logging.root.handlers: Handler = [InterceptHandler()]
logging.root.setLevel(log_level)
2023-05-13 15:03:41 +00:00
fmt = "<green>[{time}]</green> <level>[{level}]</level> - <level>{message}</level>"
logger.configure(handlers=[{"sink": sys.stdout, "serialize": False, "format": fmt}])
2023-01-16 14:59:30 +00:00
def get_random_slug(length) -> str:
"""Function that generates a random string of a defined size"""
chars: str = string.ascii_lowercase + string.digits
return "".join(random.choices(chars, k=length))
if __name__ == "__main__":
init_logging()
logging.info("Starting collabore tunnel SSH server...")
os.umask(0o000)
generate_ssh_key()
logging.info("Checking for the existence of a folder for UNIX sockets...")
check_unix_sockets_dir()
loop: AbstractEventLoop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(start_ssh_server())
except KeyboardInterrupt:
pass
except (OSError, asyncssh.Error) as ssh_server_startup_exception:
logging.critical(
"An error occurred while starting the SSH server: %s",
str(ssh_server_startup_exception),
)
sys.exit()
try:
loop.run_forever()
except KeyboardInterrupt:
sys.exit()