Update main.py
This commit is contained in:
parent
beae19da38
commit
25111acee3
189
main.py
189
main.py
|
@ -6,50 +6,134 @@ import os
|
|||
import random
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
from asyncio import AbstractEventLoop
|
||||
from collections import deque
|
||||
from os import path
|
||||
from types import FrameType
|
||||
from typing import AnyStr, Optional, Tuple
|
||||
from _asyncio import Task
|
||||
|
||||
import asyncssh
|
||||
from asyncssh import SSHKey, SSHServerConnection
|
||||
from asyncssh.channel import (
|
||||
SSHUNIXChannel,
|
||||
SSHUNIXSession,
|
||||
SSHUNIXSessionFactory,
|
||||
)
|
||||
from asyncssh.listener import create_unix_forward_listener
|
||||
from asyncssh.misc import MaybeAwait
|
||||
from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory
|
||||
from loguru import logger
|
||||
from loguru._handler import Handler
|
||||
|
||||
|
||||
unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel")
|
||||
server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re")
|
||||
config_dir: str = os.getenv("CONFIG_DIRECTORY", ".")
|
||||
welcome_banner_file: str = os.getenv("WELCOME_BANNER_FILE", "./welcome_banner.txt")
|
||||
rate_limit_count: int = int(os.getenv("RATE_LIMIT_COUNT", "5"))
|
||||
rate_limit_interval: int = int(os.getenv("RATE_LIMIT_INTERVAL", "60"))
|
||||
max_connections_per_ip: int = int(os.getenv("MAX_CONNECTIONS_PER_IP", "5"))
|
||||
timeout: int = int(os.getenv("TIMEOUT", "120"))
|
||||
ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0")
|
||||
ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22"))
|
||||
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
log_depth: int = int(os.getenv("LOG_DEPTH", "2"))
|
||||
|
||||
welcome_banner = f"===============================================================================\n\
|
||||
Welcome to collabore tunnel!\n\
|
||||
collabore tunnel is a free and open source service offered as part of the\n\
|
||||
club elec collabore platform (https://collabore.fr) operated by club elec that\n\
|
||||
allows you to expose your local services on the public Internet.\n\
|
||||
To learn more about collabore tunnel,\n\
|
||||
visit the documentation website: https://tunnel.collabore.fr/\n\
|
||||
club elec (https://clubelec.insset.fr) is a french not-for-profit\n\
|
||||
student organisation.\n\
|
||||
===============================================================================\n\n"
|
||||
|
||||
def read_welcome_banner() -> str:
|
||||
"""Read the welcome banner from a file"""
|
||||
if not os.path.exists(welcome_banner_file):
|
||||
return welcome_banner
|
||||
with open(welcome_banner_file, "r", encoding="UTF-8") as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
welcome_banner: str = read_welcome_banner()
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter handling class"""
|
||||
|
||||
def __init__(self, max_requests: int, interval: int):
|
||||
"""Init class"""
|
||||
self.max_requests: int = max_requests
|
||||
self.interval: int = interval
|
||||
self.timestamps: deque = deque()
|
||||
|
||||
def is_rate_limited(self) -> bool:
|
||||
"""Check if rate limited"""
|
||||
now: float = time.time()
|
||||
while self.timestamps and self.timestamps[0] < now - self.interval:
|
||||
self.timestamps.popleft()
|
||||
if len(self.timestamps) >= self.max_requests:
|
||||
return True
|
||||
self.timestamps.append(now)
|
||||
return False
|
||||
|
||||
|
||||
class ConcurrentConnections:
|
||||
"""Concurrent connection handling class"""
|
||||
|
||||
def __init__(self):
|
||||
"""Init class"""
|
||||
self.ip_connections: dict = {}
|
||||
|
||||
def increment(self, ip_addr: str) -> None:
|
||||
"""Increment the number of concurrent connections for an IP"""
|
||||
if ip_addr not in self.ip_connections:
|
||||
self.ip_connections[ip_addr] = 1
|
||||
else:
|
||||
self.ip_connections[ip_addr] += 1
|
||||
|
||||
def decrement(self, ip_addr: str) -> None:
|
||||
"""Decrement the number of concurrent connections for an IP"""
|
||||
self.ip_connections[ip_addr] -= 1
|
||||
|
||||
def get(self, ip_addr: str) -> int:
|
||||
"""Get the number of concurent connection for an IP"""
|
||||
return self.ip_connections.get(ip_addr, 0)
|
||||
|
||||
|
||||
ip_address_connections = ConcurrentConnections()
|
||||
|
||||
|
||||
def check_concurrent_connections(ip_addr: str) -> bool:
|
||||
"""Checking for concurrent connections"""
|
||||
return ip_address_connections.get(ip_addr) >= max_connections_per_ip
|
||||
|
||||
|
||||
class SSHServer(asyncssh.SSHServer):
|
||||
"""SSH server protocol handler class"""
|
||||
|
||||
rate_limiters: dict = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Init class"""
|
||||
self.conn: SSHServerConnection
|
||||
self.socket_path: str
|
||||
self.ip_addr: str
|
||||
|
||||
def check_rate_limit(self, ip_addr: str) -> bool:
|
||||
"""Check if rate limited"""
|
||||
if ip_addr not in self.rate_limiters:
|
||||
self.rate_limiters[ip_addr] = RateLimiter(
|
||||
rate_limit_count, rate_limit_interval
|
||||
)
|
||||
return self.rate_limiters[ip_addr].is_rate_limited()
|
||||
|
||||
def connection_made(self, conn: SSHServerConnection) -> None:
|
||||
"""Called when a connection is made"""
|
||||
self.conn = conn
|
||||
self.ip_addr, _ = conn.get_extra_info("peername")
|
||||
|
||||
if self.check_rate_limit(self.ip_addr):
|
||||
conn.set_extra_info(rate_limited=True)
|
||||
|
||||
if check_concurrent_connections(self.ip_addr):
|
||||
conn.set_extra_info(connection_limited=True)
|
||||
|
||||
ip_address_connections.increment(self.ip_addr)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
"""Called when a connection is lost or closed"""
|
||||
|
@ -59,6 +143,7 @@ class SSHServer(asyncssh.SSHServer):
|
|||
os.remove(self.socket_path)
|
||||
except AttributeError:
|
||||
pass
|
||||
ip_address_connections.decrement(self.ip_addr)
|
||||
|
||||
def begin_auth(self, username: str) -> MaybeAwait[bool]:
|
||||
"""Authentication has been requested by the client"""
|
||||
|
@ -98,27 +183,67 @@ class SSHServer(asyncssh.SSHServer):
|
|||
async def handle_ssh_client(process) -> None:
|
||||
"""Function called every time a client connects to the SSH server"""
|
||||
socket_name: str = process.get_extra_info("socket_name")
|
||||
rate_limited: bool = process.get_extra_info("rate_limited")
|
||||
connection_limited: bool = process.get_extra_info("connection_limited")
|
||||
response: str = ""
|
||||
if not socket_name:
|
||||
response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
|
||||
process.stdout.write(response + "\n")
|
||||
process.exit(1)
|
||||
logging.info(
|
||||
"The user was ejected because they did not connect in port forwarding mode."
|
||||
|
||||
async def process_timeout(process):
|
||||
"""Function to terminate the connection automatically
|
||||
after a specific period of time (in minutes)"""
|
||||
await asyncio.sleep(timeout * 60)
|
||||
response = (
|
||||
f"Timeout: you were automatically ejected after {timeout} minutes of use.\n"
|
||||
)
|
||||
process.stdout.write(response + "\n")
|
||||
process.logger.info(
|
||||
f"The user was automatically ejected after {timeout} minutes of use"
|
||||
)
|
||||
process.close()
|
||||
|
||||
if not rate_limited:
|
||||
if not connection_limited:
|
||||
if not socket_name:
|
||||
response = "Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
|
||||
process.stdout.write(response + "\n")
|
||||
process.logger.info(
|
||||
"The user was ejected because they did not connect in port forwarding mode"
|
||||
)
|
||||
process.exit(1)
|
||||
return
|
||||
no_tls: str = f"{socket_name}.{server_hostname}"
|
||||
tls: str = f"https://{socket_name}.{server_hostname}"
|
||||
response = f"{welcome_banner}\nYour local service has been exposed to the public\n\
|
||||
Internet address: {no_tls}\nTLS termination: {tls}\n"
|
||||
process.stdout.write(response + "\n")
|
||||
process.logger.info(f"Exposed on {no_tls}")
|
||||
read_task: Task = asyncio.create_task(process.stdin.read())
|
||||
timeout_task: Task = asyncio.create_task(process_timeout(process))
|
||||
done, pending = await asyncio.wait(
|
||||
[read_task, timeout_task], return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in done:
|
||||
try:
|
||||
await task
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
process.exit(0)
|
||||
else:
|
||||
response = (
|
||||
"Per-IP connection limit: too many connections running over this IP.\n"
|
||||
)
|
||||
process.stdout.write(response + "\n")
|
||||
process.logger.warning("Rejected connection due to per-IP connection limit")
|
||||
process.exit(1)
|
||||
return
|
||||
else:
|
||||
response = "Rate limited: please try later.\n"
|
||||
process.stdout.write(response + "\n")
|
||||
process.logger.warning("Rejected connection due to rate limit")
|
||||
process.exit(1)
|
||||
return
|
||||
no_tls: str = f"{socket_name}.{server_hostname}"
|
||||
tls: str = f"https://{socket_name}.{server_hostname}"
|
||||
response = f"{welcome_banner}Your local service has been exposed\
|
||||
to the public Internet address: {no_tls}\nTLS termination: {tls}\n"
|
||||
process.stdout.write(response + "\n")
|
||||
logging.info(f"Exposed on {no_tls}, {tls}.")
|
||||
while not process.stdin.at_eof():
|
||||
try:
|
||||
await process.stdin.read()
|
||||
except asyncssh.TerminalSizeChanged:
|
||||
pass
|
||||
process.exit(0)
|
||||
|
||||
|
||||
async def start_ssh_server() -> None:
|
||||
|
@ -182,10 +307,8 @@ def init_logging():
|
|||
"""Init logging with a custom handler"""
|
||||
logging.root.handlers: Handler = [InterceptHandler()]
|
||||
logging.root.setLevel(log_level)
|
||||
for name in logging.root.manager.loggerDict.keys():
|
||||
logging.getLogger(name).handlers: list = []
|
||||
logging.getLogger(name).propagate: bool = True
|
||||
logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}])
|
||||
fmt = "<green>[{time}]</green> <level>[{level}]</level> - <level>{message}</level>"
|
||||
logger.configure(handlers=[{"sink": sys.stdout, "serialize": False, "format": fmt}])
|
||||
|
||||
|
||||
def get_random_slug(length) -> str:
|
||||
|
|
Loading…
Reference in New Issue
Block a user