Update main.py
This commit is contained in:
		
							parent
							
								
									beae19da38
								
							
						
					
					
						commit
						25111acee3
					
				
							
								
								
									
										189
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										189
									
								
								main.py
									
									
									
									
									
								
							|  | @ -6,50 +6,134 @@ import os | |||
| import random | ||||
| import string | ||||
| import sys | ||||
| import time | ||||
| from asyncio import AbstractEventLoop | ||||
| from collections import deque | ||||
| from os import path | ||||
| from types import FrameType | ||||
| from typing import AnyStr, Optional, Tuple | ||||
| from _asyncio import Task | ||||
| 
 | ||||
| import asyncssh | ||||
| from asyncssh import SSHKey, SSHServerConnection | ||||
| from asyncssh.channel import ( | ||||
|     SSHUNIXChannel, | ||||
|     SSHUNIXSession, | ||||
|     SSHUNIXSessionFactory, | ||||
| ) | ||||
| from asyncssh.listener import create_unix_forward_listener | ||||
| from asyncssh.misc import MaybeAwait | ||||
| from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory | ||||
| from loguru import logger | ||||
| from loguru._handler import Handler | ||||
| 
 | ||||
| 
 | ||||
| unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel") | ||||
| server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re") | ||||
| config_dir: str = os.getenv("CONFIG_DIRECTORY", ".") | ||||
| welcome_banner_file: str = os.getenv("WELCOME_BANNER_FILE", "./welcome_banner.txt") | ||||
| rate_limit_count: int = int(os.getenv("RATE_LIMIT_COUNT", "5")) | ||||
| rate_limit_interval: int = int(os.getenv("RATE_LIMIT_INTERVAL", "60")) | ||||
| max_connections_per_ip: int = int(os.getenv("MAX_CONNECTIONS_PER_IP", "5")) | ||||
| timeout: int = int(os.getenv("TIMEOUT", "120")) | ||||
| ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0") | ||||
| ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22")) | ||||
| log_level: str = os.getenv("LOG_LEVEL", "INFO") | ||||
| log_depth: int = int(os.getenv("LOG_DEPTH", "2")) | ||||
| 
 | ||||
| welcome_banner = f"===============================================================================\n\ | ||||
| Welcome to collabore tunnel!\n\ | ||||
| collabore tunnel is a free and open source service offered as part of the\n\ | ||||
| club elec collabore platform (https://collabore.fr) operated by club elec that\n\ | ||||
| allows you to expose your local services on the public Internet.\n\ | ||||
| To learn more about collabore tunnel,\n\ | ||||
| visit the documentation website: https://tunnel.collabore.fr/\n\ | ||||
| club elec (https://clubelec.insset.fr) is a french not-for-profit\n\ | ||||
| student organisation.\n\ | ||||
| ===============================================================================\n\n" | ||||
| 
 | ||||
| def read_welcome_banner() -> str: | ||||
|     """Read the welcome banner from a file""" | ||||
|     if not os.path.exists(welcome_banner_file): | ||||
|         return welcome_banner | ||||
|     with open(welcome_banner_file, "r", encoding="UTF-8") as file: | ||||
|         return file.read() | ||||
| 
 | ||||
| 
 | ||||
| welcome_banner: str = read_welcome_banner() | ||||
| 
 | ||||
| 
 | ||||
| class RateLimiter: | ||||
|     """Rate limiter handling class""" | ||||
| 
 | ||||
|     def __init__(self, max_requests: int, interval: int): | ||||
|         """Init class""" | ||||
|         self.max_requests: int = max_requests | ||||
|         self.interval: int = interval | ||||
|         self.timestamps: deque = deque() | ||||
| 
 | ||||
|     def is_rate_limited(self) -> bool: | ||||
|         """Check if rate limited""" | ||||
|         now: float = time.time() | ||||
|         while self.timestamps and self.timestamps[0] < now - self.interval: | ||||
|             self.timestamps.popleft() | ||||
|         if len(self.timestamps) >= self.max_requests: | ||||
|             return True | ||||
|         self.timestamps.append(now) | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| class ConcurrentConnections: | ||||
|     """Concurrent connection handling class""" | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         """Init class""" | ||||
|         self.ip_connections: dict = {} | ||||
| 
 | ||||
|     def increment(self, ip_addr: str) -> None: | ||||
|         """Increment the number of concurrent connections for an IP""" | ||||
|         if ip_addr not in self.ip_connections: | ||||
|             self.ip_connections[ip_addr] = 1 | ||||
|         else: | ||||
|             self.ip_connections[ip_addr] += 1 | ||||
| 
 | ||||
|     def decrement(self, ip_addr: str) -> None: | ||||
|         """Decrement the number of concurrent connections for an IP""" | ||||
|         self.ip_connections[ip_addr] -= 1 | ||||
| 
 | ||||
|     def get(self, ip_addr: str) -> int: | ||||
|         """Get the number of concurent connection for an IP""" | ||||
|         return self.ip_connections.get(ip_addr, 0) | ||||
| 
 | ||||
| 
 | ||||
| ip_address_connections = ConcurrentConnections() | ||||
| 
 | ||||
| 
 | ||||
| def check_concurrent_connections(ip_addr: str) -> bool: | ||||
|     """Checking for concurrent connections""" | ||||
|     return ip_address_connections.get(ip_addr) >= max_connections_per_ip | ||||
| 
 | ||||
| 
 | ||||
| class SSHServer(asyncssh.SSHServer): | ||||
|     """SSH server protocol handler class""" | ||||
| 
 | ||||
|     rate_limiters: dict = {} | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         """Init class""" | ||||
|         self.conn: SSHServerConnection | ||||
|         self.socket_path: str | ||||
|         self.ip_addr: str | ||||
| 
 | ||||
|     def check_rate_limit(self, ip_addr: str) -> bool: | ||||
|         """Check if rate limited""" | ||||
|         if ip_addr not in self.rate_limiters: | ||||
|             self.rate_limiters[ip_addr] = RateLimiter( | ||||
|                 rate_limit_count, rate_limit_interval | ||||
|             ) | ||||
|         return self.rate_limiters[ip_addr].is_rate_limited() | ||||
| 
 | ||||
|     def connection_made(self, conn: SSHServerConnection) -> None: | ||||
|         """Called when a connection is made""" | ||||
|         self.conn = conn | ||||
|         self.ip_addr, _ = conn.get_extra_info("peername") | ||||
| 
 | ||||
|         if self.check_rate_limit(self.ip_addr): | ||||
|             conn.set_extra_info(rate_limited=True) | ||||
| 
 | ||||
|         if check_concurrent_connections(self.ip_addr): | ||||
|             conn.set_extra_info(connection_limited=True) | ||||
| 
 | ||||
|         ip_address_connections.increment(self.ip_addr) | ||||
| 
 | ||||
|     def connection_lost(self, exc: Optional[Exception]) -> None: | ||||
|         """Called when a connection is lost or closed""" | ||||
|  | @ -59,6 +143,7 @@ class SSHServer(asyncssh.SSHServer): | |||
|             os.remove(self.socket_path) | ||||
|         except AttributeError: | ||||
|             pass | ||||
|         ip_address_connections.decrement(self.ip_addr) | ||||
| 
 | ||||
|     def begin_auth(self, username: str) -> MaybeAwait[bool]: | ||||
|         """Authentication has been requested by the client""" | ||||
|  | @ -98,27 +183,67 @@ class SSHServer(asyncssh.SSHServer): | |||
| async def handle_ssh_client(process) -> None: | ||||
|     """Function called every time a client connects to the SSH server""" | ||||
|     socket_name: str = process.get_extra_info("socket_name") | ||||
|     rate_limited: bool = process.get_extra_info("rate_limited") | ||||
|     connection_limited: bool = process.get_extra_info("connection_limited") | ||||
|     response: str = "" | ||||
|     if not socket_name: | ||||
|         response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" | ||||
|         process.stdout.write(response + "\n") | ||||
|         process.exit(1) | ||||
|         logging.info( | ||||
|             "The user was ejected because they did not connect in port forwarding mode." | ||||
| 
 | ||||
|     async def process_timeout(process): | ||||
|         """Function to terminate the connection automatically | ||||
|         after a specific period of time (in minutes)""" | ||||
|         await asyncio.sleep(timeout * 60) | ||||
|         response = ( | ||||
|             f"Timeout: you were automatically ejected after {timeout} minutes of use.\n" | ||||
|         ) | ||||
|         process.stdout.write(response + "\n") | ||||
|         process.logger.info( | ||||
|             f"The user was automatically ejected after {timeout} minutes of use" | ||||
|         ) | ||||
|         process.close() | ||||
| 
 | ||||
|     if not rate_limited: | ||||
|         if not connection_limited: | ||||
|             if not socket_name: | ||||
|                 response = "Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" | ||||
|                 process.stdout.write(response + "\n") | ||||
|                 process.logger.info( | ||||
|                     "The user was ejected because they did not connect in port forwarding mode" | ||||
|                 ) | ||||
|                 process.exit(1) | ||||
|                 return | ||||
|             no_tls: str = f"{socket_name}.{server_hostname}" | ||||
|             tls: str = f"https://{socket_name}.{server_hostname}" | ||||
|             response = f"{welcome_banner}\nYour local service has been exposed to the public\n\ | ||||
| Internet address: {no_tls}\nTLS termination: {tls}\n" | ||||
|             process.stdout.write(response + "\n") | ||||
|             process.logger.info(f"Exposed on {no_tls}") | ||||
|             read_task: Task = asyncio.create_task(process.stdin.read()) | ||||
|             timeout_task: Task = asyncio.create_task(process_timeout(process)) | ||||
|             done, pending = await asyncio.wait( | ||||
|                 [read_task, timeout_task], return_when=asyncio.FIRST_COMPLETED | ||||
|             ) | ||||
|             for task in done: | ||||
|                 try: | ||||
|                     await task | ||||
|                 except asyncssh.BreakReceived: | ||||
|                     pass | ||||
|             for task in pending: | ||||
|                 task.cancel() | ||||
| 
 | ||||
|             process.exit(0) | ||||
|         else: | ||||
|             response = ( | ||||
|                 "Per-IP connection limit: too many connections running over this IP.\n" | ||||
|             ) | ||||
|             process.stdout.write(response + "\n") | ||||
|             process.logger.warning("Rejected connection due to per-IP connection limit") | ||||
|             process.exit(1) | ||||
|             return | ||||
|     else: | ||||
|         response = "Rate limited: please try later.\n" | ||||
|         process.stdout.write(response + "\n") | ||||
|         process.logger.warning("Rejected connection due to rate limit") | ||||
|         process.exit(1) | ||||
|         return | ||||
|     no_tls: str = f"{socket_name}.{server_hostname}" | ||||
|     tls: str = f"https://{socket_name}.{server_hostname}" | ||||
|     response = f"{welcome_banner}Your local service has been exposed\ | ||||
|  to the public Internet address: {no_tls}\nTLS termination: {tls}\n" | ||||
|     process.stdout.write(response + "\n") | ||||
|     logging.info(f"Exposed on {no_tls}, {tls}.") | ||||
|     while not process.stdin.at_eof(): | ||||
|         try: | ||||
|             await process.stdin.read() | ||||
|         except asyncssh.TerminalSizeChanged: | ||||
|             pass | ||||
|     process.exit(0) | ||||
| 
 | ||||
| 
 | ||||
| async def start_ssh_server() -> None: | ||||
|  | @ -182,10 +307,8 @@ def init_logging(): | |||
|     """Init logging with a custom handler""" | ||||
|     logging.root.handlers: Handler = [InterceptHandler()] | ||||
|     logging.root.setLevel(log_level) | ||||
|     for name in logging.root.manager.loggerDict.keys(): | ||||
|         logging.getLogger(name).handlers: list = [] | ||||
|         logging.getLogger(name).propagate: bool = True | ||||
|     logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}]) | ||||
|     fmt = "<green>[{time}]</green> <level>[{level}]</level> - <level>{message}</level>" | ||||
|     logger.configure(handlers=[{"sink": sys.stdout, "serialize": False, "format": fmt}]) | ||||
| 
 | ||||
| 
 | ||||
| def get_random_slug(length) -> str: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user