Merge pull request #45 from zeroday0619/refactoring
Some checks failed
Ruff / ruff (push) Has been cancelled

feat: refactoring typed programming
This commit is contained in:
Namhyeon Go 2024-09-13 11:32:08 +09:00 committed by GitHub
commit bc08241aa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 198 additions and 146 deletions

21
base.py
View File

@ -19,6 +19,7 @@ import importlib
import subprocess import subprocess
import platform import platform
from abc import ABC, abstractmethod
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Union, List from typing import Union, List
@ -47,14 +48,14 @@ def jsonrpc2_create_id(data):
def jsonrpc2_encode(method, params=None): def jsonrpc2_encode(method, params=None):
data = {"jsonrpc": "2.0", "method": method, "params": params} data = {"jsonrpc": "2.0", "method": method, "params": params}
id = jsonrpc2_create_id(data) id = jsonrpc2_create_id(data)
id = data.get('id') id = data.get("id")
return (id, json.dumps(data)) return (id, json.dumps(data))
def jsonrpc2_decode(text): def jsonrpc2_decode(text):
data = json.loads(text) data = json.loads(text)
type = 'error' if 'error' in data else 'result' if 'result' in data else None type = "error" if "error" in data else "result" if "result" in data else None
id = data.get('id') id = data.get("id")
rpcdata = data.get(type) if type else None rpcdata = data.get(type) if type else None
return type, id, rpcdata return type, id, rpcdata
@ -68,6 +69,7 @@ def jsonrpc2_error_encode(error, id=""):
data = {"jsonrpc": "2.0", "error": error, "id": id} data = {"jsonrpc": "2.0", "error": error, "id": id}
return json.dumps(data) return json.dumps(data)
def find_openssl_binpath(): def find_openssl_binpath():
system = platform.system() system = platform.system()
@ -121,8 +123,19 @@ def find_openssl_binpath():
return "openssl" return "openssl"
class ExtensionType:
def __init__(self):
self.type: str = None
self.method: str = None
self.exported_methods: list[str] = []
self.connection_type: str = None
type extension_type = ExtensionType
class Extension: class Extension:
extensions = [] extensions: list[extension_type] = []
protocols = [] protocols = []
buffer_size = 8192 buffer_size = 8192

0
download_certs.sh Normal file → Executable file
View File

View File

@ -15,7 +15,7 @@ import requests
from decouple import config from decouple import config
from elasticsearch import Elasticsearch, NotFoundError from elasticsearch import Elasticsearch, NotFoundError
import hashlib import hashlib
from datetime import datetime from datetime import datetime, UTC
from base import Extension, Logger from base import Extension, Logger
logger = Logger(name="wayback") logger = Logger(name="wayback")
@ -29,11 +29,13 @@ except Exception as e:
es = Elasticsearch([es_host]) es = Elasticsearch([es_host])
def generate_id(url):
"""Generate a unique ID for a URL by hashing it."""
return hashlib.sha256(url.encode('utf-8')).hexdigest()
def get_cached_page_from_google(url): def generate_id(url: str):
"""Generate a unique ID for a URL by hashing it."""
return hashlib.sha256(url.encode("utf-8")).hexdigest()
def get_cached_page_from_google(url: str):
status_code, content = (0, b"") status_code, content = (0, b"")
# Google Cache URL # Google Cache URL
@ -50,8 +52,9 @@ def get_cached_page_from_google(url):
return status_code, content return status_code, content
# API documentation: https://archive.org/help/wayback_api.php # API documentation: https://archive.org/help/wayback_api.php
def get_cached_page_from_wayback(url): def get_cached_page_from_wayback(url: str):
status_code, content = (0, b"") status_code, content = (0, b"")
# Wayback Machine API URL # Wayback Machine API URL
@ -89,44 +92,52 @@ def get_cached_page_from_wayback(url):
return status_code, content return status_code, content
def get_cached_page_from_elasticsearch(url):
def get_cached_page_from_elasticsearch(url: str):
url_id = generate_id(url) url_id = generate_id(url)
try: try:
result = es.get(index=es_index, id=url_id) result = es.get(index=es_index, id=url_id)
logger.info(result['_source']) logger.info(result["_source"])
return 200, result['_source']['content'].encode(client_encoding) return 200, result["_source"]["content"].encode(client_encoding)
except NotFoundError: except NotFoundError:
return 404, b"" return 404, b""
except Exception as e: except Exception as e:
logger.error(f"Error fetching from Elasticsearch: {e}") logger.error(f"Error fetching from Elasticsearch: {e}")
return 502, b"" return 502, b""
def cache_to_elasticsearch(url, data):
def cache_to_elasticsearch(url: str, data: bytes):
url_id = generate_id(url) url_id = generate_id(url)
timestamp = datetime.utcnow().isoformat() timestamp = datetime.now(UTC).timestamp()
try: try:
es.index(index=es_index, id=url_id, body={ es.index(
index=es_index,
id=url_id,
body={
"url": url, "url": url,
"content": data.decode(client_encoding), "content": data.decode(client_encoding),
"timestamp": timestamp "timestamp": timestamp,
}) },
)
except Exception as e: except Exception as e:
logger.error(f"Error caching to Elasticsearch: {e}") logger.error(f"Error caching to Elasticsearch: {e}")
def get_page_from_origin_server(url):
def get_page_from_origin_server(url: str):
try: try:
response = requests.get(url) response = requests.get(url)
return response.status_code, response.content return response.status_code, response.content
except Exception as e: except Exception as e:
return 502, str(e).encode(client_encoding) return 502, str(e).encode(client_encoding)
class AlwaysOnline(Extension): class AlwaysOnline(Extension):
def __init__(self): def __init__(self):
self.type = "connector" # this is a connector self.type = "connector" # this is a connector
self.connection_type = "alwaysonline" self.connection_type = "alwaysonline"
self.buffer_size = 8192 self.buffer_size = 8192
def connect(self, conn, data, webserver, port, scheme, method, url): def connect(self, conn: socket.socket, data: bytes, webserver: bytes, port: bytes, scheme: bytes, method: bytes, url: bytes):
logger.info("[*] Connecting... Connecting...") logger.info("[*] Connecting... Connecting...")
connected = False connected = False
@ -135,20 +146,20 @@ class AlwaysOnline(Extension):
cache_hit = 0 cache_hit = 0
buffered = b"" buffered = b""
def sendall(sock, conn, data): def sendall(_sock: socket.socket, _conn: socket.socket, _data: bytes):
# send first chunk # send first chunk
sock.send(data) sock.send(_data)
if len(data) < self.buffer_size: if len(_data) < self.buffer_size:
return return
# send following chunks # send following chunks
conn.settimeout(1) _conn.settimeout(1)
while True: while True:
try: try:
chunk = conn.recv(self.buffer_size) chunk = _conn.recv(self.buffer_size)
if not chunk: if not chunk:
break break
sock.send(chunk) _sock.send(chunk)
except: except:
break break

View File

@ -9,13 +9,14 @@
# Updated at: 2024-07-02 # Updated at: 2024-07-02
# #
from socket import socket
from Bio.Seq import Seq from Bio.Seq import Seq
from Bio.SeqUtils import gc_fraction from Bio.SeqUtils import gc_fraction
from base import Extension from base import Extension
def _analyze_sequence(sequence) -> dict[str, str]: def _analyze_sequence(sequence: str) -> dict[str, str]:
""" """
Analyze a given DNA sequence to provide various nucleotide transformations and translations. Analyze a given DNA sequence to provide various nucleotide transformations and translations.
@ -41,7 +42,7 @@ def _analyze_sequence(sequence) -> dict[str, str]:
) )
def _gc_content_calculation(sequence) -> dict[str, str]: def _gc_content_calculation(sequence: str) -> dict[str, str]:
""" """
Calculate the GC content of a given DNA sequence and return it as a float. Calculate the GC content of a given DNA sequence and return it as a float.
@ -63,7 +64,7 @@ class PyBio(Extension):
def dispatch(self, type, id, params, conn): def dispatch(self, type, id, params, conn):
conn.send(b"Greeting! dispatch") conn.send(b"Greeting! dispatch")
def analyze_sequence(self, type, id, params, conn): def analyze_sequence(self, type, id, params, conn: socket):
""" """
Analyze a DNA sequence provided in the params dictionary. Analyze a DNA sequence provided in the params dictionary.
@ -91,7 +92,7 @@ class PyBio(Extension):
result = _analyze_sequence(params["sequence"]) result = _analyze_sequence(params["sequence"])
return result return result
def gc_content_calculation(self, type, id, params, conn): def gc_content_calculation(self, type, id, params, conn: socket):
""" """
Calculate the GC content for a given DNA sequence provided in the params dictionary. Calculate the GC content for a given DNA sequence provided in the params dictionary.

View File

@ -11,7 +11,7 @@
# #
import docker import docker
from socket import socket
from base import Extension, Logger from base import Extension, Logger
logger = Logger("Container") logger = Logger("Container")
@ -21,26 +21,36 @@ class Container(Extension):
def __init__(self): def __init__(self):
self.type = "rpcmethod" self.type = "rpcmethod"
self.method = "container_init" self.method = "container_init"
self.exported_methods = ["container_cteate", "container_start", "container_run", "container_stop", "container_pause", "container_unpause", "container_restart", "container_kill", "container_remove"] self.exported_methods = [
"container_cteate",
"container_start",
"container_run",
"container_stop",
"container_pause",
"container_unpause",
"container_restart",
"container_kill",
"container_remove",
]
# docker # docker
self.client = docker.from_env() self.client = docker.from_env()
def dispatch(self, type, id, params, conn): def dispatch(self, type, id, params, conn: socket):
logger.info("[*] Greeting! dispatch") logger.info("[*] Greeting! dispatch")
conn.send(b"Greeting! dispatch") conn.send(b"Greeting! dispatch")
def container_cteate(self, type, id, params, conn): def container_cteate(self, type, id, params, conn: socket):
# todo: - # todo: -
return b"[*] Created" return b"[*] Created"
def container_start(self, type, id, params, conn): def container_start(self, type, id, params, conn: socket):
name = params['name'] name = params["name"]
container = self.client.containers.get(name) container = self.client.containers.get(name)
container.start() container.start()
def container_run(self, type, id, params, conn): def container_run(self, type, id, params, conn: socket):
devices = params["devices"] devices = params["devices"]
image = params["image"] image = params["image"]
devices = params["devices"] devices = params["devices"]
@ -60,7 +70,7 @@ class Container(Extension):
logger.info("[*] Running...") logger.info("[*] Running...")
return b"[*] Running..." return b"[*] Running..."
def container_stop(self, type, id, params, conn): def container_stop(self, type, id, params, conn: socket):
name = params["name"] name = params["name"]
container = self.client.containers.get(name) container = self.client.containers.get(name)
@ -69,33 +79,33 @@ class Container(Extension):
logger.info("[*] Stopped") logger.info("[*] Stopped")
return b"[*] Stopped" return b"[*] Stopped"
def container_pause(self, type, id, params, conn): def container_pause(self, type, id, params, conn: socket):
name = params['name'] name = params["name"]
container = self.client.containers.get(name) container = self.client.containers.get(name)
container.pause() container.pause()
return b"[*] Paused" return b"[*] Paused"
def container_unpause(self, type, id, params, conn): def container_unpause(self, type, id, params, conn: socket):
name = params['name'] name = params["name"]
container = self.client.containers.get(name) container = self.client.containers.get(name)
container.unpause() container.unpause()
return b"[*] Unpaused" return b"[*] Unpaused"
def container_restart(self, type, id, params, conn): def container_restart(self, type, id, params, conn: socket):
name = params['name'] name = params["name"]
container = self.client.containers.get(name) container = self.client.containers.get(name)
container.restart() container.restart()
return b"[*] Restarted" return b"[*] Restarted"
def container_kill(self, type, id, params, conn): def container_kill(self, type, id, params, conn: socket):
# TODO: - # TODO: -
return b"[*] Killed" return b"[*] Killed"
def container_remove(self, type, id, params, conn): def container_remove(self, type, id, params, conn: socket):
name = params['name'] name = params["name"]
container = self.client.containers.get(name) container = self.client.containers.get(name)
container.remove() container.remove()

View File

@ -25,6 +25,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Serial(Extension): class Serial(Extension):
def __init__(self): def __init__(self):
self.type = "connector" self.type = "connector"
@ -38,7 +39,7 @@ class Serial(Extension):
connected = False connected = False
ser = None ser = None
try: try:
port_path = url.decode(client_encoding).replace('/', '') port_path = url.decode(client_encoding).replace("/", "")
if not ser: if not ser:
ser = serial.Serial(port_path, baudrate=9600, timeout=2) ser = serial.Serial(port_path, baudrate=9600, timeout=2)
connected = True connected = True

136
server.py
View File

@ -38,6 +38,7 @@ from base import (
Logger, Logger,
) )
logger = Logger(name="server") logger = Logger(name="server")
# initialization # initialization
@ -47,11 +48,11 @@ try:
config("SERVER_URL", default="") config("SERVER_URL", default="")
) )
server_connection_type = config("SERVER_CONNECTION_TYPE", default="proxy") server_connection_type = config("SERVER_CONNECTION_TYPE", default="proxy")
cakey = config("CA_KEY", default="ca.key") ca_key = config("CA_KEY", default="ca.key")
cacert = config("CA_CERT", default="ca.crt") ca_cert = config("CA_CERT", default="ca.crt")
certkey = config("CERT_KEY", default="cert.key") cert_key = config("CERT_KEY", default="cert.key")
certdir = config("CERT_DIR", default="certs/") cert_dir = config("CERT_DIR", default="certs/")
openssl_binpath = config("OPENSSL_BINPATH", default=find_openssl_binpath()) openssl_bin_path = config("OPENSSL_BINPATH", default=find_openssl_binpath())
client_encoding = config("CLIENT_ENCODING", default="utf-8") client_encoding = config("CLIENT_ENCODING", default="utf-8")
local_domain = config("LOCAL_DOMAIN", default="") local_domain = config("LOCAL_DOMAIN", default="")
proxy_pass = config("PROXY_PASS", default="") proxy_pass = config("PROXY_PASS", default="")
@ -87,7 +88,7 @@ if _username:
auth = HTTPBasicAuth(_username, _password) auth = HTTPBasicAuth(_username, _password)
def parse_first_data(data): def parse_first_data(data: bytes):
parsed_data = (b"", b"", b"", b"", b"") parsed_data = (b"", b"", b"", b"", b"")
try: try:
@ -126,13 +127,13 @@ def parse_first_data(data):
return parsed_data return parsed_data
def conn_string(conn, data, addr): def conn_string(conn: socket.socket, data: bytes, addr: bytes):
# JSON-RPC 2.0 request # JSON-RPC 2.0 request
def process_jsonrpc2(data): def process_jsonrpc2(_data: bytes):
jsondata = json.loads(data.decode(client_encoding, errors="ignore")) json_data = json.loads(_data.decode(client_encoding, errors="ignore"))
if jsondata["jsonrpc"] == "2.0": if json_data["jsonrpc"] == "2.0":
jsonrpc2_server( jsonrpc2_server(
conn, jsondata["id"], jsondata["method"], jsondata["params"] conn, json_data["id"], json_data["method"], json_data["params"]
) )
return True return True
return False return False
@ -166,42 +167,44 @@ def conn_string(conn, data, addr):
proxy_server(webserver, port, scheme, method, url, conn, addr, data) proxy_server(webserver, port, scheme, method, url, conn, addr, data)
def jsonrpc2_server(conn, id, method, params): def jsonrpc2_server(
conn: socket.socket, _id: str, method: str, params: dict[str, str | int]
):
if method == "relay_accept": if method == "relay_accept":
accepted_relay[id] = conn accepted_relay[_id] = conn
connection_speed = params["connection_speed"] connection_speed = params["connection_speed"]
logger.info("[*] connection speed: %s milliseconds" % (str(connection_speed))) logger.info("[*] connection speed: %s milliseconds" % str(connection_speed))
while conn.fileno() > -1: while conn.fileno() > -1:
time.sleep(1) time.sleep(1)
del accepted_relay[id] del accepted_relay[_id]
logger.info("[*] relay destroyed: %s" % id) logger.info("[*] relay destroyed: %s" % _id)
else: else:
Extension.dispatch_rpcmethod(method, "call", id, params, conn) Extension.dispatch_rpcmethod(method, "call", _id, params, conn)
# return in conn_string() # return in conn_string()
def proxy_connect(webserver, conn): def proxy_connect(webserver: bytes, conn: socket.socket):
hostname = webserver.decode(client_encoding) hostname = webserver.decode(client_encoding)
certpath = "%s/%s.crt" % (certdir.rstrip("/"), hostname) cert_path = "%s/%s.crt" % (cert_dir.rstrip("/"), hostname)
if not os.path.exists(certdir): if not os.path.exists(cert_dir):
os.makedirs(certdir) os.makedirs(cert_dir)
# https://stackoverflow.com/questions/24055036/handle-https-request-in-proxy-server-by-c-sharp-connect-tunnel # https://stackoverflow.com/questions/24055036/handle-https-request-in-proxy-server-by-c-sharp-connect-tunnel
conn.send(b"HTTP/1.1 200 Connection Established\r\n\r\n") conn.send(b"HTTP/1.1 200 Connection Established\r\n\r\n")
# https://github.com/inaz2/proxy2/blob/master/proxy2.py # https://github.com/inaz2/proxy2/blob/master/proxy2.py
try: try:
if not os.path.isfile(certpath): if not os.path.isfile(cert_path):
epoch = "%d" % (time.time() * 1000) epoch = "%d" % (time.time() * 1000)
p1 = Popen( p1 = Popen(
[ [
openssl_binpath, openssl_bin_path,
"req", "req",
"-new", "-new",
"-key", "-key",
certkey, cert_key,
"-subj", "-subj",
"/CN=%s" % hostname, "/CN=%s" % hostname,
], ],
@ -209,19 +212,19 @@ def proxy_connect(webserver, conn):
) )
p2 = Popen( p2 = Popen(
[ [
openssl_binpath, openssl_bin_path,
"x509", "x509",
"-req", "-req",
"-days", "-days",
"3650", "3650",
"-CA", "-CA",
cacert, ca_cert,
"-CAkey", "-CAkey",
cakey, ca_key,
"-set_serial", "-set_serial",
epoch, epoch,
"-out", "-out",
certpath, cert_path,
], ],
stdin=p1.stdout, stdin=p1.stdout,
stderr=PIPE, stderr=PIPE,
@ -232,20 +235,20 @@ def proxy_connect(webserver, conn):
"[*] OpenSSL distribution not found on this system. Skipping certificate issuance.", "[*] OpenSSL distribution not found on this system. Skipping certificate issuance.",
exc_info=e, exc_info=e,
) )
certpath = "default.crt" cert_path = "default.crt"
except Exception as e: except Exception as e:
logger.error("[*] Skipping certificate issuance.", exc_info=e) logger.error("[*] Skipping certificate issuance.", exc_info=e)
certpath = "default.crt" cert_path = "default.crt"
logger.info("[*] Certificate file: %s" % (certpath)) logger.info("[*] Certificate file: %s" % cert_path)
logger.info("[*] Private key file: %s" % (certkey)) logger.info("[*] Private key file: %s" % cert_key)
# https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server # https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server
# https://docs.python.org/3/library/ssl.html # https://docs.python.org/3/library/ssl.html
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.check_hostname = False context.check_hostname = False
context.verify_mode = ssl.CERT_NONE context.verify_mode = ssl.CERT_NONE
context.load_cert_chain(certfile=certpath, keyfile=certkey) context.load_cert_chain(certfile=cert_path, keyfile=cert_key)
try: try:
# https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server # https://stackoverflow.com/questions/11255530/python-simple-ssl-socket-server
@ -256,12 +259,14 @@ def proxy_connect(webserver, conn):
"[*] SSL negotiation failed.", "[*] SSL negotiation failed.",
exc_info=e, exc_info=e,
) )
return (conn, b"") return conn, b""
return (conn, data) return conn, data
def proxy_check_filtered(data, webserver, port, scheme, method, url): def proxy_check_filtered(
data: bytes, webserver: bytes, port: bytes, scheme: bytes, method: bytes, url: bytes
):
filtered = False filtered = False
filters = Extension.get_filters() filters = Extension.get_filters()
@ -272,7 +277,16 @@ def proxy_check_filtered(data, webserver, port, scheme, method, url):
return filtered return filtered
def proxy_server(webserver, port, scheme, method, url, conn, addr, data): def proxy_server(
webserver: bytes,
port: bytes,
scheme: bytes,
method: bytes,
url: bytes,
conn: socket.socket,
addr: bytes,
data: bytes,
):
try: try:
logger.info("[*] Started the request. %s" % (str(addr[0]))) logger.info("[*] Started the request. %s" % (str(addr[0])))
@ -296,14 +310,11 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
_, _, _, method, url = parse_first_data(data) _, _, _, method, url = parse_first_data(data)
# https://stackoverflow.com/questions/44343739/python-sockets-ssl-eof-occurred-in-violation-of-protocol # https://stackoverflow.com/questions/44343739/python-sockets-ssl-eof-occurred-in-violation-of-protocol
def sock_close(sock, is_ssl=False): def sock_close(_sock: socket.socket):
# if is_ssl: _sock.close()
# sock = sock.unwrap()
# sock.shutdown(socket.SHUT_RDWR)
sock.close()
# Wait to see if there is more data to transmit # Wait to see if there is more data to transmit
def sendall(sock, conn, data): def sendall(_sock: socket.socket, _conn: socket.socket, _data: bytes):
# send first chunk # send first chunk
if proxy_check_filtered(data, webserver, port, scheme, method, url): if proxy_check_filtered(data, webserver, port, scheme, method, url):
sock.close() sock.close()
@ -324,7 +335,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
if proxy_check_filtered( if proxy_check_filtered(
buffered, webserver, port, scheme, method, url buffered, webserver, port, scheme, method, url
): ):
sock_close(sock, is_ssl) sock_close(sock)
raise Exception("Filtered request") raise Exception("Filtered request")
sock.send(chunk) sock.send(chunk)
if len(buffered) > buffer_size * 2: if len(buffered) > buffer_size * 2:
@ -354,7 +365,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
i = 0 i = 0
is_http_403 = False is_http_403 = False
buffered = b"" _buffered = b""
while True: while True:
chunk = sock.recv(buffer_size) chunk = sock.recv(buffer_size)
if not chunk: if not chunk:
@ -362,24 +373,26 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
if i == 0 and chunk.find(b"HTTP/1.1 403") == 0: if i == 0 and chunk.find(b"HTTP/1.1 403") == 0:
is_http_403 = True is_http_403 = True
break break
buffered += chunk _buffered += chunk
if proxy_check_filtered(buffered, webserver, port, scheme, method, url): if proxy_check_filtered(
sock_close(sock, is_ssl) _buffered, webserver, port, scheme, method, url
):
sock_close(sock)
add_filtered_host(webserver.decode(client_encoding), "127.0.0.1") add_filtered_host(webserver.decode(client_encoding), "127.0.0.1")
raise Exception("Filtered response") raise Exception("Filtered response")
conn.send(chunk) conn.send(chunk)
if len(buffered) > buffer_size * 2: if len(_buffered) > buffer_size * 2:
buffered = buffered[-buffer_size * 2 :] _buffered = _buffered[-buffer_size * 2 :]
i += 1 i += 1
# when blocked # when blocked
if is_http_403: if is_http_403:
logger.warning( logger.warning(
"[*] Blocked the request by remote server: %s" "[*] Blocked the request by remote server: %s"
% (webserver.decode(client_encoding)) % webserver.decode(client_encoding)
) )
def bypass_callback(response, *args, **kwargs): def bypass_callback(response: requests.Response):
if response.status_code != 200: if response.status_code != 200:
conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}') conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}')
return return
@ -420,7 +433,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
else: else:
conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}') conn.sendall(b'HTTP/1.1 403 Forbidden\r\n\r\n{"status":403}')
sock_close(sock, is_ssl) sock_close(sock)
logger.info( logger.info(
"[*] Received %s chunks. (%s bytes per chunk)" "[*] Received %s chunks. (%s bytes per chunk)"
@ -509,27 +522,27 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
else: else:
resolved_address_list.remove(resolved_address_list[0]) resolved_address_list.remove(resolved_address_list[0])
logger.info("[*] the relay is gone. %s" % id) logger.info("[*] the relay is gone. %s" % id)
sock_close(sock, is_ssl) sock_close(sock)
return return
# get response # get response
i = 0 i = 0
buffered = b"" buffered = b""
while True: while True:
chunk = sock.recv(buffer_size) _chunk = sock.recv(buffer_size)
if not chunk: if not _chunk:
break break
buffered += chunk buffered += _chunk
if proxy_check_filtered(buffered, webserver, port, scheme, method, url): if proxy_check_filtered(buffered, webserver, port, scheme, method, url):
sock_close(sock, is_ssl) sock_close(sock)
add_filtered_host(webserver.decode(client_encoding), "127.0.0.1") add_filtered_host(webserver.decode(client_encoding), "127.0.0.1")
raise Exception("Filtered response") raise Exception("Filtered response")
conn.send(chunk) conn.send(_chunk)
if len(buffered) > buffer_size * 2: if len(buffered) > buffer_size * 2:
buffered = buffered[-buffer_size * 2 :] buffered = buffered[-buffer_size * 2 :]
i += 1 i += 1
sock_close(sock, is_ssl) sock_close(sock)
logger.info( logger.info(
"[*] Received %s chunks. (%s bytes per chunk)" "[*] Received %s chunks. (%s bytes per chunk)"
@ -604,7 +617,7 @@ def proxy_server(webserver, port, scheme, method, url, conn, addr, data):
# journaling a filtered hosts # journaling a filtered hosts
def add_filtered_host(domain, ip_address): def add_filtered_host(domain: str, ip_address: str):
hosts_path = "./filtered.hosts" hosts_path = "./filtered.hosts"
with open(hosts_path, "r") as file: with open(hosts_path, "r") as file:
lines = file.readlines() lines = file.readlines()
@ -619,6 +632,7 @@ def add_filtered_host(domain, ip_address):
def start(): # Main Program def start(): # Main Program
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", listening_port)) sock.bind(("", listening_port))
sock.listen(max_connection) sock.listen(max_connection)
logger.info("[*] Server started successfully [ %d ]" % listening_port) logger.info("[*] Server started successfully [ %d ]" % listening_port)

30
smtp.py
View File

@ -20,7 +20,7 @@ from requests.auth import HTTPBasicAuth
from base import ( from base import (
extract_credentials, extract_credentials,
jsonrpc2_encode, jsonrpc2_encode,
Logger, Logger, jsonrpc2_decode,
) )
logger = Logger(name="smtp") logger = Logger(name="smtp")
@ -40,21 +40,22 @@ auth = None
if _username: if _username:
auth = HTTPBasicAuth(_username, _password) auth = HTTPBasicAuth(_username, _password)
class CaterpillarSMTPHandler: class CaterpillarSMTPHandler:
def __init__(self): def __init__(self):
self.smtpd_hostname = "CaterpillarSMTPServer" self.smtpd_hostname = "CaterpillarSMTPServer"
self.smtp_version = "0.1.6" self.smtp_version = "0.1.6"
async def handle_DATA(self, server, session, envelope): async def handle_DATA(self, server, session, envelope):
mailfrom = envelope.mail_from mail_from = envelope.mail_from
rcpttos = envelope.rcpt_tos rcpt_tos = envelope.rcpt_tos
data = envelope.content data = envelope.content
message = EmailMessage() message = EmailMessage()
message.set_content(data) message.set_content(data)
subject = message.get('Subject', '') subject = message.get("Subject", "")
to = message.get('To', '') to = message.get("To", "")
proxy_data = { proxy_data = {
"headers": { "headers": {
@ -64,7 +65,7 @@ class CaterpillarSMTPHandler:
}, },
"data": { "data": {
"to": to, "to": to,
"from": mailfrom, "from": mail_from,
"subject": subject, "subject": subject,
"message": data.decode("utf-8"), "message": data.decode("utf-8"),
}, },
@ -75,23 +76,23 @@ class CaterpillarSMTPHandler:
response = await asyncio.to_thread( response = await asyncio.to_thread(
requests.post, requests.post,
server_url, server_url,
headers=proxy_data['headers'], headers=proxy_data["headers"],
data=raw_data, data=raw_data,
auth=auth auth=auth,
) )
if response.status_code == 200: if response.status_code == 200:
type, id, rpcdata = jsonrpc2_decode(response.text) _type, _id, rpc_data = jsonrpc2_decode(response.text)
if rpcdata['success']: if rpc_data["success"]:
logger.info("[*] Email sent successfully.") logger.info("[*] Email sent successfully.")
else: else:
raise Exception(f"({rpcdata['code']}) {rpcdata['message']}") raise Exception(f"({rpc_data['code']}) {rpc_data['message']}")
else: else:
raise Exception(f"Status {response.status_code}") raise Exception(f"Status {response.status_code}")
except Exception as e: except Exception as e:
logger.error("[*] Failed to send email", exc_info=e) logger.error("[*] Failed to send email", exc_info=e)
return '500 Could not process your message. ' + str(e) return "500 Could not process your message. " + str(e)
return '250 OK' return "250 OK"
# https://aiosmtpd-pepoluan.readthedocs.io/en/latest/migrating.html # https://aiosmtpd-pepoluan.readthedocs.io/en/latest/migrating.html
@ -101,11 +102,12 @@ def main():
# Run the event loop in a separate thread. # Run the event loop in a separate thread.
controller.start() controller.start()
# Wait for the user to press Return. # Wait for the user to press Return.
input('SMTP server running. Press Return to stop server and exit.') input("SMTP server running. Press Return to stop server and exit.")
controller.stop() controller.stop()
logger.warning("[*] User has requested an interrupt") logger.warning("[*] User has requested an interrupt")
logger.warning("[*] Application Exiting.....") logger.warning("[*] Application Exiting.....")
sys.exit() sys.exit()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

12
web.py
View File

@ -49,18 +49,18 @@ def process_jsonrpc2():
conn = Connection(request) conn = Connection(request)
# JSON-RPC 2.0 request # JSON-RPC 2.0 request
jsondata = request.get_json(silent=True) json_data = request.get_json(silent=True)
if jsondata["jsonrpc"] == "2.0": if json_data["jsonrpc"] == "2.0":
return Extension.dispatch_rpcmethod( return Extension.dispatch_rpcmethod(
jsondata["method"], "call", jsondata["id"], jsondata["params"], conn json_data["method"], "call", json_data["id"], json_data["params"], conn
) )
# when error # when error
return jsonrpc2_error_encode({"message": "Not vaild JSON-RPC 2.0 request"}) return jsonrpc2_error_encode({"message": "Not valid JSON-RPC 2.0 request"})
def jsonrpc2_server(conn, id, method, params): def jsonrpc2_server(conn, _id, method, params):
return Extension.dispatch_rpcmethod(method, "call", id, params, conn) return Extension.dispatch_rpcmethod(method, "call", _id, params, conn)
class Connection: class Connection: