feat: refactoring typed programing

This commit is contained in:
Euiseo Cha 2024-08-31 14:37:21 +09:00
parent 9c2b66fb07
commit 93e0b4edd9
No known key found for this signature in database
GPG Key ID: 39F74FF9CEA87CC9
7 changed files with 155 additions and 104 deletions

21
base.py
View File

@ -19,6 +19,7 @@ import importlib
import subprocess
import platform
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Union, List
@ -47,14 +48,14 @@ def jsonrpc2_create_id(data):
def jsonrpc2_encode(method, params=None):
data = {"jsonrpc": "2.0", "method": method, "params": params}
id = jsonrpc2_create_id(data)
id = data.get('id')
id = data.get("id")
return (id, json.dumps(data))
def jsonrpc2_decode(text):
data = json.loads(text)
type = 'error' if 'error' in data else 'result' if 'result' in data else None
id = data.get('id')
type = "error" if "error" in data else "result" if "result" in data else None
id = data.get("id")
rpcdata = data.get(type) if type else None
return type, id, rpcdata
@ -68,6 +69,7 @@ def jsonrpc2_error_encode(error, id=""):
data = {"jsonrpc": "2.0", "error": error, "id": id}
return json.dumps(data)
def find_openssl_binpath():
system = platform.system()
@ -121,8 +123,19 @@ def find_openssl_binpath():
return "openssl"
class ExtensionType:
def __init__(self):
self.type: str = None
self.method: str = None
self.exported_methods: list[str] = []
self.connection_type: str = None
type extension_type = ExtensionType
class Extension:
extensions = []
extensions: list[extension_type] = []
protocols = []
buffer_size = 8192

0
download_certs.sh Normal file → Executable file
View File

View File

@ -29,9 +29,11 @@ except Exception as e:
es = Elasticsearch([es_host])
def generate_id(url):
"""Generate a unique ID for a URL by hashing it."""
return hashlib.sha256(url.encode('utf-8')).hexdigest()
return hashlib.sha256(url.encode("utf-8")).hexdigest()
def get_cached_page_from_google(url):
status_code, content = (0, b"")
@ -50,6 +52,7 @@ def get_cached_page_from_google(url):
return status_code, content
# API documentation: https://archive.org/help/wayback_api.php
def get_cached_page_from_wayback(url):
status_code, content = (0, b"")
@ -89,30 +92,37 @@ def get_cached_page_from_wayback(url):
return status_code, content
def get_cached_page_from_elasticsearch(url):
url_id = generate_id(url)
try:
result = es.get(index=es_index, id=url_id)
logger.info(result['_source'])
return 200, result['_source']['content'].encode(client_encoding)
logger.info(result["_source"])
return 200, result["_source"]["content"].encode(client_encoding)
except NotFoundError:
return 404, b""
except Exception as e:
logger.error(f"Error fetching from Elasticsearch: {e}")
return 502, b""
def cache_to_elasticsearch(url, data):
url_id = generate_id(url)
timestamp = datetime.utcnow().isoformat()
try:
es.index(index=es_index, id=url_id, body={
"url": url,
"content": data.decode(client_encoding),
"timestamp": timestamp
})
es.index(
index=es_index,
id=url_id,
body={
"url": url,
"content": data.decode(client_encoding),
"timestamp": timestamp,
},
)
except Exception as e:
logger.error(f"Error caching to Elasticsearch: {e}")
def get_page_from_origin_server(url):
try:
response = requests.get(url)
@ -120,6 +130,7 @@ def get_page_from_origin_server(url):
except Exception as e:
return 502, str(e).encode(client_encoding)
class AlwaysOnline(Extension):
def __init__(self):
self.type = "connector" # this is a connector
@ -128,13 +139,13 @@ class AlwaysOnline(Extension):
def connect(self, conn, data, webserver, port, scheme, method, url):
logger.info("[*] Connecting... Connecting...")
connected = False
is_ssl = scheme in [b"https", b"tls", b"ssl"]
cache_hit = 0
buffered = b""
def sendall(sock, conn, data):
# send first chuck
sock.send(data)
@ -151,11 +162,11 @@ class AlwaysOnline(Extension):
sock.send(chunk)
except:
break
target_url = url.decode(client_encoding)
target_scheme = scheme.decode(client_encoding)
target_webserver = webserver.decode(client_encoding)
if "://" not in target_url:
target_url = f"{target_scheme}://{target_webserver}:{port}{target_url}"

View File

@ -21,7 +21,17 @@ class Container(Extension):
def __init__(self):
self.type = "rpcmethod"
self.method = "container_init"
self.exported_methods = ["container_cteate", "container_start", "container_run", "container_stop", "container_pause", "container_unpause", "container_restart", "container_kill", "container_remove"]
self.exported_methods = [
"container_cteate",
"container_start",
"container_run",
"container_stop",
"container_pause",
"container_unpause",
"container_restart",
"container_kill",
"container_remove",
]
# docker
self.client = docker.from_env()
@ -33,13 +43,13 @@ class Container(Extension):
def container_cteate(self, type, id, params, conn):
# todo: -
return b"[*] Created"
def container_start(self, type, id, params, conn):
name = params['name']
name = params["name"]
container = self.client.containers.get(name)
container.start()
def container_run(self, type, id, params, conn):
devices = params["devices"]
image = params["image"]
@ -68,35 +78,35 @@ class Container(Extension):
logger.info("[*] Stopped")
return b"[*] Stopped"
def container_pause(self, type, id, params, conn):
name = params['name']
name = params["name"]
container = self.client.containers.get(name)
container.pause()
return b"[*] Paused"
def container_unpause(self, type, id, params, conn):
name = params['name']
name = params["name"]
container = self.client.containers.get(name)
container.unpause()
return b"[*] Unpaused"
def container_restart(self, type, id, params, conn):
name = params['name']
name = params["name"]
container = self.client.containers.get(name)
container.restart()
return b"[*] Restarted"
def container_kill(self, type, id, params, conn):
# TODO: -
return b"[*] Killed"
def container_remove(self, type, id, params, conn):
name = params['name']
def container_remove(self, type, id, params, conn):
name = params["name"]
container = self.client.containers.get(name)
container.remove()
return b"[*] Removed"
return b"[*] Removed"

View File

@ -25,6 +25,7 @@ import logging
logger = logging.getLogger(__name__)
class Serial(Extension):
def __init__(self):
self.type = "connector"
@ -38,7 +39,7 @@ class Serial(Extension):
connected = False
ser = None
try:
port_path = url.decode(client_encoding).replace('/', '')
port_path = url.decode(client_encoding).replace("/", "")
if not ser:
ser = serial.Serial(port_path, baudrate=9600, timeout=2)
connected = True
@ -49,7 +50,7 @@ class Serial(Extension):
ser_data = ser.read_all()
logger.debug(f"Data received: {ser_data}")
if ser_data:
conn.send(ser_data.decode(client_encoding))
except serial.SerialException as e:

138
server.py
View File

@ -38,6 +38,7 @@ from base import (
Logger,
)
logger = Logger(name="server")
# initialization
@ -47,11 +48,11 @@ try:
config("SERVER_URL", default="")
)
server_connection_type = config("SERVER_CONNECTION_TYPE", default="proxy")
cakey = config("CA_KEY", default="ca.key")
cacert = config("CA_CERT", default="ca.crt")
certkey = config("CERT_KEY", default="cert.key")
certdir = config("CERT_DIR", default="certs/")
openssl_binpath = config("OPENSSL_BINPATH", default=find_openssl_binpath())
ca_key = config("CA_KEY", default="ca.key")
ca_cert = config("CA_CERT", default="ca.crt")
cert_key = config("CERT_KEY", default="cert.key")
cert_dir = config("CERT_DIR", default="certs/")
openssl_bin_path = config("OPENSSL_BINPATH", default=find_openssl_binpath())
client_encoding = config("CLIENT_ENCODING", default="utf-8")
local_domain = config("LOCAL_DOMAIN", default="")
proxy_pass = config("PROXY_PASS", default="")
@ -87,7 +88,7 @@ if _username:
auth = HTTPBasicAuth(_username, _password)
def parse_first_data(data):
def parse_first_data(data: bytes):
parsed_data = (b"", b"", b"", b"", b"")
try:
@ -126,13 +127,13 @@ def parse_first_data(data):
return parsed_data
def conn_string(conn, data, addr):
def conn_string(conn: socket.socket, data: bytes, addr: bytes):
# JSON-RPC 2.0 request
def process_jsonrpc2(data):
jsondata = json.loads(data.decode(client_encoding, errors="ignore"))
if jsondata["jsonrpc"] == "2.0":
def process_jsonrpc2(_data):
json_data = json.loads(_data.decode(client_encoding, errors="ignore"))
if json_data["jsonrpc"] == "2.0":
jsonrpc2_server(
conn, jsondata["id"], jsondata["method"], jsondata["params"]
conn, json_data["id"], json_data["method"], json_data["params"]
)
return True
return False
@ -166,42 +167,44 @@ def conn_string(conn, data, addr):
proxy_server(webserver, port, scheme, method, url, conn, addr, data)
def jsonrpc2_server(conn, id, method, params):
def jsonrpc2_server(
conn: socket.socket, _id: str, method: str, params: dict[str, str | int]
):
if method == "relay_accept":
accepted_relay[id] = conn
accepted_relay[_id] = conn
connection_speed = params["connection_speed"]
logger.info("[*] connection speed: %s milliseconds" % (str(connection_speed)))
logger.info("[*] connection speed: %s milliseconds" % str(connection_speed))
while conn.fileno() > -1:
time.sleep(1)
del accepted_relay[id]
logger.info("[*] relay destroyed: %s" % id)
del accepted_relay[_id]
logger.info("[*] relay destroyed: %s" % _id)
else:
Extension.dispatch_rpcmethod(method, "call", id, params, conn)
Extension.dispatch_rpcmethod(method, "call", _id, params, conn)
# return in conn_string()
def proxy_connect(webserver, conn):
def proxy_connect(webserver: bytes, conn: socket.socket):
hostname = webserver.decode(client_encoding)
certpath = "%s/%s.crt" % (certdir.rstrip("/"), hostname)
cert_path = "%s/%s.crt" % (cert_dir.rstrip("/"), hostname)
if not os.path.exists(certdir):
os.makedirs(certdir)
if not os.path.exists(cert_dir):
os.makedirs(cert_dir)
# https://stackoverflow.com/questions/24055036/handle-https-request-in-proxy-server-by-c-sharp-connect-tunnel
conn.send(b"HTTP/1.1 200 Connection Established\r\n\r\n")
# https://github.com/inaz2/proxy2/blob/master/proxy2.py
try:
if not os.path.isfile(certpath):
if not os.path.isfile(cert_path):
epoch = "%d" % (time.time() * 1000)
p1 = Popen(
[
openssl_binpath,
openssl_bin_path,
"req",
"-new",
"-key",
certkey,
cert_key,
"-subj",
"/CN=%s" % hostname,
],
@ -209,19 +212,19 @@ def proxy_connect(webserver, conn):
)
p2 = Popen(
[
openssl_binpath,
openssl_bin_path,
"x509",
"-req",
"-days",
"3650",
"-CA",
cacert,
ca_cert,
"-CAkey",
cakey,
ca_key,
"-set_serial",
epoch,
"-out",
certpath,
cert_path,
],
stdin=p1.stdout,
stderr=PIPE,
@ -232,20 +235,20 @@ def proxy_connect(webserver, conn):
"[*] OpenSSL distribution not found on this system. Skipping certificate issuance.",
exc_info=e,
)
certpath = "default.crt"
cert_path = "default.crt"
except Exception as e:
logger.error("[*] Skipping certificate issuance.", exc_info=e)
certpath = "default.crt"
logger.info("[*] Certificate file: %s" % (certpath))
logger.info("[*] Private key file: %s" % (certkey))
cert_path = "default.crt"
logger.info("[*] Certificate file: %s" % cert_path)
logger.info("[*] Private key file: %s" % cert_key)
# https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server
# https://docs.python.org/3/library/ssl.html
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
context.load_cert_chain(certfile=certpath, keyfile=certkey)
context.load_cert_chain(certfile=cert_path, keyfile=cert_key)
try:
# https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server
@ -256,12 +259,14 @@ def proxy_connect(webserver, conn):
"[*] SSL negotiation failed.",
exc_info=e,
)
return (conn, b"")
return conn, b""
return (conn, data)
return conn, data
def proxy_check_filtered(data, webserver, port, scheme, method, url):
def proxy_check_filtered(
data: bytes, webserver: bytes, port: bytes, scheme: bytes, method: bytes, url: bytes
):
filtered = False
filters = Extension.get_filters()
@ -272,7 +277,16 @@ def proxy_check_filtered(data, webserver, port, scheme, method, url):
return filtered
def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
def proxy_server(
webserver: bytes,
port: bytes,
scheme: bytes,
method: bytes,
url: bytes,
conn: socket.socket,
addr: bytes,
data: bytes,
):
try:
logger.info("[*] Started the request. %s" % (str(addr[0])))
@ -296,14 +310,11 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
_, _, _, method, url = parse_first_data(data)
# https://stackoverflow.com/questions/44343739/python-sockets-ssl-eof-occurred-in-violation-of-protocol
def sock_close(sock, is_ssl=False):
# if is_ssl:
# sock = sock.unwrap()
# sock.shutdown(socket.SHUT_RDWR)
sock.close()
def sock_close(_sock: socket.socket):
_sock.close()
# Wait to see if there is more data to transmit
def sendall(sock, conn, data):
def sendall(_sock: socket.socket, _conn: socket.socket, _data: bytes):
# send first chuck
if proxy_check_filtered(data, webserver, port, scheme, method, url):
sock.close()
@ -324,7 +335,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
if proxy_check_filtered(
buffered, webserver, port, scheme, method, url
):
sock_close(sock, is_ssl)
sock_close(sock)
raise Exception("Filtered request")
sock.send(chunk)
if len(buffered) > buffer_size * 2:
@ -354,7 +365,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
i = 0
is_http_403 = False
buffered = b""
_buffered = b""
while True:
chunk = sock.recv(buffer_size)
if not chunk:
@ -362,24 +373,26 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
if i == 0 and chunk.find(b"HTTP/1.1 403") == 0:
is_http_403 = True
break
buffered += chunk
if proxy_check_filtered(buffered, webserver, port, scheme, method, url):
sock_close(sock, is_ssl)
_buffered += chunk
if proxy_check_filtered(
_buffered, webserver, port, scheme, method, url
):
sock_close(sock)
add_filtered_host(webserver.decode(client_encoding), "127.0.0.1")
raise Exception("Filtered response")
conn.send(chunk)
if len(buffered) > buffer_size * 2:
buffered = buffered[-buffer_size * 2 :]
if len(_buffered) > buffer_size * 2:
_buffered = _buffered[-buffer_size * 2 :]
i += 1
# when blocked
if is_http_403:
logger.warning(
"[*] Blocked the request by remote server: %s"
% (webserver.decode(client_encoding))
% webserver.decode(client_encoding)
)
def bypass_callback(response, *args, **kwargs):
def bypass_callback(response: requests.Response):
if response.status_code != 200:
conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}')
return
@ -420,7 +433,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
else:
conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}')
sock_close(sock, is_ssl)
sock_close(sock)
logger.info(
"[*] Received %s chunks. (%s bytes per chunk)"
@ -509,27 +522,27 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
else:
resolved_address_list.remove(resolved_address_list[0])
logger.info("[*] the relay is gone. %s" % id)
sock_close(sock, is_ssl)
sock_close(sock)
return
# get response
i = 0
buffered = b""
while True:
chunk = sock.recv(buffer_size)
if not chunk:
_chunk = sock.recv(buffer_size)
if not _chunk:
break
buffered += chunk
buffered += _chunk
if proxy_check_filtered(buffered, webserver, port, scheme, method, url):
sock_close(sock, is_ssl)
sock_close(sock)
add_filtered_host(webserver.decode(client_encoding), "127.0.0.1")
raise Exception("Filtered response")
conn.send(chunk)
conn.send(_chunk)
if len(buffered) > buffer_size * 2:
buffered = buffered[-buffer_size * 2 :]
i += 1
sock_close(sock, is_ssl)
sock_close(sock)
logger.info(
"[*] Received %s chunks. (%s bytes per chunk)"
@ -604,7 +617,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
# journaling a filtered hosts
def add_filtered_host(domain, ip_address):
def add_filtered_host(domain: str, ip_address: str):
hosts_path = "./filtered.hosts"
with open(hosts_path, "r") as file:
lines = file.readlines()
@ -619,6 +632,7 @@ def add_filtered_host(domain, ip_address):
def start(): # Main Program
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", listening_port))
sock.listen(max_connection)
logger.info("[*] Server started successfully [ %d ]" % listening_port)

18
smtp.py
View File

@ -40,6 +40,7 @@ auth = None
if _username:
auth = HTTPBasicAuth(_username, _password)
class CaterpillarSMTPHandler:
def __init__(self):
self.smtpd_hostname = "CaterpillarSMTPServer"
@ -53,8 +54,8 @@ class CaterpillarSMTPHandler:
message = EmailMessage()
message.set_content(data)
subject = message.get('Subject', '')
to = message.get('To', '')
subject = message.get("Subject", "")
to = message.get("To", "")
proxy_data = {
"headers": {
@ -75,13 +76,13 @@ class CaterpillarSMTPHandler:
response = await asyncio.to_thread(
requests.post,
server_url,
headers=proxy_data['headers'],
headers=proxy_data["headers"],
data=raw_data,
auth=auth
auth=auth,
)
if response.status_code == 200:
type, id, rpcdata = jsonrpc2_decode(response.text)
if rpcdata['success']:
if rpcdata["success"]:
logger.info("[*] Email sent successfully.")
else:
raise Exception(f"({rpcdata['code']}) {rpcdata['message']}")
@ -89,9 +90,9 @@ class CaterpillarSMTPHandler:
raise Exception(f"Status {response.status_code}")
except Exception as e:
logger.error("[*] Failed to send email", exc_info=e)
return '500 Could not process your message. ' + str(e)
return "500 Could not process your message. " + str(e)
return '250 OK'
return "250 OK"
# https://aiosmtpd-pepoluan.readthedocs.io/en/latest/migrating.html
@ -101,11 +102,12 @@ def main():
# Run the event loop in a separate thread.
controller.start()
# Wait for the user to press Return.
input('SMTP server running. Press Return to stop server and exit.')
input("SMTP server running. Press Return to stop server and exit.")
controller.stop()
logger.warning("[*] User has requested an interrupt")
logger.warning("[*] Application Exiting.....")
sys.exit()
if __name__ == "__main__":
main()