From 7e9493b5d37bd4947a290f8a5a99ca1a8cfc1cb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20L=2E=20H=2E-F?= Date: Mon, 16 Jan 2023 15:59:30 +0100 Subject: [PATCH] Create main.py --- main.py | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..8ce7039 --- /dev/null +++ b/main.py @@ -0,0 +1,219 @@ +"""collabore tunnel SSH server""" + +import asyncio +import logging +import os +import random +import string +import sys +from asyncio import AbstractEventLoop +from os import path +from types import FrameType +from typing import AnyStr, Optional, Tuple + +import asyncssh +from asyncssh import SSHKey, SSHServerConnection +from asyncssh.listener import create_unix_forward_listener +from asyncssh.misc import MaybeAwait +from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory +from loguru import logger +from loguru._handler import Handler + +unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel") +server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re") +config_dir: str = os.getenv("CONFIG_DIRECTORY", ".") +ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0") +ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22")) +log_level: str = os.getenv("LOG_LEVEL", "INFO") +log_depth: int = int(os.getenv("LOG_DEPTH", "2")) + +welcome_banner = f"===============================================================================\n\ +Welcome to collabore tunnel!\n\ +collabore tunnel is a free and open source service offered as part of the\n\ +club elec collabore platform (https://collabore.fr) operated by club elec that\n\ +allows you to expose your local services on the public Internet.\n\ +To learn more about collabore tunnel,\n\ +visit the documentation website: https://tunnel.collabore.fr/\n\ +club elec (https://clubelec.insset.fr) is a french not-for-profit\n\ +student organisation.\n\ +===============================================================================\n\n" + + +class SSHServer(asyncssh.SSHServer): + """SSH server protocol handler class""" + + def __init__(self): + """Init class""" + self.conn: SSHServerConnection + self.socket_path: str + + def connection_made(self, conn: SSHServerConnection) -> None: + """Called when a connection is made""" + self.conn = conn + + def connection_lost(self, exc: Optional[Exception]) -> None: + """Called when a connection is lost or closed""" + if exc: + logging.info("The connection has been terminated: %s", str(exc)) + try: + os.remove(self.socket_path) + except AttributeError: + pass + + def begin_auth(self, username: str) -> MaybeAwait[bool]: + """Authentication has been requested by the client""" + return False + + def password_auth_supported(self) -> bool: + """Return whether or not password authentication is supported""" + return True + + def generate_socket_path(self) -> str: + """Return the path of a socket whose name has been randomly generated""" + socket_name = get_random_slug(16) + self.socket_path = os.path.join(unix_sockets_dir, f"{socket_name}.sock") + self.conn.set_extra_info(socket_name=socket_name) + return self.socket_path + + def unix_server_requested(self, listen_path: str): + """Handle a request to listen on a UNIX domain socket""" + rewrite_path: str = self.generate_socket_path() + + async def tunnel_connection( + session_factory: SSHUNIXSessionFactory[AnyStr], + ) -> Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]: + return await self.conn.create_unix_connection(session_factory, listen_path) + + try: + return create_unix_forward_listener( + self.conn, asyncio.get_event_loop(), tunnel_connection, rewrite_path + ) + except OSError as create_unix_forward_listener_exception: + logging.error( + "An error occurred while creating the forward listener: %s", + str(create_unix_forward_listener_exception), + ) + + +async def handle_ssh_client(process) -> None: + """Function called every time a client connects to the SSH server""" + socket_name: str = process.get_extra_info("socket_name") + response: str = "" + if not socket_name: + response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" + process.stdout.write(response + "\n") + process.exit(1) + logging.info( + "The user was ejected because they did not connect in port forwarding mode." + ) + return + no_tls: str = f"{socket_name}.{server_hostname}" + tls: str = f"https://{socket_name}.{server_hostname}" + response = f"{welcome_banner}Your local service has been exposed\ + to the public Internet address: {no_tls}\nTLS termination: {tls}\n" + process.stdout.write(response + "\n") + logging.info(f"Exposed on {no_tls}, {tls}.") + while not process.stdin.at_eof(): + try: + await process.stdin.read() + except asyncssh.TerminalSizeChanged: + pass + process.exit(0) + + +async def start_ssh_server() -> None: + """Start the SSH server""" + ssh_key_file: str = path.join(config_dir, "id_rsa_host") + await asyncssh.create_server( + SSHServer, + host=ssh_server_host, + port=ssh_server_port, + server_host_keys=[ssh_key_file], + process_factory=handle_ssh_client, + agent_forwarding=False, + allow_scp=False, + keepalive_interval=30, + ) + logging.info("SSH server started successfully.") + + +def check_unix_sockets_dir() -> None: + """If the directory for UNIX sockets does not exist, it is created""" + if not path.exists(unix_sockets_dir): + os.mkdir(unix_sockets_dir) + logging.warning( + "The %s folder does not exist, it has been created.", unix_sockets_dir + ) + else: + logging.info("The %s folder exist.", unix_sockets_dir) + + +def generate_ssh_key() -> None: + """If the SSH key of the server does not exist, it is generated""" + ssh_host_key: str = path.join(config_dir, "id_rsa_host") + logging.info("Loading the SSH key") + if not path.exists(ssh_host_key): + logging.warning( + "The SSH key for the host was not found, generation in progress..." + ) + key: SSHKey = asyncssh.generate_private_key("ssh-rsa") + private_key: bytes = key.export_private_key() + with open(ssh_host_key, "wb") as ssh_host_key_data: + ssh_host_key_data.write(private_key) + logging.info("The key was successfully created!") + else: + logging.info("SSH key has been found") + + +class InterceptHandler(logging.Handler): + """Intercept logging call""" + + def emit(self, record): + """Find caller from where originated the logged message""" + frame: FrameType = logging.currentframe() + depth: int = log_depth + while frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + logger.opt(exception=record.exc_info).log(log_level, record.getMessage()) + + +def init_logging(): + """Init logging with a custom handler""" + logging.root.handlers: Handler = [InterceptHandler()] + logging.root.setLevel(log_level) + for name in logging.root.manager.loggerDict.keys(): + logging.getLogger(name).handlers: list = [] + logging.getLogger(name).propagate: bool = True + logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}]) + + +def get_random_slug(length) -> str: + """Function that generates a random string of a defined size""" + chars: str = string.ascii_lowercase + string.digits + return "".join(random.choices(chars, k=length)) + + +if __name__ == "__main__": + init_logging() + logging.info("Starting collabore tunnel SSH server...") + os.umask(0o000) + generate_ssh_key() + logging.info("Checking for the existence of a folder for UNIX sockets...") + check_unix_sockets_dir() + loop: AbstractEventLoop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(start_ssh_server()) + except KeyboardInterrupt: + pass + except (OSError, asyncssh.Error) as ssh_server_startup_exception: + logging.critical( + "An error occurred while starting the SSH server: %s", + str(ssh_server_startup_exception), + ) + sys.exit() + try: + loop.run_forever() + except KeyboardInterrupt: + sys.exit()