Compare commits

..

No commits in common. "b11920e71d3ca3f730faa17e8694f8000b7f4187" and "691fd930cc82f55872b0a154e68d3f810c3e25c0" have entirely different histories.

5 changed files with 34 additions and 175 deletions

3
.gitignore vendored
View File

@ -160,6 +160,3 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
# Others
id_rsa_host
*.sock

View File

@ -6,11 +6,6 @@ After=network.target nginx.service
Environment=UNIX_SOCKETS_DIRECTORY=/tmp/collabore-tunnel Environment=UNIX_SOCKETS_DIRECTORY=/tmp/collabore-tunnel
Environment=SERVER_HOSTNAME=tnl.clb.re Environment=SERVER_HOSTNAME=tnl.clb.re
Environment=CONFIG_DIRECTORY=. Environment=CONFIG_DIRECTORY=.
Environment=WELCOME_BANNER_FILE=./welcome_banner.txt
Environment=RATE_LIMIT_COUNT=5
Environment=RATE_LIMIT_INTERVAL=60
Environment=MAX_CONNECTIONS_PER_IP=5
Environment=TIMEOUT=120
Environment=SSH_SERVER_HOST=0.0.0.0 Environment=SSH_SERVER_HOST=0.0.0.0
Environment=SSH_SERVER_PORT=22 Environment=SSH_SERVER_PORT=22
Environment=LOG_DEPTH=2 Environment=LOG_DEPTH=2

173
main.py
View File

@ -6,134 +6,50 @@ import os
import random import random
import string import string
import sys import sys
import time
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from collections import deque
from os import path from os import path
from types import FrameType from types import FrameType
from typing import AnyStr, Optional, Tuple from typing import AnyStr, Optional, Tuple
from _asyncio import Task
import asyncssh import asyncssh
from asyncssh import SSHKey, SSHServerConnection from asyncssh import SSHKey, SSHServerConnection
from asyncssh.channel import (
SSHUNIXChannel,
SSHUNIXSession,
SSHUNIXSessionFactory,
)
from asyncssh.listener import create_unix_forward_listener from asyncssh.listener import create_unix_forward_listener
from asyncssh.misc import MaybeAwait from asyncssh.misc import MaybeAwait
from asyncssh.channel import SSHUNIXChannel, SSHUNIXSession, SSHUNIXSessionFactory
from loguru import logger from loguru import logger
from loguru._handler import Handler from loguru._handler import Handler
unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel") unix_sockets_dir: str = os.getenv("UNIX_SOCKETS_DIRECTORY", "/tmp/collabore-tunnel")
server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re") server_hostname: str = os.getenv("SERVER_HOSTNAME", "tnl.clb.re")
config_dir: str = os.getenv("CONFIG_DIRECTORY", ".") config_dir: str = os.getenv("CONFIG_DIRECTORY", ".")
welcome_banner_file: str = os.getenv("WELCOME_BANNER_FILE", "./welcome_banner.txt")
rate_limit_count: int = int(os.getenv("RATE_LIMIT_COUNT", "5"))
rate_limit_interval: int = int(os.getenv("RATE_LIMIT_INTERVAL", "60"))
max_connections_per_ip: int = int(os.getenv("MAX_CONNECTIONS_PER_IP", "5"))
timeout: int = int(os.getenv("TIMEOUT", "120"))
ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0") ssh_server_host: str = os.getenv("SSH_SERVER_HOST", "0.0.0.0")
ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22")) ssh_server_port: int = int(os.getenv("SSH_SERVER_PORT", "22"))
log_level: str = os.getenv("LOG_LEVEL", "INFO") log_level: str = os.getenv("LOG_LEVEL", "INFO")
log_depth: int = int(os.getenv("LOG_DEPTH", "2")) log_depth: int = int(os.getenv("LOG_DEPTH", "2"))
welcome_banner = f"===============================================================================\n\
def read_welcome_banner() -> str: Welcome to collabore tunnel!\n\
"""Read the welcome banner from a file""" collabore tunnel is a free and open source service offered as part of the\n\
if not os.path.exists(welcome_banner_file): club elec collabore platform (https://collabore.fr) operated by club elec that\n\
return welcome_banner allows you to expose your local services on the public Internet.\n\
with open(welcome_banner_file, "r", encoding="UTF-8") as file: To learn more about collabore tunnel,\n\
return file.read() visit the documentation website: https://tunnel.collabore.fr/\n\
club elec (https://clubelec.insset.fr) is a french not-for-profit\n\
student organisation.\n\
welcome_banner: str = read_welcome_banner() ===============================================================================\n\n"
class RateLimiter:
"""Rate limiter handling class"""
def __init__(self, max_requests: int, interval: int):
"""Init class"""
self.max_requests: int = max_requests
self.interval: int = interval
self.timestamps: deque = deque()
def is_rate_limited(self) -> bool:
"""Check if rate limited"""
now: float = time.time()
while self.timestamps and self.timestamps[0] < now - self.interval:
self.timestamps.popleft()
if len(self.timestamps) >= self.max_requests:
return True
self.timestamps.append(now)
return False
class ConcurrentConnections:
"""Concurrent connection handling class"""
def __init__(self):
"""Init class"""
self.ip_connections: dict = {}
def increment(self, ip_addr: str) -> None:
"""Increment the number of concurrent connections for an IP"""
if ip_addr not in self.ip_connections:
self.ip_connections[ip_addr] = 1
else:
self.ip_connections[ip_addr] += 1
def decrement(self, ip_addr: str) -> None:
"""Decrement the number of concurrent connections for an IP"""
self.ip_connections[ip_addr] -= 1
def get(self, ip_addr: str) -> int:
"""Get the number of concurent connection for an IP"""
return self.ip_connections.get(ip_addr, 0)
ip_address_connections = ConcurrentConnections()
def check_concurrent_connections(ip_addr: str) -> bool:
"""Checking for concurrent connections"""
return ip_address_connections.get(ip_addr) >= max_connections_per_ip
class SSHServer(asyncssh.SSHServer): class SSHServer(asyncssh.SSHServer):
"""SSH server protocol handler class""" """SSH server protocol handler class"""
rate_limiters: dict = {}
def __init__(self): def __init__(self):
"""Init class""" """Init class"""
self.conn: SSHServerConnection self.conn: SSHServerConnection
self.socket_path: str self.socket_path: str
self.ip_addr: str
def check_rate_limit(self, ip_addr: str) -> bool:
"""Check if rate limited"""
if ip_addr not in self.rate_limiters:
self.rate_limiters[ip_addr] = RateLimiter(
rate_limit_count, rate_limit_interval
)
return self.rate_limiters[ip_addr].is_rate_limited()
def connection_made(self, conn: SSHServerConnection) -> None: def connection_made(self, conn: SSHServerConnection) -> None:
"""Called when a connection is made""" """Called when a connection is made"""
self.conn = conn self.conn = conn
self.ip_addr, _ = conn.get_extra_info("peername")
if self.check_rate_limit(self.ip_addr):
conn.set_extra_info(rate_limited=True)
if check_concurrent_connections(self.ip_addr):
conn.set_extra_info(connection_limited=True)
ip_address_connections.increment(self.ip_addr)
def connection_lost(self, exc: Optional[Exception]) -> None: def connection_lost(self, exc: Optional[Exception]) -> None:
"""Called when a connection is lost or closed""" """Called when a connection is lost or closed"""
@ -143,7 +59,6 @@ class SSHServer(asyncssh.SSHServer):
os.remove(self.socket_path) os.remove(self.socket_path)
except AttributeError: except AttributeError:
pass pass
ip_address_connections.decrement(self.ip_addr)
def begin_auth(self, username: str) -> MaybeAwait[bool]: def begin_auth(self, username: str) -> MaybeAwait[bool]:
"""Authentication has been requested by the client""" """Authentication has been requested by the client"""
@ -183,67 +98,27 @@ class SSHServer(asyncssh.SSHServer):
async def handle_ssh_client(process) -> None: async def handle_ssh_client(process) -> None:
"""Function called every time a client connects to the SSH server""" """Function called every time a client connects to the SSH server"""
socket_name: str = process.get_extra_info("socket_name") socket_name: str = process.get_extra_info("socket_name")
rate_limited: bool = process.get_extra_info("rate_limited")
connection_limited: bool = process.get_extra_info("connection_limited")
response: str = "" response: str = ""
async def process_timeout(process):
"""Function to terminate the connection automatically
after a specific period of time (in minutes)"""
await asyncio.sleep(timeout * 60)
response = (
f"Timeout: you were automatically ejected after {timeout} minutes of use.\n"
)
process.stdout.write(response + "\n")
process.logger.info(
f"The user was automatically ejected after {timeout} minutes of use"
)
process.close()
if not rate_limited:
if not connection_limited:
if not socket_name: if not socket_name:
response = "Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n" response = f"Usage: ssh -R /:host:port ssh.tunnel.collabore.fr\n"
process.stdout.write(response + "\n") process.stdout.write(response + "\n")
process.logger.info(
"The user was ejected because they did not connect in port forwarding mode"
)
process.exit(1) process.exit(1)
logging.info(
"The user was ejected because they did not connect in port forwarding mode."
)
return return
no_tls: str = f"{socket_name}.{server_hostname}" no_tls: str = f"{socket_name}.{server_hostname}"
tls: str = f"https://{socket_name}.{server_hostname}" tls: str = f"https://{socket_name}.{server_hostname}"
response = f"{welcome_banner}\nYour local service has been exposed to the public\n\ response = f"{welcome_banner}Your local service has been exposed\
Internet address: {no_tls}\nTLS termination: {tls}\n" to the public Internet address: {no_tls}\nTLS termination: {tls}\n"
process.stdout.write(response + "\n") process.stdout.write(response + "\n")
process.logger.info(f"Exposed on {no_tls}") logging.info(f"Exposed on {no_tls}, {tls}.")
read_task: Task = asyncio.create_task(process.stdin.read()) while not process.stdin.at_eof():
timeout_task: Task = asyncio.create_task(process_timeout(process))
done, pending = await asyncio.wait(
[read_task, timeout_task], return_when=asyncio.FIRST_COMPLETED
)
for task in done:
try: try:
await task await process.stdin.read()
except asyncssh.BreakReceived: except asyncssh.TerminalSizeChanged:
pass pass
for task in pending:
task.cancel()
process.exit(0) process.exit(0)
else:
response = (
"Per-IP connection limit: too many connections running over this IP.\n"
)
process.stdout.write(response + "\n")
process.logger.warning("Rejected connection due to per-IP connection limit")
process.exit(1)
return
else:
response = "Rate limited: please try later.\n"
process.stdout.write(response + "\n")
process.logger.warning("Rejected connection due to rate limit")
process.exit(1)
return
async def start_ssh_server() -> None: async def start_ssh_server() -> None:
@ -307,8 +182,10 @@ def init_logging():
"""Init logging with a custom handler""" """Init logging with a custom handler"""
logging.root.handlers: Handler = [InterceptHandler()] logging.root.handlers: Handler = [InterceptHandler()]
logging.root.setLevel(log_level) logging.root.setLevel(log_level)
fmt = "<green>[{time}]</green> <level>[{level}]</level> - <level>{message}</level>" for name in logging.root.manager.loggerDict.keys():
logger.configure(handlers=[{"sink": sys.stdout, "serialize": False, "format": fmt}]) logging.getLogger(name).handlers: list = []
logging.getLogger(name).propagate: bool = True
logger.configure(handlers=[{"sink": sys.stdout, "serialize": False}])
def get_random_slug(length) -> str: def get_random_slug(length) -> str:

View File

@ -1,2 +1,2 @@
asyncssh==2.13.1 asyncssh==2.12.0
loguru==0.7.0 loguru==0.6.0

View File

@ -1,10 +0,0 @@
===============================================================================
Welcome to collabore tunnel!
collabore tunnel is a free and open source service offered as part of the
club elec collabore platform (https://collabore.fr) operated by club elec that
allows you to expose your local services on the public Internet.
To learn more about collabore tunnel,
visit the documentation website: https://tunnel.collabore.fr/
club elec (https://clubelec.insset.fr) is a french not-for-profit
student organisation.
===============================================================================