2025-12-25 upload

This commit is contained in:
"shengyudong"
2025-12-25 11:16:59 +08:00
commit 322ac74336
2241 changed files with 639966 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
from . import modes
from .dns import DNSLayer
from .http import HttpLayer
from .quic import ClientQuicLayer
from .quic import QuicStreamLayer
from .quic import RawQuicLayer
from .quic import ServerQuicLayer
from .tcp import TCPLayer
from .tls import ClientTLSLayer
from .tls import ServerTLSLayer
from .udp import UDPLayer
from .websocket import WebsocketLayer
# Explicit public API of the proxy layers package; keep in sync with the imports above.
__all__ = [
"modes",
"DNSLayer",
"HttpLayer",
"QuicStreamLayer",
"RawQuicLayer",
"TCPLayer",
"UDPLayer",
"ClientQuicLayer",
"ClientTLSLayer",
"ServerQuicLayer",
"ServerTLSLayer",
"WebsocketLayer",
]

View File

@@ -0,0 +1,190 @@
import struct
import time
from dataclasses import dataclass
from typing import List
from typing import Literal
from mitmproxy import dns
from mitmproxy import flow as mflow
from mitmproxy.net.dns import response_codes
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.utils import expect
# Pre-compiled big-endian u16: the two-byte length prefix that frames DNS-over-TCP messages.
_LENGTH_LABEL = struct.Struct("!H")
@dataclass
class DnsRequestHook(commands.StartHook):
    """
    A DNS query has been received.
    """

    # Addons may set flow.response or flow.error here to short-circuit upstream resolution.
    flow: dns.DNSFlow
@dataclass
class DnsResponseHook(commands.StartHook):
    """
    A DNS response has been received or set.
    """

    # Addons may still modify (or clear) flow.response before it is relayed to the client.
    flow: dns.DNSFlow
@dataclass
class DnsErrorHook(commands.StartHook):
    """
    A DNS error has occurred.
    """

    # flow.error carries the error message; the client receives a SERVFAIL.
    flow: dns.DNSFlow
def pack_message(
    message: dns.DNSMessage, transport_protocol: Literal["tcp", "udp"]
) -> bytes:
    """
    Serialize *message* for the given transport.

    Over TCP, DNS messages carry a two-byte big-endian length prefix
    (RFC 1035, section 4.2.2); over UDP the raw packet is sent as-is.
    """
    raw = message.packed
    if transport_protocol != "tcp":
        return raw
    return struct.pack("!H", len(raw)) + raw
class DNSLayer(layer.Layer):
    """
    Layer that handles resolving DNS queries.

    Messages are matched to flows via their DNS message id. Over TCP,
    incoming bytes are buffered until a complete length-prefixed message
    is available; over UDP each datagram is one message.
    """

    # In-flight flows, keyed by DNS message id.
    flows: dict[int, dns.DNSFlow]
    # Partial DNS-over-TCP data, per direction (client->server / server->client).
    req_buf: bytearray
    resp_buf: bytearray

    def __init__(self, context: Context):
        super().__init__(context)
        self.flows = {}
        self.req_buf = bytearray()
        self.resp_buf = bytearray()

    def handle_request(
        self, flow: dns.DNSFlow, msg: dns.DNSMessage
    ) -> layer.CommandGenerator[None]:
        """Fire the request hook, then serve a response/error an addon set,
        or forward the query to the upstream server."""
        flow.request = msg  # if already set, continue and query upstream again
        yield DnsRequestHook(flow)
        if flow.response:
            yield from self.handle_response(flow, flow.response)
        elif flow.error:
            yield from self.handle_error(flow, flow.error.msg)
        elif not self.context.server.address:
            yield from self.handle_error(
                flow, "No hook has set a response and there is no upstream server."
            )
        else:
            if not self.context.server.connected:
                err = yield commands.OpenConnection(self.context.server)
                if err:
                    yield from self.handle_error(flow, str(err))
                    # cannot recover from this
                    return
            packed = pack_message(flow.request, flow.server_conn.transport_protocol)
            yield commands.SendData(self.context.server, packed)

    def handle_response(
        self, flow: dns.DNSFlow, msg: dns.DNSMessage
    ) -> layer.CommandGenerator[None]:
        """Fire the response hook and, unless an addon cleared flow.response,
        relay the response to the client."""
        flow.response = msg
        yield DnsResponseHook(flow)
        if flow.response:
            packed = pack_message(flow.response, flow.client_conn.transport_protocol)
            yield commands.SendData(self.context.client, packed)

    def handle_error(self, flow: dns.DNSFlow, err: str) -> layer.CommandGenerator[None]:
        """Record the error on the flow, fire the error hook, and answer the
        client with a SERVFAIL derived from the original query."""
        flow.error = mflow.Error(err)
        yield DnsErrorHook(flow)
        servfail = flow.request.fail(response_codes.SERVFAIL)
        yield commands.SendData(
            self.context.client,
            pack_message(servfail, flow.client_conn.transport_protocol),
        )

    def unpack_message(self, data: bytes, from_client: bool) -> List[dns.DNSMessage]:
        """Parse zero or more DNS messages from `data`.

        UDP datagrams carry exactly one message. TCP data is appended to the
        direction-specific buffer and only complete length-prefixed messages
        are consumed; a trailing partial message stays buffered.

        Raises struct.error on malformed framing or message contents.
        """
        msgs: List[dns.DNSMessage] = []
        buf = self.req_buf if from_client else self.resp_buf
        if self.context.client.transport_protocol == "udp":
            msgs.append(dns.DNSMessage.unpack(data, timestamp=time.time()))
        elif self.context.client.transport_protocol == "tcp":
            buf.extend(data)
            size = len(buf)
            offset = 0
            while True:
                if size - offset < _LENGTH_LABEL.size:
                    break  # incomplete length prefix: wait for more data
                (expected_size,) = _LENGTH_LABEL.unpack_from(buf, offset)
                offset += _LENGTH_LABEL.size
                if expected_size == 0:
                    raise struct.error("Message length field cannot be zero")
                if size - offset < expected_size:
                    # incomplete message body: rewind past the prefix and wait
                    offset -= _LENGTH_LABEL.size
                    break
                data = bytes(buf[offset : expected_size + offset])
                offset += expected_size
                msgs.append(dns.DNSMessage.unpack(data, timestamp=time.time()))
            # drop fully consumed bytes, keep any trailing partial message
            del buf[:offset]
        return msgs

    @expect(events.Start)
    def state_start(self, _) -> layer.CommandGenerator[None]:
        self._handle_event = self.state_query
        yield from ()

    @expect(events.DataReceived, events.ConnectionClosed)
    def state_query(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Main state: dispatch incoming data to request/response handling and
        tear both directions down once either side closes."""
        assert isinstance(event, events.ConnectionEvent)
        from_client = event.connection is self.context.client
        if isinstance(event, events.DataReceived):
            msgs: List[dns.DNSMessage] = []
            try:
                msgs = self.unpack_message(event.data, from_client)
            except struct.error as e:
                yield commands.Log(f"{event.connection} sent an invalid message: {e}")
                yield commands.CloseConnection(event.connection)
                self._handle_event = self.state_done
            else:
                for msg in msgs:
                    try:
                        flow = self.flows[msg.id]
                    except KeyError:
                        # unknown message id: start a new flow
                        flow = dns.DNSFlow(
                            self.context.client, self.context.server, live=True
                        )
                        self.flows[msg.id] = flow
                    if from_client:
                        yield from self.handle_request(flow, msg)
                    else:
                        yield from self.handle_response(flow, msg)
        elif isinstance(event, events.ConnectionClosed):
            other_conn = self.context.server if from_client else self.context.client
            if other_conn.connected:
                yield commands.CloseConnection(other_conn)
            self._handle_event = self.state_done
            for flow in self.flows.values():
                flow.live = False
        else:
            raise AssertionError(f"Unexpected event: {event}")

    @expect(events.DataReceived, events.ConnectionClosed)
    def state_done(self, _) -> layer.CommandGenerator[None]:
        # terminal state: ignore any further events
        yield from ()

    # layer entry point; advances to state_query on events.Start
    _handle_event = state_start

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
import html
import textwrap
from dataclasses import dataclass
from mitmproxy import http
from mitmproxy.connection import Connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.context import Context
# Type alias for an HTTP stream identifier (HTTP/1 connections use odd ids 1, 3, ...).
StreamId = int
@dataclass
class HttpEvent(events.Event):
    """Base class for all HTTP events; each event is tagged with its stream id."""

    # we need stream ids on every event to avoid race conditions
    stream_id: StreamId
class HttpConnection(layer.Layer):
    """Base layer for speaking HTTP on a single connection (client or server side)."""

    # The connection this layer operates on.
    conn: Connection

    def __init__(self, context: Context, conn: Connection):
        super().__init__(context)
        self.conn = conn
class HttpCommand(commands.Command):
    """Base class for commands originating from HTTP connections."""

    pass
class ReceiveHttp(HttpCommand):
    """Command wrapping a single HttpEvent produced by an HTTP connection."""

    event: HttpEvent

    def __init__(self, event: HttpEvent):
        self.event = event

    def __repr__(self) -> str:
        return f"Receive({self.event})"
def format_error(status_code: int, message: str) -> bytes:
    """Render a minimal HTML error page for the given status code and message."""
    reason = http.status_codes.RESPONSES.get(status_code, "Unknown")
    page = textwrap.dedent(
        f"""
        <html>
        <head>
            <title>{status_code} {reason}</title>
        </head>
        <body>
            <h1>{status_code} {reason}</h1>
            <p>{html.escape(message)}</p>
        </body>
        </html>
        """
    )
    return page.strip().encode("utf8", "replace")

View File

@@ -0,0 +1,167 @@
import enum
import typing
from dataclasses import dataclass
from ._base import HttpEvent
from mitmproxy import http
from mitmproxy.http import HTTPFlow
from mitmproxy.net.http import status_codes
@dataclass
class RequestHeaders(HttpEvent):
    """The request line and headers have been received."""

    request: http.Request
    end_stream: bool
    """
    If True, we already know at this point that there is no message body. This is useful for HTTP/2, where it allows
    us to set END_STREAM on headers already (and some servers - Akamai - implicitly expect that).
    In either case, this event will nonetheless be followed by RequestEndOfMessage.
    """
    replay_flow: HTTPFlow | None = None
    """If set, the current request headers belong to a replayed flow, which should be reused."""
@dataclass
class ResponseHeaders(HttpEvent):
    """The status line and headers of the response have been received."""

    response: http.Response
    # True if we already know there is no response body (see RequestHeaders.end_stream).
    end_stream: bool = False
# explicit constructors below to facilitate type checking in _http1/_http2
@dataclass
class RequestData(HttpEvent):
    """A chunk of the request body."""

    data: bytes

    def __init__(self, stream_id: int, data: bytes):
        self.stream_id = stream_id
        self.data = data
@dataclass
class ResponseData(HttpEvent):
    """A chunk of the response body."""

    data: bytes

    def __init__(self, stream_id: int, data: bytes):
        self.stream_id = stream_id
        self.data = data
@dataclass
class RequestTrailers(HttpEvent):
    """HTTP trailers received after the request body."""

    trailers: http.Headers

    def __init__(self, stream_id: int, trailers: http.Headers):
        self.stream_id = stream_id
        self.trailers = trailers
@dataclass
class ResponseTrailers(HttpEvent):
    """HTTP trailers received after the response body."""

    trailers: http.Headers

    def __init__(self, stream_id: int, trailers: http.Headers):
        self.stream_id = stream_id
        self.trailers = trailers
@dataclass
class RequestEndOfMessage(HttpEvent):
    """The request (body) has been fully transmitted."""

    def __init__(self, stream_id: int):
        self.stream_id = stream_id
@dataclass
class ResponseEndOfMessage(HttpEvent):
    """The response (body) has been fully transmitted."""

    def __init__(self, stream_id: int):
        self.stream_id = stream_id
class ErrorCode(enum.Enum):
    """Reasons why an HTTP stream or connection failed."""

    GENERIC_CLIENT_ERROR = 1
    GENERIC_SERVER_ERROR = 2
    REQUEST_TOO_LARGE = 3
    RESPONSE_TOO_LARGE = 4
    CONNECT_FAILED = 5
    PASSTHROUGH_CLOSE = 6
    KILL = 7
    HTTP_1_1_REQUIRED = 8
    """Client should fall back to HTTP/1.1 to perform request."""
    DESTINATION_UNKNOWN = 9
    """Proxy does not know where to send request to."""
    CLIENT_DISCONNECTED = 10
    """Client disconnected before receiving entire response."""
    CANCEL = 11
    """Client or server cancelled h2/h3 stream."""
    REQUEST_VALIDATION_FAILED = 12
    RESPONSE_VALIDATION_FAILED = 13

    def http_status_code(self) -> int | None:
        """Return the HTTP status code to send to the client for this error,
        or None if no error response should be sent at all."""
        match self:
            # Client Errors
            case (
                ErrorCode.GENERIC_CLIENT_ERROR
                | ErrorCode.REQUEST_VALIDATION_FAILED
                | ErrorCode.DESTINATION_UNKNOWN
            ):
                return status_codes.BAD_REQUEST
            case ErrorCode.REQUEST_TOO_LARGE:
                return status_codes.PAYLOAD_TOO_LARGE
            # Server errors are surfaced to the client as a gateway failure.
            case (
                ErrorCode.CONNECT_FAILED
                | ErrorCode.GENERIC_SERVER_ERROR
                | ErrorCode.RESPONSE_VALIDATION_FAILED
                | ErrorCode.RESPONSE_TOO_LARGE
            ):
                return status_codes.BAD_GATEWAY
            # Teardown without any error response.
            case (
                ErrorCode.PASSTHROUGH_CLOSE
                | ErrorCode.KILL
                | ErrorCode.HTTP_1_1_REQUIRED
                | ErrorCode.CLIENT_DISCONNECTED
                | ErrorCode.CANCEL
            ):
                return None
            case other:  # pragma: no cover
                typing.assert_never(other)
@dataclass
class RequestProtocolError(HttpEvent):
    """The request failed with a protocol error; `code` describes the cause."""

    message: str
    code: ErrorCode = ErrorCode.GENERIC_CLIENT_ERROR

    def __init__(self, stream_id: int, message: str, code: ErrorCode):
        # explicit constructor: unlike the dataclass field default, `code` is required here
        assert isinstance(code, ErrorCode)
        self.stream_id = stream_id
        self.message = message
        self.code = code
@dataclass
class ResponseProtocolError(HttpEvent):
    """The response failed with a protocol error; `code` describes the cause."""

    message: str
    code: ErrorCode = ErrorCode.GENERIC_SERVER_ERROR

    def __init__(self, stream_id: int, message: str, code: ErrorCode):
        # explicit constructor: unlike the dataclass field default, `code` is required here
        assert isinstance(code, ErrorCode)
        self.stream_id = stream_id
        self.message = message
        self.code = code
# Public API of the HTTP events module; keep in sync with the definitions above.
__all__ = [
"ErrorCode",
"HttpEvent",
"RequestHeaders",
"RequestData",
"RequestEndOfMessage",
"ResponseHeaders",
"ResponseData",
"RequestTrailers",
"ResponseTrailers",
"ResponseEndOfMessage",
"RequestProtocolError",
"ResponseProtocolError",
]

View File

@@ -0,0 +1,122 @@
from dataclasses import dataclass
from mitmproxy import http
from mitmproxy.proxy import commands
@dataclass
class HttpRequestHeadersHook(commands.StartHook):
    """
    HTTP request headers were successfully read. At this point, the body is empty.
    """

    # hook/event name exposed to addons
    name = "requestheaders"
    flow: http.HTTPFlow
@dataclass
class HttpRequestHook(commands.StartHook):
    """
    The full HTTP request has been read.

    Note: If request streaming is active, this event fires after the entire body has been streamed.
    HTTP trailers, if present, have not been transmitted to the server yet and can still be modified.
    Enabling streaming may cause unexpected event sequences: For example, `response` may now occur
    before `request` because the server replied with "413 Payload Too Large" during upload.
    """

    # hook/event name exposed to addons
    name = "request"
    flow: http.HTTPFlow
@dataclass
class HttpResponseHeadersHook(commands.StartHook):
    """
    HTTP response headers were successfully read. At this point, the body is empty.
    """

    # hook/event name exposed to addons
    name = "responseheaders"
    flow: http.HTTPFlow
@dataclass
class HttpResponseHook(commands.StartHook):
    """
    The full HTTP response has been read.

    Note: If response streaming is active, this event fires after the entire body has been streamed.
    HTTP trailers, if present, have not been transmitted to the client yet and can still be modified.
    """

    # hook/event name exposed to addons
    name = "response"
    flow: http.HTTPFlow
@dataclass
class HttpErrorHook(commands.StartHook):
    """
    An HTTP error has occurred, e.g. invalid server responses, or
    interrupted connections. This is distinct from a valid server HTTP
    error response, which is simply a response with an HTTP error code.

    Every flow will receive either an error or a response event, but not both.
    """

    # hook/event name exposed to addons
    name = "error"
    flow: http.HTTPFlow
@dataclass
class HttpConnectHook(commands.StartHook):
    """
    An HTTP CONNECT request was received. This event can be ignored for most practical purposes.

    This event only occurs in regular and upstream proxy modes
    when the client instructs mitmproxy to open a connection to an upstream host.
    Setting a non 2xx response on the flow will return the response to the client and abort the connection.

    CONNECT requests are HTTP proxy instructions for mitmproxy itself
    and not forwarded. They do not generate the usual HTTP handler events,
    but all requests going over the newly opened connection will.
    """

    flow: http.HTTPFlow
@dataclass
class HttpConnectUpstreamHook(commands.StartHook):
    """
    An HTTP CONNECT request is about to be sent to an upstream proxy.
    This event can be ignored for most practical purposes.

    This event can be used to set custom authentication headers for upstream proxies.

    CONNECT requests do not generate the usual HTTP handler events,
    but all requests going over the newly opened connection will.
    """

    flow: http.HTTPFlow
@dataclass
class HttpConnectedHook(commands.StartHook):
    """
    HTTP CONNECT was successful.

    > [!WARNING]
    > This may fire before an upstream connection has been established
    > if `connection_strategy` is set to `lazy` (default)
    """

    flow: http.HTTPFlow
@dataclass
class HttpConnectErrorHook(commands.StartHook):
    """
    HTTP CONNECT has failed.

    This can happen when the upstream server is unreachable or proxy authentication is required.
    In contrast to the `error` hook, `flow.error` is not guaranteed to be set.
    """

    flow: http.HTTPFlow

View File

@@ -0,0 +1,502 @@
import abc
from collections.abc import Callable
from typing import Union
import h11
from h11._readers import ChunkedReader
from h11._readers import ContentLengthReader
from h11._readers import Http10Reader
from h11._receivebuffer import ReceiveBuffer
from ...context import Context
from ._base import format_error
from ._base import HttpConnection
from ._events import ErrorCode
from ._events import HttpEvent
from ._events import RequestData
from ._events import RequestEndOfMessage
from ._events import RequestHeaders
from ._events import RequestProtocolError
from ._events import ResponseData
from ._events import ResponseEndOfMessage
from ._events import ResponseHeaders
from ._events import ResponseProtocolError
from mitmproxy import http
from mitmproxy import version
from mitmproxy.connection import Connection
from mitmproxy.connection import ConnectionState
from mitmproxy.net.http import http1
from mitmproxy.net.http import status_codes
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.layers.http._base import ReceiveHttp
from mitmproxy.proxy.layers.http._base import StreamId
from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human
# Union of the h11 body readers: chunked transfer-encoding, read-until-EOF
# (HTTP/1.0 framing), or a fixed Content-Length.
TBodyReader = Union[ChunkedReader, Http10Reader, ContentLengthReader]
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
    """
    Shared base for the HTTP/1 server and client connections: buffers incoming
    bytes and drives a simple state machine (start -> read_headers -> read_body
    -> ... or passthrough).
    """

    stream_id: StreamId | None = None
    request: http.Request | None = None
    response: http.Response | None = None
    request_done: bool = False
    response_done: bool = False
    # this is a bit of a hack to make both mypy and PyCharm happy.
    state: Callable[[events.Event], layer.CommandGenerator[None]] | Callable
    body_reader: TBodyReader
    buf: ReceiveBuffer
    # Direction-specific event/error classes, set by the Server/Client subclasses.
    ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
    ReceiveData: type[RequestData | ResponseData]
    ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]

    def __init__(self, context: Context, conn: Connection):
        super().__init__(context, conn)
        self.buf = ReceiveBuffer()

    @abc.abstractmethod
    def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
        """Serialize an outgoing HTTP event onto the wire."""
        yield from ()  # pragma: no cover

    @abc.abstractmethod
    def read_headers(
        self, event: events.ConnectionEvent
    ) -> layer.CommandGenerator[None]:
        """Parse the message head from the buffer and advance the state machine."""
        yield from ()  # pragma: no cover

    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        if isinstance(event, HttpEvent):
            yield from self.send(event)
        else:
            # In passthrough mode the raw bytes are forwarded directly and
            # must not be accumulated in our parse buffer.
            if (
                isinstance(event, events.DataReceived)
                and self.state != self.passthrough
            ):
                self.buf += event.data
            yield from self.state(event)

    @expect(events.Start)
    def start(self, _) -> layer.CommandGenerator[None]:
        self.state = self.read_headers
        yield from ()

    state = start

    def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Feed buffered data into the h11 body reader and emit body events."""
        assert self.stream_id is not None
        while True:
            try:
                if isinstance(event, events.DataReceived):
                    h11_event = self.body_reader(self.buf)
                elif isinstance(event, events.ConnectionClosed):
                    h11_event = self.body_reader.read_eof()
                else:
                    raise AssertionError(f"Unexpected event: {event}")
            except h11.ProtocolError as e:
                yield commands.CloseConnection(self.conn)
                yield ReceiveHttp(
                    self.ReceiveProtocolError(
                        self.stream_id,
                        f"HTTP/1 protocol error: {e}",
                        # use the direction-specific default error code
                        code=self.ReceiveProtocolError.code,
                    )
                )
                return
            if h11_event is None:
                return  # need more data
            elif isinstance(h11_event, h11.Data):
                data: bytes = bytes(h11_event.data)
                if data:
                    yield ReceiveHttp(self.ReceiveData(self.stream_id, data))
            elif isinstance(h11_event, h11.EndOfMessage):
                assert self.request
                if h11_event.headers:
                    # fix: no placeholders, so this must not be an f-string (ruff F541)
                    raise NotImplementedError("HTTP trailers are not implemented yet.")
                if self.request.data.method.upper() != b"CONNECT":
                    yield ReceiveHttp(self.ReceiveEndOfMessage(self.stream_id))
                is_request = isinstance(self, Http1Server)
                yield from self.mark_done(request=is_request, response=not is_request)
                return

    def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
        """
        We wait for the current flow to be finished before parsing the next message,
        as we may want to upgrade to WebSocket or plain TCP before that.
        """
        assert self.stream_id
        if isinstance(event, events.DataReceived):
            return
        elif isinstance(event, events.ConnectionClosed):
            # for practical purposes, we assume that a peer which sent at least a FIN
            # is not interested in any more data from us;
            # see https://github.com/httpwg/http-core/issues/22
            if event.connection.state is not ConnectionState.CLOSED:
                yield commands.CloseConnection(event.connection)
            yield ReceiveHttp(
                self.ReceiveProtocolError(
                    self.stream_id,
                    # fix: no placeholders, so this must not be an f-string (ruff F541)
                    "Client disconnected.",
                    code=ErrorCode.CLIENT_DISCONNECTED,
                )
            )
        else:  # pragma: no cover
            raise AssertionError(f"Unexpected event: {event}")

    def done(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
        """Terminal state: all further events are ignored."""
        yield from ()  # pragma: no cover

    def make_pipe(self) -> layer.CommandGenerator[None]:
        """Switch to raw passthrough mode and replay any already-buffered bytes."""
        self.state = self.passthrough
        if self.buf:
            already_received = self.buf.maybe_extract_at_most(len(self.buf)) or b""
            # Some clients send superfluous newlines after CONNECT, we want to eat those.
            already_received = already_received.lstrip(b"\r\n")
            if already_received:
                yield from self.state(events.DataReceived(self.conn, already_received))

    def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Forward raw bytes as body data; translate a close into end-of-message."""
        assert self.stream_id
        if isinstance(event, events.DataReceived):
            yield ReceiveHttp(self.ReceiveData(self.stream_id, event.data))
        elif isinstance(event, events.ConnectionClosed):
            if isinstance(self, Http1Server):
                yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
            else:
                yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))

    def mark_done(
        self, *, request: bool = False, response: bool = False
    ) -> layer.CommandGenerator[None]:
        """Record that request and/or response finished; once both are done,
        either pipe, close, or reset for the next message on this connection."""
        if request:
            self.request_done = True
        if response:
            self.response_done = True
        if self.request_done and self.response_done:
            assert self.request
            assert self.response
            if should_make_pipe(self.request, self.response):
                yield from self.make_pipe()
                return
            try:
                read_until_eof_semantics = (
                    http1.expected_http_body_size(self.request, self.response) == -1
                )
            except ValueError:
                # this may raise only now (and not earlier) because an addon set invalid headers,
                # in which case it's not really clear what we are supposed to do.
                read_until_eof_semantics = False
            connection_done = (
                read_until_eof_semantics
                or http1.connection_close(
                    self.request.http_version, self.request.headers
                )
                or http1.connection_close(
                    self.response.http_version, self.response.headers
                )
                # If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
                # This simplifies our connection management quite a bit as we can rely on
                # the proxyserver's max-connection-per-server throttling.
                or (
                    (self.request.is_http2 or self.request.is_http3)
                    and isinstance(self, Http1Client)
                )
            )
            if connection_done:
                yield commands.CloseConnection(self.conn)
                self.state = self.done
                return
            # Reset per-message state for connection reuse (keep-alive).
            self.request_done = self.response_done = False
            self.request = self.response = None
            if isinstance(self, Http1Server):
                self.stream_id += 2
            else:
                self.stream_id = None
            self.state = self.read_headers
            if self.buf:
                yield from self.state(events.DataReceived(self.conn, b""))
class Http1Server(Http1Connection):
    """A simple HTTP/1 server with no pipelining support."""

    ReceiveProtocolError = RequestProtocolError
    ReceiveData = RequestData
    ReceiveEndOfMessage = RequestEndOfMessage
    # Server-side streams use odd ids (1, 3, 5, ...), incremented in mark_done.
    stream_id: int

    def __init__(self, context: Context):
        super().__init__(context, context.client)
        self.stream_id = 1

    def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
        """Serialize outgoing response events into raw HTTP/1 bytes for the client."""
        assert event.stream_id == self.stream_id
        if isinstance(event, ResponseHeaders):
            self.response = response = event.response
            if response.is_http2 or response.is_http3:
                response = response.copy()
                # Convert to an HTTP/1 response.
                response.http_version = "HTTP/1.1"
                # not everyone supports empty reason phrases, so we better make up one.
                response.reason = status_codes.RESPONSES.get(response.status_code, "")
                # Shall we set a Content-Length header here if there is none?
                # For now, let's try to modify as little as possible.
            raw = http1.assemble_response_head(response)
            yield commands.SendData(self.conn, raw)
        elif isinstance(event, ResponseData):
            assert self.response
            if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
                raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
            else:
                raw = event.data
            if raw:
                yield commands.SendData(self.conn, raw)
        elif isinstance(event, ResponseEndOfMessage):
            assert self.request
            assert self.response
            if (
                self.request.method.upper() != "HEAD"
                and "chunked"
                in self.response.headers.get("transfer-encoding", "").lower()
            ):
                # terminating zero-length chunk
                yield commands.SendData(self.conn, b"0\r\n\r\n")
            yield from self.mark_done(response=True)
        elif isinstance(event, ResponseProtocolError):
            if not (self.conn.state & ConnectionState.CAN_WRITE):
                return
            status = event.code.http_status_code()
            # only send an error page if we haven't already started a response
            if not self.response and status is not None:
                yield commands.SendData(
                    self.conn, make_error_response(status, event.message)
                )
            yield commands.CloseConnection(self.conn)
        else:
            raise AssertionError(f"Unexpected event: {event}")

    def read_headers(
        self, event: events.ConnectionEvent
    ) -> layer.CommandGenerator[None]:
        """Parse the request head from the buffer, then hand off to read_body."""
        if isinstance(event, events.DataReceived):
            request_head = self.buf.maybe_extract_lines()
            if request_head:
                try:
                    self.request = http1.read_request_head(
                        [bytes(x) for x in request_head]
                    )
                    expected_body_size = http1.expected_http_body_size(self.request)
                except ValueError as e:
                    # malformed request head: answer 400 and tear down
                    yield commands.SendData(self.conn, make_error_response(400, str(e)))
                    yield commands.CloseConnection(self.conn)
                    if self.request:
                        # we have headers that we can show in the ui
                        yield ReceiveHttp(
                            RequestHeaders(self.stream_id, self.request, False)
                        )
                        yield ReceiveHttp(
                            RequestProtocolError(
                                self.stream_id, str(e), ErrorCode.GENERIC_CLIENT_ERROR
                            )
                        )
                    else:
                        yield commands.Log(
                            f"{human.format_address(self.conn.peername)}: {e}"
                        )
                    self.state = self.done
                    return
                yield ReceiveHttp(
                    RequestHeaders(
                        self.stream_id, self.request, expected_body_size == 0
                    )
                )
                self.body_reader = make_body_reader(expected_body_size)
                self.state = self.read_body
                yield from self.state(event)
            else:
                pass  # FIXME: protect against header size DoS
        elif isinstance(event, events.ConnectionClosed):
            buf = bytes(self.buf)
            if buf.strip():
                yield commands.Log(
                    f"Client closed connection before completing request headers: {buf!r}"
                )
            yield commands.CloseConnection(self.conn)
        else:
            raise AssertionError(f"Unexpected event: {event}")

    def mark_done(
        self, *, request: bool = False, response: bool = False
    ) -> layer.CommandGenerator[None]:
        yield from super().mark_done(request=request, response=response)
        # hold off on parsing the next request until the response is complete
        if self.request_done and not self.response_done:
            self.state = self.wait
class Http1Client(Http1Connection):
    """A simple HTTP/1 client with no pipelining support."""

    ReceiveProtocolError = ResponseProtocolError
    ReceiveData = ResponseData
    ReceiveEndOfMessage = ResponseEndOfMessage

    def __init__(self, context: Context):
        super().__init__(context, context.server)

    def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
        """Serialize outgoing request events into raw HTTP/1 bytes for the server."""
        if isinstance(event, RequestProtocolError):
            yield commands.CloseConnection(self.conn)
            return
        if self.stream_id is None:
            # the connection is bound to one stream at a time,
            # starting with the first RequestHeaders event
            assert isinstance(event, RequestHeaders)
            self.stream_id = event.stream_id
            self.request = event.request
        assert self.stream_id == event.stream_id
        if isinstance(event, RequestHeaders):
            request = event.request
            if request.is_http2 or request.is_http3:
                # Convert to an HTTP/1 request.
                request = (
                    request.copy()
                )  # (we could probably be a bit more efficient here.)
                request.http_version = "HTTP/1.1"
                if "Host" not in request.headers and request.authority:
                    request.headers.insert(0, "Host", request.authority)
                request.authority = ""
                cookie_headers = request.headers.get_all("Cookie")
                if len(cookie_headers) > 1:
                    # Only HTTP/2 supports multiple cookie headers, HTTP/1.x does not.
                    # see: https://www.rfc-editor.org/rfc/rfc6265#section-5.4
                    # https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.5
                    request.headers["Cookie"] = "; ".join(cookie_headers)
            raw = http1.assemble_request_head(request)
            yield commands.SendData(self.conn, raw)
        elif isinstance(event, RequestData):
            assert self.request
            if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
                raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
            else:
                raw = event.data
            if raw:
                yield commands.SendData(self.conn, raw)
        elif isinstance(event, RequestEndOfMessage):
            assert self.request
            if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
                # terminating zero-length chunk
                yield commands.SendData(self.conn, b"0\r\n\r\n")
            elif http1.expected_http_body_size(self.request, self.response) == -1:
                # read-until-EOF body: signal completion by half-closing our side
                yield commands.CloseTcpConnection(self.conn, half_close=True)
            yield from self.mark_done(request=True)
        else:
            raise AssertionError(f"Unexpected event: {event}")

    def read_headers(
        self, event: events.ConnectionEvent
    ) -> layer.CommandGenerator[None]:
        """Parse the response head from the buffer, then hand off to read_body."""
        if isinstance(event, events.DataReceived):
            if not self.request:
                # we just received some data for an unknown request.
                yield commands.Log(f"Unexpected data from server: {bytes(self.buf)!r}")
                yield commands.CloseConnection(self.conn)
                return
            assert self.stream_id is not None
            response_head = self.buf.maybe_extract_lines()
            if response_head:
                try:
                    self.response = http1.read_response_head(
                        [bytes(x) for x in response_head]
                    )
                    expected_size = http1.expected_http_body_size(
                        self.request, self.response
                    )
                except ValueError as e:
                    yield commands.CloseConnection(self.conn)
                    yield ReceiveHttp(
                        ResponseProtocolError(
                            self.stream_id,
                            f"Cannot parse HTTP response: {e}",
                            ErrorCode.GENERIC_SERVER_ERROR,
                        )
                    )
                    return
                yield ReceiveHttp(
                    ResponseHeaders(self.stream_id, self.response, expected_size == 0)
                )
                self.body_reader = make_body_reader(expected_size)
                self.state = self.read_body
                yield from self.state(event)
            else:
                pass  # FIXME: protect against header size DoS
        elif isinstance(event, events.ConnectionClosed):
            if self.conn.state & ConnectionState.CAN_WRITE:
                yield commands.CloseConnection(self.conn)
            if self.stream_id:
                if self.buf:
                    yield ReceiveHttp(
                        ResponseProtocolError(
                            self.stream_id,
                            f"unexpected server response: {bytes(self.buf)!r}",
                            ErrorCode.GENERIC_SERVER_ERROR,
                        )
                    )
                else:
                    # The server has closed the connection to prevent us from continuing.
                    # We need to signal that to the stream.
                    # https://tools.ietf.org/html/rfc7231#section-6.5.11
                    yield ReceiveHttp(
                        ResponseProtocolError(
                            self.stream_id,
                            "server closed connection",
                            ErrorCode.GENERIC_SERVER_ERROR,
                        )
                    )
            else:
                return
        else:
            raise AssertionError(f"Unexpected event: {event}")
def should_make_pipe(request: http.Request, response: http.Response) -> bool:
    """Return True if the connection should switch to a raw passthrough pipe:
    a protocol upgrade (101) or a successful CONNECT (200)."""
    upgraded = response.status_code == 101
    connect_ok = response.status_code == 200 and request.method.upper() == "CONNECT"
    return upgraded or connect_ok
def make_body_reader(expected_size: int | None) -> TBodyReader:
    """Pick the h11 body reader matching the expected body framing:
    None -> chunked, -1 -> read until EOF (HTTP/1.0), else fixed length."""
    if expected_size is None:
        return ChunkedReader()
    if expected_size == -1:
        return Http10Reader()
    return ContentLengthReader(expected_size)
def make_error_response(
    status_code: int,
    message: str = "",
) -> bytes:
    """Assemble a complete HTTP/1 error response with an HTML body."""
    headers = http.Headers(
        Server=version.MITMPROXY,
        Connection="close",
        Content_Type="text/html",
    )
    body = format_error(status_code, message)
    resp = http.Response.make(status_code, body, headers)
    return http1.assemble_response(resp)
# Public API of the HTTP/1 module.
__all__ = [
"Http1Client",
"Http1Server",
]

View File

@@ -0,0 +1,714 @@
import collections
import time
from collections.abc import Sequence
from enum import Enum
from logging import DEBUG
from logging import ERROR
from typing import Any
from typing import assert_never
from typing import ClassVar
import h2.config
import h2.connection
import h2.errors
import h2.events
import h2.exceptions
import h2.settings
import h2.stream
import h2.utilities
from ...commands import CloseConnection
from ...commands import Log
from ...commands import RequestWakeup
from ...commands import SendData
from ...context import Context
from ...events import ConnectionClosed
from ...events import DataReceived
from ...events import Event
from ...events import Start
from ...events import Wakeup
from ...layer import CommandGenerator
from ...utils import expect
from . import ErrorCode
from . import RequestData
from . import RequestEndOfMessage
from . import RequestHeaders
from . import RequestProtocolError
from . import RequestTrailers
from . import ResponseData
from . import ResponseEndOfMessage
from . import ResponseHeaders
from . import ResponseProtocolError
from . import ResponseTrailers
from ._base import format_error
from ._base import HttpConnection
from ._base import HttpEvent
from ._base import ReceiveHttp
from ._http_h2 import BufferedH2Connection
from ._http_h2 import H2ConnectionLogger
from mitmproxy import http
from mitmproxy import version
from mitmproxy.connection import Connection
from mitmproxy.net.http import status_codes
from mitmproxy.net.http import url
from mitmproxy.utils import human
class StreamState(Enum):
    """Progress of an HTTP/2 stream: headers not yet received vs. received."""

    EXPECTING_HEADERS = 1
    HEADERS_RECEIVED = 2
# NOTE(review): hyper-h2 internals may raise bare ValueError/IndexError on malformed
# input; presumably caught around h2 calls further below — confirm at the call sites.
CATCH_HYPER_H2_ERRORS = (ValueError, IndexError)
class Http2Connection(HttpConnection):
h2_conf: ClassVar[h2.config.H2Configuration]
h2_conf_defaults: dict[str, Any] = dict(
header_encoding=False,
validate_outbound_headers=False,
# validate_inbound_headers is controlled by the validate_inbound_headers option.
normalize_inbound_headers=False, # changing this to True is required to pass h2spec
normalize_outbound_headers=False,
)
h2_conn: BufferedH2Connection
streams: dict[int, StreamState]
"""keep track of all active stream ids to send protocol errors on teardown"""
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
ReceiveData: type[RequestData | ResponseData]
ReceiveTrailers: type[RequestTrailers | ResponseTrailers]
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
    def __init__(self, context: Context, conn: Connection):
        super().__init__(context, conn)
        if self.debug:
            # NOTE(review): h2_conf is a ClassVar, so assigning a logger and
            # validate_inbound_headers below mutates configuration shared by
            # all instances of this class — confirm this is intended.
            self.h2_conf.logger = H2ConnectionLogger(
                self.context.client.peername, self.__class__.__name__
            )
        self.h2_conf.validate_inbound_headers = (
            self.context.options.validate_inbound_headers
        )
        self.h2_conn = BufferedH2Connection(self.h2_conf)
        self.streams = {}
def is_closed(self, stream_id: int) -> bool:
"""Check if a non-idle stream is closed"""
stream = self.h2_conn.streams.get(stream_id, None)
if (
stream is not None
and stream.state_machine.state is not h2.stream.StreamState.CLOSED
and self.h2_conn.state_machine.state
is not h2.connection.ConnectionState.CLOSED
):
return False
else:
return True
def is_open_for_us(self, stream_id: int) -> bool:
"""Check if we can write to a non-idle stream."""
stream = self.h2_conn.streams.get(stream_id, None)
if (
stream is not None
and stream.state_machine.state
is not h2.stream.StreamState.HALF_CLOSED_LOCAL
and stream.state_machine.state is not h2.stream.StreamState.CLOSED
and self.h2_conn.state_machine.state
is not h2.connection.ConnectionState.CLOSED
):
return True
else:
return False
def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, Start):
self.h2_conn.initiate_connection()
yield SendData(self.conn, self.h2_conn.data_to_send())
elif isinstance(event, HttpEvent):
if isinstance(event, (RequestData, ResponseData)):
if self.is_open_for_us(event.stream_id):
self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, (RequestTrailers, ResponseTrailers)):
if self.is_open_for_us(event.stream_id):
trailers = [*event.trailers.fields]
self.h2_conn.send_trailers(event.stream_id, trailers)
elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)):
if self.is_open_for_us(event.stream_id):
self.h2_conn.end_stream(event.stream_id)
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
if not self.is_closed(event.stream_id):
stream: h2.stream.H2Stream = self.h2_conn.streams[event.stream_id]
status = event.code.http_status_code()
if (
isinstance(event, ResponseProtocolError)
and self.is_open_for_us(event.stream_id)
and not stream.state_machine.headers_sent
and status is not None
):
self.h2_conn.send_headers(
event.stream_id,
[
(b":status", b"%d" % status),
(b"server", version.MITMPROXY.encode()),
(b"content-type", b"text/html"),
],
)
self.h2_conn.send_data(
event.stream_id,
format_error(status, event.message),
end_stream=True,
)
else:
match event.code:
case ErrorCode.CANCEL | ErrorCode.CLIENT_DISCONNECTED:
error_code = h2.errors.ErrorCodes.CANCEL
case ErrorCode.KILL:
# XXX: Debateable whether this is the best error code.
error_code = h2.errors.ErrorCodes.INTERNAL_ERROR
case ErrorCode.HTTP_1_1_REQUIRED:
error_code = h2.errors.ErrorCodes.HTTP_1_1_REQUIRED
case ErrorCode.PASSTHROUGH_CLOSE:
# FIXME: This probably shouldn't be a protocol error, but an EOM event.
error_code = h2.errors.ErrorCodes.CANCEL
case (
ErrorCode.GENERIC_CLIENT_ERROR
| ErrorCode.GENERIC_SERVER_ERROR
| ErrorCode.REQUEST_TOO_LARGE
| ErrorCode.RESPONSE_TOO_LARGE
| ErrorCode.CONNECT_FAILED
| ErrorCode.DESTINATION_UNKNOWN
| ErrorCode.REQUEST_VALIDATION_FAILED
| ErrorCode.RESPONSE_VALIDATION_FAILED
):
error_code = h2.errors.ErrorCodes.INTERNAL_ERROR
case other: # pragma: no cover
assert_never(other)
self.h2_conn.reset_stream(event.stream_id, error_code.value)
else:
raise AssertionError(f"Unexpected event: {event}")
data_to_send = self.h2_conn.data_to_send()
if data_to_send:
yield SendData(self.conn, data_to_send)
elif isinstance(event, DataReceived):
try:
try:
events = self.h2_conn.receive_data(event.data)
except CATCH_HYPER_H2_ERRORS as e: # pragma: no cover
# this should never raise a ValueError, but we triggered one while fuzzing:
# https://github.com/python-hyper/hyper-h2/issues/1231
# this stays here as defense-in-depth.
raise h2.exceptions.ProtocolError(
f"uncaught hyper-h2 error: {e}"
) from e
except h2.exceptions.ProtocolError as e:
events = [e]
for h2_event in events:
if self.debug:
yield Log(f"{self.debug}[h2] {h2_event}", DEBUG)
if (yield from self.handle_h2_event(h2_event)):
if self.debug:
yield Log(f"{self.debug}[h2] done", DEBUG)
return
data_to_send = self.h2_conn.data_to_send()
if data_to_send:
yield SendData(self.conn, data_to_send)
elif isinstance(event, ConnectionClosed):
yield from self.close_connection("peer closed connection")
else:
raise AssertionError(f"Unexpected event: {event!r}")
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
"""returns true if further processing should be stopped."""
if isinstance(event, h2.events.DataReceived):
state = self.streams.get(event.stream_id, None)
if state is StreamState.HEADERS_RECEIVED:
is_empty_eos_data_frame = event.stream_ended and not event.data
if not is_empty_eos_data_frame:
yield ReceiveHttp(self.ReceiveData(event.stream_id, event.data))
elif state is StreamState.EXPECTING_HEADERS:
yield from self.protocol_error(
f"Received HTTP/2 data frame, expected headers."
)
return True
self.h2_conn.acknowledge_received_data(
event.flow_controlled_length, event.stream_id
)
elif isinstance(event, h2.events.TrailersReceived):
trailers = http.Headers(event.headers)
yield ReceiveHttp(self.ReceiveTrailers(event.stream_id, trailers))
elif isinstance(event, h2.events.StreamEnded):
state = self.streams.get(event.stream_id, None)
if state is StreamState.HEADERS_RECEIVED:
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
elif state is StreamState.EXPECTING_HEADERS:
raise AssertionError("unreachable")
if self.is_closed(event.stream_id):
self.streams.pop(event.stream_id, None)
elif isinstance(event, h2.events.StreamReset):
if event.stream_id in self.streams:
try:
err_str = h2.errors.ErrorCodes(event.error_code).name
except ValueError:
err_str = str(event.error_code)
match event.error_code:
case h2.errors.ErrorCodes.CANCEL:
err_code = ErrorCode.CANCEL
case h2.errors.ErrorCodes.HTTP_1_1_REQUIRED:
err_code = ErrorCode.HTTP_1_1_REQUIRED
case _:
err_code = self.ReceiveProtocolError.code
yield ReceiveHttp(
self.ReceiveProtocolError(
event.stream_id,
f"stream reset by client ({err_str})",
code=err_code,
)
)
self.streams.pop(event.stream_id)
else:
pass # We don't track priority frames which could be followed by a stream reset here.
elif isinstance(event, h2.exceptions.ProtocolError):
yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
return True
elif isinstance(event, h2.events.ConnectionTerminated):
yield from self.close_connection(f"HTTP/2 connection closed: {event!r}")
return True
# The implementation above isn't really ideal, we should probably only terminate streams > last_stream_id?
# We currently lack a mechanism to signal that connections are still active but cannot be reused.
# for stream_id in self.streams:
# if stream_id > event.last_stream_id:
# yield ReceiveHttp(self.ReceiveProtocolError(stream_id, f"HTTP/2 connection closed: {event!r}"))
# self.streams.pop(stream_id)
elif isinstance(event, h2.events.RemoteSettingsChanged):
pass
elif isinstance(event, h2.events.SettingsAcknowledged):
pass
elif isinstance(event, h2.events.PriorityUpdated):
pass
elif isinstance(event, h2.events.PingReceived):
pass
elif isinstance(event, h2.events.PingAckReceived):
pass
elif isinstance(event, h2.events.PushedStreamReceived):
yield Log(
"Received HTTP/2 push promise, even though we signalled no support.",
ERROR,
)
elif isinstance(event, h2.events.UnknownFrameReceived):
# https://http2.github.io/http2-spec/#rfc.section.4.1
# Implementations MUST ignore and discard any frame that has a type that is unknown.
yield Log(f"Ignoring unknown HTTP/2 frame type: {event.frame.type}")
elif isinstance(event, h2.events.AlternativeServiceAvailable):
yield Log(
"Received HTTP/2 Alt-Svc frame, which will not be forwarded.", DEBUG
)
else:
raise AssertionError(f"Unexpected event: {event!r}")
return False
def protocol_error(
self,
message: str,
error_code: int = h2.errors.ErrorCodes.PROTOCOL_ERROR,
) -> CommandGenerator[None]:
yield Log(f"{human.format_address(self.conn.peername)}: {message}")
self.h2_conn.close_connection(error_code, message.encode())
yield SendData(self.conn, self.h2_conn.data_to_send())
yield from self.close_connection(message)
def close_connection(self, msg: str) -> CommandGenerator[None]:
yield CloseConnection(self.conn)
for stream_id in self.streams:
yield ReceiveHttp(
self.ReceiveProtocolError(
stream_id, msg, self.ReceiveProtocolError.code
)
)
self.streams.clear()
self._handle_event = self.done # type: ignore
@expect(DataReceived, HttpEvent, ConnectionClosed, Wakeup)
def done(self, _) -> CommandGenerator[None]:
yield from ()
def normalize_h1_headers(
    headers: list[tuple[bytes, bytes]], is_client: bool
) -> list[tuple[bytes, bytes]]:
    """
    Lowercase header names coming from an HTTP/1 peer.

    HTTP/1 servers commonly send capitalized headers (Content-Length vs
    content-length), which isn't valid HTTP/2, so we normalize via hyper-h2.
    """
    validation_flags = h2.utilities.HeaderValidationFlags(
        is_client, False, not is_client, False
    )
    normalized = h2.utilities.normalize_outbound_headers(headers, validation_flags)
    # Materialize the iterator: hyper-h2 silently drops headers when handed
    # an already-consumed iterator instead of a reusable iterable.
    return list(normalized)
def normalize_h2_headers(headers: list[tuple[bytes, bytes]]) -> CommandGenerator[None]:
for i in range(len(headers)):
if not headers[i][0].islower():
yield Log(
f"Lowercased {repr(headers[i][0]).lstrip('b')} header as uppercase is not allowed with HTTP/2."
)
headers[i] = (headers[i][0].lower(), headers[i][1])
def format_h2_request_headers(
context: Context,
event: RequestHeaders,
) -> CommandGenerator[list[tuple[bytes, bytes]]]:
pseudo_headers = [
(b":method", event.request.data.method),
(b":scheme", event.request.data.scheme),
(b":path", event.request.data.path),
]
if event.request.authority:
pseudo_headers.append((b":authority", event.request.data.authority))
if event.request.is_http2 or event.request.is_http3:
hdrs = list(event.request.headers.fields)
if context.options.normalize_outbound_headers:
yield from normalize_h2_headers(hdrs)
else:
headers = event.request.headers
if not event.request.authority and "host" in headers:
headers = headers.copy()
pseudo_headers.append((b":authority", headers.pop(b"host")))
hdrs = normalize_h1_headers(list(headers.fields), True)
return pseudo_headers + hdrs
def format_h2_response_headers(
context: Context,
event: ResponseHeaders,
) -> CommandGenerator[list[tuple[bytes, bytes]]]:
headers = [
(b":status", b"%d" % event.response.status_code),
*event.response.headers.fields,
]
if event.response.is_http2 or event.response.is_http3:
if context.options.normalize_outbound_headers:
yield from normalize_h2_headers(headers)
else:
headers = normalize_h1_headers(headers, False)
return headers
class Http2Server(Http2Connection):
    """HTTP/2 connection to a client (mitmproxy acts as the server)."""

    h2_conf = h2.config.H2Configuration(
        **Http2Connection.h2_conf_defaults,
        client_side=False,
    )

    ReceiveProtocolError = RequestProtocolError
    ReceiveData = RequestData
    ReceiveTrailers = RequestTrailers
    ReceiveEndOfMessage = RequestEndOfMessage

    def __init__(self, context: Context):
        super().__init__(context, context.client)

    def _handle_event(self, event: Event) -> CommandGenerator[None]:
        if isinstance(event, ResponseHeaders):
            if self.is_open_for_us(event.stream_id):
                self.h2_conn.send_headers(
                    event.stream_id,
                    headers=(
                        yield from format_h2_response_headers(self.context, event)
                    ),
                    end_stream=event.end_stream,
                )
                yield SendData(self.conn, self.h2_conn.data_to_send())
        else:
            yield from super()._handle_event(event)

    def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
        """Handle client-originated h2 events; requests are parsed here."""
        if isinstance(event, h2.events.RequestReceived):
            try:
                (
                    host,
                    port,
                    method,
                    scheme,
                    authority,
                    path,
                    headers,
                ) = parse_h2_request_headers(event.headers)
            except ValueError as e:
                yield from self.protocol_error(f"Invalid HTTP/2 request headers: {e}")
                return True
            request = http.Request(
                host=host,
                port=port,
                method=method,
                scheme=scheme,
                authority=authority,
                path=path,
                http_version=b"HTTP/2.0",
                headers=headers,
                content=None,
                trailers=None,
                timestamp_start=time.time(),
                timestamp_end=None,
            )
            self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
            yield ReceiveHttp(
                RequestHeaders(
                    event.stream_id, request, end_stream=bool(event.stream_ended)
                )
            )
            return False
        else:
            return (yield from super().handle_h2_event(event))
class Http2Client(Http2Connection):
    """HTTP/2 connection to an upstream server (mitmproxy acts as the client)."""

    h2_conf = h2.config.H2Configuration(
        **Http2Connection.h2_conf_defaults,
        client_side=True,
    )

    ReceiveProtocolError = ResponseProtocolError
    ReceiveData = ResponseData
    ReceiveTrailers = ResponseTrailers
    ReceiveEndOfMessage = ResponseEndOfMessage

    # Bidirectional mapping between the proxy-internal stream ids ("theirs")
    # and the ids we allocate on this server connection ("ours"), see
    # _handle_event below for the rationale.
    our_stream_id: dict[int, int]
    their_stream_id: dict[int, int]
    stream_queue: collections.defaultdict[int, list[Event]]
    """Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
    provisional_max_concurrency: int | None = 10
    """A provisional concurrency limit before we get the server's first settings frame."""
    last_activity: float
    """Timestamp of when we've last seen network activity on this connection."""

    def __init__(self, context: Context):
        super().__init__(context, context.server)
        # Disable HTTP/2 push for now to keep things simple.
        # don't send here, that is done as part of initiate_connection().
        self.h2_conn.local_settings.enable_push = 0
        # hyper-h2 pitfall: we need to acknowledge here, otherwise it sends out the old settings.
        self.h2_conn.local_settings.acknowledge()
        self.our_stream_id = {}
        self.their_stream_id = {}
        self.stream_queue = collections.defaultdict(list)

    def _handle_event(self, event: Event) -> CommandGenerator[None]:
        # We can't reuse stream ids from the client because they may arrived reordered here
        # and HTTP/2 forbids opening a stream on a lower id than what was previously sent (see test_stream_concurrency).
        # To mitigate this, we transparently map the outside's stream id to our stream id.
        if isinstance(event, HttpEvent):
            ours = self.our_stream_id.get(event.stream_id, None)
            if ours is None:
                no_free_streams = self.h2_conn.open_outbound_streams >= (
                    self.provisional_max_concurrency
                    or self.h2_conn.remote_settings.max_concurrent_streams
                )
                if no_free_streams:
                    # hold the event back until a stream slot frees up
                    self.stream_queue[event.stream_id].append(event)
                    return
                ours = self.h2_conn.get_next_available_stream_id()
                self.our_stream_id[event.stream_id] = ours
                self.their_stream_id[ours] = event.stream_id
            event.stream_id = ours
        for cmd in self._handle_event2(event):
            if isinstance(cmd, ReceiveHttp):
                # translate back to the stream id the rest of the proxy knows
                cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
            yield cmd
        can_resume_queue = self.stream_queue and self.h2_conn.open_outbound_streams < (
            self.provisional_max_concurrency
            or self.h2_conn.remote_settings.max_concurrent_streams
        )
        if can_resume_queue:
            # popitem would be LIFO, but we want FIFO.
            events = self.stream_queue.pop(next(iter(self.stream_queue)))
            for event in events:
                yield from self._handle_event(event)

    def _handle_event2(self, event: Event) -> CommandGenerator[None]:
        if isinstance(event, Wakeup):
            send_ping_now = (
                # add one second to avoid unnecessary roundtrip, we don't need to be super correct here.
                time.time() - self.last_activity + 1
                > self.context.options.http2_ping_keepalive
            )
            if send_ping_now:
                # PING frames MUST contain 8 octets of opaque data in the payload.
                # A sender can include any value it chooses and use those octets in any fashion.
                self.last_activity = time.time()
                self.h2_conn.ping(b"0" * 8)
                data = self.h2_conn.data_to_send()
                if data is not None:
                    yield Log(
                        f"Send HTTP/2 keep-alive PING to {human.format_address(self.conn.peername)}",
                        DEBUG,
                    )
                    yield SendData(self.conn, data)
            # schedule the next keep-alive check relative to the last activity
            time_until_next_ping = self.context.options.http2_ping_keepalive - (
                time.time() - self.last_activity
            )
            yield RequestWakeup(time_until_next_ping)
            return
        self.last_activity = time.time()
        if isinstance(event, Start):
            if self.context.options.http2_ping_keepalive > 0:
                yield RequestWakeup(self.context.options.http2_ping_keepalive)
            yield from super()._handle_event(event)
        elif isinstance(event, RequestHeaders):
            self.h2_conn.send_headers(
                event.stream_id,
                headers=(yield from format_h2_request_headers(self.context, event)),
                end_stream=event.end_stream,
            )
            self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS
            yield SendData(self.conn, self.h2_conn.data_to_send())
        else:
            yield from super()._handle_event(event)

    def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
        """Handle server-originated h2 events; responses are parsed here."""
        if isinstance(event, h2.events.ResponseReceived):
            if (
                self.streams.get(event.stream_id, None)
                is not StreamState.EXPECTING_HEADERS
            ):
                yield from self.protocol_error(f"Received unexpected HTTP/2 response.")
                return True
            try:
                status_code, headers = parse_h2_response_headers(event.headers)
            except ValueError as e:
                yield from self.protocol_error(f"Invalid HTTP/2 response headers: {e}")
                return True
            response = http.Response(
                http_version=b"HTTP/2.0",
                status_code=status_code,
                reason=b"",
                headers=headers,
                content=None,
                trailers=None,
                timestamp_start=time.time(),
                timestamp_end=None,
            )
            self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
            yield ReceiveHttp(
                ResponseHeaders(event.stream_id, response, bool(event.stream_ended))
            )
            return False
        elif isinstance(event, h2.events.InformationalResponseReceived):
            # We violate the spec here ("A proxy MUST forward 1xx responses", RFC 7231),
            # but that's probably fine:
            # - 100 Continue is sent by mitmproxy to clients (irrespective of what the server does).
            # - 101 Switching Protocols is not allowed for HTTP/2.
            # - 102 Processing is WebDAV only and also ignorable.
            # - 103 Early Hints is not mission-critical.
            headers = http.Headers(event.headers)
            status: str | int = "<unknown status>"
            try:
                status = int(headers[":status"])
                reason = status_codes.RESPONSES.get(status, "")
            except (KeyError, ValueError):
                reason = ""
            yield Log(f"Swallowing HTTP/2 informational response: {status} {reason}")
            return False
        elif isinstance(event, h2.events.RequestReceived):
            yield from self.protocol_error(
                f"HTTP/2 protocol error: received request from server"
            )
            return True
        elif isinstance(event, h2.events.RemoteSettingsChanged):
            # We have received at least one settings from now,
            # which means we can rely on the max concurrency in remote_settings
            self.provisional_max_concurrency = None
            return (yield from super().handle_h2_event(event))
        else:
            return (yield from super().handle_h2_event(event))
def split_pseudo_headers(
    h2_headers: Sequence[tuple[bytes, bytes]],
) -> tuple[dict[bytes, bytes], http.Headers]:
    """
    Split the leading `:name` pseudo-headers off a raw h2/h3 header list.

    Returns the pseudo-headers as a dict and the remaining regular headers.
    Raises ValueError on duplicate pseudo-headers.
    """
    pseudo: dict[bytes, bytes] = {}
    regular_start = 0
    for name, value in h2_headers:
        if not name.startswith(b":"):
            # pseudo-headers must precede regular headers, we are done here
            break
        if name in pseudo:
            raise ValueError(f"Duplicate HTTP/2 pseudo header: {name!r}")
        pseudo[name] = value
        regular_start += 1
    return pseudo, http.Headers(h2_headers[regular_start:])
def parse_h2_request_headers(
    h2_headers: Sequence[tuple[bytes, bytes]],
) -> tuple[str, int, bytes, bytes, bytes, bytes, http.Headers]:
    """
    Split HTTP/2 pseudo-headers from the actual headers and parse them.

    Returns (host, port, method, scheme, authority, path, headers);
    raises ValueError for missing, unknown, or malformed pseudo-headers.
    """
    pseudo_headers, headers = split_pseudo_headers(h2_headers)
    try:
        method: bytes = pseudo_headers.pop(b":method")
        # popping :scheme raises for HTTP/2 CONNECT requests
        scheme: bytes = pseudo_headers.pop(b":scheme")
        path: bytes = pseudo_headers.pop(b":path")
        authority: bytes = pseudo_headers.pop(b":authority", b"")
    except KeyError as e:
        raise ValueError(f"Required pseudo header is missing: {e}")
    if pseudo_headers:
        raise ValueError(f"Unknown pseudo headers: {pseudo_headers}")

    if not authority:
        # no :authority and no Host fallback at this layer
        return "", 0, method, scheme, authority, path, headers

    host, port = url.parse_authority(authority, check=True)
    if port is None:
        port = 443 if scheme != b"http" else 80
    return host, port, method, scheme, authority, path, headers
def parse_h2_response_headers(
    h2_headers: Sequence[tuple[bytes, bytes]],
) -> tuple[int, http.Headers]:
    """
    Split HTTP/2 pseudo-headers from the actual headers and parse them.

    Returns (status_code, headers); raises ValueError when :status is
    missing or not an integer, or when unknown pseudo-headers remain.
    """
    pseudo_headers, headers = split_pseudo_headers(h2_headers)
    try:
        # int() may raise ValueError here, which propagates to the caller.
        status_code: int = int(pseudo_headers.pop(b":status"))
    except KeyError as e:
        raise ValueError(f"Required pseudo header is missing: {e}")
    if pseudo_headers:
        raise ValueError(f"Unknown pseudo headers: {pseudo_headers}")
    return status_code, headers
# Public interface of this module.
__all__ = [
    "format_h2_request_headers",
    "format_h2_response_headers",
    "parse_h2_request_headers",
    "parse_h2_response_headers",
    "Http2Client",
    "Http2Server",
]

View File

@@ -0,0 +1,309 @@
import time
from abc import abstractmethod
from typing import assert_never
from aioquic.h3.connection import ErrorCode as H3ErrorCode
from aioquic.h3.connection import FrameUnexpected as H3FrameUnexpected
from aioquic.h3.events import DataReceived
from aioquic.h3.events import HeadersReceived
from aioquic.h3.events import PushPromiseReceived
from . import ErrorCode
from . import RequestData
from . import RequestEndOfMessage
from . import RequestHeaders
from . import RequestProtocolError
from . import RequestTrailers
from . import ResponseData
from . import ResponseEndOfMessage
from . import ResponseHeaders
from . import ResponseProtocolError
from . import ResponseTrailers
from ._base import format_error
from ._base import HttpConnection
from ._base import HttpEvent
from ._base import ReceiveHttp
from ._http2 import format_h2_request_headers
from ._http2 import format_h2_response_headers
from ._http2 import parse_h2_request_headers
from ._http2 import parse_h2_response_headers
from ._http_h3 import LayeredH3Connection
from ._http_h3 import StreamClosed
from ._http_h3 import TrailersReceived
from mitmproxy import connection
from mitmproxy import http
from mitmproxy import version
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.layers.quic import error_code_to_str
from mitmproxy.proxy.layers.quic import QuicConnectionClosed
from mitmproxy.proxy.layers.quic import QuicStreamEvent
from mitmproxy.proxy.utils import expect
class Http3Connection(HttpConnection):
    """
    Translation layer between mitmproxy's HttpEvents and aioquic's HTTP/3
    connection, shared by the client and server implementations below.
    """

    h3_conn: LayeredH3Connection

    # Subclasses bind these to the request- or response-flavored event types.
    ReceiveData: type[RequestData | ResponseData]
    ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
    ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
    ReceiveTrailers: type[RequestTrailers | ResponseTrailers]

    def __init__(self, context: context.Context, conn: connection.Connection):
        super().__init__(context, conn)
        self.h3_conn = LayeredH3Connection(
            self.conn, is_client=self.conn is self.context.server
        )

    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        if isinstance(event, events.Start):
            yield from self.h3_conn.transmit()

        # send mitmproxy HTTP events over the H3 connection
        elif isinstance(event, HttpEvent):
            try:
                if isinstance(event, (RequestData, ResponseData)):
                    self.h3_conn.send_data(event.stream_id, event.data)
                elif isinstance(event, (RequestHeaders, ResponseHeaders)):
                    # header formatting is shared with HTTP/2 (same wire format)
                    headers = yield from (
                        format_h2_request_headers(self.context, event)
                        if isinstance(event, RequestHeaders)
                        else format_h2_response_headers(self.context, event)
                    )
                    self.h3_conn.send_headers(
                        event.stream_id, headers, end_stream=event.end_stream
                    )
                elif isinstance(event, (RequestTrailers, ResponseTrailers)):
                    self.h3_conn.send_trailers(
                        event.stream_id, [*event.trailers.fields]
                    )
                elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)):
                    self.h3_conn.end_stream(event.stream_id)
                elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
                    status = event.code.http_status_code()
                    if (
                        isinstance(event, ResponseProtocolError)
                        and not self.h3_conn.has_sent_headers(event.stream_id)
                        and status is not None
                    ):
                        # We can still send a proper HTTP error response to the client.
                        self.h3_conn.send_headers(
                            event.stream_id,
                            [
                                (b":status", b"%d" % status),
                                (b"server", version.MITMPROXY.encode()),
                                (b"content-type", b"text/html"),
                            ],
                        )
                        self.h3_conn.send_data(
                            event.stream_id,
                            format_error(status, event.message),
                            end_stream=True,
                        )
                    else:
                        # Too late for an HTTP response; close the stream instead,
                        # mapping mitmproxy's error code to the closest H3 one.
                        match event.code:
                            case ErrorCode.CANCEL | ErrorCode.CLIENT_DISCONNECTED:
                                error_code = H3ErrorCode.H3_REQUEST_CANCELLED
                            case ErrorCode.KILL:
                                error_code = H3ErrorCode.H3_INTERNAL_ERROR
                            case ErrorCode.HTTP_1_1_REQUIRED:
                                error_code = H3ErrorCode.H3_VERSION_FALLBACK
                            case ErrorCode.PASSTHROUGH_CLOSE:
                                # FIXME: This probably shouldn't be a protocol error, but an EOM event.
                                error_code = H3ErrorCode.H3_REQUEST_CANCELLED
                            case (
                                ErrorCode.GENERIC_CLIENT_ERROR
                                | ErrorCode.GENERIC_SERVER_ERROR
                                | ErrorCode.REQUEST_TOO_LARGE
                                | ErrorCode.RESPONSE_TOO_LARGE
                                | ErrorCode.CONNECT_FAILED
                                | ErrorCode.DESTINATION_UNKNOWN
                                | ErrorCode.REQUEST_VALIDATION_FAILED
                                | ErrorCode.RESPONSE_VALIDATION_FAILED
                            ):
                                error_code = H3ErrorCode.H3_INTERNAL_ERROR
                            case other:  # pragma: no cover
                                assert_never(other)
                        self.h3_conn.close_stream(event.stream_id, error_code.value)
                else:  # pragma: no cover
                    raise AssertionError(f"Unexpected event: {event!r}")
            except H3FrameUnexpected as e:
                # Http2Connection also ignores HttpEvents that violate the current stream state
                yield commands.Log(f"Received {event!r} unexpectedly: {e}")
            else:
                # transmit buffered data
                yield from self.h3_conn.transmit()

        # forward stream messages from the QUIC layer to the H3 connection
        elif isinstance(event, QuicStreamEvent):
            h3_events = self.h3_conn.handle_stream_event(event)
            for h3_event in h3_events:
                if isinstance(h3_event, StreamClosed):
                    err_str = error_code_to_str(h3_event.error_code)
                    match h3_event.error_code:
                        case H3ErrorCode.H3_REQUEST_CANCELLED:
                            err_code = ErrorCode.CANCEL
                        case H3ErrorCode.H3_VERSION_FALLBACK:
                            err_code = ErrorCode.HTTP_1_1_REQUIRED
                        case _:
                            err_code = self.ReceiveProtocolError.code
                    yield ReceiveHttp(
                        self.ReceiveProtocolError(
                            h3_event.stream_id,
                            f"stream closed by client ({err_str})",
                            code=err_code,
                        )
                    )
                elif isinstance(h3_event, DataReceived):
                    if h3_event.data:
                        yield ReceiveHttp(
                            self.ReceiveData(h3_event.stream_id, h3_event.data)
                        )
                    if h3_event.stream_ended:
                        yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
                elif isinstance(h3_event, HeadersReceived):
                    try:
                        receive_event = self.parse_headers(h3_event)
                    except ValueError as e:
                        self.h3_conn.close_connection(
                            error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR,
                            reason_phrase=f"Invalid HTTP/3 request headers: {e}",
                        )
                    else:
                        yield ReceiveHttp(receive_event)
                        if h3_event.stream_ended:
                            yield ReceiveHttp(
                                self.ReceiveEndOfMessage(h3_event.stream_id)
                            )
                elif isinstance(h3_event, TrailersReceived):
                    yield ReceiveHttp(
                        self.ReceiveTrailers(
                            h3_event.stream_id, http.Headers(h3_event.trailers)
                        )
                    )
                    if h3_event.stream_ended:
                        yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
                elif isinstance(h3_event, PushPromiseReceived):  # pragma: no cover
                    self.h3_conn.close_connection(
                        error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR,
                        reason_phrase=f"Received HTTP/3 push promise, even though we signalled no support.",
                    )
                else:  # pragma: no cover
                    raise AssertionError(f"Unexpected event: {event!r}")
            yield from self.h3_conn.transmit()

        # report a protocol error for all remaining open streams when a connection is closed
        elif isinstance(event, QuicConnectionClosed):
            self._handle_event = self.done  # type: ignore
            self.h3_conn.handle_connection_closed(event)
            msg = event.reason_phrase or error_code_to_str(event.error_code)
            for stream_id in self.h3_conn.get_open_stream_ids():
                yield ReceiveHttp(
                    self.ReceiveProtocolError(
                        stream_id, msg, self.ReceiveProtocolError.code
                    )
                )
        else:  # pragma: no cover
            raise AssertionError(f"Unexpected event: {event!r}")

    @expect(HttpEvent, QuicStreamEvent, QuicConnectionClosed)
    def done(self, _) -> layer.CommandGenerator[None]:
        """Terminal no-op state entered after the connection has been closed."""
        yield from ()

    @abstractmethod
    def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
        """Parse an incoming HEADERS frame into the matching mitmproxy event."""
        pass  # pragma: no cover
class Http3Server(Http3Connection):
    """HTTP/3 connection to a client (mitmproxy acts as the server)."""

    ReceiveData = RequestData
    ReceiveEndOfMessage = RequestEndOfMessage
    ReceiveProtocolError = RequestProtocolError
    ReceiveTrailers = RequestTrailers

    def __init__(self, context: context.Context):
        super().__init__(context, context.client)

    def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
        """Parse request pseudo-headers (same wire format as HTTP/2)."""
        host, port, method, scheme, authority, path, headers = (
            parse_h2_request_headers(event.headers)
        )
        request = http.Request(
            host=host,
            port=port,
            method=method,
            scheme=scheme,
            authority=authority,
            path=path,
            http_version=b"HTTP/3",
            headers=headers,
            content=None,
            trailers=None,
            timestamp_start=time.time(),
            timestamp_end=None,
        )
        return RequestHeaders(event.stream_id, request, end_stream=event.stream_ended)
class Http3Client(Http3Connection):
    """HTTP/3 connection to an upstream server (mitmproxy acts as the client)."""

    ReceiveData = ResponseData
    ReceiveEndOfMessage = ResponseEndOfMessage
    ReceiveProtocolError = ResponseProtocolError
    ReceiveTrailers = ResponseTrailers

    # Bidirectional mapping between proxy-internal stream ids ("theirs") and
    # the ids allocated on this connection ("ours").
    our_stream_id: dict[int, int]
    their_stream_id: dict[int, int]

    def __init__(self, context: context.Context):
        super().__init__(context, context.server)
        self.our_stream_id = {}
        self.their_stream_id = {}

    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        # QUIC and HTTP/3 would actually allow for direct stream ID mapping, but since we want
        # to support H2<->H3, we need to translate IDs.
        # NOTE: We always create bidirectional streams, as we can't safely infer unidirectionality.
        if isinstance(event, HttpEvent):
            if event.stream_id in self.our_stream_id:
                event.stream_id = self.our_stream_id[event.stream_id]
            else:
                local_id = self.h3_conn.get_next_available_stream_id()
                self.our_stream_id[event.stream_id] = local_id
                self.their_stream_id[local_id] = event.stream_id
                event.stream_id = local_id
        for cmd in super()._handle_event(event):
            if isinstance(cmd, ReceiveHttp):
                # translate back to the stream id the rest of the proxy knows
                cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
            yield cmd

    def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
        """Parse response pseudo-headers (same wire format as HTTP/2)."""
        status_code, headers = parse_h2_response_headers(event.headers)
        response = http.Response(
            http_version=b"HTTP/3",
            status_code=status_code,
            reason=b"",
            headers=headers,
            content=None,
            trailers=None,
            timestamp_start=time.time(),
            timestamp_end=None,
        )
        return ResponseHeaders(event.stream_id, response, event.stream_ended)
# Public interface of this module.
__all__ = [
    "Http3Client",
    "Http3Server",
]

View File

@@ -0,0 +1,207 @@
import collections
import logging
from typing import NamedTuple
import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings
import h2.stream
logger = logging.getLogger(__name__)
class H2ConnectionLogger(h2.config.DummyLogger):
    """Forwards hyper-h2's internal debug/trace output to the mitmproxy logger."""

    def __init__(self, peername: tuple, conn_type: str):
        super().__init__()
        self.peername = peername  # used as the `client` extra on log records
        self.conn_type = conn_type  # e.g. "Http2Server"/"Http2Client", prefixes each message

    def debug(self, fmtstr, *args):
        logger.debug(
            f"{self.conn_type} {fmtstr}", *args, extra={"client": self.peername}
        )

    def trace(self, fmtstr, *args):
        # hyper-h2's trace level maps to one below DEBUG
        logger.log(
            logging.DEBUG - 1,
            f"{self.conn_type} {fmtstr}",
            *args,
            extra={"client": self.peername},
        )
class SendH2Data(NamedTuple):
    """A DATA chunk queued until the flow-control window allows sending it."""
    data: bytes
    end_stream: bool  # whether this chunk also closes the sending side of the stream
class BufferedH2Connection(h2.connection.H2Connection):
    """
    This class wraps hyper-h2's H2Connection and adds internal send buffers.

    To simplify implementation, padding is unsupported.
    """

    # Per-stream FIFO of data chunks waiting for flow-control window.
    stream_buffers: collections.defaultdict[int, collections.deque[SendH2Data]]
    # Trailers queued behind still-buffered data frames (must be sent strictly after data).
    stream_trailers: dict[int, list[tuple[bytes, bytes]]]
    def __init__(self, config: h2.config.H2Configuration):
        super().__init__(config)
        # Advertise the maximum allowed flow-control window and large frames
        # to keep per-frame overhead low.
        self.local_settings.initial_window_size = 2**31 - 1
        self.local_settings.max_frame_size = 2**17
        self.max_inbound_frame_size = 2**17
        # hyper-h2 pitfall: we need to acknowledge here, otherwise it sends out the old settings.
        self.local_settings.acknowledge()
        self.stream_buffers = collections.defaultdict(collections.deque)
        self.stream_trailers = {}
    def initiate_connection(self):
        super().initiate_connection()
        # We increase the flow-control window for new streams with a setting,
        # but we need to increase the overall connection flow-control window as well.
        self.increment_flow_control_window(
            2**31 - 1 - self.inbound_flow_control_window
        )  # maximum - default
def send_data(
self,
stream_id: int,
data: bytes,
end_stream: bool = False,
pad_length: None = None,
) -> None:
"""
Send data on a given stream.
In contrast to plain hyper-h2, this method will not raise if the data cannot be sent immediately.
Data is split up and buffered internally.
"""
frame_size = len(data)
assert pad_length is None
if frame_size > self.max_outbound_frame_size:
for start in range(0, frame_size, self.max_outbound_frame_size):
chunk = data[start : start + self.max_outbound_frame_size]
self.send_data(stream_id, chunk, end_stream=False)
return
if self.stream_buffers.get(stream_id, None):
# We already have some data buffered, let's append.
self.stream_buffers[stream_id].append(SendH2Data(data, end_stream))
else:
available_window = self.local_flow_control_window(stream_id)
if frame_size <= available_window:
super().send_data(stream_id, data, end_stream)
else:
if available_window:
can_send_now = data[:available_window]
super().send_data(stream_id, can_send_now, end_stream=False)
data = data[available_window:]
# We can't send right now, so we buffer.
self.stream_buffers[stream_id].append(SendH2Data(data, end_stream))
def send_trailers(self, stream_id: int, trailers: list[tuple[bytes, bytes]]) -> None:
    """Send trailers on the stream, queuing them behind any buffered DATA frames."""
    if self.stream_buffers.get(stream_id, None):
        # Though trailers are not subject to flow control, we need to queue them and send strictly after data frames
        self.stream_trailers[stream_id] = trailers
    else:
        # No buffered data: trailers go out right away as a final HEADERS frame.
        self.send_headers(stream_id, trailers, end_stream=True)
def end_stream(self, stream_id: int) -> None:
    """Half-close the given stream by sending an empty DATA frame with the FIN bit."""
    # Queued trailers will already end the stream once they are flushed,
    # so there is nothing left for us to do in that case.
    has_pending_trailers = stream_id in self.stream_trailers
    if not has_pending_trailers:
        self.send_data(stream_id, b"", end_stream=True)
def reset_stream(self, stream_id: int, error_code: int = 0) -> None:
    """Reset the stream, discarding any data still buffered for it."""
    # Drop buffered chunks first: they must never be flushed after a RST_STREAM.
    self.stream_buffers.pop(stream_id, None)
    super().reset_stream(stream_id, error_code)
def receive_data(self, data: bytes):
    """
    Feed received bytes to hyper-h2 and post-process the resulting events.

    Window updates are consumed here (they trigger flushing of buffered data)
    and are not passed on; stream resets and connection termination clear the
    corresponding send buffers. All other events are returned unchanged.
    """
    events = super().receive_data(data)
    ret = []
    for event in events:
        if isinstance(event, h2.events.WindowUpdated):
            # Flow-control credit arrived: try to flush buffered data.
            if event.stream_id == 0:
                self.connection_window_updated()
            else:
                self.stream_window_updated(event.stream_id)
            # Window updates are handled internally and not surfaced to callers.
            continue
        elif isinstance(event, h2.events.RemoteSettingsChanged):
            if (
                h2.settings.SettingCodes.INITIAL_WINDOW_SIZE
                in event.changed_settings
            ):
                # A new initial window size may unblock buffered streams.
                self.connection_window_updated()
        elif isinstance(event, h2.events.StreamReset):
            # Never flush data for a stream the peer has reset.
            self.stream_buffers.pop(event.stream_id, None)
        elif isinstance(event, h2.events.ConnectionTerminated):
            # GOAWAY: all buffered data is moot.
            self.stream_buffers.clear()
        ret.append(event)
    return ret
def stream_window_updated(self, stream_id: int) -> bool:
    """
    The window for a specific stream has updated. Send as much buffered data as possible.

    Returns:
        True if any buffered data was flushed, False otherwise.
    """
    # If the stream has been reset in the meantime, we just clear the buffer.
    try:
        stream: h2.stream.H2Stream = self.streams[stream_id]
    except KeyError:
        # Stream no longer tracked by hyper-h2 at all.
        stream_was_reset = True
    else:
        # Only OPEN / HALF_CLOSED_REMOTE streams may still send data.
        stream_was_reset = stream.state_machine.state not in (
            h2.stream.StreamState.OPEN,
            h2.stream.StreamState.HALF_CLOSED_REMOTE,
        )
    if stream_was_reset:
        self.stream_buffers.pop(stream_id, None)
        return False

    available_window = self.local_flow_control_window(stream_id)
    sent_any_data = False
    while available_window > 0 and stream_id in self.stream_buffers:
        chunk: SendH2Data = self.stream_buffers[stream_id].popleft()
        if len(chunk.data) > available_window:
            # We can't send the entire chunk, so we have to put some bytes back into the buffer.
            # The FIN bit stays with the remainder that goes back into the buffer.
            self.stream_buffers[stream_id].appendleft(
                SendH2Data(
                    data=chunk.data[available_window:],
                    end_stream=chunk.end_stream,
                )
            )
            chunk = SendH2Data(
                data=chunk.data[:available_window],
                end_stream=False,
            )
        super().send_data(stream_id, data=chunk.data, end_stream=chunk.end_stream)
        available_window -= len(chunk.data)
        if not self.stream_buffers[stream_id]:
            # Buffer drained: drop it, then flush any trailers queued behind the data.
            del self.stream_buffers[stream_id]
            if stream_id in self.stream_trailers:
                self.send_headers(
                    stream_id, self.stream_trailers.pop(stream_id), end_stream=True
                )
        sent_any_data = True
    return sent_any_data
def connection_window_updated(self) -> None:
    """
    The connection window has updated. Send data from buffers in a round-robin fashion.
    """
    sent_any_data = True
    # Keep looping as long as at least one stream made progress.
    while sent_any_data:
        sent_any_data = False
        for stream_id in list(self.stream_buffers):
            # Rotate the stream to the end of the dict so every stream gets a
            # fair share of the connection window (round-robin).
            self.stream_buffers[stream_id] = self.stream_buffers.pop(
                stream_id
            )  # move to end of dict
            if self.stream_window_updated(stream_id):
                sent_any_data = True
            if self.outbound_flow_control_window == 0:
                # Connection window exhausted; wait for the next WINDOW_UPDATE.
                return

View File

@@ -0,0 +1,321 @@
from collections.abc import Iterable
from dataclasses import dataclass
from aioquic.h3.connection import FrameUnexpected
from aioquic.h3.connection import H3Connection
from aioquic.h3.connection import H3Event
from aioquic.h3.connection import H3Stream
from aioquic.h3.connection import Headers
from aioquic.h3.connection import HeadersState
from aioquic.h3.events import HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import StreamDataReceived
from aioquic.quic.packet import QuicErrorCode
from mitmproxy import connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import layer
from mitmproxy.proxy.layers.quic import CloseQuicConnection
from mitmproxy.proxy.layers.quic import QuicConnectionClosed
from mitmproxy.proxy.layers.quic import QuicStreamDataReceived
from mitmproxy.proxy.layers.quic import QuicStreamEvent
from mitmproxy.proxy.layers.quic import QuicStreamReset
from mitmproxy.proxy.layers.quic import QuicStreamStopSending
from mitmproxy.proxy.layers.quic import ResetQuicStream
from mitmproxy.proxy.layers.quic import SendQuicStreamData
from mitmproxy.proxy.layers.quic import StopSendingQuicStream
@dataclass
class TrailersReceived(H3Event):
    """
    The TrailersReceived event is fired whenever trailers are received,
    i.e. a final HEADERS frame that arrives after the regular headers.
    """

    trailers: Headers
    "The trailers."

    stream_id: int
    "The ID of the stream the trailers were received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."
@dataclass
class StreamClosed(H3Event):
    """
    The StreamClosed event is fired when the peer is sending a RESET_STREAM
    or a STOP_SENDING frame. For HTTP/3, we don't differentiate between the two.
    """

    stream_id: int
    "The ID of the stream that was reset."

    error_code: int
    """The error code indicating why the stream was closed."""
class MockQuic:
    """
    aioquic intermingles QUIC and HTTP/3. This is something we don't want to do because that makes testing much harder.
    Instead, we mock our QUIC connection object here and then take out the wire data to be sent.
    """

    def __init__(self, conn: connection.Connection, is_client: bool) -> None:
        self.conn = conn
        # Commands queued here are drained by `LayeredH3Connection.transmit`.
        self.pending_commands: list[commands.Command] = []
        # Next free stream ID per (unidirectional, initiator) slot, mirroring QuicConnection.
        self._next_stream_id: list[int] = [0, 1, 2, 3]
        self._is_client = is_client
        # the following fields are accessed by H3Connection
        self.configuration = QuicConfiguration(is_client=is_client)
        self._quic_logger = None
        self._remote_max_datagram_frame_size = 0

    def close(
        self,
        error_code: int = QuicErrorCode.NO_ERROR,
        frame_type: int | None = None,
        reason_phrase: str = "",
    ) -> None:
        # we'll get closed if a protocol error occurs in `H3Connection.handle_event`
        # we note the error on the connection and yield a CloseConnection
        # this will then call `QuicConnection.close` with the proper values
        # once the `Http3Connection` receives `ConnectionClosed`, it will send out `ProtocolError`
        self.pending_commands.append(
            CloseQuicConnection(self.conn, error_code, frame_type, reason_phrase)
        )

    def get_next_available_stream_id(self, is_unidirectional: bool = False) -> int:
        # since we always reserve the ID, we have to "find" the next ID like `QuicConnection` does
        # Stream ID bit layout: bit 1 = unidirectional, bit 0 = server-initiated.
        index = (int(is_unidirectional) << 1) | int(not self._is_client)
        stream_id = self._next_stream_id[index]
        self._next_stream_id[index] = stream_id + 4
        return stream_id

    def reset_stream(self, stream_id: int, error_code: int) -> None:
        # Queue a RESET_STREAM for the upper QUIC layer.
        self.pending_commands.append(ResetQuicStream(self.conn, stream_id, error_code))

    def stop_send(self, stream_id: int, error_code: int) -> None:
        # Queue a STOP_SENDING for the upper QUIC layer.
        self.pending_commands.append(
            StopSendingQuicStream(self.conn, stream_id, error_code)
        )

    def send_stream_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        # Queue outgoing stream data instead of writing it to the wire.
        self.pending_commands.append(
            SendQuicStreamData(self.conn, stream_id, data, end_stream)
        )
class LayeredH3Connection(H3Connection):
    """
    Creates a H3 connection using a fake QUIC connection, which allows layer separation.
    Also ensures that headers, data and trailers are sent in that order.
    """

    def __init__(
        self,
        conn: connection.Connection,
        is_client: bool,
        enable_webtransport: bool = False,
    ) -> None:
        # The mock records outgoing QUIC commands instead of doing real I/O.
        self._mock = MockQuic(conn, is_client)
        self._closed_streams: set[int] = set()
        """
        We keep track of all stream IDs for which we have requested
        STOP_SENDING to silently discard incoming data.
        """
        super().__init__(self._mock, enable_webtransport)  # type: ignore

    # aioquic's constructor sets and then uses _max_push_id.
    # This is a hack to forcibly disable it.
    @property
    def _max_push_id(self) -> int | None:
        return None

    @_max_push_id.setter
    def _max_push_id(self, value):
        # Swallow aioquic's assignment so server push stays disabled.
        pass

    def _after_send(self, stream_id: int, end_stream: bool) -> None:
        # if the stream ended, `QuicConnection` has an assert that no further data is being sent
        # to catch this more early on, we set the header state on the `H3Stream`
        if end_stream:
            self._stream[stream_id].headers_send_state = HeadersState.AFTER_TRAILERS

    def _handle_request_or_push_frame(
        self,
        frame_type: int,
        frame_data: bytes | None,
        stream: H3Stream,
        stream_ended: bool,
    ) -> list[H3Event]:
        # turn HeadersReceived into TrailersReceived for trailers
        events = super()._handle_request_or_push_frame(
            frame_type, frame_data, stream, stream_ended
        )
        for index, event in enumerate(events):
            if (
                isinstance(event, HeadersReceived)
                and self._stream[event.stream_id].headers_recv_state
                == HeadersState.AFTER_TRAILERS
            ):
                events[index] = TrailersReceived(
                    event.headers, event.stream_id, event.stream_ended
                )
        return events

    def close_connection(
        self,
        error_code: int = QuicErrorCode.NO_ERROR,
        frame_type: int | None = None,
        reason_phrase: str = "",
    ) -> None:
        """Closes the underlying QUIC connection and ignores any incoming events."""
        self._is_done = True
        self._quic.close(error_code, frame_type, reason_phrase)

    def end_stream(self, stream_id: int) -> None:
        """Ends the given stream if not already done so."""
        stream = self._get_or_create_stream(stream_id)
        if stream.headers_send_state != HeadersState.AFTER_TRAILERS:
            # An empty DATA frame with FIN ends the stream.
            super().send_data(stream_id, b"", end_stream=True)
            stream.headers_send_state = HeadersState.AFTER_TRAILERS

    def get_next_available_stream_id(self, is_unidirectional: bool = False):
        """Reserves and returns the next available stream ID."""
        return self._quic.get_next_available_stream_id(is_unidirectional)

    def get_open_stream_ids(self) -> Iterable[int]:
        """Iterates over all non-special open streams"""
        # A stream counts as open until both its send and receive sides
        # have reached the AFTER_TRAILERS state.
        return (
            stream.stream_id
            for stream in self._stream.values()
            if (
                stream.stream_type is None
                and not (
                    stream.headers_recv_state == HeadersState.AFTER_TRAILERS
                    and stream.headers_send_state == HeadersState.AFTER_TRAILERS
                )
            )
        )

    def handle_connection_closed(self, event: QuicConnectionClosed) -> None:
        """Marks the connection as done; subsequent stream events are ignored."""
        self._is_done = True

    def handle_stream_event(self, event: QuicStreamEvent) -> list[H3Event]:
        """Translates a mitmproxy QUIC stream event into aioquic H3 events."""
        # don't do anything if we're done
        if self._is_done:
            return []
        elif isinstance(event, (QuicStreamReset, QuicStreamStopSending)):
            # Mirror the peer's termination locally and mark the stream fully closed.
            self.close_stream(
                event.stream_id,
                event.error_code,
                stop_send=isinstance(event, QuicStreamStopSending),
            )
            stream = self._get_or_create_stream(event.stream_id)
            stream.ended = True
            stream.headers_recv_state = HeadersState.AFTER_TRAILERS
            return [StreamClosed(event.stream_id, event.error_code)]
        # convert data events from the QUIC layer back to aioquic events
        elif isinstance(event, QuicStreamDataReceived):
            # Discard contents if we have already sent STOP_SENDING on this stream.
            if event.stream_id in self._closed_streams:
                return []
            elif self._get_or_create_stream(event.stream_id).ended:
                # aioquic will not send us any data events once a stream has ended.
                # Instead, it will close the connection. We simulate this here for H3 tests.
                self.close_connection(
                    error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                    reason_phrase="stream already ended",
                )
                return []
            else:
                return self.handle_event(
                    StreamDataReceived(event.data, event.end_stream, event.stream_id)
                )
        # should never happen
        else:  # pragma: no cover
            raise AssertionError(f"Unexpected event: {event!r}")

    def has_sent_headers(self, stream_id: int) -> bool:
        """Indicates whether headers have been sent over the given stream."""
        try:
            return self._stream[stream_id].headers_send_state != HeadersState.INITIAL
        except KeyError:
            # Stream unknown: nothing has been sent on it yet.
            return False

    def close_stream(
        self, stream_id: int, error_code: int, stop_send: bool = True
    ) -> None:
        """Close a stream that hasn't been closed locally yet."""
        if stream_id not in self._closed_streams:
            self._closed_streams.add(stream_id)
            stream = self._get_or_create_stream(stream_id)
            # Forbid any further sends on this stream.
            stream.headers_send_state = HeadersState.AFTER_TRAILERS

            # https://www.rfc-editor.org/rfc/rfc9000.html#section-3.5-8
            # An endpoint that wishes to terminate both directions of
            # a bidirectional stream can terminate one direction by
            # sending a RESET_STREAM frame, and it can encourage prompt
            # termination in the opposite direction by sending a
            # STOP_SENDING frame.
            self._mock.reset_stream(stream_id=stream_id, error_code=error_code)
            if stop_send:
                self._mock.stop_send(stream_id=stream_id, error_code=error_code)

    def send_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None:
        """Sends data over the given stream."""
        super().send_data(stream_id, data, end_stream)
        self._after_send(stream_id, end_stream)

    def send_datagram(self, flow_id: int, data: bytes) -> None:
        # supporting datagrams would require additional information from the underlying QUIC connection
        raise NotImplementedError()  # pragma: no cover

    def send_headers(
        self, stream_id: int, headers: Headers, end_stream: bool = False
    ) -> None:
        """Sends headers over the given stream."""
        # ensure we haven't sent something before
        stream = self._get_or_create_stream(stream_id)
        if stream.headers_send_state != HeadersState.INITIAL:
            raise FrameUnexpected("initial HEADERS frame is not allowed in this state")
        super().send_headers(stream_id, headers, end_stream)
        self._after_send(stream_id, end_stream)

    def send_trailers(self, stream_id: int, trailers: Headers) -> None:
        """Sends trailers over the given stream and ends it."""
        # ensure we got some headers first
        stream = self._get_or_create_stream(stream_id)
        if stream.headers_send_state != HeadersState.AFTER_HEADERS:
            raise FrameUnexpected("trailing HEADERS frame is not allowed in this state")
        super().send_headers(stream_id, trailers, end_stream=True)
        self._after_send(stream_id, end_stream=True)

    def transmit(self) -> layer.CommandGenerator[None]:
        """Yields all pending commands for the upper QUIC layer."""
        while self._mock.pending_commands:
            yield self._mock.pending_commands.pop(0)
# Explicit public API of this module.
__all__ = [
    "LayeredH3Connection",
    "StreamClosed",
    "TrailersReceived",
]

View File

@@ -0,0 +1,105 @@
import time
from logging import DEBUG
from h11._receivebuffer import ReceiveBuffer
from mitmproxy import connection
from mitmproxy import http
from mitmproxy.net.http import http1
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import layer
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers import tls
from mitmproxy.proxy.layers.http._hooks import HttpConnectUpstreamHook
from mitmproxy.utils import human
class HttpUpstreamProxy(tunnel.TunnelLayer):
    """
    Tunnel layer that forwards traffic through an upstream HTTP(S) proxy,
    optionally establishing the tunnel with an HTTP CONNECT request first.
    """

    # Buffers the upstream proxy's CONNECT response until the head is complete.
    buf: ReceiveBuffer
    # Whether a CONNECT request should be sent at all.
    send_connect: bool
    conn: connection.Server
    tunnel_connection: connection.Server

    def __init__(
        self, ctx: context.Context, tunnel_conn: connection.Server, send_connect: bool
    ):
        super().__init__(ctx, tunnel_connection=tunnel_conn, conn=ctx.server)
        self.buf = ReceiveBuffer()
        self.send_connect = send_connect

    @classmethod
    def make(cls, ctx: context.Context, send_connect: bool) -> tunnel.LayerStack:
        """Build the layer stack for the upstream proxy, inserting TLS if the proxy is https."""
        assert ctx.server.via
        scheme, address = ctx.server.via
        assert scheme in ("http", "https")

        http_proxy = connection.Server(address=address)

        stack = tunnel.LayerStack()
        if scheme == "https":
            http_proxy.alpn_offers = tls.HTTP1_ALPNS
            http_proxy.sni = address[0]
            stack /= tls.ServerTLSLayer(ctx, http_proxy)
        stack /= cls(ctx, http_proxy, send_connect)

        return stack

    def start_handshake(self) -> layer.CommandGenerator[None]:
        """Send an HTTP CONNECT request to the upstream proxy (if configured to do so)."""
        if not self.send_connect:
            return (yield from super().start_handshake())
        assert self.conn.address
        flow = http.HTTPFlow(self.context.client, self.tunnel_connection)
        # authority form: "host:port", with the host IDNA-encoded.
        authority = (
            self.conn.address[0].encode("idna") + f":{self.conn.address[1]}".encode()
        )
        headers = http.Headers()
        if self.context.options.http_connect_send_host_header:
            headers.insert(0, b"Host", authority)
        flow.request = http.Request(
            host=self.conn.address[0],
            port=self.conn.address[1],
            method=b"CONNECT",
            scheme=b"",
            authority=authority,
            path=b"",
            http_version=b"HTTP/1.1",
            headers=headers,
            content=b"",
            trailers=None,
            timestamp_start=time.time(),
            timestamp_end=time.time(),
        )
        # Give addons a chance to modify the CONNECT request (e.g. add auth headers).
        yield HttpConnectUpstreamHook(flow)
        raw = http1.assemble_request(flow.request)
        yield commands.SendData(self.tunnel_connection, raw)

    def receive_handshake_data(
        self, data: bytes
    ) -> layer.CommandGenerator[tuple[bool, str | None]]:
        """
        Parse the proxy's CONNECT response.

        Returns (done, error): done is True once a 2xx response completes the
        handshake; a non-None error string aborts the tunnel.
        """
        if not self.send_connect:
            return (yield from super().receive_handshake_data(data))
        self.buf += data
        response_head = self.buf.maybe_extract_lines()
        if response_head:
            try:
                response = http1.read_response_head([bytes(x) for x in response_head])
            except ValueError as e:
                proxyaddr = human.format_address(self.tunnel_connection.address)
                yield commands.Log(f"{proxyaddr}: {e}")
                return False, f"Error connecting to {proxyaddr}: {e}"
            if 200 <= response.status_code < 300:
                if self.buf:
                    # Forward any bytes that arrived after the response head.
                    yield from self.receive_data(bytes(self.buf))
                    del self.buf
                return True, None
            else:
                proxyaddr = human.format_address(self.tunnel_connection.address)
                raw_resp = b"\n".join(response_head)
                yield commands.Log(f"{proxyaddr}: {raw_resp!r}", DEBUG)
                return (
                    False,
                    f"Upstream proxy {proxyaddr} refused HTTP CONNECT request: {response.status_code} {response.reason}",
                )
        else:
            # Response head not complete yet; wait for more data.
            return False, None

View File

@@ -0,0 +1,303 @@
from __future__ import annotations
import socket
import struct
import sys
from abc import ABCMeta
from collections.abc import Callable
from dataclasses import dataclass
from mitmproxy import connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.mode_specs import ReverseMode
from mitmproxy.proxy.utils import expect
if sys.version_info < (3, 11):
from typing_extensions import assert_never
else:
from typing import assert_never
class HttpProxy(layer.Layer):
    """Regular (explicit) HTTP proxy mode: immediately delegate to the next layer."""

    @expect(events.Start)
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        child_layer = layer.NextLayer(self.context)
        # Route all subsequent events straight to the child layer.
        self._handle_event = child_layer.handle_event
        yield from child_layer.handle_event(event)
class HttpUpstreamProxy(layer.Layer):
    """Upstream HTTP proxy mode: immediately delegate to the next layer."""

    @expect(events.Start)
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        child_layer = layer.NextLayer(self.context)
        # Route all subsequent events straight to the child layer.
        self._handle_event = child_layer.handle_event
        yield from child_layer.handle_event(event)
class DestinationKnown(layer.Layer, metaclass=ABCMeta):
    """Base layer for layers that gather connection destination info and then delegate."""

    child_layer: layer.Layer

    def finish_start(self) -> layer.CommandGenerator[str | None]:
        """
        Optionally connect upstream eagerly, then hand control to the child layer.

        Returns an error string if the eager connection attempt failed, else None.
        """
        if (
            self.context.options.connection_strategy == "eager"
            and self.context.server.address
            and self.context.server.transport_protocol == "tcp"
        ):
            err = yield commands.OpenConnection(self.context.server)
            if err:
                # Connection failed: swallow all further events for this client.
                self._handle_event = self.done  # type: ignore
                return err
        # From now on, route everything to the child layer.
        self._handle_event = self.child_layer.handle_event  # type: ignore
        yield from self.child_layer.handle_event(events.Start())
        return None

    @expect(events.DataReceived, events.ConnectionClosed)
    def done(self, _) -> layer.CommandGenerator[None]:
        # Terminal state: ignore everything.
        yield from ()
class ReverseProxy(DestinationKnown):
    """Reverse proxy mode: the upstream destination is fixed by the mode spec."""

    @expect(events.Start)
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        spec = self.context.client.proxy_mode
        assert isinstance(spec, ReverseMode)
        self.context.server.address = spec.address
        self.child_layer = layer.NextLayer(self.context)

        # For secure protocols, set SNI if keep_host_header is false
        match spec.scheme:
            case "http3" | "quic" | "https" | "tls" | "dtls":
                if not self.context.options.keep_host_header:
                    self.context.server.sni = spec.address[0]
            case "tcp" | "http" | "udp" | "dns":
                # Plaintext protocols: no SNI to set.
                pass
            case _:  # pragma: no cover
                assert_never(spec.scheme)

        err = yield from self.finish_start()
        if err:
            yield commands.CloseConnection(self.context.client)
class TransparentProxy(DestinationKnown):
    """Transparent proxy mode: the original destination was already resolved
    (e.g. via SO_ORIGINAL_DST) and stored on the server connection."""

    @expect(events.Start)
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        assert self.context.server.address, "No server address set."
        self.child_layer = layer.NextLayer(self.context)
        err = yield from self.finish_start()
        if err:
            yield commands.CloseConnection(self.context.client)
# SOCKS5 protocol constants (RFC 1928 / RFC 1929).
SOCKS5_VERSION = 0x05

# Authentication methods offered in the client greeting.
SOCKS5_METHOD_NO_AUTHENTICATION_REQUIRED = 0x00
SOCKS5_METHOD_USER_PASSWORD_AUTHENTICATION = 0x02
SOCKS5_METHOD_NO_ACCEPTABLE_METHODS = 0xFF

# Address types (ATYP) in CONNECT requests.
SOCKS5_ATYP_IPV4_ADDRESS = 0x01
SOCKS5_ATYP_DOMAINNAME = 0x03
SOCKS5_ATYP_IPV6_ADDRESS = 0x04

# Reply codes sent back to the client.
SOCKS5_REP_HOST_UNREACHABLE = 0x04
SOCKS5_REP_COMMAND_NOT_SUPPORTED = 0x07
SOCKS5_REP_ADDRESS_TYPE_NOT_SUPPORTED = 0x08
@dataclass
class Socks5AuthData:
    """Credentials received during the SOCKS5 username/password subnegotiation."""

    client_conn: connection.Client
    username: str
    password: str
    # Set to True by a Socks5AuthHook handler to accept the credentials.
    valid: bool = False
@dataclass
class Socks5AuthHook(StartHook):
    """
    Mitmproxy has received username/password SOCKS5 credentials.

    This hook decides whether they are valid by setting `data.valid`.
    """

    data: Socks5AuthData
class Socks5Proxy(DestinationKnown):
    """SOCKS5 proxy mode: perform the SOCKS5 handshake, then delegate to the next layer."""

    # Accumulates handshake bytes until a complete message is available.
    buf: bytes = b""

    def socks_err(
        self,
        message: str,
        reply_code: int | None = None,
    ) -> layer.CommandGenerator[None]:
        """Optionally send an error reply, then close the connection and go inert."""
        if reply_code is not None:
            # Reply format: VER, REP, RSV, ATYP=IPv4, BND.ADDR=0.0.0.0, BND.PORT=0.
            yield commands.SendData(
                self.context.client,
                bytes([SOCKS5_VERSION, reply_code])
                + b"\x00\x01\x00\x00\x00\x00\x00\x00",
            )
        yield commands.CloseConnection(self.context.client)
        yield commands.Log(message)
        self._handle_event = self.done

    @expect(events.Start, events.DataReceived, events.ConnectionClosed)
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        if isinstance(event, events.Start):
            pass
        elif isinstance(event, events.DataReceived):
            self.buf += event.data
            # `state` points at the current handshake stage handler.
            yield from self.state()
        elif isinstance(event, events.ConnectionClosed):
            if self.buf:
                yield commands.Log(
                    f"Client closed connection before completing SOCKS5 handshake: {self.buf!r}"
                )
            yield commands.CloseConnection(event.connection)
        else:
            raise AssertionError(f"Unknown event: {event}")

    def state_greet(self) -> layer.CommandGenerator[None]:
        """Parse the client greeting (VER, NMETHODS, METHODS...) and pick an auth method."""
        if len(self.buf) < 2:
            # Greeting not complete yet.
            return

        if self.buf[0] != SOCKS5_VERSION:
            # Three uppercase ASCII letters look like an HTTP method (GET, PUT, ...).
            if self.buf[:3].isupper():
                guess = "Probably not a SOCKS request but a regular HTTP request. "
            else:
                guess = ""
            yield from self.socks_err(
                guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.buf[0]
            )
            return

        n_methods = self.buf[1]
        if len(self.buf) < 2 + n_methods:
            # Method list not complete yet.
            return

        # Require username/password auth iff proxyauth is configured.
        if "proxyauth" in self.context.options and self.context.options.proxyauth:
            method = SOCKS5_METHOD_USER_PASSWORD_AUTHENTICATION
            self.state = self.state_auth
        else:
            method = SOCKS5_METHOD_NO_AUTHENTICATION_REQUIRED
            self.state = self.state_connect

        if method not in self.buf[2 : 2 + n_methods]:
            method_str = (
                "user/password"
                if method == SOCKS5_METHOD_USER_PASSWORD_AUTHENTICATION
                else "no"
            )
            yield from self.socks_err(
                f"Client does not support SOCKS5 with {method_str} authentication.",
                SOCKS5_METHOD_NO_ACCEPTABLE_METHODS,
            )
            return
        yield commands.SendData(self.context.client, bytes([SOCKS5_VERSION, method]))
        self.buf = self.buf[2 + n_methods :]
        yield from self.state()

    # Current handshake stage; advanced by each state handler.
    state: Callable[..., layer.CommandGenerator[None]] = state_greet

    def state_auth(self) -> layer.CommandGenerator[None]:
        """Parse the username/password subnegotiation (RFC 1929) and trigger the auth hook."""
        if len(self.buf) < 3:
            # Need at least VER, ULEN and one more byte.
            return

        # Parsing username and password, which is somewhat atrocious
        user_len = self.buf[1]
        if len(self.buf) < 3 + user_len:
            return
        pass_len = self.buf[2 + user_len]
        if len(self.buf) < 3 + user_len + pass_len:
            return
        user = self.buf[2 : (2 + user_len)].decode("utf-8", "backslashreplace")
        password = self.buf[(3 + user_len) : (3 + user_len + pass_len)].decode(
            "utf-8", "backslashreplace"
        )

        data = Socks5AuthData(self.context.client, user, password)
        # Addons decide whether the credentials are valid.
        yield Socks5AuthHook(data)
        if not data.valid:
            # The VER field contains the current **version of the subnegotiation**, which is X'01'.
            yield commands.SendData(self.context.client, b"\x01\x01")
            yield from self.socks_err("authentication failed")
            return

        yield commands.SendData(self.context.client, b"\x01\x00")
        self.buf = self.buf[3 + user_len + pass_len :]
        self.state = self.state_connect
        yield from self.state()

    def state_connect(self) -> layer.CommandGenerator[None]:
        # Parse Connect Request
        if len(self.buf) < 5:
            return

        # Only CMD=CONNECT (0x01) with RSV=0x00 is supported.
        if self.buf[:3] != b"\x05\x01\x00":
            yield from self.socks_err(
                f"Unsupported SOCKS5 request: {self.buf!r}",
                SOCKS5_REP_COMMAND_NOT_SUPPORTED,
            )
            return

        # Determine message length
        atyp = self.buf[3]
        message_len: int
        if atyp == SOCKS5_ATYP_IPV4_ADDRESS:
            message_len = 4 + 4 + 2
        elif atyp == SOCKS5_ATYP_IPV6_ADDRESS:
            message_len = 4 + 16 + 2
        elif atyp == SOCKS5_ATYP_DOMAINNAME:
            # One length byte precedes the domain name.
            message_len = 4 + 1 + self.buf[4] + 2
        else:
            yield from self.socks_err(
                f"Unknown address type: {atyp}", SOCKS5_REP_ADDRESS_TYPE_NOT_SUPPORTED
            )
            return

        # Do we have enough bytes yet?
        if len(self.buf) < message_len:
            return

        # Parse host and port
        msg, self.buf = self.buf[:message_len], self.buf[message_len:]

        host: str
        if atyp == SOCKS5_ATYP_IPV4_ADDRESS:
            host = socket.inet_ntop(socket.AF_INET, msg[4:-2])
        elif atyp == SOCKS5_ATYP_IPV6_ADDRESS:
            host = socket.inet_ntop(socket.AF_INET6, msg[4:-2])
        else:
            host_bytes = msg[5:-2]
            host = host_bytes.decode("ascii", "replace")
        (port,) = struct.unpack("!H", msg[-2:])

        # We now have all we need, let's get going.
        self.context.server.address = (host, port)
        self.child_layer = layer.NextLayer(self.context)

        # this already triggers the child layer's Start event,
        # but that's not a problem in practice...
        err = yield from self.finish_start()
        if err:
            # REP=0x04: host unreachable.
            yield commands.SendData(
                self.context.client, b"\x05\x04\x00\x01\x00\x00\x00\x00\x00\x00"
            )
            yield commands.CloseConnection(self.context.client)
        else:
            # REP=0x00: succeeded.
            yield commands.SendData(
                self.context.client, b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00"
            )
            if self.buf:
                # Forward any bytes the client pipelined after the CONNECT request.
                yield from self.child_layer.handle_event(
                    events.DataReceived(self.context.client, self.buf)
                )
                del self.buf

View File

@@ -0,0 +1,41 @@
from ._client_hello_parser import quic_parse_client_hello_from_datagrams
from ._commands import CloseQuicConnection
from ._commands import ResetQuicStream
from ._commands import SendQuicStreamData
from ._commands import StopSendingQuicStream
from ._events import QuicConnectionClosed
from ._events import QuicStreamDataReceived
from ._events import QuicStreamEvent
from ._events import QuicStreamReset
from ._events import QuicStreamStopSending
from ._hooks import QuicStartClientHook
from ._hooks import QuicStartServerHook
from ._hooks import QuicTlsData
from ._hooks import QuicTlsSettings
from ._raw_layers import QuicStreamLayer
from ._raw_layers import RawQuicLayer
from ._stream_layers import ClientQuicLayer
from ._stream_layers import error_code_to_str
from ._stream_layers import ServerQuicLayer
# Explicit public API re-exported from the private submodules.
__all__ = [
    "quic_parse_client_hello_from_datagrams",
    "CloseQuicConnection",
    "ResetQuicStream",
    "SendQuicStreamData",
    "StopSendingQuicStream",
    "QuicConnectionClosed",
    "QuicStreamDataReceived",
    "QuicStreamEvent",
    "QuicStreamReset",
    "QuicStreamStopSending",
    "QuicStartClientHook",
    "QuicStartServerHook",
    "QuicTlsData",
    "QuicTlsSettings",
    "QuicStreamLayer",
    "RawQuicLayer",
    "ClientQuicLayer",
    "error_code_to_str",
    "ServerQuicLayer",
]

View File

@@ -0,0 +1,111 @@
"""
This module contains a very terrible QUIC client hello parser.
Nothing is more permanent than a temporary solution!
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Optional
from aioquic.buffer import Buffer as QuicBuffer
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionError
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import PACKET_TYPE_INITIAL
from aioquic.quic.packet import pull_quic_header
from aioquic.tls import HandshakeType
from mitmproxy.tls import ClientHello
@dataclass
class QuicClientHello(Exception):
    """Helper error only used in `quic_parse_client_hello_from_datagrams`."""

    # Raw TLS ClientHello bytes captured from the handshake buffer.
    data: bytes
def quic_parse_client_hello_from_datagrams(
    datagrams: list[bytes],
) -> Optional[ClientHello]:
    """
    Check if the supplied bytes contain a full ClientHello message,
    and if so, parse it.

    Args:
      - msgs: list of ClientHello fragments received from client

    Returns:
      - A ClientHello object on success
      - None, if the QUIC record is incomplete

    Raises:
      - A ValueError, if the passed ClientHello is invalid
    """

    # ensure the first packet is indeed the initial one
    buffer = QuicBuffer(data=datagrams[0])
    header = pull_quic_header(buffer, 8)
    if header.packet_type != PACKET_TYPE_INITIAL:
        raise ValueError("Packet is not initial one.")

    # patch aioquic to intercept the client hello
    quic = QuicConnection(
        configuration=QuicConfiguration(
            is_client=False,
            certificate="",
            private_key="",
            quic_logger=QuicLogger(),
        ),
        original_destination_connection_id=header.destination_cid,
    )
    _initialize = quic._initialize

    def server_handle_hello_replacement(
        input_buf: QuicBuffer,
        initial_buf: QuicBuffer,
        handshake_buf: QuicBuffer,
        onertt_buf: QuicBuffer,
    ) -> None:
        # Parse the TLS handshake header: one type byte, then a 24-bit big-endian length.
        assert input_buf.pull_uint8() == HandshakeType.CLIENT_HELLO
        length = 0
        for b in input_buf.pull_bytes(3):
            length = (length << 8) | b
        offset = input_buf.tell()
        # Bail out of aioquic's handshake via an exception that carries the raw hello.
        raise QuicClientHello(input_buf.data_slice(offset, offset + length))

    def initialize_replacement(peer_cid: bytes) -> None:
        try:
            return _initialize(peer_cid)
        finally:
            # _initialize (re)creates the TLS context, so the hook must be installed afterwards.
            quic.tls._server_handle_hello = server_handle_hello_replacement  # type: ignore

    quic._initialize = initialize_replacement  # type: ignore
    try:
        for dgm in datagrams:
            quic.receive_datagram(dgm, ("0.0.0.0", 0), now=time.time())
    except QuicClientHello as hello:
        # Our hook fired: we captured the full ClientHello.
        try:
            return ClientHello(hello.data)
        except EOFError as e:
            raise ValueError("Invalid ClientHello data.") from e
    except QuicConnectionError as e:
        raise ValueError(e.reason_phrase) from e

    # No ClientHello yet: inspect the qlog trace to see whether aioquic
    # silently dropped a packet (which would make waiting for more data futile).
    quic_logger = quic._configuration.quic_logger
    assert isinstance(quic_logger, QuicLogger)
    traces = quic_logger.to_dict().get("traces")
    assert isinstance(traces, list)
    for trace in traces:
        quic_events = trace.get("events")
        for event in quic_events:
            if event["name"] == "transport:packet_dropped":
                raise ValueError(
                    f"Invalid ClientHello packet: {event['data']['trigger']}"
                )

    return None  # pragma: no cover  # FIXME: this should have test coverage

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
from mitmproxy import connection
from mitmproxy.proxy import commands
class QuicStreamCommand(commands.ConnectionCommand):
    """Base class for all QUIC stream commands, tying a command to one stream of a connection."""

    stream_id: int
    """The ID of the stream the command was issued for."""

    def __init__(self, connection: connection.Connection, stream_id: int) -> None:
        super().__init__(connection)
        self.stream_id = stream_id
class SendQuicStreamData(QuicStreamCommand):
    """Command that sends data on a stream."""

    data: bytes
    """The data which should be sent."""

    end_stream: bool
    """Whether the FIN bit should be set in the STREAM frame."""

    def __init__(
        self,
        connection: connection.Connection,
        stream_id: int,
        data: bytes,
        end_stream: bool = False,
    ) -> None:
        super().__init__(connection, stream_id)
        self.data = data
        self.end_stream = end_stream

    def __repr__(self):
        # e.g. "SendQuicStreamData(server on 3, [end_stream] b'...')"
        target = repr(self.connection).partition("(")[0].lower()
        end_stream = "[end_stream] " if self.end_stream else ""
        return f"SendQuicStreamData({target} on {self.stream_id}, {end_stream}{self.data!r})"
class ResetQuicStream(QuicStreamCommand):
    """Abruptly terminate the sending part of a stream (RESET_STREAM frame)."""

    error_code: int
    """An error code indicating why the stream is being reset."""

    def __init__(
        self, connection: connection.Connection, stream_id: int, error_code: int
    ) -> None:
        super().__init__(connection, stream_id)
        self.error_code = error_code
class StopSendingQuicStream(QuicStreamCommand):
    """Request termination of the receiving part of a stream (STOP_SENDING frame)."""

    error_code: int
    """An error code indicating why the stream is being stopped."""

    def __init__(
        self, connection: connection.Connection, stream_id: int, error_code: int
    ) -> None:
        super().__init__(connection, stream_id)
        self.error_code = error_code
class CloseQuicConnection(commands.CloseConnection):
    """Close a QUIC connection, carrying the QUIC-specific close details."""

    error_code: int
    "The error code which was specified when closing the connection."

    frame_type: int | None
    "The frame type which caused the connection to be closed, or `None`."

    reason_phrase: str
    "The human-readable reason for which the connection was closed."

    # XXX: A bit much boilerplate right now. Should switch to dataclasses.
    def __init__(
        self,
        conn: connection.Connection,
        error_code: int,
        frame_type: int | None,
        reason_phrase: str,
    ) -> None:
        super().__init__(conn)
        self.error_code = error_code
        self.frame_type = frame_type
        self.reason_phrase = reason_phrase

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
from dataclasses import dataclass
from mitmproxy import connection
from mitmproxy.proxy import events
@dataclass
class QuicStreamEvent(events.ConnectionEvent):
    """Base class for all QUIC stream events."""

    # `connection` (inherited from ConnectionEvent) is the QUIC connection
    # carrying the stream; `stream_id` identifies the stream on it.
    stream_id: int
    """The ID of the stream the event was fired for."""
@dataclass
class QuicStreamDataReceived(QuicStreamEvent):
    """Fired whenever payload data arrives on a QUIC stream."""

    data: bytes
    """The received payload."""
    end_stream: bool
    """`True` if the STREAM frame carried the FIN bit."""

    def __repr__(self):
        conn_kind = repr(self.connection).partition("(")[0].lower()
        fin_marker = "[end_stream] " if self.end_stream else ""
        return f"QuicStreamDataReceived({conn_kind} on {self.stream_id}, {fin_marker}{self.data!r})"
@dataclass
class QuicStreamReset(QuicStreamEvent):
    """Event that is fired when the remote peer resets a stream."""

    # Corresponds to an incoming RESET_STREAM frame: the peer abruptly ended
    # its sending side of the stream.
    error_code: int
    """The error code that triggered the reset."""
@dataclass
class QuicStreamStopSending(QuicStreamEvent):
    """Event that is fired when the remote peer sends a STOP_SENDING frame."""

    # The peer no longer wants to receive data on this stream.
    error_code: int
    """The application protocol error code."""
class QuicConnectionClosed(events.ConnectionClosed):
    """The QUIC connection has been closed, with the close details attached."""

    error_code: int
    "The error code which was specified when closing the connection."
    frame_type: int | None
    "The frame type which caused the connection to be closed, or `None`."
    reason_phrase: str
    "The human-readable reason for which the connection was closed."

    def __init__(
        self,
        conn: connection.Connection,
        error_code: int,
        frame_type: int | None,
        reason_phrase: str,
    ) -> None:
        super().__init__(conn)
        self.reason_phrase = reason_phrase
        self.frame_type = frame_type
        self.error_code = error_code

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import field
from ssl import VerifyMode
from aioquic.tls import CipherSuite
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import rsa
from mitmproxy.proxy import commands
from mitmproxy.tls import TlsData
@dataclass
class QuicTlsSettings:
    """
    Settings necessary to establish QUIC's TLS context.

    Populated by an addon during the `quic_start_client`/`quic_start_server`
    hooks (see `QuicTlsData`).
    """

    alpn_protocols: list[str] | None = None
    """A list of supported ALPN protocols."""
    certificate: x509.Certificate | None = None
    """The certificate to use for the connection."""
    certificate_chain: list[x509.Certificate] = field(default_factory=list)
    """A list of additional certificates to send to the peer."""
    certificate_private_key: (
        dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey | None
    ) = None
    """The certificate's private key."""
    cipher_suites: list[CipherSuite] | None = None
    """An optional list of allowed/advertised cipher suites."""
    ca_path: str | None = None
    """An optional path to a directory that contains the necessary information to verify the peer certificate."""
    ca_file: str | None = None
    """An optional path to a PEM file that will be used to verify the peer certificate."""
    # `ssl.VerifyMode` values (CERT_NONE/CERT_OPTIONAL/CERT_REQUIRED).
    verify_mode: VerifyMode | None = None
    """An optional flag that specifies how/if the peer's certificate should be validated."""
@dataclass
class QuicTlsData(TlsData):
    """
    Event data for `quic_start_client` and `quic_start_server` event hooks.
    """

    # Remains `None` until an addon fills it in during the hook.
    settings: QuicTlsSettings | None = None
    """
    The associated `QuicTlsSettings` object.
    This will be set by an addon in the `quic_start_*` event hooks.
    """
@dataclass
class QuicStartClientHook(commands.StartHook):
    """
    TLS negotiation between mitmproxy and a client over QUIC is about to start.

    An addon is expected to initialize data.settings.
    (by default, this is done by `mitmproxy.addons.tlsconfig`)
    """

    data: QuicTlsData
@dataclass
class QuicStartServerHook(commands.StartHook):
    """
    TLS negotiation between mitmproxy and a server over QUIC is about to start.

    An addon is expected to initialize data.settings.
    (by default, this is done by `mitmproxy.addons.tlsconfig`)
    """

    data: QuicTlsData

View File

@@ -0,0 +1,433 @@
"""
This module contains the proxy layers for raw QUIC proxying.
This is used if we want to speak QUIC, but we do not want to do HTTP.
"""
from __future__ import annotations
import time
from aioquic.quic.connection import QuicErrorCode
from aioquic.quic.connection import stream_is_client_initiated
from aioquic.quic.connection import stream_is_unidirectional
from ._commands import CloseQuicConnection
from ._commands import ResetQuicStream
from ._commands import SendQuicStreamData
from ._commands import StopSendingQuicStream
from ._events import QuicConnectionClosed
from ._events import QuicStreamDataReceived
from ._events import QuicStreamEvent
from ._events import QuicStreamReset
from mitmproxy import connection
from mitmproxy.connection import Connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers.tcp import TCPLayer
from mitmproxy.proxy.layers.udp import UDPLayer
class QuicStreamNextLayer(layer.NextLayer):
    """`NextLayer` variant that callbacks `QuicStreamLayer` after layer decision."""
    def __init__(
        self,
        context: context.Context,
        stream: QuicStreamLayer,
        ask_on_start: bool = False,
    ) -> None:
        super().__init__(context, ask_on_start)
        # the owning stream layer, notified whenever the layer decision is made
        self._stream = stream
        self._layer: layer.Layer | None = None
    # `layer` shadows the base class attribute with a property so that
    # assigning a decided layer triggers a metadata refresh on the stream.
    @property # type: ignore
    def layer(self) -> layer.Layer | None: # type: ignore
        return self._layer
    @layer.setter
    def layer(self, value: layer.Layer | None) -> None:
        self._layer = value
        if self._layer:
            self._stream.refresh_metadata()
class QuicStreamLayer(layer.Layer):
    """
    Layer for QUIC streams.
    Serves as a marker for NextLayer and keeps track of the connection states.

    Each QUIC stream gets its own forked context with virtual client/server
    connections, so that the per-stream child layer can treat the stream like
    an ordinary TCP connection.
    """
    client: connection.Client
    """Virtual client connection for this stream. Use this in QuicRawLayer instead of `context.client`."""
    server: connection.Server
    """Virtual server connection for this stream. Use this in QuicRawLayer instead of `context.server`."""
    child_layer: layer.Layer
    """The stream's child layer."""
    def __init__(
        self, context: context.Context, force_raw: bool, stream_id: int
    ) -> None:
        # we mustn't reuse the client from the QUIC connection, as the state and protocol differs
        self.client = context.client = context.client.copy()
        self.client.transport_protocol = "tcp"
        self.client.state = connection.ConnectionState.OPEN
        # unidirectional client streams are not fully open, set the appropriate state
        if stream_is_unidirectional(stream_id):
            self.client.state = (
                connection.ConnectionState.CAN_READ
                if stream_is_client_initiated(stream_id)
                else connection.ConnectionState.CAN_WRITE
            )
        self._client_stream_id = stream_id
        # start with a closed server
        self.server = context.server = connection.Server(
            address=context.server.address,
            transport_protocol="tcp",
        )
        self._server_stream_id: int | None = None
        super().__init__(context)
        self.child_layer = (
            TCPLayer(context) if force_raw else QuicStreamNextLayer(context, self)
        )
        self.refresh_metadata()
        # we don't handle any events, pass everything to the child layer
        # (the instance attributes below shadow the methods defined on the class)
        self.handle_event = self.child_layer.handle_event # type: ignore
        self._handle_event = self.child_layer._handle_event # type: ignore
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        # never reached: __init__ replaces this with the child layer's handler
        raise AssertionError # pragma: no cover
    def open_server_stream(self, server_stream_id) -> None:
        """Associates the (so far closed) virtual server connection with a server-side stream ID."""
        assert self._server_stream_id is None
        self._server_stream_id = server_stream_id
        self.server.timestamp_start = time.time()
        # mirror the directionality of the stream in the connection state
        self.server.state = (
            (
                connection.ConnectionState.CAN_WRITE
                if stream_is_client_initiated(server_stream_id)
                else connection.ConnectionState.CAN_READ
            )
            if stream_is_unidirectional(server_stream_id)
            else connection.ConnectionState.OPEN
        )
        self.refresh_metadata()
    def refresh_metadata(self) -> None:
        """Stores QUIC stream information in the flow metadata of the underlying TCP/UDP layer (if any)."""
        # find the first transport layer
        child_layer: layer.Layer | None = self.child_layer
        while True:
            if isinstance(child_layer, layer.NextLayer):
                child_layer = child_layer.layer
            elif isinstance(child_layer, tunnel.TunnelLayer):
                child_layer = child_layer.child_layer
            else:
                break # pragma: no cover
        if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow:
            child_layer.flow.metadata["quic_is_unidirectional"] = (
                stream_is_unidirectional(self._client_stream_id)
            )
            child_layer.flow.metadata["quic_initiator"] = (
                "client"
                if stream_is_client_initiated(self._client_stream_id)
                else "server"
            )
            child_layer.flow.metadata["quic_stream_id_client"] = self._client_stream_id
            child_layer.flow.metadata["quic_stream_id_server"] = self._server_stream_id
    def stream_id(self, client: bool) -> int | None:
        """Returns the client- or server-side stream ID (`None` if the server stream is not open yet)."""
        return self._client_stream_id if client else self._server_stream_id
class RawQuicLayer(layer.Layer):
    """
    This layer is responsible for de-multiplexing QUIC streams into an individual layer stack per stream.
    """
    force_raw: bool
    """Indicates whether traffic should be treated as raw TCP/UDP without further protocol detection."""
    datagram_layer: layer.Layer
    """
    The layer that is handling datagrams over QUIC. It's like a child_layer, but with a forked context.
    Instead of having a datagram-equivalent for all `QuicStream*` classes, we use `SendData` and `DataReceived` instead.
    There is also no need for another `NextLayer` marker, as a missing `QuicStreamLayer` implies UDP,
    and the connection state is the same as the one of the underlying QUIC connection.
    """
    client_stream_ids: dict[int, QuicStreamLayer]
    """Maps stream IDs from the client connection to stream layers."""
    server_stream_ids: dict[int, QuicStreamLayer]
    """Maps stream IDs from the server connection to stream layers."""
    connections: dict[connection.Connection, layer.Layer]
    """Maps connections to layers."""
    command_sources: dict[commands.Command, layer.Layer]
    """Keeps track of blocking commands and wakeup requests."""
    next_stream_id: list[int]
    """List containing the next stream ID for all four is_unidirectional/is_client combinations."""
    def __init__(self, context: context.Context, force_raw: bool = False) -> None:
        super().__init__(context)
        self.force_raw = force_raw
        self.datagram_layer = (
            UDPLayer(self.context.fork())
            if force_raw
            else layer.NextLayer(self.context.fork())
        )
        self.client_stream_ids = {}
        self.server_stream_ids = {}
        self.connections = {
            context.client: self.datagram_layer,
            context.server: self.datagram_layer,
        }
        self.command_sources = {}
        # indexed by (is_unidirectional << 1) | (not is_client), see
        # get_next_available_stream_id; QUIC stream IDs advance in steps of 4
        self.next_stream_id = [0, 1, 2, 3]
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Routes events to the datagram layer or the matching per-stream layer."""
        # we treat the datagram layer as child layer, so forward Start
        if isinstance(event, events.Start):
            if self.context.server.timestamp_start is None:
                err = yield commands.OpenConnection(self.context.server)
                if err:
                    yield commands.CloseConnection(self.context.client)
                    self._handle_event = self.done # type: ignore
                    return
            yield from self.event_to_child(self.datagram_layer, event)
        # properly forward completion events based on their command
        elif isinstance(event, events.CommandCompleted):
            yield from self.event_to_child(
                self.command_sources.pop(event.command), event
            )
        # route injected messages based on their connections (prefer client, fallback to server)
        elif isinstance(event, events.MessageInjected):
            if event.flow.client_conn in self.connections:
                yield from self.event_to_child(
                    self.connections[event.flow.client_conn], event
                )
            elif event.flow.server_conn in self.connections:
                yield from self.event_to_child(
                    self.connections[event.flow.server_conn], event
                )
            else:
                raise AssertionError(f"Flow not associated: {event.flow!r}")
        # handle stream events targeting this context
        elif isinstance(event, QuicStreamEvent) and (
            event.connection is self.context.client
            or event.connection is self.context.server
        ):
            from_client = event.connection is self.context.client
            # fetch or create the layer
            stream_ids = (
                self.client_stream_ids if from_client else self.server_stream_ids
            )
            if event.stream_id in stream_ids:
                stream_layer = stream_ids[event.stream_id]
            else:
                # ensure we haven't just forgotten to register the ID
                assert stream_is_client_initiated(event.stream_id) == from_client
                # for server-initiated streams we need to open the client as well
                if from_client:
                    client_stream_id = event.stream_id
                    server_stream_id = None
                else:
                    client_stream_id = self.get_next_available_stream_id(
                        is_client=False,
                        is_unidirectional=stream_is_unidirectional(event.stream_id),
                    )
                    server_stream_id = event.stream_id
                # create, register and start the layer
                stream_layer = QuicStreamLayer(
                    self.context.fork(),
                    force_raw=self.force_raw,
                    stream_id=client_stream_id,
                )
                self.client_stream_ids[client_stream_id] = stream_layer
                if server_stream_id is not None:
                    stream_layer.open_server_stream(server_stream_id)
                    self.server_stream_ids[server_stream_id] = stream_layer
                self.connections[stream_layer.client] = stream_layer
                self.connections[stream_layer.server] = stream_layer
                yield from self.event_to_child(stream_layer, events.Start())
            # forward data and close events
            conn: Connection = (
                stream_layer.client if from_client else stream_layer.server
            )
            if isinstance(event, QuicStreamDataReceived):
                if event.data:
                    yield from self.event_to_child(
                        stream_layer, events.DataReceived(conn, event.data)
                    )
                if event.end_stream:
                    yield from self.close_stream_layer(stream_layer, from_client)
            elif isinstance(event, QuicStreamReset):
                # preserve stream resets
                # (the child layer's FIN-only send on the opposite stream is
                # upgraded into a reset carrying the original error code)
                for command in self.close_stream_layer(stream_layer, from_client):
                    if (
                        isinstance(command, SendQuicStreamData)
                        and command.stream_id == stream_layer.stream_id(not from_client)
                        and command.end_stream
                        and not command.data
                    ):
                        yield ResetQuicStream(
                            command.connection, command.stream_id, event.error_code
                        )
                    else:
                        yield command
            else:
                raise AssertionError(f"Unexpected stream event: {event!r}")
        # handle close events that target this context
        elif isinstance(event, QuicConnectionClosed) and (
            event.connection is self.context.client
            or event.connection is self.context.server
        ):
            from_client = event.connection is self.context.client
            other_conn = self.context.server if from_client else self.context.client
            # be done if both connections are closed
            if other_conn.connected:
                yield CloseQuicConnection(
                    other_conn, event.error_code, event.frame_type, event.reason_phrase
                )
            else:
                self._handle_event = self.done # type: ignore
            # always forward to the datagram layer and swallow `CloseConnection` commands
            for command in self.event_to_child(self.datagram_layer, event):
                if (
                    not isinstance(command, commands.CloseConnection)
                    or command.connection is not other_conn
                ):
                    yield command
            # forward to either the client or server connection of stream layers and swallow empty stream end
            for conn, child_layer in self.connections.items():
                if isinstance(child_layer, QuicStreamLayer) and (
                    (conn is child_layer.client)
                    if from_client
                    else (conn is child_layer.server)
                ):
                    conn.state &= ~connection.ConnectionState.CAN_WRITE
                    for command in self.close_stream_layer(child_layer, from_client):
                        if not isinstance(command, SendQuicStreamData) or command.data:
                            yield command
        # all other connection events are routed to their corresponding layer
        elif isinstance(event, events.ConnectionEvent):
            yield from self.event_to_child(self.connections[event.connection], event)
        else:
            raise AssertionError(f"Unexpected event: {event!r}")
    def close_stream_layer(
        self, stream_layer: QuicStreamLayer, client: bool
    ) -> layer.CommandGenerator[None]:
        """Closes the incoming part of a connection."""
        conn = stream_layer.client if client else stream_layer.server
        conn.state &= ~connection.ConnectionState.CAN_READ
        assert conn.timestamp_start is not None
        if conn.timestamp_end is None:
            conn.timestamp_end = time.time()
        yield from self.event_to_child(stream_layer, events.ConnectionClosed(conn))
    def event_to_child(
        self, child_layer: layer.Layer, event: events.Event
    ) -> layer.CommandGenerator[None]:
        """Forwards events to child layers and translates commands."""
        for command in child_layer.handle_event(event):
            # intercept commands for streams connections
            if (
                isinstance(child_layer, QuicStreamLayer)
                and isinstance(command, commands.ConnectionCommand)
                and (
                    command.connection is child_layer.client
                    or command.connection is child_layer.server
                )
            ):
                # get the target connection and stream ID
                to_client = command.connection is child_layer.client
                quic_conn = self.context.client if to_client else self.context.server
                stream_id = child_layer.stream_id(to_client)
                # write data and check CloseConnection wasn't called before
                if isinstance(command, commands.SendData):
                    assert stream_id is not None
                    if command.connection.state & connection.ConnectionState.CAN_WRITE:
                        yield SendQuicStreamData(quic_conn, stream_id, command.data)
                # send a FIN and optionally also a STOP frame
                elif isinstance(command, commands.CloseConnection):
                    assert stream_id is not None
                    if command.connection.state & connection.ConnectionState.CAN_WRITE:
                        command.connection.state &= (
                            ~connection.ConnectionState.CAN_WRITE
                        )
                        yield SendQuicStreamData(
                            quic_conn, stream_id, b"", end_stream=True
                        )
                    # XXX: Use `command.connection.state & connection.ConnectionState.CAN_READ` instead?
                    only_close_our_half = (
                        isinstance(command, commands.CloseTcpConnection)
                        and command.half_close
                    )
                    if not only_close_our_half:
                        if stream_is_client_initiated(
                            stream_id
                        ) == to_client or not stream_is_unidirectional(stream_id):
                            yield StopSendingQuicStream(
                                quic_conn, stream_id, QuicErrorCode.NO_ERROR
                            )
                        yield from self.close_stream_layer(child_layer, to_client)
                # open server connections by reserving the next stream ID
                elif isinstance(command, commands.OpenConnection):
                    assert not to_client
                    assert stream_id is None
                    client_stream_id = child_layer.stream_id(client=True)
                    assert client_stream_id is not None
                    stream_id = self.get_next_available_stream_id(
                        is_client=True,
                        is_unidirectional=stream_is_unidirectional(client_stream_id),
                    )
                    child_layer.open_server_stream(stream_id)
                    self.server_stream_ids[stream_id] = child_layer
                    yield from self.event_to_child(
                        child_layer, events.OpenConnectionCompleted(command, None)
                    )
                else:
                    raise AssertionError(
                        f"Unexpected stream connection command: {command!r}"
                    )
            # remember blocking and wakeup commands
            else:
                if command.blocking or isinstance(command, commands.RequestWakeup):
                    self.command_sources[command] = child_layer
                if isinstance(command, commands.OpenConnection):
                    self.connections[command.connection] = child_layer
                yield command
    def get_next_available_stream_id(
        self, is_client: bool, is_unidirectional: bool = False
    ) -> int:
        """Reserves and returns the next free stream ID for the given direction/initiator combination."""
        index = (int(is_unidirectional) << 1) | int(not is_client)
        stream_id = self.next_stream_id[index]
        self.next_stream_id[index] = stream_id + 4
        return stream_id
    # terminal no-op state once this layer is finished
    def done(self, _) -> layer.CommandGenerator[None]: # pragma: no cover
        yield from ()

View File

@@ -0,0 +1,638 @@
"""
This module contains the client and server proxy layers for QUIC streams
which decrypt and encrypt traffic. Decrypted stream data is then forwarded
to either the raw layers, or the HTTP/3 client in ../http/_http3.py.
"""
from __future__ import annotations
import time
from collections.abc import Callable
from logging import DEBUG
from logging import ERROR
from logging import WARNING
from aioquic.buffer import Buffer as QuicBuffer
from aioquic.h3.connection import ErrorCode as H3ErrorCode
from aioquic.quic import events as quic_events
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionState
from aioquic.quic.connection import QuicErrorCode
from aioquic.quic.packet import encode_quic_version_negotiation
from aioquic.quic.packet import PACKET_TYPE_INITIAL
from aioquic.quic.packet import pull_quic_header
from cryptography import x509
from ._client_hello_parser import quic_parse_client_hello_from_datagrams
from ._commands import CloseQuicConnection
from ._commands import QuicStreamCommand
from ._commands import ResetQuicStream
from ._commands import SendQuicStreamData
from ._commands import StopSendingQuicStream
from ._events import QuicConnectionClosed
from ._events import QuicStreamDataReceived
from ._events import QuicStreamReset
from ._events import QuicStreamStopSending
from ._hooks import QuicStartClientHook
from ._hooks import QuicStartServerHook
from ._hooks import QuicTlsData
from ._hooks import QuicTlsSettings
from mitmproxy import certs
from mitmproxy import connection
from mitmproxy import ctx
from mitmproxy.net import tls
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.layers.tls import TlsClienthelloHook
from mitmproxy.proxy.layers.tls import TlsEstablishedClientHook
from mitmproxy.proxy.layers.tls import TlsEstablishedServerHook
from mitmproxy.proxy.layers.tls import TlsFailedClientHook
from mitmproxy.proxy.layers.tls import TlsFailedServerHook
from mitmproxy.proxy.layers.udp import UDPLayer
from mitmproxy.tls import ClientHelloData
# All QUIC versions aioquic supports on the server side; used below for version negotiation.
SUPPORTED_QUIC_VERSIONS_SERVER = QuicConfiguration(is_client=False).supported_versions
class QuicLayer(tunnel.TunnelLayer):
    """
    Base tunnel layer that drives an aioquic `QuicConnection` for one peer
    and translates between mitmproxy commands/events and aioquic's API.
    """
    quic: QuicConnection | None = None
    tls: QuicTlsSettings | None = None
    def __init__(
        self,
        context: context.Context,
        conn: connection.Connection,
        time: Callable[[], float] | None,
    ) -> None:
        super().__init__(context, tunnel_connection=conn, conn=conn)
        self.child_layer = layer.NextLayer(self.context, ask_on_start=True)
        # `time` is injectable (e.g. for tests); defaults to the event loop clock
        self._time = time or ctx.master.event_loop.time
        # maps pending RequestWakeup commands to their scheduled trigger time
        self._wakeup_commands: dict[commands.RequestWakeup, float] = dict()
        conn.tls = True
    def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Feeds our timer wakeups into aioquic; everything else goes to TunnelLayer."""
        if isinstance(event, events.Wakeup) and event.command in self._wakeup_commands:
            # TunnelLayer has no understanding of wakeups, so we turn this into an empty DataReceived event
            # which TunnelLayer recognizes as belonging to our connection.
            assert self.quic
            scheduled_time = self._wakeup_commands.pop(event.command)
            if self.quic._state is not QuicConnectionState.TERMINATED:
                # weird quirk: asyncio sometimes returns a bit ahead of time.
                now = max(scheduled_time, self._time())
                self.quic.handle_timer(now)
                yield from super()._handle_event(
                    events.DataReceived(self.tunnel_connection, b"")
                )
        else:
            yield from super()._handle_event(event)
    def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
        # the parent will call _handle_command multiple times, we transmit cumulative afterwards
        # this will reduce the number of sends, especially if data=b"" and end_stream=True
        yield from super().event_to_child(event)
        if self.quic:
            yield from self.tls_interact()
    def _handle_command(
        self, command: commands.Command
    ) -> layer.CommandGenerator[None]:
        """Turns stream commands into aioquic connection invocations."""
        if isinstance(command, QuicStreamCommand) and command.connection is self.conn:
            assert self.quic
            if isinstance(command, SendQuicStreamData):
                self.quic.send_stream_data(
                    command.stream_id, command.data, command.end_stream
                )
            elif isinstance(command, ResetQuicStream):
                # only reset once; aioquic raises on double resets
                stream = self.quic._get_or_create_stream_for_send(command.stream_id)
                existing_reset_error_code = stream.sender._reset_error_code
                if existing_reset_error_code is None:
                    self.quic.reset_stream(command.stream_id, command.error_code)
                elif self.debug: # pragma: no cover
                    yield commands.Log(
                        f"{self.debug}[quic] stream {stream.stream_id} already reset ({existing_reset_error_code=}, {command.error_code=})",
                        DEBUG,
                    )
            elif isinstance(command, StopSendingQuicStream):
                # the stream might have already been closed, check before stopping
                if command.stream_id in self.quic._streams:
                    self.quic.stop_stream(command.stream_id, command.error_code)
            else:
                raise AssertionError(f"Unexpected stream command: {command!r}")
        else:
            yield from super()._handle_command(command)
    def start_tls(
        self, original_destination_connection_id: bytes | None
    ) -> layer.CommandGenerator[None]:
        """Initiates the aioquic connection."""
        # must only be called if QUIC is uninitialized
        assert not self.quic
        assert not self.tls
        # query addons to provide the necessary TLS settings
        tls_data = QuicTlsData(self.conn, self.context)
        if self.conn is self.context.client:
            yield QuicStartClientHook(tls_data)
        else:
            yield QuicStartServerHook(tls_data)
        if not tls_data.settings:
            yield commands.Log(
                f"No QUIC context was provided, failing connection.", ERROR
            )
            yield commands.CloseConnection(self.conn)
            return
        # build the aioquic connection
        configuration = tls_settings_to_configuration(
            settings=tls_data.settings,
            is_client=self.conn is self.context.server,
            server_name=self.conn.sni,
        )
        self.quic = QuicConnection(
            configuration=configuration,
            original_destination_connection_id=original_destination_connection_id,
        )
        self.tls = tls_data.settings
        # if we act as client, connect to upstream
        if original_destination_connection_id is None:
            self.quic.connect(self.conn.peername, now=self._time())
        yield from self.tls_interact()
    def tls_interact(self) -> layer.CommandGenerator[None]:
        """Retrieves all pending outgoing packets from aioquic and sends the data."""
        # send all queued datagrams
        assert self.quic
        now = self._time()
        for data, addr in self.quic.datagrams_to_send(now=now):
            assert addr == self.conn.peername
            yield commands.SendData(self.tunnel_connection, data)
        timer = self.quic.get_timer()
        if timer is not None:
            # smooth wakeups a bit.
            smoothed = timer + 0.002
            # request a new wakeup if all pending requests trigger at a later time
            if not any(
                existing <= smoothed for existing in self._wakeup_commands.values()
            ):
                command = commands.RequestWakeup(timer - now)
                self._wakeup_commands[command] = timer
                yield command
    def receive_handshake_data(
        self, data: bytes
    ) -> layer.CommandGenerator[tuple[bool, str | None]]:
        """Feeds a datagram into aioquic and reports (handshake_done, error)."""
        assert self.quic
        # forward incoming data to aioquic
        if data:
            self.quic.receive_datagram(data, self.conn.peername, now=self._time())
        # handle pre-handshake events
        while event := self.quic.next_event():
            if isinstance(event, quic_events.ConnectionTerminated):
                err = event.reason_phrase or error_code_to_str(event.error_code)
                return False, err
            elif isinstance(event, quic_events.HandshakeCompleted):
                # concatenate all peer certificates
                all_certs: list[x509.Certificate] = []
                if self.quic.tls._peer_certificate:
                    all_certs.append(self.quic.tls._peer_certificate)
                all_certs.extend(self.quic.tls._peer_certificate_chain)
                # set the connection's TLS properties
                self.conn.timestamp_tls_setup = time.time()
                if event.alpn_protocol:
                    self.conn.alpn = event.alpn_protocol.encode("ascii")
                self.conn.certificate_list = [certs.Cert(cert) for cert in all_certs]
                assert self.quic.tls.key_schedule
                self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name
                self.conn.tls_version = "QUICv1"
                # log the result and report the success to addons
                if self.debug:
                    yield commands.Log(
                        f"{self.debug}[quic] tls established: {self.conn}", DEBUG
                    )
                if self.conn is self.context.client:
                    yield TlsEstablishedClientHook(
                        QuicTlsData(self.conn, self.context, settings=self.tls)
                    )
                else:
                    yield TlsEstablishedServerHook(
                        QuicTlsData(self.conn, self.context, settings=self.tls)
                    )
                yield from self.tls_interact()
                return True, None
            elif isinstance(
                event,
                (
                    quic_events.ConnectionIdIssued,
                    quic_events.ConnectionIdRetired,
                    quic_events.PingAcknowledged,
                    quic_events.ProtocolNegotiated,
                ),
            ):
                pass
            else:
                raise AssertionError(f"Unexpected event: {event!r}")
        # transmit buffered data and re-arm timer
        yield from self.tls_interact()
        return False, None
    def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
        """Records the error on the connection and notifies addons of the failure."""
        self.conn.error = err
        if self.conn is self.context.client:
            yield TlsFailedClientHook(
                QuicTlsData(self.conn, self.context, settings=self.tls)
            )
        else:
            yield TlsFailedServerHook(
                QuicTlsData(self.conn, self.context, settings=self.tls)
            )
        yield from super().on_handshake_error(err)
    def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
        """Feeds a post-handshake datagram into aioquic and dispatches the resulting events."""
        assert self.quic
        # forward incoming data to aioquic
        if data:
            self.quic.receive_datagram(data, self.conn.peername, now=self._time())
        # handle post-handshake events
        while event := self.quic.next_event():
            if isinstance(event, quic_events.ConnectionTerminated):
                if self.debug:
                    reason = event.reason_phrase or error_code_to_str(event.error_code)
                    yield commands.Log(
                        f"{self.debug}[quic] close_notify {self.conn} ({reason=!s})",
                        DEBUG,
                    )
                # We don't rely on `ConnectionTerminated` to dispatch `QuicConnectionClosed`, because
                # after aioquic receives a termination frame, it still waits for the next `handle_timer`
                # before returning `ConnectionTerminated` in `next_event`. In the meantime, the underlying
                # connection could be closed. Therefore, we instead dispatch on `ConnectionClosed` and simply
                # close the connection here.
                yield commands.CloseConnection(self.tunnel_connection)
                return # we don't handle any further events, nor do/can we transmit data, so exit
            elif isinstance(event, quic_events.DatagramFrameReceived):
                yield from self.event_to_child(
                    events.DataReceived(self.conn, event.data)
                )
            elif isinstance(event, quic_events.StreamDataReceived):
                yield from self.event_to_child(
                    QuicStreamDataReceived(
                        self.conn, event.stream_id, event.data, event.end_stream
                    )
                )
            elif isinstance(event, quic_events.StreamReset):
                yield from self.event_to_child(
                    QuicStreamReset(self.conn, event.stream_id, event.error_code)
                )
            elif isinstance(event, quic_events.StopSendingReceived):
                yield from self.event_to_child(
                    QuicStreamStopSending(self.conn, event.stream_id, event.error_code)
                )
            elif isinstance(
                event,
                (
                    quic_events.ConnectionIdIssued,
                    quic_events.ConnectionIdRetired,
                    quic_events.PingAcknowledged,
                    quic_events.ProtocolNegotiated,
                ),
            ):
                pass
            else:
                raise AssertionError(f"Unexpected event: {event!r}")
        # transmit buffered data and re-arm timer
        yield from self.tls_interact()
    def receive_close(self) -> layer.CommandGenerator[None]:
        """Translates the closed underlying connection into a `QuicConnectionClosed` event."""
        assert self.quic
        # if `_close_event` is not set, the underlying connection has been closed
        # we turn this into a QUIC close event as well
        close_event = self.quic._close_event or quic_events.ConnectionTerminated(
            QuicErrorCode.NO_ERROR, None, "Connection closed."
        )
        yield from self.event_to_child(
            QuicConnectionClosed(
                self.conn,
                close_event.error_code,
                close_event.frame_type,
                close_event.reason_phrase,
            )
        )
    def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
        # non-stream data uses datagram frames
        assert self.quic
        if data:
            self.quic.send_datagram_frame(data)
        yield from self.tls_interact()
    def send_close(
        self, command: commands.CloseConnection
    ) -> layer.CommandGenerator[None]:
        # properly close the QUIC connection
        if self.quic:
            if isinstance(command, CloseQuicConnection):
                self.quic.close(
                    command.error_code, command.frame_type, command.reason_phrase
                )
            else:
                self.quic.close()
            yield from self.tls_interact()
        yield from super().send_close(command)
class ServerQuicLayer(QuicLayer):
    """
    Establishes QUIC for a single upstream server connection.
    """

    wait_for_clienthello: bool = False

    def __init__(
        self,
        context: context.Context,
        conn: connection.Server | None = None,
        time: Callable[[], float] | None = None,
    ):
        super().__init__(context, conn or context.server, time)

    def start_handshake(self) -> layer.CommandGenerator[None]:
        # With a ClientQuicLayer below us and no pending reply, defer the
        # server handshake until the client hello has been observed.
        if self.command_to_reply_to or not isinstance(
            self.child_layer, ClientQuicLayer
        ):
            yield from self.start_tls(None)
        else:
            self.wait_for_clienthello = True
            self.tunnel_state = tunnel.TunnelState.CLOSED

    def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
        if not self.wait_for_clienthello:
            yield from super().event_to_child(event)
            return
        for command in super().event_to_child(event):
            opens_our_conn = (
                isinstance(command, commands.OpenConnection)
                and command.connection == self.conn
            )
            if opens_our_conn:
                # swallow the open request and resume normal forwarding
                self.wait_for_clienthello = False
            else:
                yield command

    def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
        yield commands.Log(f"Server QUIC handshake failed. {err}", level=WARNING)
        yield from super().on_handshake_error(err)
class ClientQuicLayer(QuicLayer):
    """
    This layer establishes QUIC on a single client connection.
    """

    server_tls_available: bool
    """Indicates whether the parent layer is a ServerQuicLayer."""

    # initial datagrams are buffered here until a complete ClientHello arrived
    handshake_datagram_buf: list[bytes]

    def __init__(
        self, context: context.Context, time: Callable[[], float] | None = None
    ) -> None:
        # same as ClientTLSLayer, we might be nested in some other transport
        if context.client.tls:
            context.client.alpn = None
            context.client.cipher = None
            context.client.sni = None
            context.client.timestamp_tls_setup = None
            context.client.tls_version = None
            context.client.certificate_list = []
            context.client.mitmcert = None
            context.client.alpn_offers = []
            context.client.cipher_list = []
        super().__init__(context, context.client, time)
        self.server_tls_available = len(self.context.layers) >= 2 and isinstance(
            self.context.layers[-2], ServerQuicLayer
        )
        self.handshake_datagram_buf = []

    def start_handshake(self) -> layer.CommandGenerator[None]:
        # nothing to send proactively; the handshake starts with the first
        # client datagram handled in receive_handshake_data.
        yield from ()

    def receive_handshake_data(
        self, data: bytes
    ) -> layer.CommandGenerator[tuple[bool, str | None]]:
        """Buffer initial datagrams, negotiate the QUIC version, parse the
        ClientHello, consult addons, and finally hand off to aioquic.

        Returns (done, error): `done` is True once the handshake proceeds,
        `error` is a message describing a fatal failure.
        """
        if not self.context.options.http3:
            yield commands.Log(
                "Swallowing QUIC handshake because HTTP/3 is disabled.", DEBUG
            )
            return False, None
        # if we already had a valid client hello, don't process further packets
        if self.tls:
            return (yield from super().receive_handshake_data(data))

        # fail if the received data is not a QUIC packet
        buffer = QuicBuffer(data=data)
        try:
            header = pull_quic_header(buffer)
        except TypeError:
            return False, f"Cannot parse QUIC header: Malformed head ({data.hex()})"
        except ValueError as e:
            return False, f"Cannot parse QUIC header: {e} ({data.hex()})"

        # negotiate version, support all versions known to aioquic
        if (
            header.version is not None
            and header.version not in SUPPORTED_QUIC_VERSIONS_SERVER
        ):
            yield commands.SendData(
                self.tunnel_connection,
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=SUPPORTED_QUIC_VERSIONS_SERVER,
                ),
            )
            return False, None

        # ensure it's (likely) a client handshake packet
        if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL:
            return (
                False,
                f"Invalid handshake received, roaming not supported. ({data.hex()})",
            )

        self.handshake_datagram_buf.append(data)
        # extract the client hello
        try:
            client_hello = quic_parse_client_hello_from_datagrams(
                self.handshake_datagram_buf
            )
        except ValueError as e:
            msgs = b"\n".join(self.handshake_datagram_buf)
            dbg = f"Cannot parse ClientHello: {e} ({msgs.hex()})"
            self.handshake_datagram_buf.clear()
            return False, dbg
        if not client_hello:
            # ClientHello spans multiple datagrams; wait for more.
            return False, None

        # copy the client hello information
        self.conn.sni = client_hello.sni
        self.conn.alpn_offers = client_hello.alpn_protocols

        # check with addons what we shall do
        tls_clienthello = ClientHelloData(self.context, client_hello)
        yield TlsClienthelloHook(tls_clienthello)

        # replace the QUIC layer with an UDP layer if requested
        if tls_clienthello.ignore_connection:
            self.conn = self.tunnel_connection = connection.Client(
                peername=("ignore-conn", 0),
                sockname=("ignore-conn", 0),
                transport_protocol="udp",
                state=connection.ConnectionState.OPEN,
            )

            # we need to replace the server layer as well, if there is one
            parent_layer = self.context.layers[self.context.layers.index(self) - 1]
            if isinstance(parent_layer, ServerQuicLayer):
                parent_layer.conn = parent_layer.tunnel_connection = connection.Server(
                    address=None
                )
            replacement_layer = UDPLayer(self.context, ignore=True)
            parent_layer.handle_event = replacement_layer.handle_event  # type: ignore
            parent_layer._handle_event = replacement_layer._handle_event  # type: ignore
            yield from parent_layer.handle_event(events.Start())
            # replay the buffered datagrams through the replacement layer
            for dgm in self.handshake_datagram_buf:
                yield from parent_layer.handle_event(
                    events.DataReceived(self.context.client, dgm)
                )
            self.handshake_datagram_buf.clear()
            return True, None

        # start the server QUIC connection if demanded and available
        if (
            tls_clienthello.establish_server_tls_first
            and not self.context.server.tls_established
        ):
            err = yield from self.start_server_tls()
            if err:
                yield commands.Log(
                    f"Unable to establish QUIC connection with server ({err}). "
                    "Trying to establish QUIC with client anyway. "
                    "If you plan to redirect requests away from this server, "
                    "consider setting `connection_strategy` to `lazy` to suppress early connections."
                )

        # start the client QUIC connection
        yield from self.start_tls(header.destination_cid)
        # XXX copied from TLS, we assume that `CloseConnection` in `start_tls` takes effect immediately
        if not self.conn.connected:
            return False, "connection closed early"

        # send the client hello to aioquic
        assert self.quic
        for dgm in self.handshake_datagram_buf:
            self.quic.receive_datagram(dgm, self.conn.peername, now=self._time())
        self.handshake_datagram_buf.clear()

        # handle events emanating from `self.quic`
        return (yield from super().receive_handshake_data(b""))

    def start_server_tls(self) -> layer.CommandGenerator[str | None]:
        """Open the upstream connection so the server handshake can proceed."""
        if not self.server_tls_available:
            return "No server QUIC available."
        err = yield commands.OpenConnection(self.context.server)
        return err

    def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
        yield commands.Log(f"Client QUIC handshake failed. {err}", level=WARNING)
        yield from super().on_handshake_error(err)
        # from now on, silently drop all events for this broken connection
        self.event_to_child = self.errored  # type: ignore

    def errored(self, event: events.Event) -> layer.CommandGenerator[None]:
        if self.debug is not None:
            yield commands.Log(
                f"{self.debug}[quic] Swallowing {event} as handshake failed.", DEBUG
            )
class QuicSecretsLogger:
    """File-like adapter that feeds aioquic's keylog writes into mitmproxy's
    master secret logger."""

    logger: tls.MasterSecretLogger

    def __init__(self, logger: tls.MasterSecretLogger) -> None:
        super().__init__()
        self.logger = logger

    def write(self, s: str) -> int:
        # aioquic emits one line per secret; strip the trailing newline
        # before forwarding the record to the keylog callback.
        record = s.removesuffix("\n").encode("ascii")
        self.logger(None, record)  # type: ignore
        # report the length including the stripped newline
        return len(record) + 1

    def flush(self) -> None:
        # nothing buffered: write() forwards each record immediately
        pass
def error_code_to_str(error_code: int) -> str:
    """Returns the corresponding name of the given error code or a string containing its numeric value."""
    # prefer the HTTP/3 name, then the QUIC transport name
    for code_enum in (H3ErrorCode, QuicErrorCode):
        try:
            return code_enum(error_code).name
        except ValueError:
            continue
    return f"unknown error (0x{error_code:x})"
def is_success_error_code(error_code: int) -> bool:
    """Returns whether the given error code actually indicates no error."""
    return (
        error_code == QuicErrorCode.NO_ERROR or error_code == H3ErrorCode.H3_NO_ERROR
    )
def tls_settings_to_configuration(
    settings: QuicTlsSettings,
    is_client: bool,
    server_name: str | None = None,
) -> QuicConfiguration:
    """Converts `QuicTlsSettings` to `QuicConfiguration`."""
    # hook up SSLKEYLOGFILE-style secret logging if mitmproxy has it enabled
    if tls.log_master_secret is not None:
        secrets_log_file = QuicSecretsLogger(tls.log_master_secret)  # type: ignore
    else:
        secrets_log_file = None
    return QuicConfiguration(
        alpn_protocols=settings.alpn_protocols,
        is_client=is_client,
        secrets_log_file=secrets_log_file,
        server_name=server_name,
        cafile=settings.ca_file,
        capath=settings.ca_path,
        certificate=settings.certificate,
        certificate_chain=settings.certificate_chain,
        cipher_suites=settings.cipher_suites,
        private_key=settings.certificate_private_key,
        verify_mode=settings.verify_mode,
        max_datagram_frame_size=65536,
    )

View File

@@ -0,0 +1,143 @@
from dataclasses import dataclass
from mitmproxy import flow
from mitmproxy import tcp
from mitmproxy.connection import Connection
from mitmproxy.connection import ConnectionState
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import MessageInjected
from mitmproxy.proxy.utils import expect
@dataclass
class TcpStartHook(StartHook):
    """
    A TCP connection has started.
    """

    # the flow tracking this connection
    flow: tcp.TCPFlow
@dataclass
class TcpMessageHook(StartHook):
    """
    A TCP connection has received a message. The most recent message
    will be flow.messages[-1]. The message is user-modifiable.
    """

    # the flow whose last message was just appended
    flow: tcp.TCPFlow
@dataclass
class TcpEndHook(StartHook):
    """
    A TCP connection has ended.
    """

    # the flow that has finished
    flow: tcp.TCPFlow
@dataclass
class TcpErrorHook(StartHook):
    """
    A TCP error has occurred.

    Every TCP flow will receive either a tcp_error or a tcp_end event, but not both.
    """

    # the flow whose .error attribute has been set
    flow: tcp.TCPFlow
class TcpMessageInjected(MessageInjected[tcp.TCPMessage]):
    """
    The user has injected a custom TCP message.
    The message is relayed as if it had been received from the spoofed peer.
    """
class TCPLayer(layer.Layer):
    """
    Simple TCP layer that just relays messages right now.
    """

    # None when the connection is "ignored" (no hooks, plain passthrough)
    flow: tcp.TCPFlow | None

    def __init__(self, context: Context, ignore: bool = False):
        super().__init__(context)
        if ignore:
            self.flow = None
        else:
            self.flow = tcp.TCPFlow(self.context.client, self.context.server, True)

    @expect(events.Start)
    def start(self, _) -> layer.CommandGenerator[None]:
        """Fire tcp_start and open the upstream connection if not yet connected."""
        if self.flow:
            yield TcpStartHook(self.flow)
        if self.context.server.timestamp_start is None:
            # lazy connection strategy: no upstream connection yet, open one now
            err = yield commands.OpenConnection(self.context.server)
            if err:
                if self.flow:
                    self.flow.error = flow.Error(str(err))
                    yield TcpErrorHook(self.flow)
                yield commands.CloseConnection(self.context.client)
                self._handle_event = self.done
                return
        self._handle_event = self.relay_messages

    # initial event handler until Start has been processed
    _handle_event = start

    @expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected)
    def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Forward data between client and server, firing hooks per message."""
        if isinstance(event, TcpMessageInjected):
            # we just spoof that we received data here and then process that regularly.
            event = events.DataReceived(
                self.context.client
                if event.message.from_client
                else self.context.server,
                event.message.content,
            )

        assert isinstance(event, events.ConnectionEvent)
        # data from the client goes to the server and vice versa
        from_client = event.connection == self.context.client
        send_to: Connection
        if from_client:
            send_to = self.context.server
        else:
            send_to = self.context.client

        if isinstance(event, events.DataReceived):
            if self.flow:
                tcp_message = tcp.TCPMessage(from_client, event.data)
                self.flow.messages.append(tcp_message)
                # addons may modify tcp_message.content before it is forwarded
                yield TcpMessageHook(self.flow)
                yield commands.SendData(send_to, tcp_message.content)
            else:
                yield commands.SendData(send_to, event.data)

        elif isinstance(event, events.ConnectionClosed):
            # done once neither side can produce more data to read
            all_done = not (
                (self.context.client.state & ConnectionState.CAN_READ)
                or (self.context.server.state & ConnectionState.CAN_READ)
            )
            if all_done:
                self._handle_event = self.done
                if self.context.server.state is not ConnectionState.CLOSED:
                    yield commands.CloseConnection(self.context.server)
                if self.context.client.state is not ConnectionState.CLOSED:
                    yield commands.CloseConnection(self.context.client)
                if self.flow:
                    yield TcpEndHook(self.flow)
                    self.flow.live = False
            else:
                # only one direction closed: propagate a half-close
                yield commands.CloseTcpConnection(send_to, half_close=True)
        else:
            raise AssertionError(f"Unexpected event: {event}")

    @expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected)
    def done(self, _) -> layer.CommandGenerator[None]:
        # terminal state: silently discard any remaining events
        yield from ()

View File

@@ -0,0 +1,692 @@
import struct
import time
import typing
from collections.abc import Iterator
from dataclasses import dataclass
from logging import DEBUG
from logging import ERROR
from logging import INFO
from logging import WARNING
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy import connection
from mitmproxy.connection import TlsVersion
from mitmproxy.net.tls import starts_like_dtls_record
from mitmproxy.net.tls import starts_like_tls_record
from mitmproxy.proxy import commands
from mitmproxy.proxy import context
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy import tunnel
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.layers import tcp
from mitmproxy.proxy.layers import udp
from mitmproxy.tls import ClientHello
from mitmproxy.tls import ClientHelloData
from mitmproxy.tls import TlsData
from mitmproxy.utils import human
def handshake_record_contents(data: bytes) -> Iterator[bytes]:
    """
    Returns a generator that yields the bytes contained in each handshake record.
    This will raise an error on the first non-handshake record, so fully exhausting this
    generator is a bad idea.
    """
    pos = 0
    # TLS record header: type (1) + version (2) + length (2) = 5 bytes
    while len(data) >= pos + 5:
        header = data[pos : pos + 5]
        if not starts_like_tls_record(header):
            raise ValueError(f"Expected TLS record, got {header!r} instead.")
        (size,) = struct.unpack("!H", header[3:])
        if size == 0:
            raise ValueError("Record must not be empty.")
        pos += 5
        if len(data) < pos + size:
            return  # incomplete record body, stop without error
        yield data[pos : pos + size]
        pos += size
def get_client_hello(data: bytes) -> bytes | None:
    """
    Read all TLS records that contain the initial ClientHello.
    Returns the raw handshake packet bytes, without TLS record headers.
    """
    hello = bytearray()
    for record_body in handshake_record_contents(data):
        hello += record_body
        if len(hello) >= 4:
            # handshake header: type (1) + 24-bit length (3)
            total = 4 + struct.unpack("!I", b"\x00" + hello[1:4])[0]
            if len(hello) >= total:
                return bytes(hello[:total])
    return None
def parse_client_hello(data: bytes) -> ClientHello | None:
    """
    Check if the supplied bytes contain a full ClientHello message,
    and if so, parse it.

    Returns:
        - A ClientHello object on success
        - None, if the TLS record is not complete

    Raises:
        - A ValueError, if the passed ClientHello is invalid
    """
    raw = get_client_hello(data)
    if not raw:
        # record(s) incomplete, need more data
        return None
    try:
        # skip the 4-byte handshake header
        return ClientHello(raw[4:])
    except EOFError as e:
        raise ValueError("Invalid ClientHello") from e
def dtls_handshake_record_contents(data: bytes) -> Iterator[bytes]:
    """
    Returns a generator that yields the bytes contained in each handshake record.
    This will raise an error on the first non-handshake record, so fully exhausting this
    generator is a bad idea.
    """
    pos = 0
    # DTLS adds epoch + sequence number (8 bytes) between version and length,
    # so the header is 13 bytes instead of TLS's 5.
    while len(data) >= pos + 13:
        header = data[pos : pos + 13]
        if not starts_like_dtls_record(header):
            raise ValueError(f"Expected DTLS record, got {header!r} instead.")
        # the length field starts at offset 11
        (size,) = struct.unpack("!H", header[11:])
        if size == 0:
            raise ValueError("Record must not be empty.")
        pos += 13
        if len(data) < pos + size:
            return  # incomplete record body, stop without error
        yield data[pos : pos + size]
        pos += size
def get_dtls_client_hello(data: bytes) -> bytes | None:
    """
    Read all DTLS records that contain the initial ClientHello.
    Returns the raw handshake packet bytes, without TLS record headers.
    """
    hello = bytearray()
    for record_body in dtls_handshake_record_contents(data):
        hello += record_body
        if len(hello) >= 13:
            # DTLS handshake header is 12 bytes; 24-bit length at offset 9
            total = 12 + struct.unpack("!I", b"\x00" + hello[9:12])[0]
            if len(hello) >= total:
                return bytes(hello[:total])
    return None
def dtls_parse_client_hello(data: bytes) -> ClientHello | None:
    """
    Check if the supplied bytes contain a full ClientHello message,
    and if so, parse it.

    Returns:
        - A ClientHello object on success
        - None, if the TLS record is not complete

    Raises:
        - A ValueError, if the passed ClientHello is invalid
    """
    raw = get_dtls_client_hello(data)
    if not raw:
        # record(s) incomplete, need more data
        return None
    try:
        # skip the 12-byte DTLS handshake header
        return ClientHello(raw[12:], dtls=True)
    except EOFError as e:
        raise ValueError("Invalid ClientHello") from e
# ALPN protocol identifiers for the HTTP versions mitmproxy understands.
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
HTTP2_ALPN = b"h2"
HTTP3_ALPN = b"h3"
# all HTTP ALPNs, newest protocol version first
HTTP_ALPNS = (HTTP3_ALPN, HTTP2_ALPN, *HTTP1_ALPNS)
# We need these classes as hooks can only have one argument at the moment.
@dataclass
class TlsClienthelloHook(StartHook):
    """
    Mitmproxy has received a TLS ClientHello message.

    This hook decides whether a server connection is needed
    to negotiate TLS with the client (data.establish_server_tls_first)
    """

    # parsed ClientHello plus the addon-modifiable decision flags
    data: ClientHelloData
@dataclass
class TlsStartClientHook(StartHook):
    """
    TLS negotiation between mitmproxy and a client is about to start.

    An addon is expected to initialize data.ssl_conn.
    (by default, this is done by `mitmproxy.addons.tlsconfig`)
    """

    data: TlsData
@dataclass
class TlsStartServerHook(StartHook):
    """
    TLS negotiation between mitmproxy and a server is about to start.

    An addon is expected to initialize data.ssl_conn.
    (by default, this is done by `mitmproxy.addons.tlsconfig`)
    """

    data: TlsData
@dataclass
class TlsEstablishedClientHook(StartHook):
    """
    The TLS handshake with the client has been completed successfully.
    """

    data: TlsData
@dataclass
class TlsEstablishedServerHook(StartHook):
    """
    The TLS handshake with the server has been completed successfully.
    """

    data: TlsData
@dataclass
class TlsFailedClientHook(StartHook):
    """
    The TLS handshake with the client has failed.
    """

    data: TlsData
@dataclass
class TlsFailedServerHook(StartHook):
    """
    The TLS handshake with the server has failed.
    """

    data: TlsData
class TLSLayer(tunnel.TunnelLayer):
    """Tunnel layer that wraps a connection in (D)TLS via pyOpenSSL memory BIOs:
    ciphertext moves through bio_read/bio_write, plaintext through recv/sendall."""

    tls: SSL.Connection = None  # type: ignore
    """The OpenSSL connection object"""

    def __init__(self, context: context.Context, conn: connection.Connection):
        super().__init__(
            context,
            tunnel_connection=conn,
            conn=conn,
        )
        conn.tls = True

    def __repr__(self):
        return (
            super().__repr__().replace(")", f" {self.conn.sni!r} {self.conn.alpn!r})")
        )

    @property
    def is_dtls(self):
        # UDP transport implies DTLS instead of TLS
        return self.conn.transport_protocol == "udp"

    @property
    def proto_name(self):
        return "DTLS" if self.is_dtls else "TLS"

    def start_tls(self) -> layer.CommandGenerator[None]:
        """Ask addons (via tls_start_client/server) for an SSL.Connection object."""
        assert not self.tls
        tls_start = TlsData(self.conn, self.context, is_dtls=self.is_dtls)
        if self.conn == self.context.client:
            yield TlsStartClientHook(tls_start)
        else:
            yield TlsStartServerHook(tls_start)

        if not tls_start.ssl_conn:
            yield commands.Log(
                f"No {self.proto_name} context was provided, failing connection.", ERROR
            )
            yield commands.CloseConnection(self.conn)
            return
        assert tls_start.ssl_conn
        self.tls = tls_start.ssl_conn

    def tls_interact(self) -> layer.CommandGenerator[None]:
        """Drain OpenSSL's outgoing BIO and send the ciphertext on the wire."""
        while True:
            try:
                data = self.tls.bio_read(65535)
            except SSL.WantReadError:
                return  # Okay, nothing more waiting to be sent.
            else:
                yield commands.SendData(self.conn, data)

    def receive_handshake_data(
        self, data: bytes
    ) -> layer.CommandGenerator[tuple[bool, str | None]]:
        """Feed ciphertext into the handshake; returns (done, error message)."""
        # bio_write errors for b"", so we need to check first if we actually received something.
        if data:
            self.tls.bio_write(data)
        try:
            self.tls.do_handshake()
        except SSL.WantReadError:
            yield from self.tls_interact()
            return False, None
        except SSL.Error as e:
            # provide more detailed information for some errors.
            last_err = (
                e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
            )
            if last_err in [
                (
                    "SSL routines",
                    "tls_process_server_certificate",
                    "certificate verify failed",
                ),
                ("SSL routines", "", "certificate verify failed"),  # OpenSSL 3+
            ]:
                verify_result = SSL._lib.SSL_get_verify_result(self.tls._ssl)  # type: ignore
                error = SSL._ffi.string(  # type: ignore
                    SSL._lib.X509_verify_cert_error_string(verify_result)  # type: ignore
                ).decode()
                err = f"Certificate verify failed: {error}"
            elif last_err in [
                ("SSL routines", "ssl3_read_bytes", "tlsv1 alert unknown ca"),
                ("SSL routines", "ssl3_read_bytes", "sslv3 alert bad certificate"),
                ("SSL routines", "ssl3_read_bytes", "ssl/tls alert bad certificate"),
                ("SSL routines", "", "tlsv1 alert unknown ca"),  # OpenSSL 3+
                ("SSL routines", "", "sslv3 alert bad certificate"),  # OpenSSL 3+
                ("SSL routines", "", "ssl/tls alert bad certificate"),  # OpenSSL 3.2+
            ]:
                assert isinstance(last_err, tuple)
                err = last_err[2]
            elif (
                last_err
                in [
                    ("SSL routines", "ssl3_get_record", "wrong version number"),
                    ("SSL routines", "", "wrong version number"),  # OpenSSL 3+
                    ("SSL routines", "", "packet length too long"),  # OpenSSL 3+
                    ("SSL routines", "", "record layer failure"),  # OpenSSL 3+
                ]
                and data[:4].isascii()
            ):
                err = "The remote server does not speak TLS."
            elif last_err in [
                ("SSL routines", "ssl3_read_bytes", "tlsv1 alert protocol version"),
                ("SSL routines", "", "tlsv1 alert protocol version"),  # OpenSSL 3+
            ]:
                err = (
                    "The remote server and mitmproxy cannot agree on a TLS version to use. "
                    "You may need to adjust mitmproxy's tls_version_server_min option."
                )
            else:
                err = f"OpenSSL {e!r}"
            return False, err
        else:
            # Here we set all attributes that are only known *after* the handshake.

            # Get all peer certificates.
            # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
            # If called on the client side, the stack also contains the peer's certificate; if called on the server
            # side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
            all_certs = self.tls.get_peer_cert_chain() or []
            if self.conn == self.context.client:
                cert = self.tls.get_peer_certificate()
                if cert:
                    all_certs.insert(0, cert)
            self.conn.certificate_list = []
            for cert in all_certs:
                try:
                    # This may fail for weird certs, https://github.com/mitmproxy/mitmproxy/issues/6968.
                    parsed_cert = certs.Cert.from_pyopenssl(cert)
                except ValueError as e:
                    yield commands.Log(
                        f"{self.debug}[tls] failed to parse certificate: {e}", WARNING
                    )
                else:
                    self.conn.certificate_list.append(parsed_cert)
            self.conn.timestamp_tls_setup = time.time()
            self.conn.alpn = self.tls.get_alpn_proto_negotiated()
            self.conn.cipher = self.tls.get_cipher_name()
            self.conn.tls_version = typing.cast(
                TlsVersion, self.tls.get_protocol_version_name()
            )
            if self.debug:
                yield commands.Log(
                    f"{self.debug}[tls] tls established: {self.conn}", DEBUG
                )
            if self.conn == self.context.client:
                yield TlsEstablishedClientHook(
                    TlsData(self.conn, self.context, self.tls)
                )
            else:
                yield TlsEstablishedServerHook(
                    TlsData(self.conn, self.context, self.tls)
                )
            # drain any application data that arrived with the final flight
            yield from self.receive_data(b"")
            return True, None

    def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
        self.conn.error = err
        if self.conn == self.context.client:
            yield TlsFailedClientHook(TlsData(self.conn, self.context, self.tls))
        else:
            yield TlsFailedServerHook(TlsData(self.conn, self.context, self.tls))
        yield from super().on_handshake_error(err)

    def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
        """Decrypt incoming ciphertext and forward plaintext to the child layer."""
        if data:
            self.tls.bio_write(data)
        plaintext = bytearray()
        close = False
        while True:
            try:
                plaintext.extend(self.tls.recv(65535))
            except SSL.WantReadError:
                break
            except SSL.ZeroReturnError:
                # TLS close_notify received: the read side is done
                close = True
                break
            except SSL.Error as e:
                # This may be happening because the other side send an alert.
                # There's somewhat ugly behavior with Firefox on Android here,
                # which upon mistrusting a certificate still completes the handshake
                # and then sends an alert in the next packet. At this point we have unfortunately
                # already fired out `tls_established_client` hook.
                yield commands.Log(f"TLS Error: {e}", WARNING)
                break

        # Can we send something?
        # Note that this must happen after `recv()`, which may have advanced the state machine.
        # https://github.com/mitmproxy/mitmproxy/discussions/7550
        yield from self.tls_interact()

        if plaintext:
            yield from self.event_to_child(
                events.DataReceived(self.conn, bytes(plaintext))
            )
        if close:
            self.conn.state &= ~connection.ConnectionState.CAN_READ
            if self.debug:
                yield commands.Log(f"{self.debug}[tls] close_notify {self.conn}", DEBUG)
            yield from self.event_to_child(events.ConnectionClosed(self.conn))

    def receive_close(self) -> layer.CommandGenerator[None]:
        if self.tls.get_shutdown() & SSL.RECEIVED_SHUTDOWN:
            pass  # We have already dispatched a ConnectionClosed to the child layer.
        else:
            yield from super().receive_close()

    def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
        """Encrypt plaintext from the child layer and flush it to the wire."""
        try:
            self.tls.sendall(data)
        except (SSL.ZeroReturnError, SSL.SysCallError):
            # The other peer may still be trying to send data over, which we discard here.
            pass
        yield from self.tls_interact()

    def send_close(
        self, command: commands.CloseConnection
    ) -> layer.CommandGenerator[None]:
        # We should probably shutdown the TLS connection properly here.
        yield from super().send_close(command)
class ServerTLSLayer(TLSLayer):
    """
    This layer establishes TLS for a single server connection.
    """

    # set when the server handshake is deferred until the client's ClientHello
    # has been parsed (so SNI/ALPN can be incorporated)
    wait_for_clienthello: bool = False

    def __init__(self, context: context.Context, conn: connection.Server | None = None):
        super().__init__(context, conn or context.server)

    def start_handshake(self) -> layer.CommandGenerator[None]:
        wait_for_clienthello = (
            # if command_to_reply_to is set, we've been instructed to open the connection from the child layer.
            # in that case any potential ClientHello is already parsed (by the ClientTLS child layer).
            not self.command_to_reply_to
            # if command_to_reply_to is not set, the connection was already open when this layer received its Start
            # event (eager connection strategy). We now want to establish TLS right away, _unless_ we already know
            # that there's TLS on the client side as well (we check if our immediate child layer is set to be ClientTLS)
            # In this case want to wait for ClientHello to be parsed, so that we can incorporate SNI/ALPN from there.
            and isinstance(self.child_layer, ClientTLSLayer)
        )
        if wait_for_clienthello:
            self.wait_for_clienthello = True
            self.tunnel_state = tunnel.TunnelState.CLOSED
        else:
            yield from self.start_tls()
            if self.tls:
                yield from self.receive_handshake_data(b"")

    def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
        if self.wait_for_clienthello:
            for command in super().event_to_child(event):
                if (
                    isinstance(command, commands.OpenConnection)
                    and command.connection == self.conn
                ):
                    self.wait_for_clienthello = False
                    # swallow OpenConnection here by not re-yielding it.
                else:
                    yield command
        else:
            yield from super().event_to_child(event)

    def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
        yield commands.Log(f"Server TLS handshake failed. {err}", level=WARNING)
        yield from super().on_handshake_error(err)
class ClientTLSLayer(TLSLayer):
    """
    This layer establishes TLS on a single client connection.

    ┌─────┐
    │Start│
    └┬────┘
    ┌────────────────────┐
    │Wait for ClientHello│
    └┬───────────────────┘
    ┌────────────────┐
    │Process messages│
    └────────────────┘
    """

    # raw bytes received before the full ClientHello has been parsed
    recv_buffer: bytearray
    # whether the parent layer is a ServerTLSLayer (server TLS can be started)
    server_tls_available: bool
    client_hello_parsed: bool = False

    def __init__(self, context: context.Context):
        if context.client.tls:
            # In the case of TLS-over-TLS, we already have client TLS. As the outer TLS connection between client
            # and proxy isn't that interesting to us, we just unset the attributes here and keep the inner TLS
            # session's attributes.
            # Alternatively we could create a new Client instance,
            # but for now we keep it simple. There is a proof-of-concept at
            # https://github.com/mitmproxy/mitmproxy/commit/9b6e2a716888b7787514733b76a5936afa485352.
            context.client.alpn = None
            context.client.cipher = None
            context.client.sni = None
            context.client.timestamp_tls_setup = None
            context.client.tls_version = None
            context.client.certificate_list = []
            context.client.mitmcert = None
            context.client.alpn_offers = []
            context.client.cipher_list = []

        super().__init__(context, context.client)
        self.server_tls_available = isinstance(self.context.layers[-2], ServerTLSLayer)
        self.recv_buffer = bytearray()

    def start_handshake(self) -> layer.CommandGenerator[None]:
        # nothing to send proactively; we wait for the client's ClientHello
        yield from ()

    def receive_handshake_data(
        self, data: bytes
    ) -> layer.CommandGenerator[tuple[bool, str | None]]:
        """Buffer data until a full ClientHello arrived, consult addons,
        optionally connect upstream first, then run the real handshake."""
        if self.client_hello_parsed:
            return (yield from super().receive_handshake_data(data))
        self.recv_buffer.extend(data)
        try:
            if self.is_dtls:
                client_hello = dtls_parse_client_hello(self.recv_buffer)
            else:
                client_hello = parse_client_hello(self.recv_buffer)
        except ValueError:
            return False, f"Cannot parse ClientHello: {self.recv_buffer.hex()}"

        if client_hello:
            self.client_hello_parsed = True
        else:
            return False, None

        self.conn.sni = client_hello.sni
        self.conn.alpn_offers = client_hello.alpn_protocols

        tls_clienthello = ClientHelloData(self.context, client_hello)
        yield TlsClienthelloHook(tls_clienthello)

        if tls_clienthello.ignore_connection:
            # we've figured out that we don't want to intercept this connection, so we assign fake connection objects
            # to all TLS layers. This makes the real connection contents just go through.
            self.conn = self.tunnel_connection = connection.Client(
                peername=("ignore-conn", 0), sockname=("ignore-conn", 0)
            )
            parent_layer = self.context.layers[self.context.layers.index(self) - 1]
            if isinstance(parent_layer, ServerTLSLayer):
                parent_layer.conn = parent_layer.tunnel_connection = connection.Server(
                    address=None
                )
            if self.is_dtls:
                self.child_layer = udp.UDPLayer(self.context, ignore=True)
            else:
                self.child_layer = tcp.TCPLayer(self.context, ignore=True)
            yield from self.event_to_child(
                events.DataReceived(self.context.client, bytes(self.recv_buffer))
            )
            self.recv_buffer.clear()
            return True, None

        if (
            tls_clienthello.establish_server_tls_first
            and not self.context.server.tls_established
        ):
            err = yield from self.start_server_tls()
            if err:
                yield commands.Log(
                    f"Unable to establish {self.proto_name} connection with server ({err}). "
                    f"Trying to establish {self.proto_name} with client anyway. "
                    "If you plan to redirect requests away from this server, "
                    "consider setting `connection_strategy` to `lazy` to suppress early connections."
                )

        yield from self.start_tls()
        if not self.conn.connected:
            return False, "connection closed early"

        # replay the buffered ClientHello through the real handshake
        ret = yield from super().receive_handshake_data(bytes(self.recv_buffer))
        self.recv_buffer.clear()
        return ret

    def start_server_tls(self) -> layer.CommandGenerator[str | None]:
        """
        We often need information from the upstream connection to establish TLS with the client.
        For example, we need to check if the client does ALPN or not.
        """
        if not self.server_tls_available:
            return f"No server {self.proto_name} available."
        err = yield commands.OpenConnection(self.context.server)
        return err

    def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
        """Translate common OpenSSL errors into actionable log messages."""
        if self.conn.sni:
            dest = self.conn.sni
        else:
            dest = human.format_address(self.context.server.address)
        level: int = WARNING
        if err.startswith("Cannot parse ClientHello"):
            pass
        elif (
            "('SSL routines', 'tls_early_post_process_client_hello', 'unsupported protocol')"
            in err
            or "('SSL routines', '', 'unsupported protocol')" in err  # OpenSSL 3+
        ):
            err = (
                "Client and mitmproxy cannot agree on a TLS version to use. "
                "You may need to adjust mitmproxy's tls_version_client_min option."
            )
        elif (
            "unknown ca" in err
            or "bad certificate" in err
            or "certificate unknown" in err
        ):
            err = (
                f"The client does not trust the proxy's certificate for {dest} ({err})"
            )
        elif err == "connection closed":
            err = (
                f"The client disconnected during the handshake. If this happens consistently for {dest}, "
                "this may indicate that the client does not trust the proxy's certificate."
            )
            level = INFO
        elif err == "connection closed early":
            pass
        else:
            err = f"The client may not trust the proxy's certificate for {dest} ({err})"
        if err != "connection closed early":
            yield commands.Log(f"Client TLS handshake failed. {err}", level=level)
        yield from super().on_handshake_error(err)
        # from now on, silently drop all events for this broken connection
        self.event_to_child = self.errored  # type: ignore

    def errored(self, event: events.Event) -> layer.CommandGenerator[None]:
        if self.debug is not None:
            yield commands.Log(
                f"{self.debug}[tls] Swallowing {event} as handshake failed.", DEBUG
            )
class MockTLSLayer(TLSLayer):
    """Mock layer to disable actual TLS and use cleartext in tests.

    Use like so:
        monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
    """

    def __init__(self, ctx: context.Context):
        # a dummy server connection satisfies TLSLayer's constructor
        super().__init__(ctx, connection.Server(address=None))

View File

@@ -0,0 +1,132 @@
from dataclasses import dataclass
from mitmproxy import flow
from mitmproxy import udp
from mitmproxy.connection import Connection
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import MessageInjected
from mitmproxy.proxy.utils import expect
@dataclass
class UdpStartHook(StartHook):
    """
    A UDP connection has started.
    """

    # the flow tracking this connection
    flow: udp.UDPFlow
@dataclass
class UdpMessageHook(StartHook):
    """
    A UDP connection has received a message. The most recent message
    will be flow.messages[-1]. The message is user-modifiable.
    """

    # the flow whose last message was just appended
    flow: udp.UDPFlow
@dataclass
class UdpEndHook(StartHook):
    """
    A UDP connection has ended.
    """

    # the flow that has finished
    flow: udp.UDPFlow
@dataclass
class UdpErrorHook(StartHook):
    """
    A UDP error has occurred.

    Every UDP flow will receive either a udp_error or a udp_end event, but not both.
    """

    # the flow whose .error attribute has been set
    flow: udp.UDPFlow
class UdpMessageInjected(MessageInjected[udp.UDPMessage]):
    """
    The user has injected a custom UDP message.
    """
class UDPLayer(layer.Layer):
    """
    Minimal UDP relay: forwards datagrams between client and server.

    Unless the connection is ignored, datagrams are recorded on a UDPFlow so
    that addons can observe and modify them via the udp_* hooks.
    """

    # None when the connection is ignored (pure passthrough).
    flow: udp.UDPFlow | None

    def __init__(self, context: Context, ignore: bool = False):
        super().__init__(context)
        self.flow = (
            None
            if ignore
            else udp.UDPFlow(self.context.client, self.context.server, True)
        )

    @expect(events.Start)
    def start(self, _) -> layer.CommandGenerator[None]:
        """Fire udp_start, open the upstream connection if needed, then relay."""
        if self.flow:
            yield UdpStartHook(self.flow)
        # A server connection without a start timestamp has not been opened yet.
        if self.context.server.timestamp_start is None:
            err = yield commands.OpenConnection(self.context.server)
            if err:
                # Upstream connect failed: record the error and tear down.
                if self.flow:
                    self.flow.error = flow.Error(str(err))
                    yield UdpErrorHook(self.flow)
                yield commands.CloseConnection(self.context.client)
                self._handle_event = self.done
                return
        self._handle_event = self.relay_messages

    _handle_event = start

    @expect(events.DataReceived, events.ConnectionClosed, UdpMessageInjected)
    def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Forward each datagram to the opposite side, firing hooks if recording."""
        if isinstance(event, UdpMessageInjected):
            # Pretend the injected message arrived on the wire and process it
            # through the regular code path below.
            origin = (
                self.context.client
                if event.message.from_client
                else self.context.server
            )
            event = events.DataReceived(origin, event.message.content)
        assert isinstance(event, events.ConnectionEvent)
        from_client = event.connection == self.context.client
        # Datagrams always travel to the peer of whoever sent them.
        send_to: Connection = (
            self.context.server if from_client else self.context.client
        )
        if isinstance(event, events.DataReceived):
            if self.flow is None:
                yield commands.SendData(send_to, event.data)
            else:
                msg = udp.UDPMessage(from_client, event.data)
                self.flow.messages.append(msg)
                yield UdpMessageHook(self.flow)
                # Send the (possibly hook-modified) message content.
                yield commands.SendData(send_to, msg.content)
        elif isinstance(event, events.ConnectionClosed):
            # One side closed: close the other side as well and finish the flow.
            self._handle_event = self.done
            yield commands.CloseConnection(send_to)
            if self.flow:
                yield UdpEndHook(self.flow)
                self.flow.live = False
        else:  # pragma: no cover
            raise AssertionError(f"Unexpected event: {event}")

    @expect(events.DataReceived, events.ConnectionClosed, UdpMessageInjected)
    def done(self, _) -> layer.CommandGenerator[None]:
        """Terminal state: silently drop any further events."""
        yield from ()

View File

@@ -0,0 +1,272 @@
import time
from collections.abc import Iterator
from dataclasses import dataclass
import wsproto.extensions
import wsproto.frame_protocol
import wsproto.utilities
from wsproto import ConnectionState
from wsproto.frame_protocol import Opcode
from mitmproxy import connection
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import MessageInjected
from mitmproxy.proxy.utils import expect
@dataclass
class WebsocketStartHook(StartHook):
    """
    A WebSocket connection has commenced.
    """

    # The HTTP flow that was upgraded to WebSocket; WS state lives on flow.websocket.
    flow: http.HTTPFlow
@dataclass
class WebsocketMessageHook(StartHook):
    """
    Called when a WebSocket message is received from the client or
    server. The most recent message will be flow.messages[-1]. The
    message is user-modifiable. Currently there are two types of
    messages, corresponding to the BINARY and TEXT frame types.
    """

    # Inspect/modify flow.websocket.messages[-1] in this hook.
    flow: http.HTTPFlow
@dataclass
class WebsocketEndHook(StartHook):
    """
    A WebSocket connection has ended.
    You can check `flow.websocket.close_code` to determine why it ended.
    """

    # The flow whose WebSocket session just closed.
    flow: http.HTTPFlow
class WebSocketMessageInjected(MessageInjected[websocket.WebSocketMessage]):
    """
    The user has injected a custom WebSocket message.
    """
class WebsocketConnection(wsproto.Connection):
    """
    Thin convenience wrapper around wsproto.Connection that

    - remembers the proxy-side connection object it belongs to,
    - buffers the fragments of a partially received message, and
    - offers send2(), which turns a wsproto event into a SendData command
      that can be yielded directly.
    """

    conn: connection.Connection
    frame_buf: list[bytes]

    def __init__(self, *args, conn: connection.Connection, **kwargs):
        super().__init__(*args, **kwargs)
        self.conn = conn
        # Always keep one (possibly empty) entry for the in-progress frame.
        self.frame_buf = [b""]

    def send2(self, event: wsproto.events.Event) -> commands.SendData:
        return commands.SendData(self.conn, self.send(event))

    def __repr__(self):
        return f"WebsocketConnection<{self.state.name}, {self.conn}>"
class WebsocketLayer(layer.Layer):
    """
    WebSocket layer that intercepts and relays messages.

    Runs two wsproto state machines (one facing the client, one facing the
    server), reassembles fragmented messages, fires the websocket_* hooks,
    and re-fragments outgoing messages via Fragmentizer.
    """

    # The HTTP flow that carried the upgrade; WS state lives on flow.websocket.
    flow: http.HTTPFlow
    client_ws: WebsocketConnection
    server_ws: WebsocketConnection

    def __init__(self, context: Context, flow: http.HTTPFlow):
        super().__init__(context)
        self.flow = flow

    @expect(events.Start)
    def start(self, _) -> layer.CommandGenerator[None]:
        """Negotiate extensions from the response headers and enter relay mode."""
        client_extensions = []
        server_extensions = []
        # Parse extension headers. We only support deflate at the moment and ignore everything else.
        assert self.flow.response  # satisfy type checker
        ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
        if ext_header:
            for ext in wsproto.utilities.split_comma_header(
                ext_header.encode("ascii", "replace")
            ):
                ext_name = ext.split(";", 1)[0].strip()
                if ext_name == wsproto.extensions.PerMessageDeflate.name:
                    # Two independent deflate contexts: one per direction.
                    client_deflate = wsproto.extensions.PerMessageDeflate()
                    client_deflate.finalize(ext)
                    client_extensions.append(client_deflate)
                    server_deflate = wsproto.extensions.PerMessageDeflate()
                    server_deflate.finalize(ext)
                    server_extensions.append(server_deflate)
                else:
                    yield commands.Log(
                        f"Ignoring unknown WebSocket extension {ext_name!r}."
                    )
        # The proxy plays SERVER towards the client and CLIENT towards the server.
        self.client_ws = WebsocketConnection(
            wsproto.ConnectionType.SERVER, client_extensions, conn=self.context.client
        )
        self.server_ws = WebsocketConnection(
            wsproto.ConnectionType.CLIENT, server_extensions, conn=self.context.server
        )
        yield WebsocketStartHook(self.flow)
        self._handle_event = self.relay_messages

    _handle_event = start

    @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
    def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
        """Feed incoming bytes/injections into wsproto and relay resulting events.

        Messages are buffered per-frame until complete, exposed to addons via
        WebsocketMessageHook, and then re-sent to the opposite side unless dropped.
        """
        assert self.flow.websocket  # satisfy type checker
        # Determine direction and whether this message was user-injected.
        if isinstance(event, events.ConnectionEvent):
            from_client = event.connection == self.context.client
            injected = False
        elif isinstance(event, WebSocketMessageInjected):
            from_client = event.message.from_client
            injected = True
        else:
            raise AssertionError(f"Unexpected event: {event}")
        from_str = "client" if from_client else "server"
        # src_ws parses what this side sent; dst_ws serializes toward the peer.
        if from_client:
            src_ws = self.client_ws
            dst_ws = self.server_ws
        else:
            src_ws = self.server_ws
            dst_ws = self.client_ws
        if isinstance(event, events.DataReceived):
            src_ws.receive_data(event.data)
        elif isinstance(event, events.ConnectionClosed):
            # wsproto convention: None signals end-of-stream.
            src_ws.receive_data(None)
        elif isinstance(event, WebSocketMessageInjected):
            # Synthesize wsproto message events directly into the source
            # connection's internal queue (no empty fragment_lengths: ~4kB chunks).
            fragmentizer = Fragmentizer([], event.message.type == Opcode.TEXT)
            src_ws._events.extend(fragmentizer(event.message.content))
        else:  # pragma: no cover
            raise AssertionError(f"Unexpected event: {event}")
        for ws_event in src_ws.events():
            if isinstance(ws_event, wsproto.events.Message):
                is_text = isinstance(ws_event.data, str)
                # Accumulate this fragment into the current frame buffer slot.
                if is_text:
                    typ = Opcode.TEXT
                    src_ws.frame_buf[-1] += ws_event.data.encode()
                else:
                    typ = Opcode.BINARY
                    src_ws.frame_buf[-1] += ws_event.data
                if ws_event.message_finished:
                    content = b"".join(src_ws.frame_buf)
                    # Remember the original fragment sizes so an unmodified
                    # message can be re-sent with identical chunking.
                    fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
                    src_ws.frame_buf = [b""]
                    message = websocket.WebSocketMessage(
                        typ, from_client, content, injected=injected
                    )
                    self.flow.websocket.messages.append(message)
                    yield WebsocketMessageHook(self.flow)
                    if not message.dropped:
                        # Forward the (possibly hook-modified) content.
                        for msg in fragmentizer(message.content):
                            yield dst_ws.send2(msg)
                elif ws_event.frame_finished:
                    # Frame complete but message continues: start a new slot.
                    src_ws.frame_buf.append(b"")
            elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
                yield commands.Log(
                    f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
                    f"(payload: {bytes(ws_event.payload)!r})"
                )
                yield dst_ws.send2(ws_event)
            elif isinstance(ws_event, wsproto.events.CloseConnection):
                # Record close metadata on the flow, then close both sides.
                self.flow.websocket.timestamp_end = time.time()
                self.flow.websocket.closed_by_client = from_client
                self.flow.websocket.close_code = ws_event.code
                self.flow.websocket.close_reason = ws_event.reason
                for ws in [self.server_ws, self.client_ws]:
                    if ws.state in {
                        ConnectionState.OPEN,
                        ConnectionState.REMOTE_CLOSING,
                    }:
                        # response == original event, so no need to differentiate here.
                        yield ws.send2(ws_event)
                    yield commands.CloseConnection(ws.conn)
                yield WebsocketEndHook(self.flow)
                self.flow.live = False
                self._handle_event = self.done
            else:  # pragma: no cover
                raise AssertionError(f"Unexpected WebSocket event: {ws_event}")

    @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
    def done(self, _) -> layer.CommandGenerator[None]:
        # Terminal state: silently drop any further events.
        yield from ()
class Fragmentizer:
    """
    Re-fragment an outgoing WebSocket message.

    Theory (RFC 6455): frames carry no semantic meaning unless an extension
    says otherwise, so an intermediary may coalesce and/or split them.

    Practice: some WebSocket servers reject large payloads, others reject
    CONTINUATION frames. As a workaround we replay the original chunking when
    the payload is unmodified, and fall back to ~4kB chunks otherwise.
    Addons that must modify messages for servers without CONTINUATION support
    need to monkeypatch FRAGMENT_SIZE.
    """

    # A bit less than 4kB to leave room for frame headers.
    FRAGMENT_SIZE = 4000

    def __init__(self, fragments: list[bytes], is_text: bool):
        self.fragment_lengths = [len(x) for x in fragments]
        self.is_text = is_text

    def msg(self, data: bytes, message_finished: bool):
        """Wrap *data* in a wsproto Text/Bytes message event."""
        if not self.is_text:
            return wsproto.events.BytesMessage(data, message_finished=message_finished)
        return wsproto.events.TextMessage(
            data.decode(errors="replace"), message_finished=message_finished
        )

    def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]:
        if len(content) == sum(self.fragment_lengths):
            # Same total size as the original message: replay original chunking.
            pos = 0
            for length in self.fragment_lengths[:-1]:
                yield self.msg(content[pos : pos + length], False)
                pos += length
            yield self.msg(content[pos:], True)
        else:
            # Payload size changed: emit fixed-size chunks plus the remainder.
            pos = 0
            last_chunk_start = len(content) - self.FRAGMENT_SIZE
            while pos < last_chunk_start:
                yield self.msg(content[pos : pos + self.FRAGMENT_SIZE], False)
                pos += self.FRAGMENT_SIZE
            yield self.msg(content[pos:], True)