""" This module defines "server instances", which manage the TCP/UDP servers spawned by mitmproxy as specified by the proxy mode. Example: mode = ProxyMode.parse("reverse:https://example.com") inst = ServerInstance.make(mode, manager_that_handles_callbacks) await inst.start() # TCP server is running now. """ from __future__ import annotations import asyncio import errno import json import logging import os import socket import sys import textwrap import typing from abc import ABCMeta from abc import abstractmethod from contextlib import contextmanager from pathlib import Path from typing import cast from typing import ClassVar from typing import Generic from typing import get_args from typing import TYPE_CHECKING from typing import TypeVar import mitmproxy_rs from mitmproxy import ctx from mitmproxy import flow from mitmproxy import platform from mitmproxy.connection import Address from mitmproxy.net import local_ip from mitmproxy.net.free_port import get_free_port from mitmproxy.proxy import commands from mitmproxy.proxy import layers from mitmproxy.proxy import mode_specs from mitmproxy.proxy import server from mitmproxy.proxy.context import Context from mitmproxy.proxy.layer import Layer from mitmproxy.utils import human if sys.version_info < (3, 11): from typing_extensions import Self # pragma: no cover else: from typing import Self if TYPE_CHECKING: from mitmproxy.master import Master logger = logging.getLogger(__name__) class ProxyConnectionHandler(server.LiveConnectionHandler): master: Master def __init__(self, master, r, w, options, mode): self.master = master super().__init__(r, w, options, mode) self.log_prefix = f"{human.format_address(self.client.peername)}: " async def handle_hook(self, hook: commands.StartHook) -> None: with self.timeout_watchdog.disarm(): # We currently only support single-argument hooks. (data,) = hook.args() await self.master.addons.handle_lifecycle(hook) if isinstance(data, flow.Flow): await data.wait_for_resume() # pragma: no cover M = TypeVar("M", bound=mode_specs.ProxyMode) class ServerManager(typing.Protocol): # temporary workaround: for UDP, we use the 4-tuple because we don't have a uuid. connections: dict[tuple | str, ProxyConnectionHandler] @contextmanager def register_connection( self, connection_id: tuple | str, handler: ProxyConnectionHandler ): ... # pragma: no cover class ServerInstance(Generic[M], metaclass=ABCMeta): __modes: ClassVar[dict[str, type[ServerInstance]]] = {} last_exception: Exception | None = None def __init__(self, mode: M, manager: ServerManager): self.mode: M = mode self.manager: ServerManager = manager def __init_subclass__(cls, **kwargs): """Register all subclasses so that make() finds them.""" # extract mode from Generic[Mode]. mode = get_args(cls.__orig_bases__[0])[0] # type: ignore if not isinstance(mode, TypeVar): assert issubclass(mode, mode_specs.ProxyMode) assert mode.type_name not in ServerInstance.__modes ServerInstance.__modes[mode.type_name] = cls @classmethod def make( cls, mode: mode_specs.ProxyMode | str, manager: ServerManager, ) -> Self: if isinstance(mode, str): mode = mode_specs.ProxyMode.parse(mode) inst = ServerInstance.__modes[mode.type_name](mode, manager) if not isinstance(inst, cls): raise ValueError(f"{mode!r} is not a spec for a {cls.__name__} server.") return inst @property @abstractmethod def is_running(self) -> bool: pass async def start(self) -> None: try: await self._start() except Exception as e: self.last_exception = e raise else: self.last_exception = None if self.listen_addrs: addrs = " and ".join({human.format_address(a) for a in self.listen_addrs}) logger.info(f"{self.mode.description} listening at {addrs}.") else: logger.info(f"{self.mode.description} started.") async def stop(self) -> None: listen_addrs = self.listen_addrs try: await self._stop() except Exception as e: self.last_exception = e raise else: self.last_exception = None if listen_addrs: addrs = " and ".join({human.format_address(a) for a in listen_addrs}) logger.info(f"{self.mode.description} at {addrs} stopped.") else: logger.info(f"{self.mode.description} stopped.") @abstractmethod async def _start(self) -> None: pass @abstractmethod async def _stop(self) -> None: pass @property @abstractmethod def listen_addrs(self) -> tuple[Address, ...]: pass @abstractmethod def make_top_layer(self, context: Context) -> Layer: pass def to_json(self) -> dict: return { "type": self.mode.type_name, "description": self.mode.description, "full_spec": self.mode.full_spec, "is_running": self.is_running, "last_exception": str(self.last_exception) if self.last_exception else None, "listen_addrs": self.listen_addrs, } async def handle_stream( self, reader: asyncio.StreamReader | mitmproxy_rs.Stream, writer: asyncio.StreamWriter | mitmproxy_rs.Stream | None = None, ) -> None: if writer is None: assert isinstance(reader, mitmproxy_rs.Stream) writer = reader handler = ProxyConnectionHandler( ctx.master, reader, writer, ctx.options, self.mode ) handler.layer = self.make_top_layer(handler.layer.context) if isinstance(self.mode, mode_specs.TransparentMode): assert isinstance(writer, asyncio.StreamWriter) s = cast(socket.socket, writer.get_extra_info("socket")) try: assert platform.original_addr original_dst = platform.original_addr(s) except Exception as e: logger.error(f"Transparent mode failure: {e!r}") writer.close() return else: handler.layer.context.client.sockname = original_dst handler.layer.context.server.address = original_dst elif isinstance( self.mode, (mode_specs.WireGuardMode, mode_specs.LocalMode, mode_specs.TunMode), ): # pragma: no cover on platforms without wg-test-client handler.layer.context.server.address = writer.get_extra_info( "remote_endpoint", handler.layer.context.client.sockname ) with self.manager.register_connection(handler.layer.context.client.id, handler): await handler.handle_client() class AsyncioServerInstance(ServerInstance[M], metaclass=ABCMeta): _servers: list[ asyncio.Server | mitmproxy_rs.udp.UdpServer | mitmproxy_rs.wireguard.WireGuardServer ] def __init__(self, *args, **kwargs) -> None: self._servers = [] super().__init__(*args, **kwargs) @property def is_running(self) -> bool: return bool(self._servers) @property def listen_addrs(self) -> tuple[Address, ...]: addrs = [] for s in self._servers: if isinstance( s, (mitmproxy_rs.udp.UdpServer, mitmproxy_rs.wireguard.WireGuardServer) ): addrs.append(s.getsockname()) else: try: addrs.extend(sock.getsockname() for sock in s.sockets) except OSError: # pragma: no cover pass # this can fail during shutdown, see https://github.com/mitmproxy/mitmproxy/issues/6529 return tuple(addrs) async def _start(self) -> None: assert not self._servers host = self.mode.listen_host(ctx.options.listen_host) port = self.mode.listen_port(ctx.options.listen_port) assert port is not None try: self._servers = await self.listen(host, port) except OSError as e: message = f"{self.mode.description} failed to listen on {host or '*'}:{port} with {e}" if e.errno == errno.EADDRINUSE and self.mode.custom_listen_port is None: assert ( self.mode.custom_listen_host is None ) # since [@ [listen_addr:]listen_port] message += f"\nTry specifying a different port by using `--mode {self.mode.full_spec}@{port + 2}`." raise OSError(e.errno, message, e.filename) from e async def _stop(self) -> None: assert self._servers try: for s in self._servers: s.close() # https://github.com/python/cpython/issues/104344 # await asyncio.gather(*[s.wait_closed() for s in self._servers]) finally: # we always reset _server and ignore failures self._servers = [] async def listen( self, host: str, port: int ) -> list[ asyncio.Server | mitmproxy_rs.udp.UdpServer | mitmproxy_rs.wireguard.WireGuardServer ]: if self.mode.transport_protocol not in ("tcp", "udp", "both"): raise AssertionError(self.mode.transport_protocol) # workaround for https://github.com/python/cpython/issues/89856: # We want both IPv4 and IPv6 sockets to bind to the same port. # This may fail (https://github.com/mitmproxy/mitmproxy/pull/5542#issuecomment-1222803291), # so we try to cover the 99% case and then give up and fall back to what asyncio does. if port == 0: try: return await self.listen(host, get_free_port()) except Exception as e: logger.debug( f"Failed to listen on a single port ({e!r}), falling back to default behavior." ) servers: list[ asyncio.Server | mitmproxy_rs.udp.UdpServer | mitmproxy_rs.wireguard.WireGuardServer ] = [] if self.mode.transport_protocol in ("tcp", "both"): servers.append(await asyncio.start_server(self.handle_stream, host, port)) if self.mode.transport_protocol in ("udp", "both"): # we start two servers for dual-stack support. # On Linux, this would also be achievable by toggling IPV6_V6ONLY off, but this here works cross-platform. if host == "": ipv4 = await self.start_udp_based_server("0.0.0.0", port) servers.append(ipv4) try: ipv6 = await self.start_udp_based_server( "::", ipv4.getsockname()[1] ) servers.append(ipv6) # pragma: no cover except Exception: # pragma: no cover logger.debug("Failed to listen on '::', listening on IPv4 only.") else: servers.append(await self.start_udp_based_server(host, port)) return servers async def start_udp_based_server( self, host, port ) -> mitmproxy_rs.udp.UdpServer | mitmproxy_rs.wireguard.WireGuardServer: return await mitmproxy_rs.udp.start_udp_server( host, port, self.handle_stream, ) class WireGuardServerInstance(AsyncioServerInstance[mode_specs.WireGuardMode]): server_key: str client_key: str pubkey: str def make_top_layer( self, context: Context ) -> Layer: # pragma: no cover on platforms without wg-test-client return layers.modes.TransparentProxy(context) async def _start(self) -> None: if self.mode.data: conf_path = Path(self.mode.data).expanduser() else: conf_path = Path(ctx.options.confdir).expanduser() / "wireguard.conf" if not conf_path.exists(): conf_path.parent.mkdir(parents=True, exist_ok=True) conf_path.write_text( json.dumps( { "server_key": mitmproxy_rs.wireguard.genkey(), "client_key": mitmproxy_rs.wireguard.genkey(), }, indent=4, ) ) try: c = json.loads(conf_path.read_text()) self.server_key = c["server_key"] self.client_key = c["client_key"] except Exception as e: raise ValueError(f"Invalid configuration file ({conf_path}): {e}") from e # error early on invalid keys self.pubkey = mitmproxy_rs.wireguard.pubkey(self.client_key) _ = mitmproxy_rs.wireguard.pubkey(self.server_key) await super()._start() conf = self.client_conf() assert conf logger.info("-" * 60 + "\n" + conf + "\n" + "-" * 60) async def start_udp_based_server( self, host, port ) -> mitmproxy_rs.wireguard.WireGuardServer: return await mitmproxy_rs.wireguard.start_wireguard_server( host, port, self.server_key, [self.pubkey], self.handle_stream, self.handle_stream, ) def client_conf(self) -> str | None: if not self._servers: return None host = ( self.mode.listen_host(ctx.options.listen_host) or local_ip.get_local_ip() or local_ip.get_local_ip6() ) port = self.mode.listen_port(ctx.options.listen_port) return textwrap.dedent( f""" [Interface] PrivateKey = {self.client_key} Address = 10.0.0.1/32 DNS = 10.0.0.53 [Peer] PublicKey = {mitmproxy_rs.wireguard.pubkey(self.server_key)} AllowedIPs = 0.0.0.0/0 Endpoint = {host}:{port} """ ).strip() def to_json(self) -> dict: return {"wireguard_conf": self.client_conf(), **super().to_json()} class LocalRedirectorInstance(ServerInstance[mode_specs.LocalMode]): _server: ClassVar[mitmproxy_rs.local.LocalRedirector | None] = None """The local redirector daemon. Will be started once and then reused for all future instances.""" _instance: ClassVar[LocalRedirectorInstance | None] = None """The current LocalRedirectorInstance. Will be unset again if an instance is stopped.""" listen_addrs = () @property def is_running(self) -> bool: return self._instance is not None def make_top_layer(self, context: Context) -> Layer: return layers.modes.TransparentProxy(context) @classmethod async def redirector_handle_stream( cls, stream: mitmproxy_rs.Stream, ) -> None: if cls._instance is not None: await cls._instance.handle_stream(stream) async def _start(self) -> None: if self._instance: raise RuntimeError("Cannot spawn more than one local redirector.") if self.mode.data: spec = f"{self.mode.data},!{os.getpid()}" else: spec = f"!{os.getpid()}" cls = self.__class__ cls._instance = self # assign before awaiting to avoid races if cls._server is None: try: cls._server = await mitmproxy_rs.local.start_local_redirector( cls.redirector_handle_stream, cls.redirector_handle_stream, ) except Exception: cls._instance = None raise cls._server.set_intercept(spec) async def _stop(self) -> None: assert self._instance assert self._server self.__class__._instance = None # We're not shutting down the server because we want to avoid additional UAC prompts. self._server.set_intercept("") class RegularInstance(AsyncioServerInstance[mode_specs.RegularMode]): def make_top_layer(self, context: Context) -> Layer: return layers.modes.HttpProxy(context) class UpstreamInstance(AsyncioServerInstance[mode_specs.UpstreamMode]): def make_top_layer(self, context: Context) -> Layer: return layers.modes.HttpUpstreamProxy(context) class TransparentInstance(AsyncioServerInstance[mode_specs.TransparentMode]): def make_top_layer(self, context: Context) -> Layer: return layers.modes.TransparentProxy(context) class ReverseInstance(AsyncioServerInstance[mode_specs.ReverseMode]): def make_top_layer(self, context: Context) -> Layer: return layers.modes.ReverseProxy(context) class Socks5Instance(AsyncioServerInstance[mode_specs.Socks5Mode]): def make_top_layer(self, context: Context) -> Layer: return layers.modes.Socks5Proxy(context) class DnsInstance(AsyncioServerInstance[mode_specs.DnsMode]): def make_top_layer(self, context: Context) -> Layer: return layers.DNSLayer(context) class TunInstance(ServerInstance[mode_specs.TunMode]): _server: mitmproxy_rs.tun.TunInterface | None = None listen_addrs = () def make_top_layer( self, context: Context ) -> Layer: # pragma: no cover mocked in tests return layers.modes.TransparentProxy(context) @property def is_running(self) -> bool: return self._server is not None @property def tun_name(self) -> str | None: if self._server: return self._server.tun_name() else: return None def to_json(self) -> dict: return {"tun_name": self.tun_name, **super().to_json()} async def _start(self) -> None: assert self._server is None self._server = await mitmproxy_rs.tun.create_tun_interface( self.handle_stream, self.handle_stream, tun_name=self.mode.data or None, ) logger.info(f"TUN interface created: {self._server.tun_name()}") async def _stop(self) -> None: assert self._server is not None try: self._server.close() await self._server.wait_closed() finally: self._server = None # class Http3Instance(AsyncioServerInstance[mode_specs.Http3Mode]): # def make_top_layer(self, context: Context) -> Layer: # return layers.modes.HttpProxy(context)