2025-12-25 upload

This commit is contained in:
"shengyudong"
2025-12-25 11:16:59 +08:00
commit 322ac74336
2241 changed files with 639966 additions and 0 deletions

View File

@@ -0,0 +1,610 @@
"""
Proxy Server Implementation using asyncio.
The very high level overview is as follows:
- Spawn one coroutine per client connection and create a reverse proxy layer to example.com
- Process any commands from layer (such as opening a server connection)
- Wait for any IO and send it as events to top layer.
"""
import abc
import asyncio
import collections
import logging
import time
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import dataclass
from types import TracebackType
from typing import Literal
from OpenSSL import SSL
import mitmproxy_rs
from mitmproxy import http
from mitmproxy import options as moptions
from mitmproxy import tls
from mitmproxy.connection import Address
from mitmproxy.connection import Client
from mitmproxy.connection import Connection
from mitmproxy.connection import ConnectionState
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import layers
from mitmproxy.proxy import mode_specs
from mitmproxy.proxy import server_hooks
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.layers.http import HTTPMode
from mitmproxy.utils import asyncio_utils
from mitmproxy.utils import human
from mitmproxy.utils.data import pkg_data
logger = logging.getLogger(__name__)
UDP_TIMEOUT = 20
class TimeoutWatchdog:
    """
    Fire a callback once a connection has seen no activity for `timeout` seconds.

    The watchdog can be temporarily disarmed (e.g. while a blocking hook is
    running) via the `disarm` context manager; disarms are reference-counted,
    so nesting is safe.
    """

    last_activity: float
    timeout: int
    can_timeout: asyncio.Event
    blocker: int

    def __init__(self, timeout: int, callback: Callable[[], Awaitable]):
        self.timeout = timeout
        self.callback = callback
        self.last_activity = time.time()
        self.can_timeout = asyncio.Event()
        self.can_timeout.set()
        self.blocker = 0

    def register_activity(self):
        """Record that the connection just saw some traffic."""
        self.last_activity = time.time()

    async def watch(self):
        """Sleep until the timeout elapses, then invoke the callback once."""
        try:
            while True:
                await self.can_timeout.wait()
                remaining = self.timeout - (time.time() - self.last_activity)
                await asyncio.sleep(remaining)
                # Activity may have happened while we slept; only fire if the
                # deadline has truly passed.
                if time.time() > self.last_activity + self.timeout:
                    await self.callback()
                    return
        except asyncio.CancelledError:
            return

    @contextmanager
    def disarm(self):
        """Suspend timeout handling; re-arm once the outermost disarm exits."""
        self.can_timeout.clear()
        self.blocker += 1
        try:
            yield
        finally:
            self.blocker -= 1
            if not self.blocker:
                # Treat re-arming itself as activity so the clock restarts.
                self.register_activity()
                self.can_timeout.set()
@dataclass
class ConnectionIO:
    """Bundles the asyncio primitives associated with a single connection."""

    # Task running handle_connection()/open_connection() for this connection.
    handler: asyncio.Task | None = None
    # Stream to read incoming bytes from; None until the connection is established.
    reader: asyncio.StreamReader | mitmproxy_rs.Stream | None = None
    # Stream to write outgoing bytes to; None until the connection is established.
    writer: asyncio.StreamWriter | mitmproxy_rs.Stream | None = None
class ConnectionHandler(metaclass=abc.ABCMeta):
    """
    Handle a single client connection and all server connections opened on
    its behalf, for their entire lifetime.

    Subclasses must implement `handle_hook` to dispatch mitmproxy hooks.
    """

    # One ConnectionIO (reader/writer/handler task) per currently-open connection.
    transports: MutableMapping[Connection, ConnectionIO]
    timeout_watchdog: TimeoutWatchdog
    client: Client
    # Per-address semaphore limiting concurrent connection attempts to the same server.
    max_conns: collections.defaultdict[Address, asyncio.Semaphore]
    layer: "layer.Layer"
    # Strong references to pending RequestWakeup timers so they can be cancelled.
    wakeup_timer: set[asyncio.Task]

    def __init__(self, context: Context) -> None:
        self.client = context.client
        self.transports = {}
        self.max_conns = collections.defaultdict(lambda: asyncio.Semaphore(5))
        self.wakeup_timer = set()
        # Ask for the first layer right away.
        # In a reverse proxy scenario, this is necessary as we would otherwise hang
        # on protocols that start with a server greeting.
        self.layer = layer.NextLayer(context, ask_on_start=True)
        if self.client.transport_protocol == "tcp":
            timeout = context.options.tcp_timeout
        else:
            timeout = UDP_TIMEOUT
        self.timeout_watchdog = TimeoutWatchdog(timeout, self.on_timeout)
        self._server_event_lock = asyncio.Lock()
        # workaround for https://bugs.python.org/issue40124 / https://bugs.python.org/issue29930
        self._drain_lock = asyncio.Lock()

    async def handle_client(self) -> None:
        """Main entry point: process the client connection until it closes."""
        asyncio_utils.set_current_task_debug_info(
            name="client handler",
            client=self.client.peername,
        )
        watch = asyncio_utils.create_task(
            self.timeout_watchdog.watch(),
            name="timeout watchdog",
            keep_ref=False,
            client=self.client.peername,
        )
        self.log("client connect")
        await self.handle_hook(server_hooks.ClientConnectedHook(self.client))
        if self.client.error:
            # A hook has vetoed the connection: close it without starting layers.
            self.log("client kill connection")
            writer = self.transports.pop(self.client).writer
            assert writer
            writer.close()
        else:
            await self.server_event(events.Start())
            handler = asyncio_utils.create_task(
                self.handle_connection(self.client),
                name="client connection handler",
                keep_ref=False,
                client=self.client.peername,
            )
            self.transports[self.client].handler = handler
            await asyncio.wait([handler])
            if not handler.cancelled() and (e := handler.exception()):
                self.log(
                    f"connection handler has crashed: {e}",
                    logging.ERROR,
                    exc_info=(type(e), e, e.__traceback__),
                )
        watch.cancel()
        while self.wakeup_timer:
            timer = self.wakeup_timer.pop()
            timer.cancel()
        self.log("client disconnect")
        self.client.timestamp_end = time.time()
        await self.handle_hook(server_hooks.ClientDisconnectedHook(self.client))
        # Tear down any server connections that are still around.
        if self.transports:
            self.log("closing transports...", logging.DEBUG)
            for io in self.transports.values():
                if io.handler:
                    io.handler.cancel("client disconnected")
            await asyncio.wait(
                [x.handler for x in self.transports.values() if x.handler]
            )
            self.log("transports closed!", logging.DEBUG)

    async def open_connection(self, command: commands.OpenConnection) -> None:
        """
        Establish a server connection requested by the layer stack, then
        handle it until it is closed. Completion (success or error) is
        reported back to the layers via OpenConnectionCompleted.
        """
        if not command.connection.address:
            self.log("Cannot open connection, no hostname given.")
            await self.server_event(
                events.OpenConnectionCompleted(
                    command, "Cannot open connection, no hostname given."
                )
            )
            return
        hook_data = server_hooks.ServerConnectionHookData(
            client=self.client, server=command.connection
        )
        await self.handle_hook(server_hooks.ServerConnectHook(hook_data))
        if err := command.connection.error:
            # A hook has killed the connection before we even connected.
            self.log(
                f"server connection to {human.format_address(command.connection.address)} killed before connect: {err}"
            )
            await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
            await self.server_event(
                events.OpenConnectionCompleted(command, f"Connection killed: {err}")
            )
            return
        async with self.max_conns[command.connection.address]:
            reader: asyncio.StreamReader | mitmproxy_rs.Stream
            writer: asyncio.StreamWriter | mitmproxy_rs.Stream
            try:
                command.connection.timestamp_start = time.time()
                if command.connection.transport_protocol == "tcp":
                    reader, writer = await asyncio.open_connection(
                        *command.connection.address,
                        local_addr=command.connection.sockname,
                    )
                elif command.connection.transport_protocol == "udp":
                    reader = writer = await mitmproxy_rs.udp.open_udp_connection(
                        *command.connection.address,
                        local_addr=command.connection.sockname,
                    )
                else:
                    raise AssertionError(command.connection.transport_protocol)
            except (OSError, asyncio.CancelledError) as e:
                err = str(e)
                if not err:  # str(CancelledError()) returns empty string.
                    err = "connection cancelled"
                self.log(f"error establishing server connection: {err}")
                command.connection.error = err
                await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
                await self.server_event(events.OpenConnectionCompleted(command, err))
                if isinstance(e, asyncio.CancelledError):
                    # From https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError:
                    # > In almost all situations the exception must be re-raised.
                    # It is not really defined what almost means here, but we play safe.
                    raise
            else:
                if command.connection.transport_protocol == "tcp":
                    # TODO: Rename to `timestamp_setup` and make it agnostic for both TCP (SYN/ACK) and UDP (DNS resl.)
                    command.connection.timestamp_tcp_setup = time.time()
                command.connection.state = ConnectionState.OPEN
                command.connection.peername = writer.get_extra_info("peername")
                command.connection.sockname = writer.get_extra_info("sockname")
                self.transports[command.connection] = ConnectionIO(
                    handler=asyncio.current_task(),
                    reader=reader,
                    writer=writer,
                )
                assert command.connection.peername
                # Show the resolved peer address if it differs from the requested one.
                if command.connection.address[0] != command.connection.peername[0]:
                    addr = f"{human.format_address(command.connection.address)} ({human.format_address(command.connection.peername)})"
                else:
                    addr = human.format_address(command.connection.address)
                self.log(f"server connect {addr}")
                await self.handle_hook(server_hooks.ServerConnectedHook(hook_data))
                await self.server_event(events.OpenConnectionCompleted(command, None))
                try:
                    await self.handle_connection(command.connection)
                finally:
                    self.log(f"server disconnect {addr}")
                    command.connection.timestamp_end = time.time()
                    await self.handle_hook(
                        server_hooks.ServerDisconnectedHook(hook_data)
                    )

    async def wakeup(self, request: commands.RequestWakeup) -> None:
        """Sleep for the requested delay, then deliver a Wakeup event to the layers."""
        await asyncio.sleep(request.delay)
        task = asyncio.current_task()
        assert task is not None
        self.wakeup_timer.discard(task)
        await self.server_event(events.Wakeup(request))

    async def handle_connection(self, connection: Connection) -> None:
        """
        Handle a connection for its entire lifetime.
        This means we read until EOF,
        but then possibly also keep on waiting for our side of the connection to be closed.
        """
        cancelled = None
        reader = self.transports[connection].reader
        assert reader
        while True:
            try:
                data = await reader.read(65535)
                if not data:
                    raise OSError("Connection closed by peer.")
            except OSError:
                break
            except asyncio.CancelledError as e:
                # Remember the cancellation so we can re-raise it after cleanup.
                cancelled = e
                break
            await self.server_event(events.DataReceived(connection, data))
            try:
                await self.drain_writers()
            except asyncio.CancelledError as e:
                cancelled = e
                break
        if cancelled is None and connection.transport_protocol == "tcp":
            # TCP connections can be half-closed.
            connection.state &= ~ConnectionState.CAN_READ
        else:
            connection.state = ConnectionState.CLOSED
        await self.server_event(events.ConnectionClosed(connection))
        if connection.state is ConnectionState.CAN_WRITE:
            # we may still use this connection to *send* stuff,
            # even though the remote has closed their side of the connection.
            # to make this work we keep this task running and wait for cancellation.
            try:
                await asyncio.Event().wait()
            except asyncio.CancelledError as e:
                cancelled = e
        try:
            writer = self.transports[connection].writer
            assert writer
            writer.close()
        except OSError:
            pass
        self.transports.pop(connection)
        if cancelled:
            raise cancelled

    async def drain_writers(self):
        """
        Drain all writers to create some backpressure. We won't continue reading until there's space available in our
        write buffers, so if we cannot write fast enough our own read buffers run full and the TCP recv stream is throttled.
        """
        async with self._drain_lock:
            for transport in list(self.transports.values()):
                if transport.writer is not None:
                    try:
                        await transport.writer.drain()
                    except OSError as e:
                        if transport.handler is not None:
                            transport.handler.cancel(f"Error sending data: {e}")

    async def on_timeout(self) -> None:
        """Watchdog callback: cancel the client handler after prolonged inactivity."""
        try:
            handler = self.transports[self.client].handler
        except KeyError:  # pragma: no cover
            # there is a super short window between connection close and watchdog cancellation
            pass
        else:
            if self.client.transport_protocol == "tcp":
                # UDP "connections" time out silently; only log for TCP.
                self.log(f"Closing connection due to inactivity: {self.client}")
            assert handler
            handler.cancel("timeout")

    async def hook_task(self, hook: commands.StartHook) -> None:
        """Run a hook; for blocking hooks, notify the layers upon completion."""
        await self.handle_hook(hook)
        if hook.blocking:
            await self.server_event(events.HookCompleted(hook))

    @abc.abstractmethod
    async def handle_hook(self, hook: commands.StartHook) -> None:
        """Dispatch a hook to the surrounding application; subclass responsibility."""

    def log(
        self,
        message: str,
        level: int = logging.INFO,
        exc_info: Literal[True]
        | tuple[type[BaseException], BaseException, TracebackType | None]
        | None = None,
    ) -> None:
        """Log `message` annotated with the client's peer address."""
        logger.log(
            level, message, extra={"client": self.client.peername}, exc_info=exc_info
        )

    async def server_event(self, event: events.Event) -> None:
        """Feed an event into the layer stack and execute the resulting commands."""
        # server_event is supposed to be completely sync without any `await` that could pause execution.
        # However, create_task with an [eager task factory] will schedule tasks immediately,
        # which causes [reentrancy issues]. So we put the entire thing behind a lock.
        #
        # [eager task factory]: https://docs.python.org/3/library/asyncio-task.html#eager-task-factory
        # [reentrancy issues]: https://github.com/mitmproxy/mitmproxy/issues/7027.
        async with self._server_event_lock:
            # No `await` beyond this point.
            self.timeout_watchdog.register_activity()
            try:
                layer_commands = self.layer.handle_event(event)
                for command in layer_commands:
                    if isinstance(command, commands.OpenConnection):
                        assert command.connection not in self.transports
                        handler = asyncio_utils.create_task(
                            self.open_connection(command),
                            name=f"server connection handler {command.connection.address}",
                            keep_ref=False,
                            client=self.client.peername,
                        )
                        self.transports[command.connection] = ConnectionIO(
                            handler=handler
                        )
                    elif isinstance(command, commands.RequestWakeup):
                        task = asyncio_utils.create_task(
                            self.wakeup(command),
                            name=f"wakeup timer ({command.delay:.1f}s)",
                            keep_ref=False,
                            client=self.client.peername,
                        )
                        assert task is not None
                        self.wakeup_timer.add(task)
                    elif (
                        isinstance(command, commands.ConnectionCommand)
                        and command.connection not in self.transports
                    ):
                        pass  # The connection has already been closed.
                    elif isinstance(command, commands.SendData):
                        writer = self.transports[command.connection].writer
                        assert writer
                        if not writer.is_closing():
                            writer.write(command.data)
                    elif isinstance(command, commands.CloseTcpConnection):
                        self.close_connection(command.connection, command.half_close)
                    elif isinstance(command, commands.CloseConnection):
                        self.close_connection(command.connection, False)
                    elif isinstance(command, commands.StartHook):
                        asyncio_utils.create_task(
                            self.hook_task(command),
                            name=f"handle_hook({command.name})",
                            keep_ref=True,
                            client=self.client.peername,
                        )
                    elif isinstance(command, commands.Log):
                        self.log(command.message, command.level)
                    else:
                        raise RuntimeError(f"Unexpected command: {command}")
            except Exception:
                self.log("mitmproxy has crashed!", logging.ERROR, exc_info=True)

    def close_connection(
        self, connection: Connection, half_close: bool = False
    ) -> None:
        """
        Close (or half-close) a connection's write side; cancel its handler
        task once the connection is fully closed.
        """
        if half_close:
            if not connection.state & ConnectionState.CAN_WRITE:
                return
            self.log(f"half-closing {connection}", logging.DEBUG)
            try:
                writer = self.transports[connection].writer
                assert writer
                if not writer.is_closing():
                    writer.write_eof()
            except OSError:
                # if we can't write to the socket anymore we presume it completely dead.
                connection.state = ConnectionState.CLOSED
            else:
                connection.state &= ~ConnectionState.CAN_WRITE
        else:
            connection.state = ConnectionState.CLOSED
        if connection.state is ConnectionState.CLOSED:
            handler = self.transports[connection].handler
            assert handler
            handler.cancel("closed by command")
class LiveConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):
    """
    Connection handler bound to a live client socket.

    Builds the `Client` connection object from the writer's transport
    metadata and registers the client's reader/writer as a transport.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader | mitmproxy_rs.Stream,
        writer: asyncio.StreamWriter | mitmproxy_rs.Stream,
        options: moptions.Options,
        mode: mode_specs.ProxyMode,
    ) -> None:
        client_conn = Client(
            transport_protocol=writer.get_extra_info("transport_protocol", "tcp"),
            peername=writer.get_extra_info("peername"),
            sockname=writer.get_extra_info("sockname"),
            timestamp_start=time.time(),
            proxy_mode=mode,
            state=ConnectionState.OPEN,
        )
        super().__init__(Context(client_conn, options))
        self.transports[client_conn] = ConnectionIO(
            handler=None, reader=reader, writer=writer
        )
class SimpleConnectionHandler(LiveConnectionHandler):  # pragma: no cover
    """Simple handler that does not really process any hooks."""

    # Maps hook names to plain callables; hooks without an entry are ignored.
    hook_handlers: dict[str, Callable]

    def __init__(self, reader, writer, options, mode, hook_handlers):
        super().__init__(reader, writer, options, mode)
        self.hook_handlers = hook_handlers

    async def handle_hook(self, hook: commands.StartHook) -> None:
        """Invoke the registered callable for this hook, if any."""
        try:
            handler = self.hook_handlers[hook.name]
        except KeyError:
            return
        handler(*hook.args())
if __name__ == "__main__":  # pragma: no cover
    # simple standalone implementation for testing.
    opts = moptions.Options()
    # options duplicated here to simplify testing setup
    opts.add_option(
        "store_streamed_bodies",
        bool,
        False,
        "",
    )
    opts.add_option(
        "connection_strategy",
        str,
        "lazy",
        "Determine when server connections should be established.",
        choices=("eager", "lazy"),
    )
    opts.add_option(
        "keep_host_header",
        bool,
        False,
        """
        Reverse Proxy: Keep the original host header instead of rewriting it
        to the reverse proxy target.
        """,
    )

    async def handle(reader, writer):
        """Handle one incoming client connection with a fixed layer stack."""
        layer_stack = [
            # lambda ctx: layers.ServerTLSLayer(ctx),
            # lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular),
            # lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
            # lambda ctx: layers.ClientTLSLayer(ctx),
            lambda ctx: layers.modes.ReverseProxy(ctx),
            lambda ctx: layers.HttpLayer(ctx, HTTPMode.transparent),
        ]

        def next_layer(nl: layer.NextLayer):
            # Pop the next layer off the predefined stack instead of autodetecting.
            layr = layer_stack.pop(0)(nl.context)
            layr.debug = " " * len(nl.context.layers)
            nl.layer = layr

        def request(flow: http.HTTPFlow):
            if "cached" in flow.request.path:
                flow.response = http.Response.make(418, f"(cached) {flow.request.text}")
            if "toggle-tls" in flow.request.path:
                if flow.request.url.startswith("https://"):
                    flow.request.url = flow.request.url.replace("https://", "http://")
                else:
                    flow.request.url = flow.request.url.replace("http://", "https://")
            if "redirect" in flow.request.path:
                flow.request.host = "httpbin.org"

        def tls_start_client(tls_start: tls.TlsData):
            # INSECURE
            ssl_context = SSL.Context(SSL.SSLv23_METHOD)
            ssl_context.use_privatekey_file(
                pkg_data.path(
                    "../test/mitmproxy/data/verificationcerts/trusted-leaf.key"
                )
            )
            ssl_context.use_certificate_chain_file(
                pkg_data.path(
                    "../test/mitmproxy/data/verificationcerts/trusted-leaf.crt"
                )
            )
            tls_start.ssl_conn = SSL.Connection(ssl_context)
            tls_start.ssl_conn.set_accept_state()

        def tls_start_server(tls_start: tls.TlsData):
            # INSECURE
            ssl_context = SSL.Context(SSL.SSLv23_METHOD)
            tls_start.ssl_conn = SSL.Connection(ssl_context)
            tls_start.ssl_conn.set_connect_state()
            if tls_start.context.client.sni is not None:
                tls_start.ssl_conn.set_tlsext_host_name(
                    tls_start.context.client.sni.encode()
                )

        await SimpleConnectionHandler(
            reader,
            writer,
            opts,
            mode_specs.ProxyMode.parse("reverse:http://127.0.0.1:3000/"),
            {
                "next_layer": next_layer,
                "request": request,
                "tls_start_client": tls_start_client,
                "tls_start_server": tls_start_server,
            },
        ).handle_client()

    async def main() -> None:
        # `asyncio.get_event_loop()` and the `loop=` keyword argument to
        # `asyncio.start_server` were removed in Python 3.10; run everything
        # inside an async entry point via asyncio.run instead.
        server = await asyncio.start_server(handle, "127.0.0.1", 8080)
        assert server.sockets
        print(f"Serving on {human.format_address(server.sockets[0].getsockname())}")
        async with server:
            # Serve requests until Ctrl+C is pressed
            await server.serve_forever()

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass