Create main.py
This commit is contained in:
		
							parent
							
								
									6a19eb14f2
								
							
						
					
					
						commit
						7e9493b5d3
					
				
							
								
								
									
										219
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,219 @@ | |||
| """collabore tunnel SSH server""" | ||||
| 
 | ||||
| import asyncio | ||||
| import logging | ||||
| import os | ||||
| import random | ||||
| import string | ||||
| import sys | ||||
| from asyncio import AbstractEventLoop | ||||
| from os import path | ||||
| from types import FrameType | ||||
| from typing import AnyStr, Optional, Tuple | ||||
| 
 | ||||
| import asyncssh | ||||
| from asyncssh import SSHKey, SSHServerConnection | ||||
| from asyncssh.listener import create_unix_forward_listener | ||||
| from asyncssh.misc import MaybeAwait | ||||
| from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory | ||||
| from loguru import logger | ||||
| from loguru._handler import Handler | ||||
| 
 | ||||
| unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel") | ||||
| server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re") | ||||
| config_dir: str = os.getenv("CONFIG_DIRECTORY", ".") | ||||
| ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0") | ||||
| ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22")) | ||||
| log_level: str = os.getenv("LOG_LEVEL", "INFO") | ||||
| log_depth: int = int(os.getenv("LOG_DEPTH", "2")) | ||||
| 
 | ||||
| welcome_banner = f"===============================================================================\n\ | ||||
| Welcome to collabore tunnel!\n\ | ||||
| collabore tunnel is a free and open source service offered as part of the\n\ | ||||
| club elec collabore platform (https://collabore.fr) operated by club elec that\n\ | ||||
| allows you to expose your local services on the public Internet.\n\ | ||||
| To learn more about collabore tunnel,\n\ | ||||
| visit the documentation website: https://tunnel.collabore.fr/\n\ | ||||
| club elec (https://clubelec.insset.fr) is a french not-for-profit\n\ | ||||
| student organisation.\n\ | ||||
| ===============================================================================\n\n" | ||||
| 
 | ||||
| 
 | ||||
| class SSHServer(asyncssh.SSHServer): | ||||
|     """SSH server protocol handler class""" | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         """Init class""" | ||||
|         self.conn: SSHServerConnection | ||||
|         self.socket_path: str | ||||
| 
 | ||||
|     def connection_made(self, conn: SSHServerConnection) -> None: | ||||
|         """Called when a connection is made""" | ||||
|         self.conn = conn | ||||
| 
 | ||||
|     def connection_lost(self, exc: Optional[Exception]) -> None: | ||||
|         """Called when a connection is lost or closed""" | ||||
|         if exc: | ||||
|             logging.info("The connection has been terminated: %s", str(exc)) | ||||
|         try: | ||||
|             os.remove(self.socket_path) | ||||
|         except AttributeError: | ||||
|             pass | ||||
| 
 | ||||
|     def begin_auth(self, username: str) -> MaybeAwait[bool]: | ||||
|         """Authentication has been requested by the client""" | ||||
|         return False | ||||
| 
 | ||||
|     def password_auth_supported(self) -> bool: | ||||
|         """Return whether or not password authentication is supported""" | ||||
|         return True | ||||
| 
 | ||||
|     def generate_socket_path(self) -> str: | ||||
|         """Return the path of a socket whose name has been randomly generated""" | ||||
|         socket_name = get_random_slug(16) | ||||
|         self.socket_path = os.path.join(unix_sockets_dir, f"{socket_name}.sock") | ||||
|         self.conn.set_extra_info(socket_name=socket_name) | ||||
|         return self.socket_path | ||||
| 
 | ||||
|     def unix_server_requested(self, listen_path: str): | ||||
|         """Handle a request to listen on a UNIX domain socket""" | ||||
|         rewrite_path: str = self.generate_socket_path() | ||||
| 
 | ||||
|         async def tunnel_connection( | ||||
|             session_factory: SSHUNIXSessionFactory[AnyStr], | ||||
|         ) -> Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]: | ||||
|             return await self.conn.create_unix_connection(session_factory, listen_path) | ||||
| 
 | ||||
|         try: | ||||
|             return create_unix_forward_listener( | ||||
|                 self.conn, asyncio.get_event_loop(), tunnel_connection, rewrite_path | ||||
|             ) | ||||
|         except OSError as create_unix_forward_listener_exception: | ||||
|             logging.error( | ||||
|                 "An error occurred while creating the forward listener: %s", | ||||
|                 str(create_unix_forward_listener_exception), | ||||
|             ) | ||||
| 
 | ||||
| 
 | ||||
| async def handle_ssh_client(process) -> None: | ||||
|     """Function called every time a client connects to the SSH server""" | ||||
|     socket_name: str = process.get_extra_info("socket_name") | ||||
|     response: str = "" | ||||
|     if not socket_name: | ||||
|         response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" | ||||
|         process.stdout.write(response + "\n") | ||||
|         process.exit(1) | ||||
|         logging.info( | ||||
|             "The user was ejected because they did not connect in port forwarding mode." | ||||
|         ) | ||||
|         return | ||||
|     no_tls: str = f"{socket_name}.{server_hostname}" | ||||
|     tls: str = f"https://{socket_name}.{server_hostname}" | ||||
|     response = f"{welcome_banner}Your local service has been exposed\ | ||||
|  to the public Internet address: {no_tls}\nTLS termination: {tls}\n" | ||||
|     process.stdout.write(response + "\n") | ||||
|     logging.info(f"Exposed on {no_tls}, {tls}.") | ||||
|     while not process.stdin.at_eof(): | ||||
|         try: | ||||
|             await process.stdin.read() | ||||
|         except asyncssh.TerminalSizeChanged: | ||||
|             pass | ||||
|     process.exit(0) | ||||
| 
 | ||||
| 
 | ||||
| async def start_ssh_server() -> None: | ||||
|     """Start the SSH server""" | ||||
|     ssh_key_file: str = path.join(config_dir, "id_rsa_host") | ||||
|     await asyncssh.create_server( | ||||
|         SSHServer, | ||||
|         host=ssh_server_host, | ||||
|         port=ssh_server_port, | ||||
|         server_host_keys=[ssh_key_file], | ||||
|         process_factory=handle_ssh_client, | ||||
|         agent_forwarding=False, | ||||
|         allow_scp=False, | ||||
|         keepalive_interval=30, | ||||
|     ) | ||||
|     logging.info("SSH server started successfully.") | ||||
| 
 | ||||
| 
 | ||||
| def check_unix_sockets_dir() -> None: | ||||
|     """If the directory for UNIX sockets does not exist, it is created""" | ||||
|     if not path.exists(unix_sockets_dir): | ||||
|         os.mkdir(unix_sockets_dir) | ||||
|         logging.warning( | ||||
|             "The %s folder does not exist, it has been created.", unix_sockets_dir | ||||
|         ) | ||||
|     else: | ||||
|         logging.info("The %s folder exist.", unix_sockets_dir) | ||||
| 
 | ||||
| 
 | ||||
| def generate_ssh_key() -> None: | ||||
|     """If the SSH key of the server does not exist, it is generated""" | ||||
|     ssh_host_key: str = path.join(config_dir, "id_rsa_host") | ||||
|     logging.info("Loading the SSH key") | ||||
|     if not path.exists(ssh_host_key): | ||||
|         logging.warning( | ||||
|             "The SSH key for the host was not found, generation in progress..." | ||||
|         ) | ||||
|         key: SSHKey = asyncssh.generate_private_key("ssh-rsa") | ||||
|         private_key: bytes = key.export_private_key() | ||||
|         with open(ssh_host_key, "wb") as ssh_host_key_data: | ||||
|             ssh_host_key_data.write(private_key) | ||||
|         logging.info("The key was successfully created!") | ||||
|     else: | ||||
|         logging.info("SSH key has been found") | ||||
| 
 | ||||
| 
 | ||||
| class InterceptHandler(logging.Handler): | ||||
|     """Intercept logging call""" | ||||
| 
 | ||||
|     def emit(self, record): | ||||
|         """Find caller from where originated the logged message""" | ||||
|         frame: FrameType = logging.currentframe() | ||||
|         depth: int = log_depth | ||||
|         while frame.f_code.co_filename == logging.__file__: | ||||
|             frame = frame.f_back | ||||
|             depth += 1 | ||||
|         logger.opt(exception=record.exc_info).log(log_level, record.getMessage()) | ||||
| 
 | ||||
| 
 | ||||
| def init_logging(): | ||||
|     """Init logging with a custom handler""" | ||||
|     logging.root.handlers: Handler = [InterceptHandler()] | ||||
|     logging.root.setLevel(log_level) | ||||
|     for name in logging.root.manager.loggerDict.keys(): | ||||
|         logging.getLogger(name).handlers: list = [] | ||||
|         logging.getLogger(name).propagate: bool = True | ||||
|     logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}]) | ||||
| 
 | ||||
| 
 | ||||
| def get_random_slug(length) -> str: | ||||
|     """Function that generates a random string of a defined size""" | ||||
|     chars: str = string.ascii_lowercase + string.digits | ||||
|     return "".join(random.choices(chars, k=length)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     init_logging() | ||||
|     logging.info("Starting collabore tunnel SSH server...") | ||||
|     os.umask(0o000) | ||||
|     generate_ssh_key() | ||||
|     logging.info("Checking for the existence of a folder for UNIX sockets...") | ||||
|     check_unix_sockets_dir() | ||||
|     loop: AbstractEventLoop = asyncio.new_event_loop() | ||||
|     asyncio.set_event_loop(loop) | ||||
|     try: | ||||
|         loop.run_until_complete(start_ssh_server()) | ||||
|     except KeyboardInterrupt: | ||||
|         pass | ||||
|     except (OSError, asyncssh.Error) as ssh_server_startup_exception: | ||||
|         logging.critical( | ||||
|             "An error occurred while starting the SSH server: %s", | ||||
|             str(ssh_server_startup_exception), | ||||
|         ) | ||||
|         sys.exit() | ||||
|     try: | ||||
|         loop.run_forever() | ||||
|     except KeyboardInterrupt: | ||||
|         sys.exit() | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user