diff --git a/main.py b/main.py index 8ce7039..f8eedf3 100644 --- a/main.py +++ b/main.py @@ -6,50 +6,134 @@ import os import random import string import sys +import time from asyncio import AbstractEventLoop +from collections import deque from os import path from types import FrameType from typing import AnyStr, Optional, Tuple +from _asyncio import Task import asyncssh from asyncssh import SSHKey, SSHServerConnection +from asyncssh.channel import ( + SSHUNIXChannel, + SSHUNIXSession, + SSHUNIXSessionFactory, +) from asyncssh.listener import create_unix_forward_listener from asyncssh.misc import MaybeAwait -from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory from loguru import logger from loguru._handler import Handler + unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel") server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re") config_dir: str = os.getenv("CONFIG_DIRECTORY", ".") +welcome_banner_file: str = os.getenv("WELCOME_BANNER_FILE", "./welcome_banner.txt") +rate_limit_count: int = int(os.getenv("RATE_LIMIT_COUNT", "5")) +rate_limit_interval: int = int(os.getenv("RATE_LIMIT_INTERVAL", "60")) +max_connections_per_ip: int = int(os.getenv("MAX_CONNECTIONS_PER_IP", "5")) +timeout: int = int(os.getenv("TIMEOUT", "120")) ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0") ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22")) log_level: str = os.getenv("LOG_LEVEL", "INFO") log_depth: int = int(os.getenv("LOG_DEPTH", "2")) -welcome_banner = f"===============================================================================\n\ -Welcome to collabore tunnel!\n\ -collabore tunnel is a free and open source service offered as part of the\n\ -club elec collabore platform (https://collabore.fr) operated by club elec that\n\ -allows you to expose your local services on the public Internet.\n\ -To learn more about collabore tunnel,\n\ -visit the documentation website: https://tunnel.collabore.fr/\n\ -club elec (https://clubelec.insset.fr) is a french not-for-profit\n\ -student organisation.\n\ -===============================================================================\n\n" + +def read_welcome_banner() -> str: + """Read the welcome banner from a file""" + if not os.path.exists(welcome_banner_file): + return welcome_banner + with open(welcome_banner_file, "r", encoding="UTF-8") as file: + return file.read() + + +welcome_banner: str = read_welcome_banner() + + +class RateLimiter: + """Rate limiter handling class""" + + def __init__(self, max_requests: int, interval: int): + """Init class""" + self.max_requests: int = max_requests + self.interval: int = interval + self.timestamps: deque = deque() + + def is_rate_limited(self) -> bool: + """Check if rate limited""" + now: float = time.time() + while self.timestamps and self.timestamps[0] < now - self.interval: + self.timestamps.popleft() + if len(self.timestamps) >= self.max_requests: + return True + self.timestamps.append(now) + return False + + +class ConcurrentConnections: + """Concurrent connection handling class""" + + def __init__(self): + """Init class""" + self.ip_connections: dict = {} + + def increment(self, ip_addr: str) -> None: + """Increment the number of concurrent connections for an IP""" + if ip_addr not in self.ip_connections: + self.ip_connections[ip_addr] = 1 + else: + self.ip_connections[ip_addr] += 1 + + def decrement(self, ip_addr: str) -> None: + """Decrement the number of concurrent connections for an IP""" + self.ip_connections[ip_addr] -= 1 + + def get(self, ip_addr: str) -> int: + """Get the number of concurent connection for an IP""" + return self.ip_connections.get(ip_addr, 0) + + +ip_address_connections = ConcurrentConnections() + + +def check_concurrent_connections(ip_addr: str) -> bool: + """Checking for concurrent connections""" + return ip_address_connections.get(ip_addr) >= max_connections_per_ip class SSHServer(asyncssh.SSHServer): """SSH server protocol handler class""" + rate_limiters: dict = {} + def __init__(self): """Init class""" self.conn: SSHServerConnection self.socket_path: str + self.ip_addr: str + + def check_rate_limit(self, ip_addr: str) -> bool: + """Check if rate limited""" + if ip_addr not in self.rate_limiters: + self.rate_limiters[ip_addr] = RateLimiter( + rate_limit_count, rate_limit_interval + ) + return self.rate_limiters[ip_addr].is_rate_limited() def connection_made(self, conn: SSHServerConnection) -> None: """Called when a connection is made""" self.conn = conn + self.ip_addr, _ = conn.get_extra_info("peername") + + if self.check_rate_limit(self.ip_addr): + conn.set_extra_info(rate_limited=True) + + if check_concurrent_connections(self.ip_addr): + conn.set_extra_info(connection_limited=True) + + ip_address_connections.increment(self.ip_addr) def connection_lost(self, exc: Optional[Exception]) -> None: """Called when a connection is lost or closed""" @@ -59,6 +143,7 @@ class SSHServer(asyncssh.SSHServer): os.remove(self.socket_path) except AttributeError: pass + ip_address_connections.decrement(self.ip_addr) def begin_auth(self, username: str) -> MaybeAwait[bool]: """Authentication has been requested by the client""" @@ -98,27 +183,67 @@ class SSHServer(asyncssh.SSHServer): async def handle_ssh_client(process) -> None: """Function called every time a client connects to the SSH server""" socket_name: str = process.get_extra_info("socket_name") + rate_limited: bool = process.get_extra_info("rate_limited") + connection_limited: bool = process.get_extra_info("connection_limited") response: str = "" - if not socket_name: - response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" - process.stdout.write(response + "\n") - process.exit(1) - logging.info( - "The user was ejected because they did not connect in port forwarding mode." + + async def process_timeout(process): + """Function to terminate the connection automatically + after a specific period of time (in minutes)""" + await asyncio.sleep(timeout * 60) + response = ( + f"Timeout: you were automatically ejected after {timeout} minutes of use.\n" ) + process.stdout.write(response + "\n") + process.logger.info( + f"The user was automatically ejected after {timeout} minutes of use" + ) + process.close() + + if not rate_limited: + if not connection_limited: + if not socket_name: + response = "Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" + process.stdout.write(response + "\n") + process.logger.info( + "The user was ejected because they did not connect in port forwarding mode" + ) + process.exit(1) + return + no_tls: str = f"{socket_name}.{server_hostname}" + tls: str = f"https://{socket_name}.{server_hostname}" + response = f"{welcome_banner}\nYour local service has been exposed to the public\n\ +Internet address: {no_tls}\nTLS termination: {tls}\n" + process.stdout.write(response + "\n") + process.logger.info(f"Exposed on {no_tls}") + read_task: Task = asyncio.create_task(process.stdin.read()) + timeout_task: Task = asyncio.create_task(process_timeout(process)) + done, pending = await asyncio.wait( + [read_task, timeout_task], return_when=asyncio.FIRST_COMPLETED + ) + for task in done: + try: + await task + except asyncssh.BreakReceived: + pass + for task in pending: + task.cancel() + + process.exit(0) + else: + response = ( + "Per-IP connection limit: too many connections running over this IP.\n" + ) + process.stdout.write(response + "\n") + process.logger.warning("Rejected connection due to per-IP connection limit") + process.exit(1) + return + else: + response = "Rate limited: please try later.\n" + process.stdout.write(response + "\n") + process.logger.warning("Rejected connection due to rate limit") + process.exit(1) return - no_tls: str = f"{socket_name}.{server_hostname}" - tls: str = f"https://{socket_name}.{server_hostname}" - response = f"{welcome_banner}Your local service has been exposed\ - to the public Internet address: {no_tls}\nTLS termination: {tls}\n" - process.stdout.write(response + "\n") - logging.info(f"Exposed on {no_tls}, {tls}.") - while not process.stdin.at_eof(): - try: - await process.stdin.read() - except asyncssh.TerminalSizeChanged: - pass - process.exit(0) async def start_ssh_server() -> None: @@ -182,10 +307,8 @@ def init_logging(): """Init logging with a custom handler""" logging.root.handlers: Handler = [InterceptHandler()] logging.root.setLevel(log_level) - for name in logging.root.manager.loggerDict.keys(): - logging.getLogger(name).handlers: list = [] - logging.getLogger(name).propagate: bool = True - logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}]) + fmt = "[{time}] [{level}] - {message}" + logger.configure(handlers=[{"sink": sys.stdout, "serialize": False, "format": fmt}]) def get_random_slug(length) -> str: