diff --git a/main.py b/main.py
index 8ce7039..f8eedf3 100644
--- a/main.py
+++ b/main.py
@@ -6,50 +6,134 @@ import os
import random
import string
import sys
+import time
from asyncio import AbstractEventLoop
+from collections import deque
from os import path
from types import FrameType
from typing import AnyStr, Optional, Tuple
+from _asyncio import Task
import asyncssh
from asyncssh import SSHKey, SSHServerConnection
+from asyncssh.channel import (
+ SSHUNIXChannel,
+ SSHUNIXSession,
+ SSHUNIXSessionFactory,
+)
from asyncssh.listener import create_unix_forward_listener
from asyncssh.misc import MaybeAwait
-from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory
from loguru import logger
from loguru._handler import Handler
+
unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel")
server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re")
config_dir: str = os.getenv("CONFIG_DIRECTORY", ".")
+welcome_banner_file: str = os.getenv("WELCOME_BANNER_FILE", "./welcome_banner.txt")
+rate_limit_count: int = int(os.getenv("RATE_LIMIT_COUNT", "5"))
+rate_limit_interval: int = int(os.getenv("RATE_LIMIT_INTERVAL", "60"))
+max_connections_per_ip: int = int(os.getenv("MAX_CONNECTIONS_PER_IP", "5"))
+timeout: int = int(os.getenv("TIMEOUT", "120"))
ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0")
ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22"))
log_level: str = os.getenv("LOG_LEVEL", "INFO")
log_depth: int = int(os.getenv("LOG_DEPTH", "2"))
-welcome_banner = f"===============================================================================\n\
-Welcome to collabore tunnel!\n\
-collabore tunnel is a free and open source service offered as part of the\n\
-club elec collabore platform (https://collabore.fr) operated by club elec that\n\
-allows you to expose your local services on the public Internet.\n\
-To learn more about collabore tunnel,\n\
-visit the documentation website: https://tunnel.collabore.fr/\n\
-club elec (https://clubelec.insset.fr) is a french not-for-profit\n\
-student organisation.\n\
-===============================================================================\n\n"
+
+def read_welcome_banner() -> str:
+ """Read the welcome banner from a file"""
+ if not os.path.exists(welcome_banner_file):
+ return welcome_banner
+ with open(welcome_banner_file, "r", encoding="UTF-8") as file:
+ return file.read()
+
+
+welcome_banner: str = read_welcome_banner()
+
+
+class RateLimiter:
+ """Rate limiter handling class"""
+
+ def __init__(self, max_requests: int, interval: int):
+ """Init class"""
+ self.max_requests: int = max_requests
+ self.interval: int = interval
+ self.timestamps: deque = deque()
+
+ def is_rate_limited(self) -> bool:
+ """Check if rate limited"""
+ now: float = time.time()
+ while self.timestamps and self.timestamps[0] < now - self.interval:
+ self.timestamps.popleft()
+ if len(self.timestamps) >= self.max_requests:
+ return True
+ self.timestamps.append(now)
+ return False
+
+
+class ConcurrentConnections:
+ """Concurrent connection handling class"""
+
+ def __init__(self):
+ """Init class"""
+ self.ip_connections: dict = {}
+
+ def increment(self, ip_addr: str) -> None:
+ """Increment the number of concurrent connections for an IP"""
+ if ip_addr not in self.ip_connections:
+ self.ip_connections[ip_addr] = 1
+ else:
+ self.ip_connections[ip_addr] += 1
+
+ def decrement(self, ip_addr: str) -> None:
+ """Decrement the number of concurrent connections for an IP"""
+ self.ip_connections[ip_addr] -= 1
+
+ def get(self, ip_addr: str) -> int:
+ """Get the number of concurent connection for an IP"""
+ return self.ip_connections.get(ip_addr, 0)
+
+
+ip_address_connections = ConcurrentConnections()
+
+
+def check_concurrent_connections(ip_addr: str) -> bool:
+ """Checking for concurrent connections"""
+ return ip_address_connections.get(ip_addr) >= max_connections_per_ip
class SSHServer(asyncssh.SSHServer):
"""SSH server protocol handler class"""
+ rate_limiters: dict = {}
+
def __init__(self):
"""Init class"""
self.conn: SSHServerConnection
self.socket_path: str
+ self.ip_addr: str
+
+ def check_rate_limit(self, ip_addr: str) -> bool:
+ """Check if rate limited"""
+ if ip_addr not in self.rate_limiters:
+ self.rate_limiters[ip_addr] = RateLimiter(
+ rate_limit_count, rate_limit_interval
+ )
+ return self.rate_limiters[ip_addr].is_rate_limited()
def connection_made(self, conn: SSHServerConnection) -> None:
"""Called when a connection is made"""
self.conn = conn
+ self.ip_addr, _ = conn.get_extra_info("peername")
+
+ if self.check_rate_limit(self.ip_addr):
+ conn.set_extra_info(rate_limited=True)
+
+ if check_concurrent_connections(self.ip_addr):
+ conn.set_extra_info(connection_limited=True)
+
+ ip_address_connections.increment(self.ip_addr)
def connection_lost(self, exc: Optional[Exception]) -> None:
"""Called when a connection is lost or closed"""
@@ -59,6 +143,7 @@ class SSHServer(asyncssh.SSHServer):
os.remove(self.socket_path)
except AttributeError:
pass
+ ip_address_connections.decrement(self.ip_addr)
def begin_auth(self, username: str) -> MaybeAwait[bool]:
"""Authentication has been requested by the client"""
@@ -98,27 +183,67 @@ class SSHServer(asyncssh.SSHServer):
async def handle_ssh_client(process) -> None:
"""Function called every time a client connects to the SSH server"""
socket_name: str = process.get_extra_info("socket_name")
+ rate_limited: bool = process.get_extra_info("rate_limited")
+ connection_limited: bool = process.get_extra_info("connection_limited")
response: str = ""
- if not socket_name:
- response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
- process.stdout.write(response + "\n")
- process.exit(1)
- logging.info(
- "The user was ejected because they did not connect in port forwarding mode."
+
+ async def process_timeout(process):
+ """Function to terminate the connection automatically
+ after a specific period of time (in minutes)"""
+ await asyncio.sleep(timeout * 60)
+ response = (
+ f"Timeout: you were automatically ejected after {timeout} minutes of use.\n"
)
+ process.stdout.write(response + "\n")
+ process.logger.info(
+ f"The user was automatically ejected after {timeout} minutes of use"
+ )
+ process.close()
+
+ if not rate_limited:
+ if not connection_limited:
+ if not socket_name:
+ response = "Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
+ process.stdout.write(response + "\n")
+ process.logger.info(
+ "The user was ejected because they did not connect in port forwarding mode"
+ )
+ process.exit(1)
+ return
+ no_tls: str = f"{socket_name}.{server_hostname}"
+ tls: str = f"https://{socket_name}.{server_hostname}"
+ response = f"{welcome_banner}\nYour local service has been exposed to the public\n\
+Internet address: {no_tls}\nTLS termination: {tls}\n"
+ process.stdout.write(response + "\n")
+ process.logger.info(f"Exposed on {no_tls}")
+ read_task: Task = asyncio.create_task(process.stdin.read())
+ timeout_task: Task = asyncio.create_task(process_timeout(process))
+ done, pending = await asyncio.wait(
+ [read_task, timeout_task], return_when=asyncio.FIRST_COMPLETED
+ )
+ for task in done:
+ try:
+ await task
+ except asyncssh.BreakReceived:
+ pass
+ for task in pending:
+ task.cancel()
+
+ process.exit(0)
+ else:
+ response = (
+ "Per-IP connection limit: too many connections running over this IP.\n"
+ )
+ process.stdout.write(response + "\n")
+ process.logger.warning("Rejected connection due to per-IP connection limit")
+ process.exit(1)
+ return
+ else:
+ response = "Rate limited: please try later.\n"
+ process.stdout.write(response + "\n")
+ process.logger.warning("Rejected connection due to rate limit")
+ process.exit(1)
return
- no_tls: str = f"{socket_name}.{server_hostname}"
- tls: str = f"https://{socket_name}.{server_hostname}"
- response = f"{welcome_banner}Your local service has been exposed\
- to the public Internet address: {no_tls}\nTLS termination: {tls}\n"
- process.stdout.write(response + "\n")
- logging.info(f"Exposed on {no_tls}, {tls}.")
- while not process.stdin.at_eof():
- try:
- await process.stdin.read()
- except asyncssh.TerminalSizeChanged:
- pass
- process.exit(0)
async def start_ssh_server() -> None:
@@ -182,10 +307,8 @@ def init_logging():
"""Init logging with a custom handler"""
logging.root.handlers: Handler = [InterceptHandler()]
logging.root.setLevel(log_level)
- for name in logging.root.manager.loggerDict.keys():
- logging.getLogger(name).handlers: list = []
- logging.getLogger(name).propagate: bool = True
- logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}])
+ fmt = "[{time}] [{level}] - {message}"
+ logger.configure(handlers=[{"sink": sys.stdout, "serialize": False, "format": fmt}])
def get_random_slug(length) -> str: