Create main.py
This commit is contained in:
parent
6a19eb14f2
commit
7e9493b5d3
219
main.py
Normal file
219
main.py
Normal file
|
@ -0,0 +1,219 @@
|
||||||
|
"""collabore tunnel SSH server"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import sys
|
||||||
|
from asyncio import AbstractEventLoop
|
||||||
|
from os import path
|
||||||
|
from types import FrameType
|
||||||
|
from typing import AnyStr, Optional, Tuple
|
||||||
|
|
||||||
|
import asyncssh
|
||||||
|
from asyncssh import SSHKey, SSHServerConnection
|
||||||
|
from asyncssh.listener import create_unix_forward_listener
|
||||||
|
from asyncssh.misc import MaybeAwait
|
||||||
|
from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory
|
||||||
|
from loguru import logger
|
||||||
|
from loguru._handler import Handler
|
||||||
|
|
||||||
|
unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel")
|
||||||
|
server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re")
|
||||||
|
config_dir: str = os.getenv("CONFIG_DIRECTORY", ".")
|
||||||
|
ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0")
|
||||||
|
ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22"))
|
||||||
|
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
|
log_depth: int = int(os.getenv("LOG_DEPTH", "2"))
|
||||||
|
|
||||||
|
welcome_banner = f"===============================================================================\n\
|
||||||
|
Welcome to collabore tunnel!\n\
|
||||||
|
collabore tunnel is a free and open source service offered as part of the\n\
|
||||||
|
club elec collabore platform (https://collabore.fr) operated by club elec that\n\
|
||||||
|
allows you to expose your local services on the public Internet.\n\
|
||||||
|
To learn more about collabore tunnel,\n\
|
||||||
|
visit the documentation website: https://tunnel.collabore.fr/\n\
|
||||||
|
club elec (https://clubelec.insset.fr) is a french not-for-profit\n\
|
||||||
|
student organisation.\n\
|
||||||
|
===============================================================================\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
class SSHServer(asyncssh.SSHServer):
|
||||||
|
"""SSH server protocol handler class"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Init class"""
|
||||||
|
self.conn: SSHServerConnection
|
||||||
|
self.socket_path: str
|
||||||
|
|
||||||
|
def connection_made(self, conn: SSHServerConnection) -> None:
|
||||||
|
"""Called when a connection is made"""
|
||||||
|
self.conn = conn
|
||||||
|
|
||||||
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||||
|
"""Called when a connection is lost or closed"""
|
||||||
|
if exc:
|
||||||
|
logging.info("The connection has been terminated: %s", str(exc))
|
||||||
|
try:
|
||||||
|
os.remove(self.socket_path)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def begin_auth(self, username: str) -> MaybeAwait[bool]:
|
||||||
|
"""Authentication has been requested by the client"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def password_auth_supported(self) -> bool:
|
||||||
|
"""Return whether or not password authentication is supported"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def generate_socket_path(self) -> str:
|
||||||
|
"""Return the path of a socket whose name has been randomly generated"""
|
||||||
|
socket_name = get_random_slug(16)
|
||||||
|
self.socket_path = os.path.join(unix_sockets_dir, f"{socket_name}.sock")
|
||||||
|
self.conn.set_extra_info(socket_name=socket_name)
|
||||||
|
return self.socket_path
|
||||||
|
|
||||||
|
def unix_server_requested(self, listen_path: str):
|
||||||
|
"""Handle a request to listen on a UNIX domain socket"""
|
||||||
|
rewrite_path: str = self.generate_socket_path()
|
||||||
|
|
||||||
|
async def tunnel_connection(
|
||||||
|
session_factory: SSHUNIXSessionFactory[AnyStr],
|
||||||
|
) -> Tuple[SSHUNIXChannel[AnyStr], SSHUNIXSession[AnyStr]]:
|
||||||
|
return await self.conn.create_unix_connection(session_factory, listen_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return create_unix_forward_listener(
|
||||||
|
self.conn, asyncio.get_event_loop(), tunnel_connection, rewrite_path
|
||||||
|
)
|
||||||
|
except OSError as create_unix_forward_listener_exception:
|
||||||
|
logging.error(
|
||||||
|
"An error occurred while creating the forward listener: %s",
|
||||||
|
str(create_unix_forward_listener_exception),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ssh_client(process) -> None:
|
||||||
|
"""Function called every time a client connects to the SSH server"""
|
||||||
|
socket_name: str = process.get_extra_info("socket_name")
|
||||||
|
response: str = ""
|
||||||
|
if not socket_name:
|
||||||
|
response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
|
||||||
|
process.stdout.write(response + "\n")
|
||||||
|
process.exit(1)
|
||||||
|
logging.info(
|
||||||
|
"The user was ejected because they did not connect in port forwarding mode."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
no_tls: str = f"{socket_name}.{server_hostname}"
|
||||||
|
tls: str = f"https://{socket_name}.{server_hostname}"
|
||||||
|
response = f"{welcome_banner}Your local service has been exposed\
|
||||||
|
to the public Internet address: {no_tls}\nTLS termination: {tls}\n"
|
||||||
|
process.stdout.write(response + "\n")
|
||||||
|
logging.info(f"Exposed on {no_tls}, {tls}.")
|
||||||
|
while not process.stdin.at_eof():
|
||||||
|
try:
|
||||||
|
await process.stdin.read()
|
||||||
|
except asyncssh.TerminalSizeChanged:
|
||||||
|
pass
|
||||||
|
process.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_ssh_server() -> None:
|
||||||
|
"""Start the SSH server"""
|
||||||
|
ssh_key_file: str = path.join(config_dir, "id_rsa_host")
|
||||||
|
await asyncssh.create_server(
|
||||||
|
SSHServer,
|
||||||
|
host=ssh_server_host,
|
||||||
|
port=ssh_server_port,
|
||||||
|
server_host_keys=[ssh_key_file],
|
||||||
|
process_factory=handle_ssh_client,
|
||||||
|
agent_forwarding=False,
|
||||||
|
allow_scp=False,
|
||||||
|
keepalive_interval=30,
|
||||||
|
)
|
||||||
|
logging.info("SSH server started successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
def check_unix_sockets_dir() -> None:
|
||||||
|
"""If the directory for UNIX sockets does not exist, it is created"""
|
||||||
|
if not path.exists(unix_sockets_dir):
|
||||||
|
os.mkdir(unix_sockets_dir)
|
||||||
|
logging.warning(
|
||||||
|
"The %s folder does not exist, it has been created.", unix_sockets_dir
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("The %s folder exist.", unix_sockets_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ssh_key() -> None:
|
||||||
|
"""If the SSH key of the server does not exist, it is generated"""
|
||||||
|
ssh_host_key: str = path.join(config_dir, "id_rsa_host")
|
||||||
|
logging.info("Loading the SSH key")
|
||||||
|
if not path.exists(ssh_host_key):
|
||||||
|
logging.warning(
|
||||||
|
"The SSH key for the host was not found, generation in progress..."
|
||||||
|
)
|
||||||
|
key: SSHKey = asyncssh.generate_private_key("ssh-rsa")
|
||||||
|
private_key: bytes = key.export_private_key()
|
||||||
|
with open(ssh_host_key, "wb") as ssh_host_key_data:
|
||||||
|
ssh_host_key_data.write(private_key)
|
||||||
|
logging.info("The key was successfully created!")
|
||||||
|
else:
|
||||||
|
logging.info("SSH key has been found")
|
||||||
|
|
||||||
|
|
||||||
|
class InterceptHandler(logging.Handler):
|
||||||
|
"""Intercept logging call"""
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
"""Find caller from where originated the logged message"""
|
||||||
|
frame: FrameType = logging.currentframe()
|
||||||
|
depth: int = log_depth
|
||||||
|
while frame.f_code.co_filename == logging.__file__:
|
||||||
|
frame = frame.f_back
|
||||||
|
depth += 1
|
||||||
|
logger.opt(exception=record.exc_info).log(log_level, record.getMessage())
|
||||||
|
|
||||||
|
|
||||||
|
def init_logging():
|
||||||
|
"""Init logging with a custom handler"""
|
||||||
|
logging.root.handlers: Handler = [InterceptHandler()]
|
||||||
|
logging.root.setLevel(log_level)
|
||||||
|
for name in logging.root.manager.loggerDict.keys():
|
||||||
|
logging.getLogger(name).handlers: list = []
|
||||||
|
logging.getLogger(name).propagate: bool = True
|
||||||
|
logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}])
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_slug(length) -> str:
|
||||||
|
"""Function that generates a random string of a defined size"""
|
||||||
|
chars: str = string.ascii_lowercase + string.digits
|
||||||
|
return "".join(random.choices(chars, k=length))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
init_logging()
|
||||||
|
logging.info("Starting collabore tunnel SSH server...")
|
||||||
|
os.umask(0o000)
|
||||||
|
generate_ssh_key()
|
||||||
|
logging.info("Checking for the existence of a folder for UNIX sockets...")
|
||||||
|
check_unix_sockets_dir()
|
||||||
|
loop: AbstractEventLoop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(start_ssh_server())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
except (OSError, asyncssh.Error) as ssh_server_startup_exception:
|
||||||
|
logging.critical(
|
||||||
|
"An error occurred while starting the SSH server: %s",
|
||||||
|
str(ssh_server_startup_exception),
|
||||||
|
)
|
||||||
|
sys.exit()
|
||||||
|
try:
|
||||||
|
loop.run_forever()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
sys.exit()
|
Loading…
Reference in New Issue
Block a user