From 93e0b4edd958f54c7313a6ace91ef79569b1eb6b Mon Sep 17 00:00:00 2001 From: Euiseo Cha Date: Sat, 31 Aug 2024 14:37:21 +0900 Subject: [PATCH] feat: refactoring typed programing --- base.py | 21 ++++-- download_certs.sh | 0 plugins/alwaysonline.py | 37 +++++++---- plugins/container.py | 40 +++++++----- plugins/serial.py | 5 +- server.py | 138 ++++++++++++++++++++++------------------ smtp.py | 18 +++--- 7 files changed, 155 insertions(+), 104 deletions(-) mode change 100644 => 100755 download_certs.sh diff --git a/base.py b/base.py index 8816905..d0588af 100644 --- a/base.py +++ b/base.py @@ -19,6 +19,7 @@ import importlib import subprocess import platform +from abc import ABC, abstractmethod from datetime import datetime, timezone from typing import Union, List @@ -47,14 +48,14 @@ def jsonrpc2_create_id(data): def jsonrpc2_encode(method, params=None): data = {"jsonrpc": "2.0", "method": method, "params": params} id = jsonrpc2_create_id(data) - id = data.get('id') + id = data.get("id") return (id, json.dumps(data)) def jsonrpc2_decode(text): data = json.loads(text) - type = 'error' if 'error' in data else 'result' if 'result' in data else None - id = data.get('id') + type = "error" if "error" in data else "result" if "result" in data else None + id = data.get("id") rpcdata = data.get(type) if type else None return type, id, rpcdata @@ -68,6 +69,7 @@ def jsonrpc2_error_encode(error, id=""): data = {"jsonrpc": "2.0", "error": error, "id": id} return json.dumps(data) + def find_openssl_binpath(): system = platform.system() @@ -121,8 +123,19 @@ def find_openssl_binpath(): return "openssl" +class ExtensionType: + def __init__(self): + self.type: str = None + self.method: str = None + self.exported_methods: list[str] = [] + self.connection_type: str = None + + +type extension_type = ExtensionType + + class Extension: - extensions = [] + extensions: list[extension_type] = [] protocols = [] buffer_size = 8192 diff --git a/download_certs.sh b/download_certs.sh old mode 100644 new mode 100755 diff --git a/plugins/alwaysonline.py b/plugins/alwaysonline.py index 6402e01..4e75996 100644 --- a/plugins/alwaysonline.py +++ b/plugins/alwaysonline.py @@ -29,9 +29,11 @@ except Exception as e: es = Elasticsearch([es_host]) + def generate_id(url): """Generate a unique ID for a URL by hashing it.""" - return hashlib.sha256(url.encode('utf-8')).hexdigest() + return hashlib.sha256(url.encode("utf-8")).hexdigest() + def get_cached_page_from_google(url): status_code, content = (0, b"") @@ -50,6 +52,7 @@ def get_cached_page_from_google(url): return status_code, content + # API documentation: https://archive.org/help/wayback_api.php def get_cached_page_from_wayback(url): status_code, content = (0, b"") @@ -89,30 +92,37 @@ def get_cached_page_from_wayback(url): return status_code, content + def get_cached_page_from_elasticsearch(url): url_id = generate_id(url) try: result = es.get(index=es_index, id=url_id) - logger.info(result['_source']) - return 200, result['_source']['content'].encode(client_encoding) + logger.info(result["_source"]) + return 200, result["_source"]["content"].encode(client_encoding) except NotFoundError: return 404, b"" except Exception as e: logger.error(f"Error fetching from Elasticsearch: {e}") return 502, b"" + def cache_to_elasticsearch(url, data): url_id = generate_id(url) timestamp = datetime.utcnow().isoformat() try: - es.index(index=es_index, id=url_id, body={ - "url": url, - "content": data.decode(client_encoding), - "timestamp": timestamp - }) + es.index( + index=es_index, + id=url_id, + body={ + "url": url, + "content": data.decode(client_encoding), + "timestamp": timestamp, + }, + ) except Exception as e: logger.error(f"Error caching to Elasticsearch: {e}") + def get_page_from_origin_server(url): try: response = requests.get(url) @@ -120,6 +130,7 @@ def get_page_from_origin_server(url): except Exception as e: return 502, str(e).encode(client_encoding) + class AlwaysOnline(Extension): def __init__(self): self.type = "connector" # this is a connector @@ -128,13 +139,13 @@ class AlwaysOnline(Extension): def connect(self, conn, data, webserver, port, scheme, method, url): logger.info("[*] Connecting... Connecting...") - + connected = False - + is_ssl = scheme in [b"https", b"tls", b"ssl"] cache_hit = 0 buffered = b"" - + def sendall(sock, conn, data): # send first chuck sock.send(data) @@ -151,11 +162,11 @@ class AlwaysOnline(Extension): sock.send(chunk) except: break - + target_url = url.decode(client_encoding) target_scheme = scheme.decode(client_encoding) target_webserver = webserver.decode(client_encoding) - + if "://" not in target_url: target_url = f"{target_scheme}://{target_webserver}:{port}{target_url}" diff --git a/plugins/container.py b/plugins/container.py index da14569..a74b4bd 100644 --- a/plugins/container.py +++ b/plugins/container.py @@ -21,7 +21,17 @@ class Container(Extension): def __init__(self): self.type = "rpcmethod" self.method = "container_init" - self.exported_methods = ["container_cteate", "container_start", "container_run", "container_stop", "container_pause", "container_unpause", "container_restart", "container_kill", "container_remove"] + self.exported_methods = [ + "container_cteate", + "container_start", + "container_run", + "container_stop", + "container_pause", + "container_unpause", + "container_restart", + "container_kill", + "container_remove", + ] # docker self.client = docker.from_env() @@ -33,13 +43,13 @@ class Container(Extension): def container_cteate(self, type, id, params, conn): # todo: - return b"[*] Created" - + def container_start(self, type, id, params, conn): - name = params['name'] + name = params["name"] container = self.client.containers.get(name) container.start() - + def container_run(self, type, id, params, conn): devices = params["devices"] image = params["image"] @@ -68,35 +78,35 @@ class Container(Extension): logger.info("[*] Stopped") return b"[*] Stopped" - + def container_pause(self, type, id, params, conn): - name = params['name'] + name = params["name"] container = self.client.containers.get(name) container.pause() return b"[*] Paused" - + def container_unpause(self, type, id, params, conn): - name = params['name'] + name = params["name"] container = self.client.containers.get(name) container.unpause() return b"[*] Unpaused" - + def container_restart(self, type, id, params, conn): - name = params['name'] + name = params["name"] container = self.client.containers.get(name) container.restart() return b"[*] Restarted" - + def container_kill(self, type, id, params, conn): # TODO: - return b"[*] Killed" - - def container_remove(self, type, id, params, conn): - name = params['name'] + + def container_remove(self, type, id, params, conn): + name = params["name"] container = self.client.containers.get(name) container.remove() - return b"[*] Removed" \ No newline at end of file + return b"[*] Removed" diff --git a/plugins/serial.py b/plugins/serial.py index dec37a9..e83bda3 100644 --- a/plugins/serial.py +++ b/plugins/serial.py @@ -25,6 +25,7 @@ import logging logger = logging.getLogger(__name__) + class Serial(Extension): def __init__(self): self.type = "connector" @@ -38,7 +39,7 @@ class Serial(Extension): connected = False ser = None try: - port_path = url.decode(client_encoding).replace('/', '') + port_path = url.decode(client_encoding).replace("/", "") if not ser: ser = serial.Serial(port_path, baudrate=9600, timeout=2) connected = True @@ -49,7 +50,7 @@ class Serial(Extension): ser_data = ser.read_all() logger.debug(f"Data received: {ser_data}") - + if ser_data: conn.send(ser_data.decode(client_encoding)) except serial.SerialException as e: diff --git a/server.py b/server.py index 6734d0c..e065cdd 100644 --- a/server.py +++ b/server.py @@ -38,6 +38,7 @@ from base import ( Logger, ) + logger = Logger(name="server") # initialization @@ -47,11 +48,11 @@ try: config("SERVER_URL", default="") ) server_connection_type = config("SERVER_CONNECTION_TYPE", default="proxy") - cakey = config("CA_KEY", default="ca.key") - cacert = config("CA_CERT", default="ca.crt") - certkey = config("CERT_KEY", default="cert.key") - certdir = config("CERT_DIR", default="certs/") - openssl_binpath = config("OPENSSL_BINPATH", default=find_openssl_binpath()) + ca_key = config("CA_KEY", default="ca.key") + ca_cert = config("CA_CERT", default="ca.crt") + cert_key = config("CERT_KEY", default="cert.key") + cert_dir = config("CERT_DIR", default="certs/") + openssl_bin_path = config("OPENSSL_BINPATH", default=find_openssl_binpath()) client_encoding = config("CLIENT_ENCODING", default="utf-8") local_domain = config("LOCAL_DOMAIN", default="") proxy_pass = config("PROXY_PASS", default="") @@ -87,7 +88,7 @@ if _username: auth = HTTPBasicAuth(_username, _password) -def parse_first_data(data): +def parse_first_data(data: bytes): parsed_data = (b"", b"", b"", b"", b"") try: @@ -126,13 +127,13 @@ def parse_first_data(data): return parsed_data -def conn_string(conn, data, addr): +def conn_string(conn: socket.socket, data: bytes, addr: bytes): # JSON-RPC 2.0 request - def process_jsonrpc2(data): - jsondata = json.loads(data.decode(client_encoding, errors="ignore")) - if jsondata["jsonrpc"] == "2.0": + def process_jsonrpc2(_data): + json_data = json.loads(_data.decode(client_encoding, errors="ignore")) + if json_data["jsonrpc"] == "2.0": jsonrpc2_server( - conn, jsondata["id"], jsondata["method"], jsondata["params"] + conn, json_data["id"], json_data["method"], json_data["params"] ) return True return False @@ -166,42 +167,44 @@ def conn_string(conn, data, addr): proxy_server(webserver, port, scheme, method, url, conn, addr, data) -def jsonrpc2_server(conn, id, method, params): +def jsonrpc2_server( + conn: socket.socket, _id: str, method: str, params: dict[str, str | int] +): if method == "relay_accept": - accepted_relay[id] = conn + accepted_relay[_id] = conn connection_speed = params["connection_speed"] - logger.info("[*] connection speed: %s milliseconds" % (str(connection_speed))) + logger.info("[*] connection speed: %s milliseconds" % str(connection_speed)) while conn.fileno() > -1: time.sleep(1) - del accepted_relay[id] - logger.info("[*] relay destroyed: %s" % id) + del accepted_relay[_id] + logger.info("[*] relay destroyed: %s" % _id) else: - Extension.dispatch_rpcmethod(method, "call", id, params, conn) + Extension.dispatch_rpcmethod(method, "call", _id, params, conn) # return in conn_string() -def proxy_connect(webserver, conn): +def proxy_connect(webserver: bytes, conn: socket.socket): hostname = webserver.decode(client_encoding) - certpath = "%s/%s.crt" % (certdir.rstrip("/"), hostname) + cert_path = "%s/%s.crt" % (cert_dir.rstrip("/"), hostname) - if not os.path.exists(certdir): - os.makedirs(certdir) + if not os.path.exists(cert_dir): + os.makedirs(cert_dir) # https://stackoverflow.com/questions/24055036/handle-https-request-in-proxy-server-by-c-sharp-connect-tunnel conn.send(b"HTTP/1.1 200 Connection Established\r\n\r\n") # https://github.com/inaz2/proxy2/blob/master/proxy2.py try: - if not os.path.isfile(certpath): + if not os.path.isfile(cert_path): epoch = "%d" % (time.time() * 1000) p1 = Popen( [ - openssl_binpath, + openssl_bin_path, "req", "-new", "-key", - certkey, + cert_key, "-subj", "/CN=%s" % hostname, ], @@ -209,19 +212,19 @@ def proxy_connect(webserver, conn): ) p2 = Popen( [ - openssl_binpath, + openssl_bin_path, "x509", "-req", "-days", "3650", "-CA", - cacert, + ca_cert, "-CAkey", - cakey, + ca_key, "-set_serial", epoch, "-out", - certpath, + cert_path, ], stdin=p1.stdout, stderr=PIPE, @@ -232,20 +235,20 @@ def proxy_connect(webserver, conn): "[*] OpenSSL distribution not found on this system. Skipping certificate issuance.", exc_info=e, ) - certpath = "default.crt" + cert_path = "default.crt" except Exception as e: logger.error("[*] Skipping certificate issuance.", exc_info=e) - certpath = "default.crt" - - logger.info("[*] Certificate file: %s" % (certpath)) - logger.info("[*] Private key file: %s" % (certkey)) + cert_path = "default.crt" + + logger.info("[*] Certificate file: %s" % cert_path) + logger.info("[*] Private key file: %s" % cert_key) # https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server # https://docs.python.org/3/library/ssl.html context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context.check_hostname = False context.verify_mode = ssl.CERT_NONE - context.load_cert_chain(certfile=certpath, keyfile=certkey) + context.load_cert_chain(certfile=cert_path, keyfile=cert_key) try: # https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server @@ -256,12 +259,14 @@ def proxy_connect(webserver, conn): "[*] SSL negotiation failed.", exc_info=e, ) - return (conn, b"") + return conn, b"" - return (conn, data) + return conn, data -def proxy_check_filtered(data, webserver, port, scheme, method, url): +def proxy_check_filtered( + data: bytes, webserver: bytes, port: bytes, scheme: bytes, method: bytes, url: bytes +): filtered = False filters = Extension.get_filters() @@ -272,7 +277,16 @@ def proxy_check_filtered(data, webserver, port, scheme, method, url): return filtered -def proxy_server(webserver, port, scheme, method, url, conn, addr, data): +def proxy_server( + webserver: bytes, + port: bytes, + scheme: bytes, + method: bytes, + url: bytes, + conn: socket.socket, + addr: bytes, + data: bytes, +): try: logger.info("[*] Started the request. %s" % (str(addr[0]))) @@ -296,14 +310,11 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): _, _, _, method, url = parse_first_data(data) # https://stackoverflow.com/questions/44343739/python-sockets-ssl-eof-occurred-in-violation-of-protocol - def sock_close(sock, is_ssl=False): - # if is_ssl: - # sock = sock.unwrap() - # sock.shutdown(socket.SHUT_RDWR) - sock.close() + def sock_close(_sock: socket.socket): + _sock.close() # Wait to see if there is more data to transmit - def sendall(sock, conn, data): + def sendall(_sock: socket.socket, _conn: socket.socket, _data: bytes): # send first chuck if proxy_check_filtered(data, webserver, port, scheme, method, url): sock.close() @@ -324,7 +335,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): if proxy_check_filtered( buffered, webserver, port, scheme, method, url ): - sock_close(sock, is_ssl) + sock_close(sock) raise Exception("Filtered request") sock.send(chunk) if len(buffered) > buffer_size * 2: @@ -354,7 +365,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): i = 0 is_http_403 = False - buffered = b"" + _buffered = b"" while True: chunk = sock.recv(buffer_size) if not chunk: @@ -362,24 +373,26 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): if i == 0 and chunk.find(b"HTTP/1.1 403") == 0: is_http_403 = True break - buffered += chunk - if proxy_check_filtered(buffered, webserver, port, scheme, method, url): - sock_close(sock, is_ssl) + _buffered += chunk + if proxy_check_filtered( + _buffered, webserver, port, scheme, method, url + ): + sock_close(sock) add_filtered_host(webserver.decode(client_encoding), "127.0.0.1") raise Exception("Filtered response") conn.send(chunk) - if len(buffered) > buffer_size * 2: - buffered = buffered[-buffer_size * 2 :] + if len(_buffered) > buffer_size * 2: + _buffered = _buffered[-buffer_size * 2 :] i += 1 # when blocked if is_http_403: logger.warning( "[*] Blocked the request by remote server: %s" - % (webserver.decode(client_encoding)) + % webserver.decode(client_encoding) ) - def bypass_callback(response, *args, **kwargs): + def bypass_callback(response: requests.Response): if response.status_code != 200: conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}') return @@ -420,7 +433,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): else: conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}') - sock_close(sock, is_ssl) + sock_close(sock) logger.info( "[*] Received %s chunks. (%s bytes per chunk)" @@ -509,27 +522,27 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): else: resolved_address_list.remove(resolved_address_list[0]) logger.info("[*] the relay is gone. %s" % id) - sock_close(sock, is_ssl) + sock_close(sock) return # get response i = 0 buffered = b"" while True: - chunk = sock.recv(buffer_size) - if not chunk: + _chunk = sock.recv(buffer_size) + if not _chunk: break - buffered += chunk + buffered += _chunk if proxy_check_filtered(buffered, webserver, port, scheme, method, url): - sock_close(sock, is_ssl) + sock_close(sock) add_filtered_host(webserver.decode(client_encoding), "127.0.0.1") raise Exception("Filtered response") - conn.send(chunk) + conn.send(_chunk) if len(buffered) > buffer_size * 2: buffered = buffered[-buffer_size * 2 :] i += 1 - sock_close(sock, is_ssl) + sock_close(sock) logger.info( "[*] Received %s chunks. (%s bytes per chunk)" @@ -604,7 +617,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data): # journaling a filtered hosts -def add_filtered_host(domain, ip_address): +def add_filtered_host(domain: str, ip_address: str): hosts_path = "./filtered.hosts" with open(hosts_path, "r") as file: lines = file.readlines() @@ -619,6 +632,7 @@ def add_filtered_host(domain, ip_address): def start(): # Main Program try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", listening_port)) sock.listen(max_connection) logger.info("[*] Server started successfully [ %d ]" % listening_port) diff --git a/smtp.py b/smtp.py index 23d6536..8e0382d 100644 --- a/smtp.py +++ b/smtp.py @@ -40,6 +40,7 @@ auth = None if _username: auth = HTTPBasicAuth(_username, _password) + class CaterpillarSMTPHandler: def __init__(self): self.smtpd_hostname = "CaterpillarSMTPServer" @@ -53,8 +54,8 @@ class CaterpillarSMTPHandler: message = EmailMessage() message.set_content(data) - subject = message.get('Subject', '') - to = message.get('To', '') + subject = message.get("Subject", "") + to = message.get("To", "") proxy_data = { "headers": { @@ -75,13 +76,13 @@ class CaterpillarSMTPHandler: response = await asyncio.to_thread( requests.post, server_url, - headers=proxy_data['headers'], + headers=proxy_data["headers"], data=raw_data, - auth=auth + auth=auth, ) if response.status_code == 200: type, id, rpcdata = jsonrpc2_decode(response.text) - if rpcdata['success']: + if rpcdata["success"]: logger.info("[*] Email sent successfully.") else: raise Exception(f"({rpcdata['code']}) {rpcdata['message']}") @@ -89,9 +90,9 @@ class CaterpillarSMTPHandler: raise Exception(f"Status {response.status_code}") except Exception as e: logger.error("[*] Failed to send email", exc_info=e) - return '500 Could not process your message. ' + str(e) + return "500 Could not process your message. " + str(e) - return '250 OK' + return "250 OK" # https://aiosmtpd-pepoluan.readthedocs.io/en/latest/migrating.html @@ -101,11 +102,12 @@ def main(): # Run the event loop in a separate thread. controller.start() # Wait for the user to press Return. - input('SMTP server running. Press Return to stop server and exit.') + input("SMTP server running. Press Return to stop server and exit.") controller.stop() logger.warning("[*] User has requested an interrupt") logger.warning("[*] Application Exiting.....") sys.exit() + if __name__ == "__main__": main()