2025-12-25 upload
This commit is contained in:
27
venv/Lib/site-packages/mitmproxy/proxy/layers/__init__.py
Normal file
27
venv/Lib/site-packages/mitmproxy/proxy/layers/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from . import modes
|
||||
from .dns import DNSLayer
|
||||
from .http import HttpLayer
|
||||
from .quic import ClientQuicLayer
|
||||
from .quic import QuicStreamLayer
|
||||
from .quic import RawQuicLayer
|
||||
from .quic import ServerQuicLayer
|
||||
from .tcp import TCPLayer
|
||||
from .tls import ClientTLSLayer
|
||||
from .tls import ServerTLSLayer
|
||||
from .udp import UDPLayer
|
||||
from .websocket import WebsocketLayer
|
||||
|
||||
__all__ = [
|
||||
"modes",
|
||||
"DNSLayer",
|
||||
"HttpLayer",
|
||||
"QuicStreamLayer",
|
||||
"RawQuicLayer",
|
||||
"TCPLayer",
|
||||
"UDPLayer",
|
||||
"ClientQuicLayer",
|
||||
"ClientTLSLayer",
|
||||
"ServerQuicLayer",
|
||||
"ServerTLSLayer",
|
||||
"WebsocketLayer",
|
||||
]
|
||||
190
venv/Lib/site-packages/mitmproxy/proxy/layers/dns.py
Normal file
190
venv/Lib/site-packages/mitmproxy/proxy/layers/dns.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
|
||||
from mitmproxy import dns
|
||||
from mitmproxy import flow as mflow
|
||||
from mitmproxy.net.dns import response_codes
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
_LENGTH_LABEL = struct.Struct("!H")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DnsRequestHook(commands.StartHook):
|
||||
"""
|
||||
A DNS query has been received.
|
||||
"""
|
||||
|
||||
flow: dns.DNSFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class DnsResponseHook(commands.StartHook):
|
||||
"""
|
||||
A DNS response has been received or set.
|
||||
"""
|
||||
|
||||
flow: dns.DNSFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class DnsErrorHook(commands.StartHook):
|
||||
"""
|
||||
A DNS error has occurred.
|
||||
"""
|
||||
|
||||
flow: dns.DNSFlow
|
||||
|
||||
|
||||
def pack_message(
|
||||
message: dns.DNSMessage, transport_protocol: Literal["tcp", "udp"]
|
||||
) -> bytes:
|
||||
packed = message.packed
|
||||
if transport_protocol == "tcp":
|
||||
return struct.pack("!H", len(packed)) + packed
|
||||
else:
|
||||
return packed
|
||||
|
||||
|
||||
class DNSLayer(layer.Layer):
|
||||
"""
|
||||
Layer that handles resolving DNS queries.
|
||||
"""
|
||||
|
||||
flows: dict[int, dns.DNSFlow]
|
||||
req_buf: bytearray
|
||||
resp_buf: bytearray
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context)
|
||||
self.flows = {}
|
||||
self.req_buf = bytearray()
|
||||
self.resp_buf = bytearray()
|
||||
|
||||
def handle_request(
|
||||
self, flow: dns.DNSFlow, msg: dns.DNSMessage
|
||||
) -> layer.CommandGenerator[None]:
|
||||
flow.request = msg # if already set, continue and query upstream again
|
||||
yield DnsRequestHook(flow)
|
||||
if flow.response:
|
||||
yield from self.handle_response(flow, flow.response)
|
||||
elif flow.error:
|
||||
yield from self.handle_error(flow, flow.error.msg)
|
||||
elif not self.context.server.address:
|
||||
yield from self.handle_error(
|
||||
flow, "No hook has set a response and there is no upstream server."
|
||||
)
|
||||
else:
|
||||
if not self.context.server.connected:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
yield from self.handle_error(flow, str(err))
|
||||
# cannot recover from this
|
||||
return
|
||||
packed = pack_message(flow.request, flow.server_conn.transport_protocol)
|
||||
yield commands.SendData(self.context.server, packed)
|
||||
|
||||
def handle_response(
|
||||
self, flow: dns.DNSFlow, msg: dns.DNSMessage
|
||||
) -> layer.CommandGenerator[None]:
|
||||
flow.response = msg
|
||||
yield DnsResponseHook(flow)
|
||||
if flow.response:
|
||||
packed = pack_message(flow.response, flow.client_conn.transport_protocol)
|
||||
yield commands.SendData(self.context.client, packed)
|
||||
|
||||
def handle_error(self, flow: dns.DNSFlow, err: str) -> layer.CommandGenerator[None]:
|
||||
flow.error = mflow.Error(err)
|
||||
yield DnsErrorHook(flow)
|
||||
servfail = flow.request.fail(response_codes.SERVFAIL)
|
||||
yield commands.SendData(
|
||||
self.context.client,
|
||||
pack_message(servfail, flow.client_conn.transport_protocol),
|
||||
)
|
||||
|
||||
def unpack_message(self, data: bytes, from_client: bool) -> List[dns.DNSMessage]:
|
||||
msgs: List[dns.DNSMessage] = []
|
||||
|
||||
buf = self.req_buf if from_client else self.resp_buf
|
||||
|
||||
if self.context.client.transport_protocol == "udp":
|
||||
msgs.append(dns.DNSMessage.unpack(data, timestamp=time.time()))
|
||||
elif self.context.client.transport_protocol == "tcp":
|
||||
buf.extend(data)
|
||||
size = len(buf)
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
if size - offset < _LENGTH_LABEL.size:
|
||||
break
|
||||
(expected_size,) = _LENGTH_LABEL.unpack_from(buf, offset)
|
||||
offset += _LENGTH_LABEL.size
|
||||
if expected_size == 0:
|
||||
raise struct.error("Message length field cannot be zero")
|
||||
|
||||
if size - offset < expected_size:
|
||||
offset -= _LENGTH_LABEL.size
|
||||
break
|
||||
|
||||
data = bytes(buf[offset : expected_size + offset])
|
||||
offset += expected_size
|
||||
msgs.append(dns.DNSMessage.unpack(data, timestamp=time.time()))
|
||||
|
||||
del buf[:offset]
|
||||
return msgs
|
||||
|
||||
@expect(events.Start)
|
||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
||||
self._handle_event = self.state_query
|
||||
yield from ()
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def state_query(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert isinstance(event, events.ConnectionEvent)
|
||||
from_client = event.connection is self.context.client
|
||||
|
||||
if isinstance(event, events.DataReceived):
|
||||
msgs: List[dns.DNSMessage] = []
|
||||
try:
|
||||
msgs = self.unpack_message(event.data, from_client)
|
||||
except struct.error as e:
|
||||
yield commands.Log(f"{event.connection} sent an invalid message: {e}")
|
||||
yield commands.CloseConnection(event.connection)
|
||||
self._handle_event = self.state_done
|
||||
else:
|
||||
for msg in msgs:
|
||||
try:
|
||||
flow = self.flows[msg.id]
|
||||
except KeyError:
|
||||
flow = dns.DNSFlow(
|
||||
self.context.client, self.context.server, live=True
|
||||
)
|
||||
self.flows[msg.id] = flow
|
||||
if from_client:
|
||||
yield from self.handle_request(flow, msg)
|
||||
else:
|
||||
yield from self.handle_response(flow, msg)
|
||||
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
other_conn = self.context.server if from_client else self.context.client
|
||||
if other_conn.connected:
|
||||
yield commands.CloseConnection(other_conn)
|
||||
self._handle_event = self.state_done
|
||||
for flow in self.flows.values():
|
||||
flow.live = False
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def state_done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
_handle_event = state_start
|
||||
1210
venv/Lib/site-packages/mitmproxy/proxy/layers/http/__init__.py
Normal file
1210
venv/Lib/site-packages/mitmproxy/proxy/layers/http/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
61
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_base.py
Normal file
61
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_base.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import html
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import http
|
||||
from mitmproxy.connection import Connection
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.context import Context
|
||||
|
||||
StreamId = int
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpEvent(events.Event):
|
||||
# we need stream ids on every event to avoid race conditions
|
||||
stream_id: StreamId
|
||||
|
||||
|
||||
class HttpConnection(layer.Layer):
|
||||
conn: Connection
|
||||
|
||||
def __init__(self, context: Context, conn: Connection):
|
||||
super().__init__(context)
|
||||
self.conn = conn
|
||||
|
||||
|
||||
class HttpCommand(commands.Command):
|
||||
pass
|
||||
|
||||
|
||||
class ReceiveHttp(HttpCommand):
|
||||
event: HttpEvent
|
||||
|
||||
def __init__(self, event: HttpEvent):
|
||||
self.event = event
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Receive({self.event})"
|
||||
|
||||
|
||||
def format_error(status_code: int, message: str) -> bytes:
|
||||
reason = http.status_codes.RESPONSES.get(status_code, "Unknown")
|
||||
return (
|
||||
textwrap.dedent(
|
||||
f"""
|
||||
<html>
|
||||
<head>
|
||||
<title>{status_code} {reason}</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>{status_code} {reason}</h1>
|
||||
<p>{html.escape(message)}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
.strip()
|
||||
.encode("utf8", "replace")
|
||||
)
|
||||
167
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_events.py
Normal file
167
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_events.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import enum
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ._base import HttpEvent
|
||||
from mitmproxy import http
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.net.http import status_codes
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestHeaders(HttpEvent):
|
||||
request: http.Request
|
||||
end_stream: bool
|
||||
"""
|
||||
If True, we already know at this point that there is no message body. This is useful for HTTP/2, where it allows
|
||||
us to set END_STREAM on headers already (and some servers - Akamai - implicitly expect that).
|
||||
In either case, this event will nonetheless be followed by RequestEndOfMessage.
|
||||
"""
|
||||
replay_flow: HTTPFlow | None = None
|
||||
"""If set, the current request headers belong to a replayed flow, which should be reused."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseHeaders(HttpEvent):
|
||||
response: http.Response
|
||||
end_stream: bool = False
|
||||
|
||||
|
||||
# explicit constructors below to facilitate type checking in _http1/_http2
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestData(HttpEvent):
|
||||
data: bytes
|
||||
|
||||
def __init__(self, stream_id: int, data: bytes):
|
||||
self.stream_id = stream_id
|
||||
self.data = data
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseData(HttpEvent):
|
||||
data: bytes
|
||||
|
||||
def __init__(self, stream_id: int, data: bytes):
|
||||
self.stream_id = stream_id
|
||||
self.data = data
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestTrailers(HttpEvent):
|
||||
trailers: http.Headers
|
||||
|
||||
def __init__(self, stream_id: int, trailers: http.Headers):
|
||||
self.stream_id = stream_id
|
||||
self.trailers = trailers
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseTrailers(HttpEvent):
|
||||
trailers: http.Headers
|
||||
|
||||
def __init__(self, stream_id: int, trailers: http.Headers):
|
||||
self.stream_id = stream_id
|
||||
self.trailers = trailers
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestEndOfMessage(HttpEvent):
|
||||
def __init__(self, stream_id: int):
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseEndOfMessage(HttpEvent):
|
||||
def __init__(self, stream_id: int):
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
class ErrorCode(enum.Enum):
|
||||
GENERIC_CLIENT_ERROR = 1
|
||||
GENERIC_SERVER_ERROR = 2
|
||||
REQUEST_TOO_LARGE = 3
|
||||
RESPONSE_TOO_LARGE = 4
|
||||
CONNECT_FAILED = 5
|
||||
PASSTHROUGH_CLOSE = 6
|
||||
KILL = 7
|
||||
HTTP_1_1_REQUIRED = 8
|
||||
"""Client should fall back to HTTP/1.1 to perform request."""
|
||||
DESTINATION_UNKNOWN = 9
|
||||
"""Proxy does not know where to send request to."""
|
||||
CLIENT_DISCONNECTED = 10
|
||||
"""Client disconnected before receiving entire response."""
|
||||
CANCEL = 11
|
||||
"""Client or server cancelled h2/h3 stream."""
|
||||
REQUEST_VALIDATION_FAILED = 12
|
||||
RESPONSE_VALIDATION_FAILED = 13
|
||||
|
||||
def http_status_code(self) -> int | None:
|
||||
match self:
|
||||
# Client Errors
|
||||
case (
|
||||
ErrorCode.GENERIC_CLIENT_ERROR
|
||||
| ErrorCode.REQUEST_VALIDATION_FAILED
|
||||
| ErrorCode.DESTINATION_UNKNOWN
|
||||
):
|
||||
return status_codes.BAD_REQUEST
|
||||
case ErrorCode.REQUEST_TOO_LARGE:
|
||||
return status_codes.PAYLOAD_TOO_LARGE
|
||||
case (
|
||||
ErrorCode.CONNECT_FAILED
|
||||
| ErrorCode.GENERIC_SERVER_ERROR
|
||||
| ErrorCode.RESPONSE_VALIDATION_FAILED
|
||||
| ErrorCode.RESPONSE_TOO_LARGE
|
||||
):
|
||||
return status_codes.BAD_GATEWAY
|
||||
case (
|
||||
ErrorCode.PASSTHROUGH_CLOSE
|
||||
| ErrorCode.KILL
|
||||
| ErrorCode.HTTP_1_1_REQUIRED
|
||||
| ErrorCode.CLIENT_DISCONNECTED
|
||||
| ErrorCode.CANCEL
|
||||
):
|
||||
return None
|
||||
case other: # pragma: no cover
|
||||
typing.assert_never(other)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestProtocolError(HttpEvent):
|
||||
message: str
|
||||
code: ErrorCode = ErrorCode.GENERIC_CLIENT_ERROR
|
||||
|
||||
def __init__(self, stream_id: int, message: str, code: ErrorCode):
|
||||
assert isinstance(code, ErrorCode)
|
||||
self.stream_id = stream_id
|
||||
self.message = message
|
||||
self.code = code
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseProtocolError(HttpEvent):
|
||||
message: str
|
||||
code: ErrorCode = ErrorCode.GENERIC_SERVER_ERROR
|
||||
|
||||
def __init__(self, stream_id: int, message: str, code: ErrorCode):
|
||||
assert isinstance(code, ErrorCode)
|
||||
self.stream_id = stream_id
|
||||
self.message = message
|
||||
self.code = code
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ErrorCode",
|
||||
"HttpEvent",
|
||||
"RequestHeaders",
|
||||
"RequestData",
|
||||
"RequestEndOfMessage",
|
||||
"ResponseHeaders",
|
||||
"ResponseData",
|
||||
"RequestTrailers",
|
||||
"ResponseTrailers",
|
||||
"ResponseEndOfMessage",
|
||||
"RequestProtocolError",
|
||||
"ResponseProtocolError",
|
||||
]
|
||||
122
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_hooks.py
Normal file
122
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_hooks.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import http
|
||||
from mitmproxy.proxy import commands
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpRequestHeadersHook(commands.StartHook):
|
||||
"""
|
||||
HTTP request headers were successfully read. At this point, the body is empty.
|
||||
"""
|
||||
|
||||
name = "requestheaders"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpRequestHook(commands.StartHook):
|
||||
"""
|
||||
The full HTTP request has been read.
|
||||
|
||||
Note: If request streaming is active, this event fires after the entire body has been streamed.
|
||||
HTTP trailers, if present, have not been transmitted to the server yet and can still be modified.
|
||||
Enabling streaming may cause unexpected event sequences: For example, `response` may now occur
|
||||
before `request` because the server replied with "413 Payload Too Large" during upload.
|
||||
"""
|
||||
|
||||
name = "request"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpResponseHeadersHook(commands.StartHook):
|
||||
"""
|
||||
HTTP response headers were successfully read. At this point, the body is empty.
|
||||
"""
|
||||
|
||||
name = "responseheaders"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpResponseHook(commands.StartHook):
|
||||
"""
|
||||
The full HTTP response has been read.
|
||||
|
||||
Note: If response streaming is active, this event fires after the entire body has been streamed.
|
||||
HTTP trailers, if present, have not been transmitted to the client yet and can still be modified.
|
||||
"""
|
||||
|
||||
name = "response"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpErrorHook(commands.StartHook):
|
||||
"""
|
||||
An HTTP error has occurred, e.g. invalid server responses, or
|
||||
interrupted connections. This is distinct from a valid server HTTP
|
||||
error response, which is simply a response with an HTTP error code.
|
||||
|
||||
Every flow will receive either an error or an response event, but not both.
|
||||
"""
|
||||
|
||||
name = "error"
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpConnectHook(commands.StartHook):
|
||||
"""
|
||||
An HTTP CONNECT request was received. This event can be ignored for most practical purposes.
|
||||
|
||||
This event only occurs in regular and upstream proxy modes
|
||||
when the client instructs mitmproxy to open a connection to an upstream host.
|
||||
Setting a non 2xx response on the flow will return the response to the client and abort the connection.
|
||||
|
||||
CONNECT requests are HTTP proxy instructions for mitmproxy itself
|
||||
and not forwarded. They do not generate the usual HTTP handler events,
|
||||
but all requests going over the newly opened connection will.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpConnectUpstreamHook(commands.StartHook):
|
||||
"""
|
||||
An HTTP CONNECT request is about to be sent to an upstream proxy.
|
||||
This event can be ignored for most practical purposes.
|
||||
|
||||
This event can be used to set custom authentication headers for upstream proxies.
|
||||
|
||||
CONNECT requests do not generate the usual HTTP handler events,
|
||||
but all requests going over the newly opened connection will.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpConnectedHook(commands.StartHook):
|
||||
"""
|
||||
HTTP CONNECT was successful
|
||||
|
||||
> [!WARNING]
|
||||
> This may fire before an upstream connection has been established
|
||||
> if `connection_strategy` is set to `lazy` (default)
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class HttpConnectErrorHook(commands.StartHook):
|
||||
"""
|
||||
HTTP CONNECT has failed.
|
||||
This can happen when the upstream server is unreachable or proxy authentication is required.
|
||||
In contrast to the `error` hook, `flow.error` is not guaranteed to be set.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
502
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http1.py
Normal file
502
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http1.py
Normal file
@@ -0,0 +1,502 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
from typing import Union
|
||||
|
||||
import h11
|
||||
from h11._readers import ChunkedReader
|
||||
from h11._readers import ContentLengthReader
|
||||
from h11._readers import Http10Reader
|
||||
from h11._receivebuffer import ReceiveBuffer
|
||||
|
||||
from ...context import Context
|
||||
from ._base import format_error
|
||||
from ._base import HttpConnection
|
||||
from ._events import ErrorCode
|
||||
from ._events import HttpEvent
|
||||
from ._events import RequestData
|
||||
from ._events import RequestEndOfMessage
|
||||
from ._events import RequestHeaders
|
||||
from ._events import RequestProtocolError
|
||||
from ._events import ResponseData
|
||||
from ._events import ResponseEndOfMessage
|
||||
from ._events import ResponseHeaders
|
||||
from ._events import ResponseProtocolError
|
||||
from mitmproxy import http
|
||||
from mitmproxy import version
|
||||
from mitmproxy.connection import Connection
|
||||
from mitmproxy.connection import ConnectionState
|
||||
from mitmproxy.net.http import http1
|
||||
from mitmproxy.net.http import status_codes
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.layers.http._base import ReceiveHttp
|
||||
from mitmproxy.proxy.layers.http._base import StreamId
|
||||
from mitmproxy.proxy.utils import expect
|
||||
from mitmproxy.utils import human
|
||||
|
||||
TBodyReader = Union[ChunkedReader, Http10Reader, ContentLengthReader]
|
||||
|
||||
|
||||
class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
||||
stream_id: StreamId | None = None
|
||||
request: http.Request | None = None
|
||||
response: http.Response | None = None
|
||||
request_done: bool = False
|
||||
response_done: bool = False
|
||||
# this is a bit of a hack to make both mypy and PyCharm happy.
|
||||
state: Callable[[events.Event], layer.CommandGenerator[None]] | Callable
|
||||
body_reader: TBodyReader
|
||||
buf: ReceiveBuffer
|
||||
|
||||
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
|
||||
ReceiveData: type[RequestData | ResponseData]
|
||||
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
|
||||
|
||||
def __init__(self, context: Context, conn: Connection):
|
||||
super().__init__(context, conn)
|
||||
self.buf = ReceiveBuffer()
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
yield from () # pragma: no cover
|
||||
|
||||
@abc.abstractmethod
|
||||
def read_headers(
|
||||
self, event: events.ConnectionEvent
|
||||
) -> layer.CommandGenerator[None]:
|
||||
yield from () # pragma: no cover
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, HttpEvent):
|
||||
yield from self.send(event)
|
||||
else:
|
||||
if (
|
||||
isinstance(event, events.DataReceived)
|
||||
and self.state != self.passthrough
|
||||
):
|
||||
self.buf += event.data
|
||||
yield from self.state(event)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
self.state = self.read_headers
|
||||
yield from ()
|
||||
|
||||
state = start
|
||||
|
||||
def read_body(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.stream_id is not None
|
||||
while True:
|
||||
try:
|
||||
if isinstance(event, events.DataReceived):
|
||||
h11_event = self.body_reader(self.buf)
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
h11_event = self.body_reader.read_eof()
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
except h11.ProtocolError as e:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
self.stream_id,
|
||||
f"HTTP/1 protocol error: {e}",
|
||||
code=self.ReceiveProtocolError.code,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if h11_event is None:
|
||||
return
|
||||
elif isinstance(h11_event, h11.Data):
|
||||
data: bytes = bytes(h11_event.data)
|
||||
if data:
|
||||
yield ReceiveHttp(self.ReceiveData(self.stream_id, data))
|
||||
elif isinstance(h11_event, h11.EndOfMessage):
|
||||
assert self.request
|
||||
if h11_event.headers:
|
||||
raise NotImplementedError(f"HTTP trailers are not implemented yet.")
|
||||
if self.request.data.method.upper() != b"CONNECT":
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(self.stream_id))
|
||||
is_request = isinstance(self, Http1Server)
|
||||
yield from self.mark_done(request=is_request, response=not is_request)
|
||||
return
|
||||
|
||||
def wait(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
"""
|
||||
We wait for the current flow to be finished before parsing the next message,
|
||||
as we may want to upgrade to WebSocket or plain TCP before that.
|
||||
"""
|
||||
assert self.stream_id
|
||||
if isinstance(event, events.DataReceived):
|
||||
return
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
# for practical purposes, we assume that a peer which sent at least a FIN
|
||||
# is not interested in any more data from us, see
|
||||
# see https://github.com/httpwg/http-core/issues/22
|
||||
if event.connection.state is not ConnectionState.CLOSED:
|
||||
yield commands.CloseConnection(event.connection)
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
self.stream_id,
|
||||
f"Client disconnected.",
|
||||
code=ErrorCode.CLIENT_DISCONNECTED,
|
||||
)
|
||||
)
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def done(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
|
||||
yield from () # pragma: no cover
|
||||
|
||||
def make_pipe(self) -> layer.CommandGenerator[None]:
|
||||
self.state = self.passthrough
|
||||
if self.buf:
|
||||
already_received = self.buf.maybe_extract_at_most(len(self.buf)) or b""
|
||||
# Some clients send superfluous newlines after CONNECT, we want to eat those.
|
||||
already_received = already_received.lstrip(b"\r\n")
|
||||
if already_received:
|
||||
yield from self.state(events.DataReceived(self.conn, already_received))
|
||||
|
||||
def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.stream_id
|
||||
if isinstance(event, events.DataReceived):
|
||||
yield ReceiveHttp(self.ReceiveData(self.stream_id, event.data))
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if isinstance(self, Http1Server):
|
||||
yield ReceiveHttp(RequestEndOfMessage(self.stream_id))
|
||||
else:
|
||||
yield ReceiveHttp(ResponseEndOfMessage(self.stream_id))
|
||||
|
||||
def mark_done(
|
||||
self, *, request: bool = False, response: bool = False
|
||||
) -> layer.CommandGenerator[None]:
|
||||
if request:
|
||||
self.request_done = True
|
||||
if response:
|
||||
self.response_done = True
|
||||
if self.request_done and self.response_done:
|
||||
assert self.request
|
||||
assert self.response
|
||||
if should_make_pipe(self.request, self.response):
|
||||
yield from self.make_pipe()
|
||||
return
|
||||
try:
|
||||
read_until_eof_semantics = (
|
||||
http1.expected_http_body_size(self.request, self.response) == -1
|
||||
)
|
||||
except ValueError:
|
||||
# this may raise only now (and not earlier) because an addon set invalid headers,
|
||||
# in which case it's not really clear what we are supposed to do.
|
||||
read_until_eof_semantics = False
|
||||
connection_done = (
|
||||
read_until_eof_semantics
|
||||
or http1.connection_close(
|
||||
self.request.http_version, self.request.headers
|
||||
)
|
||||
or http1.connection_close(
|
||||
self.response.http_version, self.response.headers
|
||||
)
|
||||
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
|
||||
# This simplifies our connection management quite a bit as we can rely on
|
||||
# the proxyserver's max-connection-per-server throttling.
|
||||
or (
|
||||
(self.request.is_http2 or self.request.is_http3)
|
||||
and isinstance(self, Http1Client)
|
||||
)
|
||||
)
|
||||
if connection_done:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
self.state = self.done
|
||||
return
|
||||
self.request_done = self.response_done = False
|
||||
self.request = self.response = None
|
||||
if isinstance(self, Http1Server):
|
||||
self.stream_id += 2
|
||||
else:
|
||||
self.stream_id = None
|
||||
self.state = self.read_headers
|
||||
if self.buf:
|
||||
yield from self.state(events.DataReceived(self.conn, b""))
|
||||
|
||||
|
||||
class Http1Server(Http1Connection):
|
||||
"""A simple HTTP/1 server with no pipelining support."""
|
||||
|
||||
ReceiveProtocolError = RequestProtocolError
|
||||
ReceiveData = RequestData
|
||||
ReceiveEndOfMessage = RequestEndOfMessage
|
||||
stream_id: int
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.client)
|
||||
self.stream_id = 1
|
||||
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
assert event.stream_id == self.stream_id
|
||||
if isinstance(event, ResponseHeaders):
|
||||
self.response = response = event.response
|
||||
|
||||
if response.is_http2 or response.is_http3:
|
||||
response = response.copy()
|
||||
# Convert to an HTTP/1 response.
|
||||
response.http_version = "HTTP/1.1"
|
||||
# not everyone supports empty reason phrases, so we better make up one.
|
||||
response.reason = status_codes.RESPONSES.get(response.status_code, "")
|
||||
# Shall we set a Content-Length header here if there is none?
|
||||
# For now, let's try to modify as little as possible.
|
||||
|
||||
raw = http1.assemble_response_head(response)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, ResponseData):
|
||||
assert self.response
|
||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
else:
|
||||
raw = event.data
|
||||
if raw:
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, ResponseEndOfMessage):
|
||||
assert self.request
|
||||
assert self.response
|
||||
if (
|
||||
self.request.method.upper() != "HEAD"
|
||||
and "chunked"
|
||||
in self.response.headers.get("transfer-encoding", "").lower()
|
||||
):
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
yield from self.mark_done(response=True)
|
||||
elif isinstance(event, ResponseProtocolError):
|
||||
if not (self.conn.state & ConnectionState.CAN_WRITE):
|
||||
return
|
||||
status = event.code.http_status_code()
|
||||
if not self.response and status is not None:
|
||||
yield commands.SendData(
|
||||
self.conn, make_error_response(status, event.message)
|
||||
)
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def read_headers(
|
||||
self, event: events.ConnectionEvent
|
||||
) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
request_head = self.buf.maybe_extract_lines()
|
||||
if request_head:
|
||||
try:
|
||||
self.request = http1.read_request_head(
|
||||
[bytes(x) for x in request_head]
|
||||
)
|
||||
expected_body_size = http1.expected_http_body_size(self.request)
|
||||
except ValueError as e:
|
||||
yield commands.SendData(self.conn, make_error_response(400, str(e)))
|
||||
yield commands.CloseConnection(self.conn)
|
||||
if self.request:
|
||||
# we have headers that we can show in the ui
|
||||
yield ReceiveHttp(
|
||||
RequestHeaders(self.stream_id, self.request, False)
|
||||
)
|
||||
yield ReceiveHttp(
|
||||
RequestProtocolError(
|
||||
self.stream_id, str(e), ErrorCode.GENERIC_CLIENT_ERROR
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield commands.Log(
|
||||
f"{human.format_address(self.conn.peername)}: {e}"
|
||||
)
|
||||
self.state = self.done
|
||||
return
|
||||
yield ReceiveHttp(
|
||||
RequestHeaders(
|
||||
self.stream_id, self.request, expected_body_size == 0
|
||||
)
|
||||
)
|
||||
self.body_reader = make_body_reader(expected_body_size)
|
||||
self.state = self.read_body
|
||||
yield from self.state(event)
|
||||
else:
|
||||
pass # FIXME: protect against header size DoS
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
buf = bytes(self.buf)
|
||||
if buf.strip():
|
||||
yield commands.Log(
|
||||
f"Client closed connection before completing request headers: {buf!r}"
|
||||
)
|
||||
yield commands.CloseConnection(self.conn)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def mark_done(
|
||||
self, *, request: bool = False, response: bool = False
|
||||
) -> layer.CommandGenerator[None]:
|
||||
yield from super().mark_done(request=request, response=response)
|
||||
if self.request_done and not self.response_done:
|
||||
self.state = self.wait
|
||||
|
||||
|
||||
class Http1Client(Http1Connection):
|
||||
"""A simple HTTP/1 client with no pipelining support."""
|
||||
|
||||
ReceiveProtocolError = ResponseProtocolError
|
||||
ReceiveData = ResponseData
|
||||
ReceiveEndOfMessage = ResponseEndOfMessage
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.server)
|
||||
|
||||
def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, RequestProtocolError):
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
|
||||
if self.stream_id is None:
|
||||
assert isinstance(event, RequestHeaders)
|
||||
self.stream_id = event.stream_id
|
||||
self.request = event.request
|
||||
assert self.stream_id == event.stream_id
|
||||
|
||||
if isinstance(event, RequestHeaders):
|
||||
request = event.request
|
||||
if request.is_http2 or request.is_http3:
|
||||
# Convert to an HTTP/1 request.
|
||||
request = (
|
||||
request.copy()
|
||||
) # (we could probably be a bit more efficient here.)
|
||||
request.http_version = "HTTP/1.1"
|
||||
if "Host" not in request.headers and request.authority:
|
||||
request.headers.insert(0, "Host", request.authority)
|
||||
request.authority = ""
|
||||
cookie_headers = request.headers.get_all("Cookie")
|
||||
if len(cookie_headers) > 1:
|
||||
# Only HTTP/2 supports multiple cookie headers, HTTP/1.x does not.
|
||||
# see: https://www.rfc-editor.org/rfc/rfc6265#section-5.4
|
||||
# https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.5
|
||||
request.headers["Cookie"] = "; ".join(cookie_headers)
|
||||
raw = http1.assemble_request_head(request)
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestData):
|
||||
assert self.request
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
raw = b"%x\r\n%s\r\n" % (len(event.data), event.data)
|
||||
else:
|
||||
raw = event.data
|
||||
if raw:
|
||||
yield commands.SendData(self.conn, raw)
|
||||
elif isinstance(event, RequestEndOfMessage):
|
||||
assert self.request
|
||||
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
|
||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
||||
yield commands.CloseTcpConnection(self.conn, half_close=True)
|
||||
yield from self.mark_done(request=True)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
def read_headers(
|
||||
self, event: events.ConnectionEvent
|
||||
) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.DataReceived):
|
||||
if not self.request:
|
||||
# we just received some data for an unknown request.
|
||||
yield commands.Log(f"Unexpected data from server: {bytes(self.buf)!r}")
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
assert self.stream_id is not None
|
||||
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
if response_head:
|
||||
try:
|
||||
self.response = http1.read_response_head(
|
||||
[bytes(x) for x in response_head]
|
||||
)
|
||||
expected_size = http1.expected_http_body_size(
|
||||
self.request, self.response
|
||||
)
|
||||
except ValueError as e:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
yield ReceiveHttp(
|
||||
ResponseProtocolError(
|
||||
self.stream_id,
|
||||
f"Cannot parse HTTP response: {e}",
|
||||
ErrorCode.GENERIC_SERVER_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
yield ReceiveHttp(
|
||||
ResponseHeaders(self.stream_id, self.response, expected_size == 0)
|
||||
)
|
||||
self.body_reader = make_body_reader(expected_size)
|
||||
|
||||
self.state = self.read_body
|
||||
yield from self.state(event)
|
||||
else:
|
||||
pass # FIXME: protect against header size DoS
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if self.conn.state & ConnectionState.CAN_WRITE:
|
||||
yield commands.CloseConnection(self.conn)
|
||||
if self.stream_id:
|
||||
if self.buf:
|
||||
yield ReceiveHttp(
|
||||
ResponseProtocolError(
|
||||
self.stream_id,
|
||||
f"unexpected server response: {bytes(self.buf)!r}",
|
||||
ErrorCode.GENERIC_SERVER_ERROR,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# The server has closed the connection to prevent us from continuing.
|
||||
# We need to signal that to the stream.
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.5.11
|
||||
yield ReceiveHttp(
|
||||
ResponseProtocolError(
|
||||
self.stream_id,
|
||||
"server closed connection",
|
||||
ErrorCode.GENERIC_SERVER_ERROR,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
|
||||
def should_make_pipe(request: http.Request, response: http.Response) -> bool:
|
||||
if response.status_code == 101:
|
||||
return True
|
||||
elif response.status_code == 200 and request.method.upper() == "CONNECT":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def make_body_reader(expected_size: int | None) -> TBodyReader:
|
||||
if expected_size is None:
|
||||
return ChunkedReader()
|
||||
elif expected_size == -1:
|
||||
return Http10Reader()
|
||||
else:
|
||||
return ContentLengthReader(expected_size)
|
||||
|
||||
|
||||
def make_error_response(
|
||||
status_code: int,
|
||||
message: str = "",
|
||||
) -> bytes:
|
||||
resp = http.Response.make(
|
||||
status_code,
|
||||
format_error(status_code, message),
|
||||
http.Headers(
|
||||
Server=version.MITMPROXY,
|
||||
Connection="close",
|
||||
Content_Type="text/html",
|
||||
),
|
||||
)
|
||||
return http1.assemble_response(resp)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Http1Client",
|
||||
"Http1Server",
|
||||
]
|
||||
714
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http2.py
Normal file
714
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http2.py
Normal file
@@ -0,0 +1,714 @@
|
||||
import collections
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from logging import DEBUG
|
||||
from logging import ERROR
|
||||
from typing import Any
|
||||
from typing import assert_never
|
||||
from typing import ClassVar
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.errors
|
||||
import h2.events
|
||||
import h2.exceptions
|
||||
import h2.settings
|
||||
import h2.stream
|
||||
import h2.utilities
|
||||
|
||||
from ...commands import CloseConnection
|
||||
from ...commands import Log
|
||||
from ...commands import RequestWakeup
|
||||
from ...commands import SendData
|
||||
from ...context import Context
|
||||
from ...events import ConnectionClosed
|
||||
from ...events import DataReceived
|
||||
from ...events import Event
|
||||
from ...events import Start
|
||||
from ...events import Wakeup
|
||||
from ...layer import CommandGenerator
|
||||
from ...utils import expect
|
||||
from . import ErrorCode
|
||||
from . import RequestData
|
||||
from . import RequestEndOfMessage
|
||||
from . import RequestHeaders
|
||||
from . import RequestProtocolError
|
||||
from . import RequestTrailers
|
||||
from . import ResponseData
|
||||
from . import ResponseEndOfMessage
|
||||
from . import ResponseHeaders
|
||||
from . import ResponseProtocolError
|
||||
from . import ResponseTrailers
|
||||
from ._base import format_error
|
||||
from ._base import HttpConnection
|
||||
from ._base import HttpEvent
|
||||
from ._base import ReceiveHttp
|
||||
from ._http_h2 import BufferedH2Connection
|
||||
from ._http_h2 import H2ConnectionLogger
|
||||
from mitmproxy import http
|
||||
from mitmproxy import version
|
||||
from mitmproxy.connection import Connection
|
||||
from mitmproxy.net.http import status_codes
|
||||
from mitmproxy.net.http import url
|
||||
from mitmproxy.utils import human
|
||||
|
||||
|
||||
class StreamState(Enum):
|
||||
EXPECTING_HEADERS = 1
|
||||
HEADERS_RECEIVED = 2
|
||||
|
||||
|
||||
CATCH_HYPER_H2_ERRORS = (ValueError, IndexError)
|
||||
|
||||
|
||||
class Http2Connection(HttpConnection):
|
||||
h2_conf: ClassVar[h2.config.H2Configuration]
|
||||
h2_conf_defaults: dict[str, Any] = dict(
|
||||
header_encoding=False,
|
||||
validate_outbound_headers=False,
|
||||
# validate_inbound_headers is controlled by the validate_inbound_headers option.
|
||||
normalize_inbound_headers=False, # changing this to True is required to pass h2spec
|
||||
normalize_outbound_headers=False,
|
||||
)
|
||||
h2_conn: BufferedH2Connection
|
||||
streams: dict[int, StreamState]
|
||||
"""keep track of all active stream ids to send protocol errors on teardown"""
|
||||
|
||||
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
|
||||
ReceiveData: type[RequestData | ResponseData]
|
||||
ReceiveTrailers: type[RequestTrailers | ResponseTrailers]
|
||||
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
|
||||
|
||||
def __init__(self, context: Context, conn: Connection):
|
||||
super().__init__(context, conn)
|
||||
if self.debug:
|
||||
self.h2_conf.logger = H2ConnectionLogger(
|
||||
self.context.client.peername, self.__class__.__name__
|
||||
)
|
||||
self.h2_conf.validate_inbound_headers = (
|
||||
self.context.options.validate_inbound_headers
|
||||
)
|
||||
self.h2_conn = BufferedH2Connection(self.h2_conf)
|
||||
self.streams = {}
|
||||
|
||||
def is_closed(self, stream_id: int) -> bool:
|
||||
"""Check if a non-idle stream is closed"""
|
||||
stream = self.h2_conn.streams.get(stream_id, None)
|
||||
if (
|
||||
stream is not None
|
||||
and stream.state_machine.state is not h2.stream.StreamState.CLOSED
|
||||
and self.h2_conn.state_machine.state
|
||||
is not h2.connection.ConnectionState.CLOSED
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def is_open_for_us(self, stream_id: int) -> bool:
|
||||
"""Check if we can write to a non-idle stream."""
|
||||
stream = self.h2_conn.streams.get(stream_id, None)
|
||||
if (
|
||||
stream is not None
|
||||
and stream.state_machine.state
|
||||
is not h2.stream.StreamState.HALF_CLOSED_LOCAL
|
||||
and stream.state_machine.state is not h2.stream.StreamState.CLOSED
|
||||
and self.h2_conn.state_machine.state
|
||||
is not h2.connection.ConnectionState.CLOSED
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _handle_event(self, event: Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, Start):
|
||||
self.h2_conn.initiate_connection()
|
||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||
|
||||
elif isinstance(event, HttpEvent):
|
||||
if isinstance(event, (RequestData, ResponseData)):
|
||||
if self.is_open_for_us(event.stream_id):
|
||||
self.h2_conn.send_data(event.stream_id, event.data)
|
||||
elif isinstance(event, (RequestTrailers, ResponseTrailers)):
|
||||
if self.is_open_for_us(event.stream_id):
|
||||
trailers = [*event.trailers.fields]
|
||||
self.h2_conn.send_trailers(event.stream_id, trailers)
|
||||
elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)):
|
||||
if self.is_open_for_us(event.stream_id):
|
||||
self.h2_conn.end_stream(event.stream_id)
|
||||
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
|
||||
if not self.is_closed(event.stream_id):
|
||||
stream: h2.stream.H2Stream = self.h2_conn.streams[event.stream_id]
|
||||
status = event.code.http_status_code()
|
||||
if (
|
||||
isinstance(event, ResponseProtocolError)
|
||||
and self.is_open_for_us(event.stream_id)
|
||||
and not stream.state_machine.headers_sent
|
||||
and status is not None
|
||||
):
|
||||
self.h2_conn.send_headers(
|
||||
event.stream_id,
|
||||
[
|
||||
(b":status", b"%d" % status),
|
||||
(b"server", version.MITMPROXY.encode()),
|
||||
(b"content-type", b"text/html"),
|
||||
],
|
||||
)
|
||||
self.h2_conn.send_data(
|
||||
event.stream_id,
|
||||
format_error(status, event.message),
|
||||
end_stream=True,
|
||||
)
|
||||
else:
|
||||
match event.code:
|
||||
case ErrorCode.CANCEL | ErrorCode.CLIENT_DISCONNECTED:
|
||||
error_code = h2.errors.ErrorCodes.CANCEL
|
||||
case ErrorCode.KILL:
|
||||
# XXX: Debateable whether this is the best error code.
|
||||
error_code = h2.errors.ErrorCodes.INTERNAL_ERROR
|
||||
case ErrorCode.HTTP_1_1_REQUIRED:
|
||||
error_code = h2.errors.ErrorCodes.HTTP_1_1_REQUIRED
|
||||
case ErrorCode.PASSTHROUGH_CLOSE:
|
||||
# FIXME: This probably shouldn't be a protocol error, but an EOM event.
|
||||
error_code = h2.errors.ErrorCodes.CANCEL
|
||||
case (
|
||||
ErrorCode.GENERIC_CLIENT_ERROR
|
||||
| ErrorCode.GENERIC_SERVER_ERROR
|
||||
| ErrorCode.REQUEST_TOO_LARGE
|
||||
| ErrorCode.RESPONSE_TOO_LARGE
|
||||
| ErrorCode.CONNECT_FAILED
|
||||
| ErrorCode.DESTINATION_UNKNOWN
|
||||
| ErrorCode.REQUEST_VALIDATION_FAILED
|
||||
| ErrorCode.RESPONSE_VALIDATION_FAILED
|
||||
):
|
||||
error_code = h2.errors.ErrorCodes.INTERNAL_ERROR
|
||||
case other: # pragma: no cover
|
||||
assert_never(other)
|
||||
self.h2_conn.reset_stream(event.stream_id, error_code.value)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
data_to_send = self.h2_conn.data_to_send()
|
||||
if data_to_send:
|
||||
yield SendData(self.conn, data_to_send)
|
||||
|
||||
elif isinstance(event, DataReceived):
|
||||
try:
|
||||
try:
|
||||
events = self.h2_conn.receive_data(event.data)
|
||||
except CATCH_HYPER_H2_ERRORS as e: # pragma: no cover
|
||||
# this should never raise a ValueError, but we triggered one while fuzzing:
|
||||
# https://github.com/python-hyper/hyper-h2/issues/1231
|
||||
# this stays here as defense-in-depth.
|
||||
raise h2.exceptions.ProtocolError(
|
||||
f"uncaught hyper-h2 error: {e}"
|
||||
) from e
|
||||
except h2.exceptions.ProtocolError as e:
|
||||
events = [e]
|
||||
|
||||
for h2_event in events:
|
||||
if self.debug:
|
||||
yield Log(f"{self.debug}[h2] {h2_event}", DEBUG)
|
||||
if (yield from self.handle_h2_event(h2_event)):
|
||||
if self.debug:
|
||||
yield Log(f"{self.debug}[h2] done", DEBUG)
|
||||
return
|
||||
|
||||
data_to_send = self.h2_conn.data_to_send()
|
||||
if data_to_send:
|
||||
yield SendData(self.conn, data_to_send)
|
||||
|
||||
elif isinstance(event, ConnectionClosed):
|
||||
yield from self.close_connection("peer closed connection")
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
|
||||
"""returns true if further processing should be stopped."""
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
state = self.streams.get(event.stream_id, None)
|
||||
if state is StreamState.HEADERS_RECEIVED:
|
||||
is_empty_eos_data_frame = event.stream_ended and not event.data
|
||||
if not is_empty_eos_data_frame:
|
||||
yield ReceiveHttp(self.ReceiveData(event.stream_id, event.data))
|
||||
elif state is StreamState.EXPECTING_HEADERS:
|
||||
yield from self.protocol_error(
|
||||
f"Received HTTP/2 data frame, expected headers."
|
||||
)
|
||||
return True
|
||||
self.h2_conn.acknowledge_received_data(
|
||||
event.flow_controlled_length, event.stream_id
|
||||
)
|
||||
elif isinstance(event, h2.events.TrailersReceived):
|
||||
trailers = http.Headers(event.headers)
|
||||
yield ReceiveHttp(self.ReceiveTrailers(event.stream_id, trailers))
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
state = self.streams.get(event.stream_id, None)
|
||||
if state is StreamState.HEADERS_RECEIVED:
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
|
||||
elif state is StreamState.EXPECTING_HEADERS:
|
||||
raise AssertionError("unreachable")
|
||||
if self.is_closed(event.stream_id):
|
||||
self.streams.pop(event.stream_id, None)
|
||||
elif isinstance(event, h2.events.StreamReset):
|
||||
if event.stream_id in self.streams:
|
||||
try:
|
||||
err_str = h2.errors.ErrorCodes(event.error_code).name
|
||||
except ValueError:
|
||||
err_str = str(event.error_code)
|
||||
match event.error_code:
|
||||
case h2.errors.ErrorCodes.CANCEL:
|
||||
err_code = ErrorCode.CANCEL
|
||||
case h2.errors.ErrorCodes.HTTP_1_1_REQUIRED:
|
||||
err_code = ErrorCode.HTTP_1_1_REQUIRED
|
||||
case _:
|
||||
err_code = self.ReceiveProtocolError.code
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
event.stream_id,
|
||||
f"stream reset by client ({err_str})",
|
||||
code=err_code,
|
||||
)
|
||||
)
|
||||
self.streams.pop(event.stream_id)
|
||||
else:
|
||||
pass # We don't track priority frames which could be followed by a stream reset here.
|
||||
elif isinstance(event, h2.exceptions.ProtocolError):
|
||||
yield from self.protocol_error(f"HTTP/2 protocol error: {event}")
|
||||
return True
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
yield from self.close_connection(f"HTTP/2 connection closed: {event!r}")
|
||||
return True
|
||||
# The implementation above isn't really ideal, we should probably only terminate streams > last_stream_id?
|
||||
# We currently lack a mechanism to signal that connections are still active but cannot be reused.
|
||||
# for stream_id in self.streams:
|
||||
# if stream_id > event.last_stream_id:
|
||||
# yield ReceiveHttp(self.ReceiveProtocolError(stream_id, f"HTTP/2 connection closed: {event!r}"))
|
||||
# self.streams.pop(stream_id)
|
||||
elif isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
pass
|
||||
elif isinstance(event, h2.events.SettingsAcknowledged):
|
||||
pass
|
||||
elif isinstance(event, h2.events.PriorityUpdated):
|
||||
pass
|
||||
elif isinstance(event, h2.events.PingReceived):
|
||||
pass
|
||||
elif isinstance(event, h2.events.PingAckReceived):
|
||||
pass
|
||||
elif isinstance(event, h2.events.PushedStreamReceived):
|
||||
yield Log(
|
||||
"Received HTTP/2 push promise, even though we signalled no support.",
|
||||
ERROR,
|
||||
)
|
||||
elif isinstance(event, h2.events.UnknownFrameReceived):
|
||||
# https://http2.github.io/http2-spec/#rfc.section.4.1
|
||||
# Implementations MUST ignore and discard any frame that has a type that is unknown.
|
||||
yield Log(f"Ignoring unknown HTTP/2 frame type: {event.frame.type}")
|
||||
elif isinstance(event, h2.events.AlternativeServiceAvailable):
|
||||
yield Log(
|
||||
"Received HTTP/2 Alt-Svc frame, which will not be forwarded.", DEBUG
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
return False
|
||||
|
||||
def protocol_error(
|
||||
self,
|
||||
message: str,
|
||||
error_code: int = h2.errors.ErrorCodes.PROTOCOL_ERROR,
|
||||
) -> CommandGenerator[None]:
|
||||
yield Log(f"{human.format_address(self.conn.peername)}: {message}")
|
||||
self.h2_conn.close_connection(error_code, message.encode())
|
||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||
yield from self.close_connection(message)
|
||||
|
||||
def close_connection(self, msg: str) -> CommandGenerator[None]:
|
||||
yield CloseConnection(self.conn)
|
||||
for stream_id in self.streams:
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
stream_id, msg, self.ReceiveProtocolError.code
|
||||
)
|
||||
)
|
||||
self.streams.clear()
|
||||
self._handle_event = self.done # type: ignore
|
||||
|
||||
@expect(DataReceived, HttpEvent, ConnectionClosed, Wakeup)
|
||||
def done(self, _) -> CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
|
||||
def normalize_h1_headers(
|
||||
headers: list[tuple[bytes, bytes]], is_client: bool
|
||||
) -> list[tuple[bytes, bytes]]:
|
||||
# HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length),
|
||||
# which isn't valid HTTP/2. As such we normalize.
|
||||
# Make sure that this is not just an iterator but an iterable,
|
||||
# otherwise hyper-h2 will silently drop headers.
|
||||
return list(
|
||||
h2.utilities.normalize_outbound_headers(
|
||||
headers,
|
||||
h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def normalize_h2_headers(headers: list[tuple[bytes, bytes]]) -> CommandGenerator[None]:
|
||||
for i in range(len(headers)):
|
||||
if not headers[i][0].islower():
|
||||
yield Log(
|
||||
f"Lowercased {repr(headers[i][0]).lstrip('b')} header as uppercase is not allowed with HTTP/2."
|
||||
)
|
||||
headers[i] = (headers[i][0].lower(), headers[i][1])
|
||||
|
||||
|
||||
def format_h2_request_headers(
|
||||
context: Context,
|
||||
event: RequestHeaders,
|
||||
) -> CommandGenerator[list[tuple[bytes, bytes]]]:
|
||||
pseudo_headers = [
|
||||
(b":method", event.request.data.method),
|
||||
(b":scheme", event.request.data.scheme),
|
||||
(b":path", event.request.data.path),
|
||||
]
|
||||
if event.request.authority:
|
||||
pseudo_headers.append((b":authority", event.request.data.authority))
|
||||
|
||||
if event.request.is_http2 or event.request.is_http3:
|
||||
hdrs = list(event.request.headers.fields)
|
||||
if context.options.normalize_outbound_headers:
|
||||
yield from normalize_h2_headers(hdrs)
|
||||
else:
|
||||
headers = event.request.headers
|
||||
if not event.request.authority and "host" in headers:
|
||||
headers = headers.copy()
|
||||
pseudo_headers.append((b":authority", headers.pop(b"host")))
|
||||
hdrs = normalize_h1_headers(list(headers.fields), True)
|
||||
|
||||
return pseudo_headers + hdrs
|
||||
|
||||
|
||||
def format_h2_response_headers(
|
||||
context: Context,
|
||||
event: ResponseHeaders,
|
||||
) -> CommandGenerator[list[tuple[bytes, bytes]]]:
|
||||
headers = [
|
||||
(b":status", b"%d" % event.response.status_code),
|
||||
*event.response.headers.fields,
|
||||
]
|
||||
if event.response.is_http2 or event.response.is_http3:
|
||||
if context.options.normalize_outbound_headers:
|
||||
yield from normalize_h2_headers(headers)
|
||||
else:
|
||||
headers = normalize_h1_headers(headers, False)
|
||||
return headers
|
||||
|
||||
|
||||
class Http2Server(Http2Connection):
|
||||
h2_conf = h2.config.H2Configuration(
|
||||
**Http2Connection.h2_conf_defaults,
|
||||
client_side=False,
|
||||
)
|
||||
|
||||
ReceiveProtocolError = RequestProtocolError
|
||||
ReceiveData = RequestData
|
||||
ReceiveTrailers = RequestTrailers
|
||||
ReceiveEndOfMessage = RequestEndOfMessage
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.client)
|
||||
|
||||
def _handle_event(self, event: Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, ResponseHeaders):
|
||||
if self.is_open_for_us(event.stream_id):
|
||||
self.h2_conn.send_headers(
|
||||
event.stream_id,
|
||||
headers=(
|
||||
yield from format_h2_response_headers(self.context, event)
|
||||
),
|
||||
end_stream=event.end_stream,
|
||||
)
|
||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||
else:
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
try:
|
||||
(
|
||||
host,
|
||||
port,
|
||||
method,
|
||||
scheme,
|
||||
authority,
|
||||
path,
|
||||
headers,
|
||||
) = parse_h2_request_headers(event.headers)
|
||||
except ValueError as e:
|
||||
yield from self.protocol_error(f"Invalid HTTP/2 request headers: {e}")
|
||||
return True
|
||||
request = http.Request(
|
||||
host=host,
|
||||
port=port,
|
||||
method=method,
|
||||
scheme=scheme,
|
||||
authority=authority,
|
||||
path=path,
|
||||
http_version=b"HTTP/2.0",
|
||||
headers=headers,
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
|
||||
yield ReceiveHttp(
|
||||
RequestHeaders(
|
||||
event.stream_id, request, end_stream=bool(event.stream_ended)
|
||||
)
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return (yield from super().handle_h2_event(event))
|
||||
|
||||
|
||||
class Http2Client(Http2Connection):
|
||||
h2_conf = h2.config.H2Configuration(
|
||||
**Http2Connection.h2_conf_defaults,
|
||||
client_side=True,
|
||||
)
|
||||
|
||||
ReceiveProtocolError = ResponseProtocolError
|
||||
ReceiveData = ResponseData
|
||||
ReceiveTrailers = ResponseTrailers
|
||||
ReceiveEndOfMessage = ResponseEndOfMessage
|
||||
|
||||
our_stream_id: dict[int, int]
|
||||
their_stream_id: dict[int, int]
|
||||
stream_queue: collections.defaultdict[int, list[Event]]
|
||||
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
|
||||
provisional_max_concurrency: int | None = 10
|
||||
"""A provisional currency limit before we get the server's first settings frame."""
|
||||
last_activity: float
|
||||
"""Timestamp of when we've last seen network activity on this connection."""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context, context.server)
|
||||
# Disable HTTP/2 push for now to keep things simple.
|
||||
# don't send here, that is done as part of initiate_connection().
|
||||
self.h2_conn.local_settings.enable_push = 0
|
||||
# hyper-h2 pitfall: we need to acknowledge here, otherwise its sends out the old settings.
|
||||
self.h2_conn.local_settings.acknowledge()
|
||||
self.our_stream_id = {}
|
||||
self.their_stream_id = {}
|
||||
self.stream_queue = collections.defaultdict(list)
|
||||
|
||||
def _handle_event(self, event: Event) -> CommandGenerator[None]:
|
||||
# We can't reuse stream ids from the client because they may arrived reordered here
|
||||
# and HTTP/2 forbids opening a stream on a lower id than what was previously sent (see test_stream_concurrency).
|
||||
# To mitigate this, we transparently map the outside's stream id to our stream id.
|
||||
if isinstance(event, HttpEvent):
|
||||
ours = self.our_stream_id.get(event.stream_id, None)
|
||||
if ours is None:
|
||||
no_free_streams = self.h2_conn.open_outbound_streams >= (
|
||||
self.provisional_max_concurrency
|
||||
or self.h2_conn.remote_settings.max_concurrent_streams
|
||||
)
|
||||
if no_free_streams:
|
||||
self.stream_queue[event.stream_id].append(event)
|
||||
return
|
||||
ours = self.h2_conn.get_next_available_stream_id()
|
||||
self.our_stream_id[event.stream_id] = ours
|
||||
self.their_stream_id[ours] = event.stream_id
|
||||
event.stream_id = ours
|
||||
|
||||
for cmd in self._handle_event2(event):
|
||||
if isinstance(cmd, ReceiveHttp):
|
||||
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
|
||||
yield cmd
|
||||
|
||||
can_resume_queue = self.stream_queue and self.h2_conn.open_outbound_streams < (
|
||||
self.provisional_max_concurrency
|
||||
or self.h2_conn.remote_settings.max_concurrent_streams
|
||||
)
|
||||
if can_resume_queue:
|
||||
# popitem would be LIFO, but we want FIFO.
|
||||
events = self.stream_queue.pop(next(iter(self.stream_queue)))
|
||||
for event in events:
|
||||
yield from self._handle_event(event)
|
||||
|
||||
def _handle_event2(self, event: Event) -> CommandGenerator[None]:
|
||||
if isinstance(event, Wakeup):
|
||||
send_ping_now = (
|
||||
# add one second to avoid unnecessary roundtrip, we don't need to be super correct here.
|
||||
time.time() - self.last_activity + 1
|
||||
> self.context.options.http2_ping_keepalive
|
||||
)
|
||||
if send_ping_now:
|
||||
# PING frames MUST contain 8 octets of opaque data in the payload.
|
||||
# A sender can include any value it chooses and use those octets in any fashion.
|
||||
self.last_activity = time.time()
|
||||
self.h2_conn.ping(b"0" * 8)
|
||||
data = self.h2_conn.data_to_send()
|
||||
if data is not None:
|
||||
yield Log(
|
||||
f"Send HTTP/2 keep-alive PING to {human.format_address(self.conn.peername)}",
|
||||
DEBUG,
|
||||
)
|
||||
yield SendData(self.conn, data)
|
||||
time_until_next_ping = self.context.options.http2_ping_keepalive - (
|
||||
time.time() - self.last_activity
|
||||
)
|
||||
yield RequestWakeup(time_until_next_ping)
|
||||
return
|
||||
|
||||
self.last_activity = time.time()
|
||||
if isinstance(event, Start):
|
||||
if self.context.options.http2_ping_keepalive > 0:
|
||||
yield RequestWakeup(self.context.options.http2_ping_keepalive)
|
||||
yield from super()._handle_event(event)
|
||||
elif isinstance(event, RequestHeaders):
|
||||
self.h2_conn.send_headers(
|
||||
event.stream_id,
|
||||
headers=(yield from format_h2_request_headers(self.context, event)),
|
||||
end_stream=event.end_stream,
|
||||
)
|
||||
self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS
|
||||
yield SendData(self.conn, self.h2_conn.data_to_send())
|
||||
else:
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
if (
|
||||
self.streams.get(event.stream_id, None)
|
||||
is not StreamState.EXPECTING_HEADERS
|
||||
):
|
||||
yield from self.protocol_error(f"Received unexpected HTTP/2 response.")
|
||||
return True
|
||||
|
||||
try:
|
||||
status_code, headers = parse_h2_response_headers(event.headers)
|
||||
except ValueError as e:
|
||||
yield from self.protocol_error(f"Invalid HTTP/2 response headers: {e}")
|
||||
return True
|
||||
|
||||
response = http.Response(
|
||||
http_version=b"HTTP/2.0",
|
||||
status_code=status_code,
|
||||
reason=b"",
|
||||
headers=headers,
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
|
||||
yield ReceiveHttp(
|
||||
ResponseHeaders(event.stream_id, response, bool(event.stream_ended))
|
||||
)
|
||||
return False
|
||||
elif isinstance(event, h2.events.InformationalResponseReceived):
|
||||
# We violate the spec here ("A proxy MUST forward 1xx responses", RFC 7231),
|
||||
# but that's probably fine:
|
||||
# - 100 Continue is sent by mitmproxy to clients (irrespective of what the server does).
|
||||
# - 101 Switching Protocols is not allowed for HTTP/2.
|
||||
# - 102 Processing is WebDAV only and also ignorable.
|
||||
# - 103 Early Hints is not mission-critical.
|
||||
headers = http.Headers(event.headers)
|
||||
status: str | int = "<unknown status>"
|
||||
try:
|
||||
status = int(headers[":status"])
|
||||
reason = status_codes.RESPONSES.get(status, "")
|
||||
except (KeyError, ValueError):
|
||||
reason = ""
|
||||
yield Log(f"Swallowing HTTP/2 informational response: {status} {reason}")
|
||||
return False
|
||||
elif isinstance(event, h2.events.RequestReceived):
|
||||
yield from self.protocol_error(
|
||||
f"HTTP/2 protocol error: received request from server"
|
||||
)
|
||||
return True
|
||||
elif isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
# We have received at least one settings from now,
|
||||
# which means we can rely on the max concurrency in remote_settings
|
||||
self.provisional_max_concurrency = None
|
||||
return (yield from super().handle_h2_event(event))
|
||||
else:
|
||||
return (yield from super().handle_h2_event(event))
|
||||
|
||||
|
||||
def split_pseudo_headers(
|
||||
h2_headers: Sequence[tuple[bytes, bytes]],
|
||||
) -> tuple[dict[bytes, bytes], http.Headers]:
|
||||
pseudo_headers: dict[bytes, bytes] = {}
|
||||
i = 0
|
||||
for header, value in h2_headers:
|
||||
if header.startswith(b":"):
|
||||
if header in pseudo_headers:
|
||||
raise ValueError(f"Duplicate HTTP/2 pseudo header: {header!r}")
|
||||
pseudo_headers[header] = value
|
||||
i += 1
|
||||
else:
|
||||
# Pseudo-headers must be at the start, we are done here.
|
||||
break
|
||||
|
||||
headers = http.Headers(h2_headers[i:])
|
||||
|
||||
return pseudo_headers, headers
|
||||
|
||||
|
||||
def parse_h2_request_headers(
|
||||
h2_headers: Sequence[tuple[bytes, bytes]],
|
||||
) -> tuple[str, int, bytes, bytes, bytes, bytes, http.Headers]:
|
||||
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
|
||||
pseudo_headers, headers = split_pseudo_headers(h2_headers)
|
||||
|
||||
try:
|
||||
method: bytes = pseudo_headers.pop(b":method")
|
||||
scheme: bytes = pseudo_headers.pop(
|
||||
b":scheme"
|
||||
) # this raises for HTTP/2 CONNECT requests
|
||||
path: bytes = pseudo_headers.pop(b":path")
|
||||
authority: bytes = pseudo_headers.pop(b":authority", b"")
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Required pseudo header is missing: {e}")
|
||||
|
||||
if pseudo_headers:
|
||||
raise ValueError(f"Unknown pseudo headers: {pseudo_headers}")
|
||||
|
||||
if authority:
|
||||
host, port = url.parse_authority(authority, check=True)
|
||||
if port is None:
|
||||
port = 80 if scheme == b"http" else 443
|
||||
else:
|
||||
host = ""
|
||||
port = 0
|
||||
|
||||
return host, port, method, scheme, authority, path, headers
|
||||
|
||||
|
||||
def parse_h2_response_headers(
|
||||
h2_headers: Sequence[tuple[bytes, bytes]],
|
||||
) -> tuple[int, http.Headers]:
|
||||
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
|
||||
pseudo_headers, headers = split_pseudo_headers(h2_headers)
|
||||
|
||||
try:
|
||||
status_code: int = int(pseudo_headers.pop(b":status"))
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Required pseudo header is missing: {e}")
|
||||
|
||||
if pseudo_headers:
|
||||
raise ValueError(f"Unknown pseudo headers: {pseudo_headers}")
|
||||
|
||||
return status_code, headers
|
||||
|
||||
|
||||
__all__ = [
|
||||
"format_h2_request_headers",
|
||||
"format_h2_response_headers",
|
||||
"parse_h2_request_headers",
|
||||
"parse_h2_response_headers",
|
||||
"Http2Client",
|
||||
"Http2Server",
|
||||
]
|
||||
309
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http3.py
Normal file
309
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http3.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import assert_never
|
||||
|
||||
from aioquic.h3.connection import ErrorCode as H3ErrorCode
|
||||
from aioquic.h3.connection import FrameUnexpected as H3FrameUnexpected
|
||||
from aioquic.h3.events import DataReceived
|
||||
from aioquic.h3.events import HeadersReceived
|
||||
from aioquic.h3.events import PushPromiseReceived
|
||||
|
||||
from . import ErrorCode
|
||||
from . import RequestData
|
||||
from . import RequestEndOfMessage
|
||||
from . import RequestHeaders
|
||||
from . import RequestProtocolError
|
||||
from . import RequestTrailers
|
||||
from . import ResponseData
|
||||
from . import ResponseEndOfMessage
|
||||
from . import ResponseHeaders
|
||||
from . import ResponseProtocolError
|
||||
from . import ResponseTrailers
|
||||
from ._base import format_error
|
||||
from ._base import HttpConnection
|
||||
from ._base import HttpEvent
|
||||
from ._base import ReceiveHttp
|
||||
from ._http2 import format_h2_request_headers
|
||||
from ._http2 import format_h2_response_headers
|
||||
from ._http2 import parse_h2_request_headers
|
||||
from ._http2 import parse_h2_response_headers
|
||||
from ._http_h3 import LayeredH3Connection
|
||||
from ._http_h3 import StreamClosed
|
||||
from ._http_h3 import TrailersReceived
|
||||
from mitmproxy import connection
|
||||
from mitmproxy import http
|
||||
from mitmproxy import version
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.layers.quic import error_code_to_str
|
||||
from mitmproxy.proxy.layers.quic import QuicConnectionClosed
|
||||
from mitmproxy.proxy.layers.quic import QuicStreamEvent
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
|
||||
class Http3Connection(HttpConnection):
|
||||
h3_conn: LayeredH3Connection
|
||||
|
||||
ReceiveData: type[RequestData | ResponseData]
|
||||
ReceiveEndOfMessage: type[RequestEndOfMessage | ResponseEndOfMessage]
|
||||
ReceiveProtocolError: type[RequestProtocolError | ResponseProtocolError]
|
||||
ReceiveTrailers: type[RequestTrailers | ResponseTrailers]
|
||||
|
||||
def __init__(self, context: context.Context, conn: connection.Connection):
|
||||
super().__init__(context, conn)
|
||||
self.h3_conn = LayeredH3Connection(
|
||||
self.conn, is_client=self.conn is self.context.server
|
||||
)
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# send mitmproxy HTTP events over the H3 connection
|
||||
elif isinstance(event, HttpEvent):
|
||||
try:
|
||||
if isinstance(event, (RequestData, ResponseData)):
|
||||
self.h3_conn.send_data(event.stream_id, event.data)
|
||||
elif isinstance(event, (RequestHeaders, ResponseHeaders)):
|
||||
headers = yield from (
|
||||
format_h2_request_headers(self.context, event)
|
||||
if isinstance(event, RequestHeaders)
|
||||
else format_h2_response_headers(self.context, event)
|
||||
)
|
||||
self.h3_conn.send_headers(
|
||||
event.stream_id, headers, end_stream=event.end_stream
|
||||
)
|
||||
elif isinstance(event, (RequestTrailers, ResponseTrailers)):
|
||||
self.h3_conn.send_trailers(
|
||||
event.stream_id, [*event.trailers.fields]
|
||||
)
|
||||
elif isinstance(event, (RequestEndOfMessage, ResponseEndOfMessage)):
|
||||
self.h3_conn.end_stream(event.stream_id)
|
||||
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
|
||||
status = event.code.http_status_code()
|
||||
if (
|
||||
isinstance(event, ResponseProtocolError)
|
||||
and not self.h3_conn.has_sent_headers(event.stream_id)
|
||||
and status is not None
|
||||
):
|
||||
self.h3_conn.send_headers(
|
||||
event.stream_id,
|
||||
[
|
||||
(b":status", b"%d" % status),
|
||||
(b"server", version.MITMPROXY.encode()),
|
||||
(b"content-type", b"text/html"),
|
||||
],
|
||||
)
|
||||
self.h3_conn.send_data(
|
||||
event.stream_id,
|
||||
format_error(status, event.message),
|
||||
end_stream=True,
|
||||
)
|
||||
else:
|
||||
match event.code:
|
||||
case ErrorCode.CANCEL | ErrorCode.CLIENT_DISCONNECTED:
|
||||
error_code = H3ErrorCode.H3_REQUEST_CANCELLED
|
||||
case ErrorCode.KILL:
|
||||
error_code = H3ErrorCode.H3_INTERNAL_ERROR
|
||||
case ErrorCode.HTTP_1_1_REQUIRED:
|
||||
error_code = H3ErrorCode.H3_VERSION_FALLBACK
|
||||
case ErrorCode.PASSTHROUGH_CLOSE:
|
||||
# FIXME: This probably shouldn't be a protocol error, but an EOM event.
|
||||
error_code = H3ErrorCode.H3_REQUEST_CANCELLED
|
||||
case (
|
||||
ErrorCode.GENERIC_CLIENT_ERROR
|
||||
| ErrorCode.GENERIC_SERVER_ERROR
|
||||
| ErrorCode.REQUEST_TOO_LARGE
|
||||
| ErrorCode.RESPONSE_TOO_LARGE
|
||||
| ErrorCode.CONNECT_FAILED
|
||||
| ErrorCode.DESTINATION_UNKNOWN
|
||||
| ErrorCode.REQUEST_VALIDATION_FAILED
|
||||
| ErrorCode.RESPONSE_VALIDATION_FAILED
|
||||
):
|
||||
error_code = H3ErrorCode.H3_INTERNAL_ERROR
|
||||
case other: # pragma: no cover
|
||||
assert_never(other)
|
||||
self.h3_conn.close_stream(event.stream_id, error_code.value)
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
except H3FrameUnexpected as e:
|
||||
# Http2Connection also ignores HttpEvents that violate the current stream state
|
||||
yield commands.Log(f"Received {event!r} unexpectedly: {e}")
|
||||
|
||||
else:
|
||||
# transmit buffered data
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# forward stream messages from the QUIC layer to the H3 connection
|
||||
elif isinstance(event, QuicStreamEvent):
|
||||
h3_events = self.h3_conn.handle_stream_event(event)
|
||||
for h3_event in h3_events:
|
||||
if isinstance(h3_event, StreamClosed):
|
||||
err_str = error_code_to_str(h3_event.error_code)
|
||||
match h3_event.error_code:
|
||||
case H3ErrorCode.H3_REQUEST_CANCELLED:
|
||||
err_code = ErrorCode.CANCEL
|
||||
case H3ErrorCode.H3_VERSION_FALLBACK:
|
||||
err_code = ErrorCode.HTTP_1_1_REQUIRED
|
||||
case _:
|
||||
err_code = self.ReceiveProtocolError.code
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
h3_event.stream_id,
|
||||
f"stream closed by client ({err_str})",
|
||||
code=err_code,
|
||||
)
|
||||
)
|
||||
elif isinstance(h3_event, DataReceived):
|
||||
if h3_event.data:
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveData(h3_event.stream_id, h3_event.data)
|
||||
)
|
||||
if h3_event.stream_ended:
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
|
||||
elif isinstance(h3_event, HeadersReceived):
|
||||
try:
|
||||
receive_event = self.parse_headers(h3_event)
|
||||
except ValueError as e:
|
||||
self.h3_conn.close_connection(
|
||||
error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR,
|
||||
reason_phrase=f"Invalid HTTP/3 request headers: {e}",
|
||||
)
|
||||
else:
|
||||
yield ReceiveHttp(receive_event)
|
||||
if h3_event.stream_ended:
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveEndOfMessage(h3_event.stream_id)
|
||||
)
|
||||
elif isinstance(h3_event, TrailersReceived):
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveTrailers(
|
||||
h3_event.stream_id, http.Headers(h3_event.trailers)
|
||||
)
|
||||
)
|
||||
if h3_event.stream_ended:
|
||||
yield ReceiveHttp(self.ReceiveEndOfMessage(h3_event.stream_id))
|
||||
elif isinstance(h3_event, PushPromiseReceived): # pragma: no cover
|
||||
self.h3_conn.close_connection(
|
||||
error_code=H3ErrorCode.H3_GENERAL_PROTOCOL_ERROR,
|
||||
reason_phrase=f"Received HTTP/3 push promise, even though we signalled no support.",
|
||||
)
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
yield from self.h3_conn.transmit()
|
||||
|
||||
# report a protocol error for all remaining open streams when a connection is closed
|
||||
elif isinstance(event, QuicConnectionClosed):
|
||||
self._handle_event = self.done # type: ignore
|
||||
self.h3_conn.handle_connection_closed(event)
|
||||
msg = event.reason_phrase or error_code_to_str(event.error_code)
|
||||
for stream_id in self.h3_conn.get_open_stream_ids():
|
||||
yield ReceiveHttp(
|
||||
self.ReceiveProtocolError(
|
||||
stream_id, msg, self.ReceiveProtocolError.code
|
||||
)
|
||||
)
|
||||
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
@expect(HttpEvent, QuicStreamEvent, QuicConnectionClosed)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
@abstractmethod
|
||||
def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
class Http3Server(Http3Connection):
|
||||
ReceiveData = RequestData
|
||||
ReceiveEndOfMessage = RequestEndOfMessage
|
||||
ReceiveProtocolError = RequestProtocolError
|
||||
ReceiveTrailers = RequestTrailers
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context, context.client)
|
||||
|
||||
def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
|
||||
# same as HTTP/2
|
||||
(
|
||||
host,
|
||||
port,
|
||||
method,
|
||||
scheme,
|
||||
authority,
|
||||
path,
|
||||
headers,
|
||||
) = parse_h2_request_headers(event.headers)
|
||||
request = http.Request(
|
||||
host=host,
|
||||
port=port,
|
||||
method=method,
|
||||
scheme=scheme,
|
||||
authority=authority,
|
||||
path=path,
|
||||
http_version=b"HTTP/3",
|
||||
headers=headers,
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
return RequestHeaders(event.stream_id, request, end_stream=event.stream_ended)
|
||||
|
||||
|
||||
class Http3Client(Http3Connection):
|
||||
ReceiveData = ResponseData
|
||||
ReceiveEndOfMessage = ResponseEndOfMessage
|
||||
ReceiveProtocolError = ResponseProtocolError
|
||||
ReceiveTrailers = ResponseTrailers
|
||||
|
||||
our_stream_id: dict[int, int]
|
||||
their_stream_id: dict[int, int]
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context, context.server)
|
||||
self.our_stream_id = {}
|
||||
self.their_stream_id = {}
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
# QUIC and HTTP/3 would actually allow for direct stream ID mapping, but since we want
|
||||
# to support H2<->H3, we need to translate IDs.
|
||||
# NOTE: We always create bidirectional streams, as we can't safely infer unidirectionality.
|
||||
if isinstance(event, HttpEvent):
|
||||
ours = self.our_stream_id.get(event.stream_id, None)
|
||||
if ours is None:
|
||||
ours = self.h3_conn.get_next_available_stream_id()
|
||||
self.our_stream_id[event.stream_id] = ours
|
||||
self.their_stream_id[ours] = event.stream_id
|
||||
event.stream_id = ours
|
||||
|
||||
for cmd in super()._handle_event(event):
|
||||
if isinstance(cmd, ReceiveHttp):
|
||||
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
|
||||
yield cmd
|
||||
|
||||
def parse_headers(self, event: HeadersReceived) -> RequestHeaders | ResponseHeaders:
|
||||
# same as HTTP/2
|
||||
status_code, headers = parse_h2_response_headers(event.headers)
|
||||
response = http.Response(
|
||||
http_version=b"HTTP/3",
|
||||
status_code=status_code,
|
||||
reason=b"",
|
||||
headers=headers,
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=None,
|
||||
)
|
||||
return ResponseHeaders(event.stream_id, response, event.stream_ended)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Http3Client",
|
||||
"Http3Server",
|
||||
]
|
||||
207
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http_h2.py
Normal file
207
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http_h2.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import collections
|
||||
import logging
|
||||
from typing import NamedTuple
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import h2.exceptions
|
||||
import h2.settings
|
||||
import h2.stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class H2ConnectionLogger(h2.config.DummyLogger):
|
||||
def __init__(self, peername: tuple, conn_type: str):
|
||||
super().__init__()
|
||||
self.peername = peername
|
||||
self.conn_type = conn_type
|
||||
|
||||
def debug(self, fmtstr, *args):
|
||||
logger.debug(
|
||||
f"{self.conn_type} {fmtstr}", *args, extra={"client": self.peername}
|
||||
)
|
||||
|
||||
def trace(self, fmtstr, *args):
|
||||
logger.log(
|
||||
logging.DEBUG - 1,
|
||||
f"{self.conn_type} {fmtstr}",
|
||||
*args,
|
||||
extra={"client": self.peername},
|
||||
)
|
||||
|
||||
|
||||
class SendH2Data(NamedTuple):
|
||||
data: bytes
|
||||
end_stream: bool
|
||||
|
||||
|
||||
class BufferedH2Connection(h2.connection.H2Connection):
|
||||
"""
|
||||
This class wrap's hyper-h2's H2Connection and adds internal send buffers.
|
||||
|
||||
To simplify implementation, padding is unsupported.
|
||||
"""
|
||||
|
||||
stream_buffers: collections.defaultdict[int, collections.deque[SendH2Data]]
|
||||
stream_trailers: dict[int, list[tuple[bytes, bytes]]]
|
||||
|
||||
def __init__(self, config: h2.config.H2Configuration):
|
||||
super().__init__(config)
|
||||
self.local_settings.initial_window_size = 2**31 - 1
|
||||
self.local_settings.max_frame_size = 2**17
|
||||
self.max_inbound_frame_size = 2**17
|
||||
# hyper-h2 pitfall: we need to acknowledge here, otherwise its sends out the old settings.
|
||||
self.local_settings.acknowledge()
|
||||
self.stream_buffers = collections.defaultdict(collections.deque)
|
||||
self.stream_trailers = {}
|
||||
|
||||
def initiate_connection(self):
|
||||
super().initiate_connection()
|
||||
# We increase the flow-control window for new streams with a setting,
|
||||
# but we need to increase the overall connection flow-control window as well.
|
||||
self.increment_flow_control_window(
|
||||
2**31 - 1 - self.inbound_flow_control_window
|
||||
) # maximum - default
|
||||
|
||||
def send_data(
|
||||
self,
|
||||
stream_id: int,
|
||||
data: bytes,
|
||||
end_stream: bool = False,
|
||||
pad_length: None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send data on a given stream.
|
||||
|
||||
In contrast to plain hyper-h2, this method will not raise if the data cannot be sent immediately.
|
||||
Data is split up and buffered internally.
|
||||
"""
|
||||
frame_size = len(data)
|
||||
assert pad_length is None
|
||||
|
||||
if frame_size > self.max_outbound_frame_size:
|
||||
for start in range(0, frame_size, self.max_outbound_frame_size):
|
||||
chunk = data[start : start + self.max_outbound_frame_size]
|
||||
self.send_data(stream_id, chunk, end_stream=False)
|
||||
|
||||
return
|
||||
|
||||
if self.stream_buffers.get(stream_id, None):
|
||||
# We already have some data buffered, let's append.
|
||||
self.stream_buffers[stream_id].append(SendH2Data(data, end_stream))
|
||||
else:
|
||||
available_window = self.local_flow_control_window(stream_id)
|
||||
if frame_size <= available_window:
|
||||
super().send_data(stream_id, data, end_stream)
|
||||
else:
|
||||
if available_window:
|
||||
can_send_now = data[:available_window]
|
||||
super().send_data(stream_id, can_send_now, end_stream=False)
|
||||
data = data[available_window:]
|
||||
# We can't send right now, so we buffer.
|
||||
self.stream_buffers[stream_id].append(SendH2Data(data, end_stream))
|
||||
|
||||
def send_trailers(self, stream_id: int, trailers: list[tuple[bytes, bytes]]):
|
||||
if self.stream_buffers.get(stream_id, None):
|
||||
# Though trailers are not subject to flow control, we need to queue them and send strictly after data frames
|
||||
self.stream_trailers[stream_id] = trailers
|
||||
else:
|
||||
self.send_headers(stream_id, trailers, end_stream=True)
|
||||
|
||||
def end_stream(self, stream_id: int) -> None:
|
||||
if stream_id in self.stream_trailers:
|
||||
return # we already have trailers queued up that will end the stream.
|
||||
self.send_data(stream_id, b"", end_stream=True)
|
||||
|
||||
def reset_stream(self, stream_id: int, error_code: int = 0) -> None:
|
||||
self.stream_buffers.pop(stream_id, None)
|
||||
super().reset_stream(stream_id, error_code)
|
||||
|
||||
def receive_data(self, data: bytes):
|
||||
events = super().receive_data(data)
|
||||
ret = []
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.WindowUpdated):
|
||||
if event.stream_id == 0:
|
||||
self.connection_window_updated()
|
||||
else:
|
||||
self.stream_window_updated(event.stream_id)
|
||||
continue
|
||||
elif isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
if (
|
||||
h2.settings.SettingCodes.INITIAL_WINDOW_SIZE
|
||||
in event.changed_settings
|
||||
):
|
||||
self.connection_window_updated()
|
||||
elif isinstance(event, h2.events.StreamReset):
|
||||
self.stream_buffers.pop(event.stream_id, None)
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
self.stream_buffers.clear()
|
||||
ret.append(event)
|
||||
return ret
|
||||
|
||||
def stream_window_updated(self, stream_id: int) -> bool:
|
||||
"""
|
||||
The window for a specific stream has updated. Send as much buffered data as possible.
|
||||
"""
|
||||
# If the stream has been reset in the meantime, we just clear the buffer.
|
||||
try:
|
||||
stream: h2.stream.H2Stream = self.streams[stream_id]
|
||||
except KeyError:
|
||||
stream_was_reset = True
|
||||
else:
|
||||
stream_was_reset = stream.state_machine.state not in (
|
||||
h2.stream.StreamState.OPEN,
|
||||
h2.stream.StreamState.HALF_CLOSED_REMOTE,
|
||||
)
|
||||
if stream_was_reset:
|
||||
self.stream_buffers.pop(stream_id, None)
|
||||
return False
|
||||
|
||||
available_window = self.local_flow_control_window(stream_id)
|
||||
sent_any_data = False
|
||||
while available_window > 0 and stream_id in self.stream_buffers:
|
||||
chunk: SendH2Data = self.stream_buffers[stream_id].popleft()
|
||||
if len(chunk.data) > available_window:
|
||||
# We can't send the entire chunk, so we have to put some bytes back into the buffer.
|
||||
self.stream_buffers[stream_id].appendleft(
|
||||
SendH2Data(
|
||||
data=chunk.data[available_window:],
|
||||
end_stream=chunk.end_stream,
|
||||
)
|
||||
)
|
||||
chunk = SendH2Data(
|
||||
data=chunk.data[:available_window],
|
||||
end_stream=False,
|
||||
)
|
||||
|
||||
super().send_data(stream_id, data=chunk.data, end_stream=chunk.end_stream)
|
||||
|
||||
available_window -= len(chunk.data)
|
||||
if not self.stream_buffers[stream_id]:
|
||||
del self.stream_buffers[stream_id]
|
||||
if stream_id in self.stream_trailers:
|
||||
self.send_headers(
|
||||
stream_id, self.stream_trailers.pop(stream_id), end_stream=True
|
||||
)
|
||||
sent_any_data = True
|
||||
|
||||
return sent_any_data
|
||||
|
||||
def connection_window_updated(self) -> None:
|
||||
"""
|
||||
The connection window has updated. Send data from buffers in a round-robin fashion.
|
||||
"""
|
||||
sent_any_data = True
|
||||
while sent_any_data:
|
||||
sent_any_data = False
|
||||
for stream_id in list(self.stream_buffers):
|
||||
self.stream_buffers[stream_id] = self.stream_buffers.pop(
|
||||
stream_id
|
||||
) # move to end of dict
|
||||
if self.stream_window_updated(stream_id):
|
||||
sent_any_data = True
|
||||
if self.outbound_flow_control_window == 0:
|
||||
return
|
||||
321
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http_h3.py
Normal file
321
venv/Lib/site-packages/mitmproxy/proxy/layers/http/_http_h3.py
Normal file
@@ -0,0 +1,321 @@
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from aioquic.h3.connection import FrameUnexpected
|
||||
from aioquic.h3.connection import H3Connection
|
||||
from aioquic.h3.connection import H3Event
|
||||
from aioquic.h3.connection import H3Stream
|
||||
from aioquic.h3.connection import Headers
|
||||
from aioquic.h3.connection import HeadersState
|
||||
from aioquic.h3.events import HeadersReceived
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.events import StreamDataReceived
|
||||
from aioquic.quic.packet import QuicErrorCode
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.layers.quic import CloseQuicConnection
|
||||
from mitmproxy.proxy.layers.quic import QuicConnectionClosed
|
||||
from mitmproxy.proxy.layers.quic import QuicStreamDataReceived
|
||||
from mitmproxy.proxy.layers.quic import QuicStreamEvent
|
||||
from mitmproxy.proxy.layers.quic import QuicStreamReset
|
||||
from mitmproxy.proxy.layers.quic import QuicStreamStopSending
|
||||
from mitmproxy.proxy.layers.quic import ResetQuicStream
|
||||
from mitmproxy.proxy.layers.quic import SendQuicStreamData
|
||||
from mitmproxy.proxy.layers.quic import StopSendingQuicStream
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrailersReceived(H3Event):
|
||||
"""
|
||||
The TrailersReceived event is fired whenever trailers are received.
|
||||
"""
|
||||
|
||||
trailers: Headers
|
||||
"The trailers."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the trailers were received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamClosed(H3Event):
|
||||
"""
|
||||
The StreamReset event is fired when the peer is sending a CLOSE_STREAM
|
||||
or a STOP_SENDING frame. For HTTP/3, we don't differentiate between the two.
|
||||
"""
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream that was reset."
|
||||
|
||||
error_code: int
|
||||
"""The error code indicating why the stream was closed."""
|
||||
|
||||
|
||||
class MockQuic:
|
||||
"""
|
||||
aioquic intermingles QUIC and HTTP/3. This is something we don't want to do because that makes testing much harder.
|
||||
Instead, we mock our QUIC connection object here and then take out the wire data to be sent.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: connection.Connection, is_client: bool) -> None:
|
||||
self.conn = conn
|
||||
self.pending_commands: list[commands.Command] = []
|
||||
self._next_stream_id: list[int] = [0, 1, 2, 3]
|
||||
self._is_client = is_client
|
||||
|
||||
# the following fields are accessed by H3Connection
|
||||
self.configuration = QuicConfiguration(is_client=is_client)
|
||||
self._quic_logger = None
|
||||
self._remote_max_datagram_frame_size = 0
|
||||
|
||||
def close(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
frame_type: int | None = None,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
# we'll get closed if a protocol error occurs in `H3Connection.handle_event`
|
||||
# we note the error on the connection and yield a CloseConnection
|
||||
# this will then call `QuicConnection.close` with the proper values
|
||||
# once the `Http3Connection` receives `ConnectionClosed`, it will send out `ProtocolError`
|
||||
self.pending_commands.append(
|
||||
CloseQuicConnection(self.conn, error_code, frame_type, reason_phrase)
|
||||
)
|
||||
|
||||
def get_next_available_stream_id(self, is_unidirectional: bool = False) -> int:
|
||||
# since we always reserve the ID, we have to "find" the next ID like `QuicConnection` does
|
||||
index = (int(is_unidirectional) << 1) | int(not self._is_client)
|
||||
stream_id = self._next_stream_id[index]
|
||||
self._next_stream_id[index] = stream_id + 4
|
||||
return stream_id
|
||||
|
||||
def reset_stream(self, stream_id: int, error_code: int) -> None:
|
||||
self.pending_commands.append(ResetQuicStream(self.conn, stream_id, error_code))
|
||||
|
||||
def stop_send(self, stream_id: int, error_code: int) -> None:
|
||||
self.pending_commands.append(
|
||||
StopSendingQuicStream(self.conn, stream_id, error_code)
|
||||
)
|
||||
|
||||
def send_stream_data(
|
||||
self, stream_id: int, data: bytes, end_stream: bool = False
|
||||
) -> None:
|
||||
self.pending_commands.append(
|
||||
SendQuicStreamData(self.conn, stream_id, data, end_stream)
|
||||
)
|
||||
|
||||
|
||||
class LayeredH3Connection(H3Connection):
|
||||
"""
|
||||
Creates a H3 connection using a fake QUIC connection, which allows layer separation.
|
||||
Also ensures that headers, data and trailers are sent in that order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: connection.Connection,
|
||||
is_client: bool,
|
||||
enable_webtransport: bool = False,
|
||||
) -> None:
|
||||
self._mock = MockQuic(conn, is_client)
|
||||
self._closed_streams: set[int] = set()
|
||||
"""
|
||||
We keep track of all stream IDs for which we have requested
|
||||
STOP_SENDING to silently discard incoming data.
|
||||
"""
|
||||
super().__init__(self._mock, enable_webtransport) # type: ignore
|
||||
|
||||
# aioquic's constructor sets and then uses _max_push_id.
|
||||
# This is a hack to forcibly disable it.
|
||||
@property
|
||||
def _max_push_id(self) -> int | None:
|
||||
return None
|
||||
|
||||
@_max_push_id.setter
|
||||
def _max_push_id(self, value):
|
||||
pass
|
||||
|
||||
def _after_send(self, stream_id: int, end_stream: bool) -> None:
|
||||
# if the stream ended, `QuicConnection` has an assert that no further data is being sent
|
||||
# to catch this more early on, we set the header state on the `H3Stream`
|
||||
if end_stream:
|
||||
self._stream[stream_id].headers_send_state = HeadersState.AFTER_TRAILERS
|
||||
|
||||
def _handle_request_or_push_frame(
|
||||
self,
|
||||
frame_type: int,
|
||||
frame_data: bytes | None,
|
||||
stream: H3Stream,
|
||||
stream_ended: bool,
|
||||
) -> list[H3Event]:
|
||||
# turn HeadersReceived into TrailersReceived for trailers
|
||||
events = super()._handle_request_or_push_frame(
|
||||
frame_type, frame_data, stream, stream_ended
|
||||
)
|
||||
for index, event in enumerate(events):
|
||||
if (
|
||||
isinstance(event, HeadersReceived)
|
||||
and self._stream[event.stream_id].headers_recv_state
|
||||
== HeadersState.AFTER_TRAILERS
|
||||
):
|
||||
events[index] = TrailersReceived(
|
||||
event.headers, event.stream_id, event.stream_ended
|
||||
)
|
||||
return events
|
||||
|
||||
def close_connection(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
frame_type: int | None = None,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
"""Closes the underlying QUIC connection and ignores any incoming events."""
|
||||
|
||||
self._is_done = True
|
||||
self._quic.close(error_code, frame_type, reason_phrase)
|
||||
|
||||
def end_stream(self, stream_id: int) -> None:
|
||||
"""Ends the given stream if not already done so."""
|
||||
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
if stream.headers_send_state != HeadersState.AFTER_TRAILERS:
|
||||
super().send_data(stream_id, b"", end_stream=True)
|
||||
stream.headers_send_state = HeadersState.AFTER_TRAILERS
|
||||
|
||||
def get_next_available_stream_id(self, is_unidirectional: bool = False):
|
||||
"""Reserves and returns the next available stream ID."""
|
||||
|
||||
return self._quic.get_next_available_stream_id(is_unidirectional)
|
||||
|
||||
def get_open_stream_ids(self) -> Iterable[int]:
|
||||
"""Iterates over all non-special open streams"""
|
||||
|
||||
return (
|
||||
stream.stream_id
|
||||
for stream in self._stream.values()
|
||||
if (
|
||||
stream.stream_type is None
|
||||
and not (
|
||||
stream.headers_recv_state == HeadersState.AFTER_TRAILERS
|
||||
and stream.headers_send_state == HeadersState.AFTER_TRAILERS
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def handle_connection_closed(self, event: QuicConnectionClosed) -> None:
|
||||
self._is_done = True
|
||||
|
||||
def handle_stream_event(self, event: QuicStreamEvent) -> list[H3Event]:
|
||||
# don't do anything if we're done
|
||||
if self._is_done:
|
||||
return []
|
||||
|
||||
elif isinstance(event, (QuicStreamReset, QuicStreamStopSending)):
|
||||
self.close_stream(
|
||||
event.stream_id,
|
||||
event.error_code,
|
||||
stop_send=isinstance(event, QuicStreamStopSending),
|
||||
)
|
||||
stream = self._get_or_create_stream(event.stream_id)
|
||||
stream.ended = True
|
||||
stream.headers_recv_state = HeadersState.AFTER_TRAILERS
|
||||
return [StreamClosed(event.stream_id, event.error_code)]
|
||||
|
||||
# convert data events from the QUIC layer back to aioquic events
|
||||
elif isinstance(event, QuicStreamDataReceived):
|
||||
# Discard contents if we have already sent STOP_SENDING on this stream.
|
||||
if event.stream_id in self._closed_streams:
|
||||
return []
|
||||
elif self._get_or_create_stream(event.stream_id).ended:
|
||||
# aioquic will not send us any data events once a stream has ended.
|
||||
# Instead, it will close the connection. We simulate this here for H3 tests.
|
||||
self.close_connection(
|
||||
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
|
||||
reason_phrase="stream already ended",
|
||||
)
|
||||
return []
|
||||
else:
|
||||
return self.handle_event(
|
||||
StreamDataReceived(event.data, event.end_stream, event.stream_id)
|
||||
)
|
||||
|
||||
# should never happen
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
def has_sent_headers(self, stream_id: int) -> bool:
|
||||
"""Indicates whether headers have been sent over the given stream."""
|
||||
|
||||
try:
|
||||
return self._stream[stream_id].headers_send_state != HeadersState.INITIAL
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def close_stream(
|
||||
self, stream_id: int, error_code: int, stop_send: bool = True
|
||||
) -> None:
|
||||
"""Close a stream that hasn't been closed locally yet."""
|
||||
if stream_id not in self._closed_streams:
|
||||
self._closed_streams.add(stream_id)
|
||||
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
stream.headers_send_state = HeadersState.AFTER_TRAILERS
|
||||
# https://www.rfc-editor.org/rfc/rfc9000.html#section-3.5-8
|
||||
# An endpoint that wishes to terminate both directions of
|
||||
# a bidirectional stream can terminate one direction by
|
||||
# sending a RESET_STREAM frame, and it can encourage prompt
|
||||
# termination in the opposite direction by sending a
|
||||
# STOP_SENDING frame.
|
||||
self._mock.reset_stream(stream_id=stream_id, error_code=error_code)
|
||||
if stop_send:
|
||||
self._mock.stop_send(stream_id=stream_id, error_code=error_code)
|
||||
|
||||
def send_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None:
|
||||
"""Sends data over the given stream."""
|
||||
|
||||
super().send_data(stream_id, data, end_stream)
|
||||
self._after_send(stream_id, end_stream)
|
||||
|
||||
def send_datagram(self, flow_id: int, data: bytes) -> None:
|
||||
# supporting datagrams would require additional information from the underlying QUIC connection
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def send_headers(
|
||||
self, stream_id: int, headers: Headers, end_stream: bool = False
|
||||
) -> None:
|
||||
"""Sends headers over the given stream."""
|
||||
|
||||
# ensure we haven't sent something before
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
if stream.headers_send_state != HeadersState.INITIAL:
|
||||
raise FrameUnexpected("initial HEADERS frame is not allowed in this state")
|
||||
super().send_headers(stream_id, headers, end_stream)
|
||||
self._after_send(stream_id, end_stream)
|
||||
|
||||
def send_trailers(self, stream_id: int, trailers: Headers) -> None:
|
||||
"""Sends trailers over the given stream and ends it."""
|
||||
|
||||
# ensure we got some headers first
|
||||
stream = self._get_or_create_stream(stream_id)
|
||||
if stream.headers_send_state != HeadersState.AFTER_HEADERS:
|
||||
raise FrameUnexpected("trailing HEADERS frame is not allowed in this state")
|
||||
super().send_headers(stream_id, trailers, end_stream=True)
|
||||
self._after_send(stream_id, end_stream=True)
|
||||
|
||||
def transmit(self) -> layer.CommandGenerator[None]:
|
||||
"""Yields all pending commands for the upper QUIC layer."""
|
||||
|
||||
while self._mock.pending_commands:
|
||||
yield self._mock.pending_commands.pop(0)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LayeredH3Connection",
|
||||
"StreamClosed",
|
||||
"TrailersReceived",
|
||||
]
|
||||
@@ -0,0 +1,105 @@
|
||||
import time
|
||||
from logging import DEBUG
|
||||
|
||||
from h11._receivebuffer import ReceiveBuffer
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy import http
|
||||
from mitmproxy.net.http import http1
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy import tunnel
|
||||
from mitmproxy.proxy.layers import tls
|
||||
from mitmproxy.proxy.layers.http._hooks import HttpConnectUpstreamHook
|
||||
from mitmproxy.utils import human
|
||||
|
||||
|
||||
class HttpUpstreamProxy(tunnel.TunnelLayer):
|
||||
buf: ReceiveBuffer
|
||||
send_connect: bool
|
||||
conn: connection.Server
|
||||
tunnel_connection: connection.Server
|
||||
|
||||
def __init__(
|
||||
self, ctx: context.Context, tunnel_conn: connection.Server, send_connect: bool
|
||||
):
|
||||
super().__init__(ctx, tunnel_connection=tunnel_conn, conn=ctx.server)
|
||||
self.buf = ReceiveBuffer()
|
||||
self.send_connect = send_connect
|
||||
|
||||
@classmethod
|
||||
def make(cls, ctx: context.Context, send_connect: bool) -> tunnel.LayerStack:
|
||||
assert ctx.server.via
|
||||
scheme, address = ctx.server.via
|
||||
assert scheme in ("http", "https")
|
||||
|
||||
http_proxy = connection.Server(address=address)
|
||||
|
||||
stack = tunnel.LayerStack()
|
||||
if scheme == "https":
|
||||
http_proxy.alpn_offers = tls.HTTP1_ALPNS
|
||||
http_proxy.sni = address[0]
|
||||
stack /= tls.ServerTLSLayer(ctx, http_proxy)
|
||||
stack /= cls(ctx, http_proxy, send_connect)
|
||||
|
||||
return stack
|
||||
|
||||
def start_handshake(self) -> layer.CommandGenerator[None]:
|
||||
if not self.send_connect:
|
||||
return (yield from super().start_handshake())
|
||||
assert self.conn.address
|
||||
flow = http.HTTPFlow(self.context.client, self.tunnel_connection)
|
||||
authority = (
|
||||
self.conn.address[0].encode("idna") + f":{self.conn.address[1]}".encode()
|
||||
)
|
||||
headers = http.Headers()
|
||||
if self.context.options.http_connect_send_host_header:
|
||||
headers.insert(0, b"Host", authority)
|
||||
flow.request = http.Request(
|
||||
host=self.conn.address[0],
|
||||
port=self.conn.address[1],
|
||||
method=b"CONNECT",
|
||||
scheme=b"",
|
||||
authority=authority,
|
||||
path=b"",
|
||||
http_version=b"HTTP/1.1",
|
||||
headers=headers,
|
||||
content=b"",
|
||||
trailers=None,
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
yield HttpConnectUpstreamHook(flow)
|
||||
raw = http1.assemble_request(flow.request)
|
||||
yield commands.SendData(self.tunnel_connection, raw)
|
||||
|
||||
def receive_handshake_data(
|
||||
self, data: bytes
|
||||
) -> layer.CommandGenerator[tuple[bool, str | None]]:
|
||||
if not self.send_connect:
|
||||
return (yield from super().receive_handshake_data(data))
|
||||
self.buf += data
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
if response_head:
|
||||
try:
|
||||
response = http1.read_response_head([bytes(x) for x in response_head])
|
||||
except ValueError as e:
|
||||
proxyaddr = human.format_address(self.tunnel_connection.address)
|
||||
yield commands.Log(f"{proxyaddr}: {e}")
|
||||
return False, f"Error connecting to {proxyaddr}: {e}"
|
||||
if 200 <= response.status_code < 300:
|
||||
if self.buf:
|
||||
yield from self.receive_data(bytes(self.buf))
|
||||
del self.buf
|
||||
return True, None
|
||||
else:
|
||||
proxyaddr = human.format_address(self.tunnel_connection.address)
|
||||
raw_resp = b"\n".join(response_head)
|
||||
yield commands.Log(f"{proxyaddr}: {raw_resp!r}", DEBUG)
|
||||
return (
|
||||
False,
|
||||
f"Upstream proxy {proxyaddr} refused HTTP CONNECT request: {response.status_code} {response.reason}",
|
||||
)
|
||||
else:
|
||||
return False, None
|
||||
303
venv/Lib/site-packages/mitmproxy/proxy/layers/modes.py
Normal file
303
venv/Lib/site-packages/mitmproxy/proxy/layers/modes.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
from abc import ABCMeta
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.mode_specs import ReverseMode
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import assert_never
|
||||
else:
|
||||
from typing import assert_never
|
||||
|
||||
|
||||
class HttpProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
|
||||
|
||||
class HttpUpstreamProxy(layer.Layer):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
child_layer = layer.NextLayer(self.context)
|
||||
self._handle_event = child_layer.handle_event
|
||||
yield from child_layer.handle_event(event)
|
||||
|
||||
|
||||
class DestinationKnown(layer.Layer, metaclass=ABCMeta):
|
||||
"""Base layer for layers that gather connection destination info and then delegate."""
|
||||
|
||||
child_layer: layer.Layer
|
||||
|
||||
def finish_start(self) -> layer.CommandGenerator[str | None]:
|
||||
if (
|
||||
self.context.options.connection_strategy == "eager"
|
||||
and self.context.server.address
|
||||
and self.context.server.transport_protocol == "tcp"
|
||||
):
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
self._handle_event = self.done # type: ignore
|
||||
return err
|
||||
|
||||
self._handle_event = self.child_layer.handle_event # type: ignore
|
||||
yield from self.child_layer.handle_event(events.Start())
|
||||
return None
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
|
||||
class ReverseProxy(DestinationKnown):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
spec = self.context.client.proxy_mode
|
||||
assert isinstance(spec, ReverseMode)
|
||||
self.context.server.address = spec.address
|
||||
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
# For secure protocols, set SNI if keep_host_header is false
|
||||
match spec.scheme:
|
||||
case "http3" | "quic" | "https" | "tls" | "dtls":
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0]
|
||||
case "tcp" | "http" | "udp" | "dns":
|
||||
pass
|
||||
case _: # pragma: no cover
|
||||
assert_never(spec.scheme)
|
||||
|
||||
err = yield from self.finish_start()
|
||||
if err:
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
|
||||
|
||||
class TransparentProxy(DestinationKnown):
|
||||
@expect(events.Start)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.context.server.address, "No server address set."
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
err = yield from self.finish_start()
|
||||
if err:
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
|
||||
|
||||
SOCKS5_VERSION = 0x05
|
||||
|
||||
SOCKS5_METHOD_NO_AUTHENTICATION_REQUIRED = 0x00
|
||||
SOCKS5_METHOD_USER_PASSWORD_AUTHENTICATION = 0x02
|
||||
SOCKS5_METHOD_NO_ACCEPTABLE_METHODS = 0xFF
|
||||
|
||||
SOCKS5_ATYP_IPV4_ADDRESS = 0x01
|
||||
SOCKS5_ATYP_DOMAINNAME = 0x03
|
||||
SOCKS5_ATYP_IPV6_ADDRESS = 0x04
|
||||
|
||||
SOCKS5_REP_HOST_UNREACHABLE = 0x04
|
||||
SOCKS5_REP_COMMAND_NOT_SUPPORTED = 0x07
|
||||
SOCKS5_REP_ADDRESS_TYPE_NOT_SUPPORTED = 0x08
|
||||
|
||||
|
||||
@dataclass
|
||||
class Socks5AuthData:
|
||||
client_conn: connection.Client
|
||||
username: str
|
||||
password: str
|
||||
valid: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Socks5AuthHook(StartHook):
|
||||
"""
|
||||
Mitmproxy has received username/password SOCKS5 credentials.
|
||||
|
||||
This hook decides whether they are valid by setting `data.valid`.
|
||||
"""
|
||||
|
||||
data: Socks5AuthData
|
||||
|
||||
|
||||
class Socks5Proxy(DestinationKnown):
|
||||
buf: bytes = b""
|
||||
|
||||
def socks_err(
|
||||
self,
|
||||
message: str,
|
||||
reply_code: int | None = None,
|
||||
) -> layer.CommandGenerator[None]:
|
||||
if reply_code is not None:
|
||||
yield commands.SendData(
|
||||
self.context.client,
|
||||
bytes([SOCKS5_VERSION, reply_code])
|
||||
+ b"\x00\x01\x00\x00\x00\x00\x00\x00",
|
||||
)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
yield commands.Log(message)
|
||||
self._handle_event = self.done
|
||||
|
||||
@expect(events.Start, events.DataReceived, events.ConnectionClosed)
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Start):
|
||||
pass
|
||||
elif isinstance(event, events.DataReceived):
|
||||
self.buf += event.data
|
||||
yield from self.state()
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
if self.buf:
|
||||
yield commands.Log(
|
||||
f"Client closed connection before completing SOCKS5 handshake: {self.buf!r}"
|
||||
)
|
||||
yield commands.CloseConnection(event.connection)
|
||||
else:
|
||||
raise AssertionError(f"Unknown event: {event}")
|
||||
|
||||
def state_greet(self) -> layer.CommandGenerator[None]:
|
||||
if len(self.buf) < 2:
|
||||
return
|
||||
|
||||
if self.buf[0] != SOCKS5_VERSION:
|
||||
if self.buf[:3].isupper():
|
||||
guess = "Probably not a SOCKS request but a regular HTTP request. "
|
||||
else:
|
||||
guess = ""
|
||||
yield from self.socks_err(
|
||||
guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.buf[0]
|
||||
)
|
||||
return
|
||||
|
||||
n_methods = self.buf[1]
|
||||
if len(self.buf) < 2 + n_methods:
|
||||
return
|
||||
|
||||
if "proxyauth" in self.context.options and self.context.options.proxyauth:
|
||||
method = SOCKS5_METHOD_USER_PASSWORD_AUTHENTICATION
|
||||
self.state = self.state_auth
|
||||
else:
|
||||
method = SOCKS5_METHOD_NO_AUTHENTICATION_REQUIRED
|
||||
self.state = self.state_connect
|
||||
|
||||
if method not in self.buf[2 : 2 + n_methods]:
|
||||
method_str = (
|
||||
"user/password"
|
||||
if method == SOCKS5_METHOD_USER_PASSWORD_AUTHENTICATION
|
||||
else "no"
|
||||
)
|
||||
yield from self.socks_err(
|
||||
f"Client does not support SOCKS5 with {method_str} authentication.",
|
||||
SOCKS5_METHOD_NO_ACCEPTABLE_METHODS,
|
||||
)
|
||||
return
|
||||
yield commands.SendData(self.context.client, bytes([SOCKS5_VERSION, method]))
|
||||
self.buf = self.buf[2 + n_methods :]
|
||||
yield from self.state()
|
||||
|
||||
state: Callable[..., layer.CommandGenerator[None]] = state_greet
|
||||
|
||||
def state_auth(self) -> layer.CommandGenerator[None]:
|
||||
if len(self.buf) < 3:
|
||||
return
|
||||
|
||||
# Parsing username and password, which is somewhat atrocious
|
||||
user_len = self.buf[1]
|
||||
if len(self.buf) < 3 + user_len:
|
||||
return
|
||||
pass_len = self.buf[2 + user_len]
|
||||
if len(self.buf) < 3 + user_len + pass_len:
|
||||
return
|
||||
user = self.buf[2 : (2 + user_len)].decode("utf-8", "backslashreplace")
|
||||
password = self.buf[(3 + user_len) : (3 + user_len + pass_len)].decode(
|
||||
"utf-8", "backslashreplace"
|
||||
)
|
||||
|
||||
data = Socks5AuthData(self.context.client, user, password)
|
||||
yield Socks5AuthHook(data)
|
||||
if not data.valid:
|
||||
# The VER field contains the current **version of the subnegotiation**, which is X'01'.
|
||||
yield commands.SendData(self.context.client, b"\x01\x01")
|
||||
yield from self.socks_err("authentication failed")
|
||||
return
|
||||
|
||||
yield commands.SendData(self.context.client, b"\x01\x00")
|
||||
self.buf = self.buf[3 + user_len + pass_len :]
|
||||
self.state = self.state_connect
|
||||
yield from self.state()
|
||||
|
||||
def state_connect(self) -> layer.CommandGenerator[None]:
|
||||
# Parse Connect Request
|
||||
if len(self.buf) < 5:
|
||||
return
|
||||
|
||||
if self.buf[:3] != b"\x05\x01\x00":
|
||||
yield from self.socks_err(
|
||||
f"Unsupported SOCKS5 request: {self.buf!r}",
|
||||
SOCKS5_REP_COMMAND_NOT_SUPPORTED,
|
||||
)
|
||||
return
|
||||
|
||||
# Determine message length
|
||||
atyp = self.buf[3]
|
||||
message_len: int
|
||||
if atyp == SOCKS5_ATYP_IPV4_ADDRESS:
|
||||
message_len = 4 + 4 + 2
|
||||
elif atyp == SOCKS5_ATYP_IPV6_ADDRESS:
|
||||
message_len = 4 + 16 + 2
|
||||
elif atyp == SOCKS5_ATYP_DOMAINNAME:
|
||||
message_len = 4 + 1 + self.buf[4] + 2
|
||||
else:
|
||||
yield from self.socks_err(
|
||||
f"Unknown address type: {atyp}", SOCKS5_REP_ADDRESS_TYPE_NOT_SUPPORTED
|
||||
)
|
||||
return
|
||||
|
||||
# Do we have enough bytes yet?
|
||||
if len(self.buf) < message_len:
|
||||
return
|
||||
|
||||
# Parse host and port
|
||||
msg, self.buf = self.buf[:message_len], self.buf[message_len:]
|
||||
|
||||
host: str
|
||||
if atyp == SOCKS5_ATYP_IPV4_ADDRESS:
|
||||
host = socket.inet_ntop(socket.AF_INET, msg[4:-2])
|
||||
elif atyp == SOCKS5_ATYP_IPV6_ADDRESS:
|
||||
host = socket.inet_ntop(socket.AF_INET6, msg[4:-2])
|
||||
else:
|
||||
host_bytes = msg[5:-2]
|
||||
host = host_bytes.decode("ascii", "replace")
|
||||
|
||||
(port,) = struct.unpack("!H", msg[-2:])
|
||||
|
||||
# We now have all we need, let's get going.
|
||||
self.context.server.address = (host, port)
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
# this already triggers the child layer's Start event,
|
||||
# but that's not a problem in practice...
|
||||
err = yield from self.finish_start()
|
||||
if err:
|
||||
yield commands.SendData(
|
||||
self.context.client, b"\x05\x04\x00\x01\x00\x00\x00\x00\x00\x00"
|
||||
)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
else:
|
||||
yield commands.SendData(
|
||||
self.context.client, b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00"
|
||||
)
|
||||
if self.buf:
|
||||
yield from self.child_layer.handle_event(
|
||||
events.DataReceived(self.context.client, self.buf)
|
||||
)
|
||||
del self.buf
|
||||
@@ -0,0 +1,41 @@
|
||||
from ._client_hello_parser import quic_parse_client_hello_from_datagrams
|
||||
from ._commands import CloseQuicConnection
|
||||
from ._commands import ResetQuicStream
|
||||
from ._commands import SendQuicStreamData
|
||||
from ._commands import StopSendingQuicStream
|
||||
from ._events import QuicConnectionClosed
|
||||
from ._events import QuicStreamDataReceived
|
||||
from ._events import QuicStreamEvent
|
||||
from ._events import QuicStreamReset
|
||||
from ._events import QuicStreamStopSending
|
||||
from ._hooks import QuicStartClientHook
|
||||
from ._hooks import QuicStartServerHook
|
||||
from ._hooks import QuicTlsData
|
||||
from ._hooks import QuicTlsSettings
|
||||
from ._raw_layers import QuicStreamLayer
|
||||
from ._raw_layers import RawQuicLayer
|
||||
from ._stream_layers import ClientQuicLayer
|
||||
from ._stream_layers import error_code_to_str
|
||||
from ._stream_layers import ServerQuicLayer
|
||||
|
||||
__all__ = [
|
||||
"quic_parse_client_hello_from_datagrams",
|
||||
"CloseQuicConnection",
|
||||
"ResetQuicStream",
|
||||
"SendQuicStreamData",
|
||||
"StopSendingQuicStream",
|
||||
"QuicConnectionClosed",
|
||||
"QuicStreamDataReceived",
|
||||
"QuicStreamEvent",
|
||||
"QuicStreamReset",
|
||||
"QuicStreamStopSending",
|
||||
"QuicStartClientHook",
|
||||
"QuicStartServerHook",
|
||||
"QuicTlsData",
|
||||
"QuicTlsSettings",
|
||||
"QuicStreamLayer",
|
||||
"RawQuicLayer",
|
||||
"ClientQuicLayer",
|
||||
"error_code_to_str",
|
||||
"ServerQuicLayer",
|
||||
]
|
||||
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
This module contains a very terrible QUIC client hello parser.
|
||||
|
||||
Nothing is more permanent than a temporary solution!
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from aioquic.buffer import Buffer as QuicBuffer
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.connection import QuicConnectionError
|
||||
from aioquic.quic.logger import QuicLogger
|
||||
from aioquic.quic.packet import PACKET_TYPE_INITIAL
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
from aioquic.tls import HandshakeType
|
||||
|
||||
from mitmproxy.tls import ClientHello
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicClientHello(Exception):
|
||||
"""Helper error only used in `quic_parse_client_hello_from_datagrams`."""
|
||||
|
||||
data: bytes
|
||||
|
||||
|
||||
def quic_parse_client_hello_from_datagrams(
|
||||
datagrams: list[bytes],
|
||||
) -> Optional[ClientHello]:
|
||||
"""
|
||||
Check if the supplied bytes contain a full ClientHello message,
|
||||
and if so, parse it.
|
||||
|
||||
Args:
|
||||
- msgs: list of ClientHello fragments received from client
|
||||
|
||||
Returns:
|
||||
- A ClientHello object on success
|
||||
- None, if the QUIC record is incomplete
|
||||
|
||||
Raises:
|
||||
- A ValueError, if the passed ClientHello is invalid
|
||||
"""
|
||||
|
||||
# ensure the first packet is indeed the initial one
|
||||
buffer = QuicBuffer(data=datagrams[0])
|
||||
header = pull_quic_header(buffer, 8)
|
||||
if header.packet_type != PACKET_TYPE_INITIAL:
|
||||
raise ValueError("Packet is not initial one.")
|
||||
|
||||
# patch aioquic to intercept the client hello
|
||||
quic = QuicConnection(
|
||||
configuration=QuicConfiguration(
|
||||
is_client=False,
|
||||
certificate="",
|
||||
private_key="",
|
||||
quic_logger=QuicLogger(),
|
||||
),
|
||||
original_destination_connection_id=header.destination_cid,
|
||||
)
|
||||
_initialize = quic._initialize
|
||||
|
||||
def server_handle_hello_replacement(
|
||||
input_buf: QuicBuffer,
|
||||
initial_buf: QuicBuffer,
|
||||
handshake_buf: QuicBuffer,
|
||||
onertt_buf: QuicBuffer,
|
||||
) -> None:
|
||||
assert input_buf.pull_uint8() == HandshakeType.CLIENT_HELLO
|
||||
length = 0
|
||||
for b in input_buf.pull_bytes(3):
|
||||
length = (length << 8) | b
|
||||
offset = input_buf.tell()
|
||||
raise QuicClientHello(input_buf.data_slice(offset, offset + length))
|
||||
|
||||
def initialize_replacement(peer_cid: bytes) -> None:
|
||||
try:
|
||||
return _initialize(peer_cid)
|
||||
finally:
|
||||
quic.tls._server_handle_hello = server_handle_hello_replacement # type: ignore
|
||||
|
||||
quic._initialize = initialize_replacement # type: ignore
|
||||
try:
|
||||
for dgm in datagrams:
|
||||
quic.receive_datagram(dgm, ("0.0.0.0", 0), now=time.time())
|
||||
except QuicClientHello as hello:
|
||||
try:
|
||||
return ClientHello(hello.data)
|
||||
except EOFError as e:
|
||||
raise ValueError("Invalid ClientHello data.") from e
|
||||
except QuicConnectionError as e:
|
||||
raise ValueError(e.reason_phrase) from e
|
||||
|
||||
quic_logger = quic._configuration.quic_logger
|
||||
assert isinstance(quic_logger, QuicLogger)
|
||||
traces = quic_logger.to_dict().get("traces")
|
||||
assert isinstance(traces, list)
|
||||
for trace in traces:
|
||||
quic_events = trace.get("events")
|
||||
for event in quic_events:
|
||||
if event["name"] == "transport:packet_dropped":
|
||||
raise ValueError(
|
||||
f"Invalid ClientHello packet: {event['data']['trigger']}"
|
||||
)
|
||||
|
||||
return None # pragma: no cover # FIXME: this should have test coverage
|
||||
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.proxy import commands
|
||||
|
||||
|
||||
class QuicStreamCommand(commands.ConnectionCommand):
|
||||
"""Base class for all QUIC stream commands."""
|
||||
|
||||
stream_id: int
|
||||
"""The ID of the stream the command was issued for."""
|
||||
|
||||
def __init__(self, connection: connection.Connection, stream_id: int) -> None:
|
||||
super().__init__(connection)
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
class SendQuicStreamData(QuicStreamCommand):
|
||||
"""Command that sends data on a stream."""
|
||||
|
||||
data: bytes
|
||||
"""The data which should be sent."""
|
||||
end_stream: bool
|
||||
"""Whether the FIN bit should be set in the STREAM frame."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: connection.Connection,
|
||||
stream_id: int,
|
||||
data: bytes,
|
||||
end_stream: bool = False,
|
||||
) -> None:
|
||||
super().__init__(connection, stream_id)
|
||||
self.data = data
|
||||
self.end_stream = end_stream
|
||||
|
||||
def __repr__(self):
|
||||
target = repr(self.connection).partition("(")[0].lower()
|
||||
end_stream = "[end_stream] " if self.end_stream else ""
|
||||
return f"SendQuicStreamData({target} on {self.stream_id}, {end_stream}{self.data!r})"
|
||||
|
||||
|
||||
class ResetQuicStream(QuicStreamCommand):
|
||||
"""Abruptly terminate the sending part of a stream."""
|
||||
|
||||
error_code: int
|
||||
"""An error code indicating why the stream is being reset."""
|
||||
|
||||
def __init__(
|
||||
self, connection: connection.Connection, stream_id: int, error_code: int
|
||||
) -> None:
|
||||
super().__init__(connection, stream_id)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class StopSendingQuicStream(QuicStreamCommand):
|
||||
"""Request termination of the receiving part of a stream."""
|
||||
|
||||
error_code: int
|
||||
"""An error code indicating why the stream is being stopped."""
|
||||
|
||||
def __init__(
|
||||
self, connection: connection.Connection, stream_id: int, error_code: int
|
||||
) -> None:
|
||||
super().__init__(connection, stream_id)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class CloseQuicConnection(commands.CloseConnection):
|
||||
"""Close a QUIC connection."""
|
||||
|
||||
error_code: int
|
||||
"The error code which was specified when closing the connection."
|
||||
|
||||
frame_type: int | None
|
||||
"The frame type which caused the connection to be closed, or `None`."
|
||||
|
||||
reason_phrase: str
|
||||
"The human-readable reason for which the connection was closed."
|
||||
|
||||
# XXX: A bit much boilerplate right now. Should switch to dataclasses.
|
||||
def __init__(
|
||||
self,
|
||||
conn: connection.Connection,
|
||||
error_code: int,
|
||||
frame_type: int | None,
|
||||
reason_phrase: str,
|
||||
) -> None:
|
||||
super().__init__(conn)
|
||||
self.error_code = error_code
|
||||
self.frame_type = frame_type
|
||||
self.reason_phrase = reason_phrase
|
||||
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.proxy import events
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStreamEvent(events.ConnectionEvent):
|
||||
"""Base class for all QUIC stream events."""
|
||||
|
||||
stream_id: int
|
||||
"""The ID of the stream the event was fired for."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStreamDataReceived(QuicStreamEvent):
|
||||
"""Event that is fired whenever data is received on a stream."""
|
||||
|
||||
data: bytes
|
||||
"""The data which was received."""
|
||||
end_stream: bool
|
||||
"""Whether the STREAM frame had the FIN bit set."""
|
||||
|
||||
def __repr__(self):
|
||||
target = repr(self.connection).partition("(")[0].lower()
|
||||
end_stream = "[end_stream] " if self.end_stream else ""
|
||||
return f"QuicStreamDataReceived({target} on {self.stream_id}, {end_stream}{self.data!r})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStreamReset(QuicStreamEvent):
|
||||
"""Event that is fired when the remote peer resets a stream."""
|
||||
|
||||
error_code: int
|
||||
"""The error code that triggered the reset."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStreamStopSending(QuicStreamEvent):
|
||||
"""Event that is fired when the remote peer sends a STOP_SENDING frame."""
|
||||
|
||||
error_code: int
|
||||
"""The application protocol error code."""
|
||||
|
||||
|
||||
class QuicConnectionClosed(events.ConnectionClosed):
|
||||
"""QUIC connection has been closed."""
|
||||
|
||||
error_code: int
|
||||
"The error code which was specified when closing the connection."
|
||||
|
||||
frame_type: int | None
|
||||
"The frame type which caused the connection to be closed, or `None`."
|
||||
|
||||
reason_phrase: str
|
||||
"The human-readable reason for which the connection was closed."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: connection.Connection,
|
||||
error_code: int,
|
||||
frame_type: int | None,
|
||||
reason_phrase: str,
|
||||
) -> None:
|
||||
super().__init__(conn)
|
||||
self.error_code = error_code
|
||||
self.frame_type = frame_type
|
||||
self.reason_phrase = reason_phrase
|
||||
77
venv/Lib/site-packages/mitmproxy/proxy/layers/quic/_hooks.py
Normal file
77
venv/Lib/site-packages/mitmproxy/proxy/layers/quic/_hooks.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from ssl import VerifyMode
|
||||
|
||||
from aioquic.tls import CipherSuite
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.tls import TlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicTlsSettings:
|
||||
"""
|
||||
Settings necessary to establish QUIC's TLS context.
|
||||
"""
|
||||
|
||||
alpn_protocols: list[str] | None = None
|
||||
"""A list of supported ALPN protocols."""
|
||||
certificate: x509.Certificate | None = None
|
||||
"""The certificate to use for the connection."""
|
||||
certificate_chain: list[x509.Certificate] = field(default_factory=list)
|
||||
"""A list of additional certificates to send to the peer."""
|
||||
certificate_private_key: (
|
||||
dsa.DSAPrivateKey | ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey | None
|
||||
) = None
|
||||
"""The certificate's private key."""
|
||||
cipher_suites: list[CipherSuite] | None = None
|
||||
"""An optional list of allowed/advertised cipher suites."""
|
||||
ca_path: str | None = None
|
||||
"""An optional path to a directory that contains the necessary information to verify the peer certificate."""
|
||||
ca_file: str | None = None
|
||||
"""An optional path to a PEM file that will be used to verify the peer certificate."""
|
||||
verify_mode: VerifyMode | None = None
|
||||
"""An optional flag that specifies how/if the peer's certificate should be validated."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicTlsData(TlsData):
|
||||
"""
|
||||
Event data for `quic_start_client` and `quic_start_server` event hooks.
|
||||
"""
|
||||
|
||||
settings: QuicTlsSettings | None = None
|
||||
"""
|
||||
The associated `QuicTlsSettings` object.
|
||||
This will be set by an addon in the `quic_start_*` event hooks.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStartClientHook(commands.StartHook):
|
||||
"""
|
||||
TLS negotiation between mitmproxy and a client over QUIC is about to start.
|
||||
|
||||
An addon is expected to initialize data.settings.
|
||||
(by default, this is done by `mitmproxy.addons.tlsconfig`)
|
||||
"""
|
||||
|
||||
data: QuicTlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicStartServerHook(commands.StartHook):
|
||||
"""
|
||||
TLS negotiation between mitmproxy and a server over QUIC is about to start.
|
||||
|
||||
An addon is expected to initialize data.settings.
|
||||
(by default, this is done by `mitmproxy.addons.tlsconfig`)
|
||||
"""
|
||||
|
||||
data: QuicTlsData
|
||||
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
This module contains the proxy layers for raw QUIC proxying.
|
||||
This is used if we want to speak QUIC, but we do not want to do HTTP.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from aioquic.quic.connection import QuicErrorCode
|
||||
from aioquic.quic.connection import stream_is_client_initiated
|
||||
from aioquic.quic.connection import stream_is_unidirectional
|
||||
|
||||
from ._commands import CloseQuicConnection
|
||||
from ._commands import ResetQuicStream
|
||||
from ._commands import SendQuicStreamData
|
||||
from ._commands import StopSendingQuicStream
|
||||
from ._events import QuicConnectionClosed
|
||||
from ._events import QuicStreamDataReceived
|
||||
from ._events import QuicStreamEvent
|
||||
from ._events import QuicStreamReset
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.connection import Connection
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy import tunnel
|
||||
from mitmproxy.proxy.layers.tcp import TCPLayer
|
||||
from mitmproxy.proxy.layers.udp import UDPLayer
|
||||
|
||||
|
||||
class QuicStreamNextLayer(layer.NextLayer):
|
||||
"""`NextLayer` variant that callbacks `QuicStreamLayer` after layer decision."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: context.Context,
|
||||
stream: QuicStreamLayer,
|
||||
ask_on_start: bool = False,
|
||||
) -> None:
|
||||
super().__init__(context, ask_on_start)
|
||||
self._stream = stream
|
||||
self._layer: layer.Layer | None = None
|
||||
|
||||
@property # type: ignore
|
||||
def layer(self) -> layer.Layer | None: # type: ignore
|
||||
return self._layer
|
||||
|
||||
@layer.setter
|
||||
def layer(self, value: layer.Layer | None) -> None:
|
||||
self._layer = value
|
||||
if self._layer:
|
||||
self._stream.refresh_metadata()
|
||||
|
||||
|
||||
class QuicStreamLayer(layer.Layer):
|
||||
"""
|
||||
Layer for QUIC streams.
|
||||
Serves as a marker for NextLayer and keeps track of the connection states.
|
||||
"""
|
||||
|
||||
client: connection.Client
|
||||
"""Virtual client connection for this stream. Use this in QuicRawLayer instead of `context.client`."""
|
||||
server: connection.Server
|
||||
"""Virtual server connection for this stream. Use this in QuicRawLayer instead of `context.server`."""
|
||||
child_layer: layer.Layer
|
||||
"""The stream's child layer."""
|
||||
|
||||
def __init__(
|
||||
self, context: context.Context, force_raw: bool, stream_id: int
|
||||
) -> None:
|
||||
# we mustn't reuse the client from the QUIC connection, as the state and protocol differs
|
||||
self.client = context.client = context.client.copy()
|
||||
self.client.transport_protocol = "tcp"
|
||||
self.client.state = connection.ConnectionState.OPEN
|
||||
|
||||
# unidirectional client streams are not fully open, set the appropriate state
|
||||
if stream_is_unidirectional(stream_id):
|
||||
self.client.state = (
|
||||
connection.ConnectionState.CAN_READ
|
||||
if stream_is_client_initiated(stream_id)
|
||||
else connection.ConnectionState.CAN_WRITE
|
||||
)
|
||||
self._client_stream_id = stream_id
|
||||
|
||||
# start with a closed server
|
||||
self.server = context.server = connection.Server(
|
||||
address=context.server.address,
|
||||
transport_protocol="tcp",
|
||||
)
|
||||
self._server_stream_id: int | None = None
|
||||
|
||||
super().__init__(context)
|
||||
self.child_layer = (
|
||||
TCPLayer(context) if force_raw else QuicStreamNextLayer(context, self)
|
||||
)
|
||||
self.refresh_metadata()
|
||||
|
||||
# we don't handle any events, pass everything to the child layer
|
||||
self.handle_event = self.child_layer.handle_event # type: ignore
|
||||
self._handle_event = self.child_layer._handle_event # type: ignore
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
raise AssertionError # pragma: no cover
|
||||
|
||||
def open_server_stream(self, server_stream_id) -> None:
|
||||
assert self._server_stream_id is None
|
||||
self._server_stream_id = server_stream_id
|
||||
self.server.timestamp_start = time.time()
|
||||
self.server.state = (
|
||||
(
|
||||
connection.ConnectionState.CAN_WRITE
|
||||
if stream_is_client_initiated(server_stream_id)
|
||||
else connection.ConnectionState.CAN_READ
|
||||
)
|
||||
if stream_is_unidirectional(server_stream_id)
|
||||
else connection.ConnectionState.OPEN
|
||||
)
|
||||
self.refresh_metadata()
|
||||
|
||||
def refresh_metadata(self) -> None:
|
||||
# find the first transport layer
|
||||
child_layer: layer.Layer | None = self.child_layer
|
||||
while True:
|
||||
if isinstance(child_layer, layer.NextLayer):
|
||||
child_layer = child_layer.layer
|
||||
elif isinstance(child_layer, tunnel.TunnelLayer):
|
||||
child_layer = child_layer.child_layer
|
||||
else:
|
||||
break # pragma: no cover
|
||||
if isinstance(child_layer, (UDPLayer, TCPLayer)) and child_layer.flow:
|
||||
child_layer.flow.metadata["quic_is_unidirectional"] = (
|
||||
stream_is_unidirectional(self._client_stream_id)
|
||||
)
|
||||
child_layer.flow.metadata["quic_initiator"] = (
|
||||
"client"
|
||||
if stream_is_client_initiated(self._client_stream_id)
|
||||
else "server"
|
||||
)
|
||||
child_layer.flow.metadata["quic_stream_id_client"] = self._client_stream_id
|
||||
child_layer.flow.metadata["quic_stream_id_server"] = self._server_stream_id
|
||||
|
||||
def stream_id(self, client: bool) -> int | None:
|
||||
return self._client_stream_id if client else self._server_stream_id
|
||||
|
||||
|
||||
class RawQuicLayer(layer.Layer):
|
||||
"""
|
||||
This layer is responsible for de-multiplexing QUIC streams into an individual layer stack per stream.
|
||||
"""
|
||||
|
||||
force_raw: bool
|
||||
"""Indicates whether traffic should be treated as raw TCP/UDP without further protocol detection."""
|
||||
datagram_layer: layer.Layer
|
||||
"""
|
||||
The layer that is handling datagrams over QUIC. It's like a child_layer, but with a forked context.
|
||||
Instead of having a datagram-equivalent for all `QuicStream*` classes, we use `SendData` and `DataReceived` instead.
|
||||
There is also no need for another `NextLayer` marker, as a missing `QuicStreamLayer` implies UDP,
|
||||
and the connection state is the same as the one of the underlying QUIC connection.
|
||||
"""
|
||||
client_stream_ids: dict[int, QuicStreamLayer]
|
||||
"""Maps stream IDs from the client connection to stream layers."""
|
||||
server_stream_ids: dict[int, QuicStreamLayer]
|
||||
"""Maps stream IDs from the server connection to stream layers."""
|
||||
connections: dict[connection.Connection, layer.Layer]
|
||||
"""Maps connections to layers."""
|
||||
command_sources: dict[commands.Command, layer.Layer]
|
||||
"""Keeps track of blocking commands and wakeup requests."""
|
||||
next_stream_id: list[int]
|
||||
"""List containing the next stream ID for all four is_unidirectional/is_client combinations."""
|
||||
|
||||
def __init__(self, context: context.Context, force_raw: bool = False) -> None:
|
||||
super().__init__(context)
|
||||
self.force_raw = force_raw
|
||||
self.datagram_layer = (
|
||||
UDPLayer(self.context.fork())
|
||||
if force_raw
|
||||
else layer.NextLayer(self.context.fork())
|
||||
)
|
||||
self.client_stream_ids = {}
|
||||
self.server_stream_ids = {}
|
||||
self.connections = {
|
||||
context.client: self.datagram_layer,
|
||||
context.server: self.datagram_layer,
|
||||
}
|
||||
self.command_sources = {}
|
||||
self.next_stream_id = [0, 1, 2, 3]
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
# we treat the datagram layer as child layer, so forward Start
|
||||
if isinstance(event, events.Start):
|
||||
if self.context.server.timestamp_start is None:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
self._handle_event = self.done # type: ignore
|
||||
return
|
||||
yield from self.event_to_child(self.datagram_layer, event)
|
||||
|
||||
# properly forward completion events based on their command
|
||||
elif isinstance(event, events.CommandCompleted):
|
||||
yield from self.event_to_child(
|
||||
self.command_sources.pop(event.command), event
|
||||
)
|
||||
|
||||
# route injected messages based on their connections (prefer client, fallback to server)
|
||||
elif isinstance(event, events.MessageInjected):
|
||||
if event.flow.client_conn in self.connections:
|
||||
yield from self.event_to_child(
|
||||
self.connections[event.flow.client_conn], event
|
||||
)
|
||||
elif event.flow.server_conn in self.connections:
|
||||
yield from self.event_to_child(
|
||||
self.connections[event.flow.server_conn], event
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Flow not associated: {event.flow!r}")
|
||||
|
||||
# handle stream events targeting this context
|
||||
elif isinstance(event, QuicStreamEvent) and (
|
||||
event.connection is self.context.client
|
||||
or event.connection is self.context.server
|
||||
):
|
||||
from_client = event.connection is self.context.client
|
||||
|
||||
# fetch or create the layer
|
||||
stream_ids = (
|
||||
self.client_stream_ids if from_client else self.server_stream_ids
|
||||
)
|
||||
if event.stream_id in stream_ids:
|
||||
stream_layer = stream_ids[event.stream_id]
|
||||
else:
|
||||
# ensure we haven't just forgotten to register the ID
|
||||
assert stream_is_client_initiated(event.stream_id) == from_client
|
||||
|
||||
# for server-initiated streams we need to open the client as well
|
||||
if from_client:
|
||||
client_stream_id = event.stream_id
|
||||
server_stream_id = None
|
||||
else:
|
||||
client_stream_id = self.get_next_available_stream_id(
|
||||
is_client=False,
|
||||
is_unidirectional=stream_is_unidirectional(event.stream_id),
|
||||
)
|
||||
server_stream_id = event.stream_id
|
||||
|
||||
# create, register and start the layer
|
||||
stream_layer = QuicStreamLayer(
|
||||
self.context.fork(),
|
||||
force_raw=self.force_raw,
|
||||
stream_id=client_stream_id,
|
||||
)
|
||||
self.client_stream_ids[client_stream_id] = stream_layer
|
||||
if server_stream_id is not None:
|
||||
stream_layer.open_server_stream(server_stream_id)
|
||||
self.server_stream_ids[server_stream_id] = stream_layer
|
||||
self.connections[stream_layer.client] = stream_layer
|
||||
self.connections[stream_layer.server] = stream_layer
|
||||
yield from self.event_to_child(stream_layer, events.Start())
|
||||
|
||||
# forward data and close events
|
||||
conn: Connection = (
|
||||
stream_layer.client if from_client else stream_layer.server
|
||||
)
|
||||
if isinstance(event, QuicStreamDataReceived):
|
||||
if event.data:
|
||||
yield from self.event_to_child(
|
||||
stream_layer, events.DataReceived(conn, event.data)
|
||||
)
|
||||
if event.end_stream:
|
||||
yield from self.close_stream_layer(stream_layer, from_client)
|
||||
elif isinstance(event, QuicStreamReset):
|
||||
# preserve stream resets
|
||||
for command in self.close_stream_layer(stream_layer, from_client):
|
||||
if (
|
||||
isinstance(command, SendQuicStreamData)
|
||||
and command.stream_id == stream_layer.stream_id(not from_client)
|
||||
and command.end_stream
|
||||
and not command.data
|
||||
):
|
||||
yield ResetQuicStream(
|
||||
command.connection, command.stream_id, event.error_code
|
||||
)
|
||||
else:
|
||||
yield command
|
||||
else:
|
||||
raise AssertionError(f"Unexpected stream event: {event!r}")
|
||||
|
||||
# handle close events that target this context
|
||||
elif isinstance(event, QuicConnectionClosed) and (
|
||||
event.connection is self.context.client
|
||||
or event.connection is self.context.server
|
||||
):
|
||||
from_client = event.connection is self.context.client
|
||||
other_conn = self.context.server if from_client else self.context.client
|
||||
|
||||
# be done if both connections are closed
|
||||
if other_conn.connected:
|
||||
yield CloseQuicConnection(
|
||||
other_conn, event.error_code, event.frame_type, event.reason_phrase
|
||||
)
|
||||
else:
|
||||
self._handle_event = self.done # type: ignore
|
||||
|
||||
# always forward to the datagram layer and swallow `CloseConnection` commands
|
||||
for command in self.event_to_child(self.datagram_layer, event):
|
||||
if (
|
||||
not isinstance(command, commands.CloseConnection)
|
||||
or command.connection is not other_conn
|
||||
):
|
||||
yield command
|
||||
|
||||
# forward to either the client or server connection of stream layers and swallow empty stream end
|
||||
for conn, child_layer in self.connections.items():
|
||||
if isinstance(child_layer, QuicStreamLayer) and (
|
||||
(conn is child_layer.client)
|
||||
if from_client
|
||||
else (conn is child_layer.server)
|
||||
):
|
||||
conn.state &= ~connection.ConnectionState.CAN_WRITE
|
||||
for command in self.close_stream_layer(child_layer, from_client):
|
||||
if not isinstance(command, SendQuicStreamData) or command.data:
|
||||
yield command
|
||||
|
||||
# all other connection events are routed to their corresponding layer
|
||||
elif isinstance(event, events.ConnectionEvent):
|
||||
yield from self.event_to_child(self.connections[event.connection], event)
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
def close_stream_layer(
|
||||
self, stream_layer: QuicStreamLayer, client: bool
|
||||
) -> layer.CommandGenerator[None]:
|
||||
"""Closes the incoming part of a connection."""
|
||||
|
||||
conn = stream_layer.client if client else stream_layer.server
|
||||
conn.state &= ~connection.ConnectionState.CAN_READ
|
||||
assert conn.timestamp_start is not None
|
||||
if conn.timestamp_end is None:
|
||||
conn.timestamp_end = time.time()
|
||||
yield from self.event_to_child(stream_layer, events.ConnectionClosed(conn))
|
||||
|
||||
def event_to_child(
|
||||
self, child_layer: layer.Layer, event: events.Event
|
||||
) -> layer.CommandGenerator[None]:
|
||||
"""Forwards events to child layers and translates commands."""
|
||||
|
||||
for command in child_layer.handle_event(event):
|
||||
# intercept commands for streams connections
|
||||
if (
|
||||
isinstance(child_layer, QuicStreamLayer)
|
||||
and isinstance(command, commands.ConnectionCommand)
|
||||
and (
|
||||
command.connection is child_layer.client
|
||||
or command.connection is child_layer.server
|
||||
)
|
||||
):
|
||||
# get the target connection and stream ID
|
||||
to_client = command.connection is child_layer.client
|
||||
quic_conn = self.context.client if to_client else self.context.server
|
||||
stream_id = child_layer.stream_id(to_client)
|
||||
|
||||
# write data and check CloseConnection wasn't called before
|
||||
if isinstance(command, commands.SendData):
|
||||
assert stream_id is not None
|
||||
if command.connection.state & connection.ConnectionState.CAN_WRITE:
|
||||
yield SendQuicStreamData(quic_conn, stream_id, command.data)
|
||||
|
||||
# send a FIN and optionally also a STOP frame
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
assert stream_id is not None
|
||||
if command.connection.state & connection.ConnectionState.CAN_WRITE:
|
||||
command.connection.state &= (
|
||||
~connection.ConnectionState.CAN_WRITE
|
||||
)
|
||||
yield SendQuicStreamData(
|
||||
quic_conn, stream_id, b"", end_stream=True
|
||||
)
|
||||
# XXX: Use `command.connection.state & connection.ConnectionState.CAN_READ` instead?
|
||||
only_close_our_half = (
|
||||
isinstance(command, commands.CloseTcpConnection)
|
||||
and command.half_close
|
||||
)
|
||||
if not only_close_our_half:
|
||||
if stream_is_client_initiated(
|
||||
stream_id
|
||||
) == to_client or not stream_is_unidirectional(stream_id):
|
||||
yield StopSendingQuicStream(
|
||||
quic_conn, stream_id, QuicErrorCode.NO_ERROR
|
||||
)
|
||||
yield from self.close_stream_layer(child_layer, to_client)
|
||||
|
||||
# open server connections by reserving the next stream ID
|
||||
elif isinstance(command, commands.OpenConnection):
|
||||
assert not to_client
|
||||
assert stream_id is None
|
||||
client_stream_id = child_layer.stream_id(client=True)
|
||||
assert client_stream_id is not None
|
||||
stream_id = self.get_next_available_stream_id(
|
||||
is_client=True,
|
||||
is_unidirectional=stream_is_unidirectional(client_stream_id),
|
||||
)
|
||||
child_layer.open_server_stream(stream_id)
|
||||
self.server_stream_ids[stream_id] = child_layer
|
||||
yield from self.event_to_child(
|
||||
child_layer, events.OpenConnectionCompleted(command, None)
|
||||
)
|
||||
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"Unexpected stream connection command: {command!r}"
|
||||
)
|
||||
|
||||
# remember blocking and wakeup commands
|
||||
else:
|
||||
if command.blocking or isinstance(command, commands.RequestWakeup):
|
||||
self.command_sources[command] = child_layer
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
self.connections[command.connection] = child_layer
|
||||
yield command
|
||||
|
||||
def get_next_available_stream_id(
|
||||
self, is_client: bool, is_unidirectional: bool = False
|
||||
) -> int:
|
||||
index = (int(is_unidirectional) << 1) | int(not is_client)
|
||||
stream_id = self.next_stream_id[index]
|
||||
self.next_stream_id[index] = stream_id + 4
|
||||
return stream_id
|
||||
|
||||
def done(self, _) -> layer.CommandGenerator[None]: # pragma: no cover
|
||||
yield from ()
|
||||
@@ -0,0 +1,638 @@
|
||||
"""
|
||||
This module contains the client and server proxy layers for QUIC streams
|
||||
which decrypt and encrypt traffic. Decrypted stream data is then forwarded
|
||||
to either the raw layers, or the HTTP/3 client in ../http/_http3.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from logging import DEBUG
|
||||
from logging import ERROR
|
||||
from logging import WARNING
|
||||
|
||||
from aioquic.buffer import Buffer as QuicBuffer
|
||||
from aioquic.h3.connection import ErrorCode as H3ErrorCode
|
||||
from aioquic.quic import events as quic_events
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.connection import QuicConnectionState
|
||||
from aioquic.quic.connection import QuicErrorCode
|
||||
from aioquic.quic.packet import encode_quic_version_negotiation
|
||||
from aioquic.quic.packet import PACKET_TYPE_INITIAL
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
from cryptography import x509
|
||||
|
||||
from ._client_hello_parser import quic_parse_client_hello_from_datagrams
|
||||
from ._commands import CloseQuicConnection
|
||||
from ._commands import QuicStreamCommand
|
||||
from ._commands import ResetQuicStream
|
||||
from ._commands import SendQuicStreamData
|
||||
from ._commands import StopSendingQuicStream
|
||||
from ._events import QuicConnectionClosed
|
||||
from ._events import QuicStreamDataReceived
|
||||
from ._events import QuicStreamReset
|
||||
from ._events import QuicStreamStopSending
|
||||
from ._hooks import QuicStartClientHook
|
||||
from ._hooks import QuicStartServerHook
|
||||
from ._hooks import QuicTlsData
|
||||
from ._hooks import QuicTlsSettings
|
||||
from mitmproxy import certs
|
||||
from mitmproxy import connection
|
||||
from mitmproxy import ctx
|
||||
from mitmproxy.net import tls
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy import tunnel
|
||||
from mitmproxy.proxy.layers.tls import TlsClienthelloHook
|
||||
from mitmproxy.proxy.layers.tls import TlsEstablishedClientHook
|
||||
from mitmproxy.proxy.layers.tls import TlsEstablishedServerHook
|
||||
from mitmproxy.proxy.layers.tls import TlsFailedClientHook
|
||||
from mitmproxy.proxy.layers.tls import TlsFailedServerHook
|
||||
from mitmproxy.proxy.layers.udp import UDPLayer
|
||||
from mitmproxy.tls import ClientHelloData
|
||||
|
||||
SUPPORTED_QUIC_VERSIONS_SERVER = QuicConfiguration(is_client=False).supported_versions
|
||||
|
||||
|
||||
class QuicLayer(tunnel.TunnelLayer):
|
||||
quic: QuicConnection | None = None
|
||||
tls: QuicTlsSettings | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: context.Context,
|
||||
conn: connection.Connection,
|
||||
time: Callable[[], float] | None,
|
||||
) -> None:
|
||||
super().__init__(context, tunnel_connection=conn, conn=conn)
|
||||
self.child_layer = layer.NextLayer(self.context, ask_on_start=True)
|
||||
self._time = time or ctx.master.event_loop.time
|
||||
self._wakeup_commands: dict[commands.RequestWakeup, float] = dict()
|
||||
conn.tls = True
|
||||
|
||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, events.Wakeup) and event.command in self._wakeup_commands:
|
||||
# TunnelLayer has no understanding of wakeups, so we turn this into an empty DataReceived event
|
||||
# which TunnelLayer recognizes as belonging to our connection.
|
||||
assert self.quic
|
||||
scheduled_time = self._wakeup_commands.pop(event.command)
|
||||
if self.quic._state is not QuicConnectionState.TERMINATED:
|
||||
# weird quirk: asyncio sometimes returns a bit ahead of time.
|
||||
now = max(scheduled_time, self._time())
|
||||
self.quic.handle_timer(now)
|
||||
yield from super()._handle_event(
|
||||
events.DataReceived(self.tunnel_connection, b"")
|
||||
)
|
||||
else:
|
||||
yield from super()._handle_event(event)
|
||||
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
# the parent will call _handle_command multiple times, we transmit cumulative afterwards
|
||||
# this will reduce the number of sends, especially if data=b"" and end_stream=True
|
||||
yield from super().event_to_child(event)
|
||||
if self.quic:
|
||||
yield from self.tls_interact()
|
||||
|
||||
def _handle_command(
|
||||
self, command: commands.Command
|
||||
) -> layer.CommandGenerator[None]:
|
||||
"""Turns stream commands into aioquic connection invocations."""
|
||||
if isinstance(command, QuicStreamCommand) and command.connection is self.conn:
|
||||
assert self.quic
|
||||
if isinstance(command, SendQuicStreamData):
|
||||
self.quic.send_stream_data(
|
||||
command.stream_id, command.data, command.end_stream
|
||||
)
|
||||
elif isinstance(command, ResetQuicStream):
|
||||
stream = self.quic._get_or_create_stream_for_send(command.stream_id)
|
||||
existing_reset_error_code = stream.sender._reset_error_code
|
||||
if existing_reset_error_code is None:
|
||||
self.quic.reset_stream(command.stream_id, command.error_code)
|
||||
elif self.debug: # pragma: no cover
|
||||
yield commands.Log(
|
||||
f"{self.debug}[quic] stream {stream.stream_id} already reset ({existing_reset_error_code=}, {command.error_code=})",
|
||||
DEBUG,
|
||||
)
|
||||
elif isinstance(command, StopSendingQuicStream):
|
||||
# the stream might have already been closed, check before stopping
|
||||
if command.stream_id in self.quic._streams:
|
||||
self.quic.stop_stream(command.stream_id, command.error_code)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected stream command: {command!r}")
|
||||
else:
|
||||
yield from super()._handle_command(command)
|
||||
|
||||
def start_tls(
|
||||
self, original_destination_connection_id: bytes | None
|
||||
) -> layer.CommandGenerator[None]:
|
||||
"""Initiates the aioquic connection."""
|
||||
|
||||
# must only be called if QUIC is uninitialized
|
||||
assert not self.quic
|
||||
assert not self.tls
|
||||
|
||||
# query addons to provide the necessary TLS settings
|
||||
tls_data = QuicTlsData(self.conn, self.context)
|
||||
if self.conn is self.context.client:
|
||||
yield QuicStartClientHook(tls_data)
|
||||
else:
|
||||
yield QuicStartServerHook(tls_data)
|
||||
if not tls_data.settings:
|
||||
yield commands.Log(
|
||||
f"No QUIC context was provided, failing connection.", ERROR
|
||||
)
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
|
||||
# build the aioquic connection
|
||||
configuration = tls_settings_to_configuration(
|
||||
settings=tls_data.settings,
|
||||
is_client=self.conn is self.context.server,
|
||||
server_name=self.conn.sni,
|
||||
)
|
||||
self.quic = QuicConnection(
|
||||
configuration=configuration,
|
||||
original_destination_connection_id=original_destination_connection_id,
|
||||
)
|
||||
self.tls = tls_data.settings
|
||||
|
||||
# if we act as client, connect to upstream
|
||||
if original_destination_connection_id is None:
|
||||
self.quic.connect(self.conn.peername, now=self._time())
|
||||
yield from self.tls_interact()
|
||||
|
||||
def tls_interact(self) -> layer.CommandGenerator[None]:
|
||||
"""Retrieves all pending outgoing packets from aioquic and sends the data."""
|
||||
|
||||
# send all queued datagrams
|
||||
assert self.quic
|
||||
now = self._time()
|
||||
|
||||
for data, addr in self.quic.datagrams_to_send(now=now):
|
||||
assert addr == self.conn.peername
|
||||
yield commands.SendData(self.tunnel_connection, data)
|
||||
|
||||
timer = self.quic.get_timer()
|
||||
if timer is not None:
|
||||
# smooth wakeups a bit.
|
||||
smoothed = timer + 0.002
|
||||
# request a new wakeup if all pending requests trigger at a later time
|
||||
if not any(
|
||||
existing <= smoothed for existing in self._wakeup_commands.values()
|
||||
):
|
||||
command = commands.RequestWakeup(timer - now)
|
||||
self._wakeup_commands[command] = timer
|
||||
yield command
|
||||
|
||||
def receive_handshake_data(
|
||||
self, data: bytes
|
||||
) -> layer.CommandGenerator[tuple[bool, str | None]]:
|
||||
assert self.quic
|
||||
|
||||
# forward incoming data to aioquic
|
||||
if data:
|
||||
self.quic.receive_datagram(data, self.conn.peername, now=self._time())
|
||||
|
||||
# handle pre-handshake events
|
||||
while event := self.quic.next_event():
|
||||
if isinstance(event, quic_events.ConnectionTerminated):
|
||||
err = event.reason_phrase or error_code_to_str(event.error_code)
|
||||
return False, err
|
||||
elif isinstance(event, quic_events.HandshakeCompleted):
|
||||
# concatenate all peer certificates
|
||||
all_certs: list[x509.Certificate] = []
|
||||
if self.quic.tls._peer_certificate:
|
||||
all_certs.append(self.quic.tls._peer_certificate)
|
||||
all_certs.extend(self.quic.tls._peer_certificate_chain)
|
||||
|
||||
# set the connection's TLS properties
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
if event.alpn_protocol:
|
||||
self.conn.alpn = event.alpn_protocol.encode("ascii")
|
||||
self.conn.certificate_list = [certs.Cert(cert) for cert in all_certs]
|
||||
assert self.quic.tls.key_schedule
|
||||
self.conn.cipher = self.quic.tls.key_schedule.cipher_suite.name
|
||||
self.conn.tls_version = "QUICv1"
|
||||
|
||||
# log the result and report the success to addons
|
||||
if self.debug:
|
||||
yield commands.Log(
|
||||
f"{self.debug}[quic] tls established: {self.conn}", DEBUG
|
||||
)
|
||||
if self.conn is self.context.client:
|
||||
yield TlsEstablishedClientHook(
|
||||
QuicTlsData(self.conn, self.context, settings=self.tls)
|
||||
)
|
||||
else:
|
||||
yield TlsEstablishedServerHook(
|
||||
QuicTlsData(self.conn, self.context, settings=self.tls)
|
||||
)
|
||||
|
||||
yield from self.tls_interact()
|
||||
return True, None
|
||||
elif isinstance(
|
||||
event,
|
||||
(
|
||||
quic_events.ConnectionIdIssued,
|
||||
quic_events.ConnectionIdRetired,
|
||||
quic_events.PingAcknowledged,
|
||||
quic_events.ProtocolNegotiated,
|
||||
),
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
# transmit buffered data and re-arm timer
|
||||
yield from self.tls_interact()
|
||||
return False, None
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
self.conn.error = err
|
||||
if self.conn is self.context.client:
|
||||
yield TlsFailedClientHook(
|
||||
QuicTlsData(self.conn, self.context, settings=self.tls)
|
||||
)
|
||||
else:
|
||||
yield TlsFailedServerHook(
|
||||
QuicTlsData(self.conn, self.context, settings=self.tls)
|
||||
)
|
||||
yield from super().on_handshake_error(err)
|
||||
|
||||
def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
|
||||
assert self.quic
|
||||
|
||||
# forward incoming data to aioquic
|
||||
if data:
|
||||
self.quic.receive_datagram(data, self.conn.peername, now=self._time())
|
||||
|
||||
# handle post-handshake events
|
||||
while event := self.quic.next_event():
|
||||
if isinstance(event, quic_events.ConnectionTerminated):
|
||||
if self.debug:
|
||||
reason = event.reason_phrase or error_code_to_str(event.error_code)
|
||||
yield commands.Log(
|
||||
f"{self.debug}[quic] close_notify {self.conn} ({reason=!s})",
|
||||
DEBUG,
|
||||
)
|
||||
# We don't rely on `ConnectionTerminated` to dispatch `QuicConnectionClosed`, because
|
||||
# after aioquic receives a termination frame, it still waits for the next `handle_timer`
|
||||
# before returning `ConnectionTerminated` in `next_event`. In the meantime, the underlying
|
||||
# connection could be closed. Therefore, we instead dispatch on `ConnectionClosed` and simply
|
||||
# close the connection here.
|
||||
yield commands.CloseConnection(self.tunnel_connection)
|
||||
return # we don't handle any further events, nor do/can we transmit data, so exit
|
||||
elif isinstance(event, quic_events.DatagramFrameReceived):
|
||||
yield from self.event_to_child(
|
||||
events.DataReceived(self.conn, event.data)
|
||||
)
|
||||
elif isinstance(event, quic_events.StreamDataReceived):
|
||||
yield from self.event_to_child(
|
||||
QuicStreamDataReceived(
|
||||
self.conn, event.stream_id, event.data, event.end_stream
|
||||
)
|
||||
)
|
||||
elif isinstance(event, quic_events.StreamReset):
|
||||
yield from self.event_to_child(
|
||||
QuicStreamReset(self.conn, event.stream_id, event.error_code)
|
||||
)
|
||||
elif isinstance(event, quic_events.StopSendingReceived):
|
||||
yield from self.event_to_child(
|
||||
QuicStreamStopSending(self.conn, event.stream_id, event.error_code)
|
||||
)
|
||||
elif isinstance(
|
||||
event,
|
||||
(
|
||||
quic_events.ConnectionIdIssued,
|
||||
quic_events.ConnectionIdRetired,
|
||||
quic_events.PingAcknowledged,
|
||||
quic_events.ProtocolNegotiated,
|
||||
),
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event!r}")
|
||||
|
||||
# transmit buffered data and re-arm timer
|
||||
yield from self.tls_interact()
|
||||
|
||||
def receive_close(self) -> layer.CommandGenerator[None]:
|
||||
assert self.quic
|
||||
# if `_close_event` is not set, the underlying connection has been closed
|
||||
# we turn this into a QUIC close event as well
|
||||
close_event = self.quic._close_event or quic_events.ConnectionTerminated(
|
||||
QuicErrorCode.NO_ERROR, None, "Connection closed."
|
||||
)
|
||||
yield from self.event_to_child(
|
||||
QuicConnectionClosed(
|
||||
self.conn,
|
||||
close_event.error_code,
|
||||
close_event.frame_type,
|
||||
close_event.reason_phrase,
|
||||
)
|
||||
)
|
||||
|
||||
def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
|
||||
# non-stream data uses datagram frames
|
||||
assert self.quic
|
||||
if data:
|
||||
self.quic.send_datagram_frame(data)
|
||||
yield from self.tls_interact()
|
||||
|
||||
def send_close(
|
||||
self, command: commands.CloseConnection
|
||||
) -> layer.CommandGenerator[None]:
|
||||
# properly close the QUIC connection
|
||||
if self.quic:
|
||||
if isinstance(command, CloseQuicConnection):
|
||||
self.quic.close(
|
||||
command.error_code, command.frame_type, command.reason_phrase
|
||||
)
|
||||
else:
|
||||
self.quic.close()
|
||||
yield from self.tls_interact()
|
||||
yield from super().send_close(command)
|
||||
|
||||
|
||||
class ServerQuicLayer(QuicLayer):
|
||||
"""
|
||||
This layer establishes QUIC for a single server connection.
|
||||
"""
|
||||
|
||||
wait_for_clienthello: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: context.Context,
|
||||
conn: connection.Server | None = None,
|
||||
time: Callable[[], float] | None = None,
|
||||
):
|
||||
super().__init__(context, conn or context.server, time)
|
||||
|
||||
def start_handshake(self) -> layer.CommandGenerator[None]:
|
||||
wait_for_clienthello = not self.command_to_reply_to and isinstance(
|
||||
self.child_layer, ClientQuicLayer
|
||||
)
|
||||
if wait_for_clienthello:
|
||||
self.wait_for_clienthello = True
|
||||
self.tunnel_state = tunnel.TunnelState.CLOSED
|
||||
else:
|
||||
yield from self.start_tls(None)
|
||||
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if self.wait_for_clienthello:
|
||||
for command in super().event_to_child(event):
|
||||
if (
|
||||
isinstance(command, commands.OpenConnection)
|
||||
and command.connection == self.conn
|
||||
):
|
||||
self.wait_for_clienthello = False
|
||||
else:
|
||||
yield command
|
||||
else:
|
||||
yield from super().event_to_child(event)
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.Log(f"Server QUIC handshake failed. {err}", level=WARNING)
|
||||
yield from super().on_handshake_error(err)
|
||||
|
||||
|
||||
class ClientQuicLayer(QuicLayer):
|
||||
"""
|
||||
This layer establishes QUIC on a single client connection.
|
||||
"""
|
||||
|
||||
server_tls_available: bool
|
||||
"""Indicates whether the parent layer is a ServerQuicLayer."""
|
||||
handshake_datagram_buf: list[bytes]
|
||||
|
||||
def __init__(
|
||||
self, context: context.Context, time: Callable[[], float] | None = None
|
||||
) -> None:
|
||||
# same as ClientTLSLayer, we might be nested in some other transport
|
||||
if context.client.tls:
|
||||
context.client.alpn = None
|
||||
context.client.cipher = None
|
||||
context.client.sni = None
|
||||
context.client.timestamp_tls_setup = None
|
||||
context.client.tls_version = None
|
||||
context.client.certificate_list = []
|
||||
context.client.mitmcert = None
|
||||
context.client.alpn_offers = []
|
||||
context.client.cipher_list = []
|
||||
|
||||
super().__init__(context, context.client, time)
|
||||
self.server_tls_available = len(self.context.layers) >= 2 and isinstance(
|
||||
self.context.layers[-2], ServerQuicLayer
|
||||
)
|
||||
self.handshake_datagram_buf = []
|
||||
|
||||
def start_handshake(self) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
def receive_handshake_data(
|
||||
self, data: bytes
|
||||
) -> layer.CommandGenerator[tuple[bool, str | None]]:
|
||||
if not self.context.options.http3:
|
||||
yield commands.Log(
|
||||
f"Swallowing QUIC handshake because HTTP/3 is disabled.", DEBUG
|
||||
)
|
||||
return False, None
|
||||
|
||||
# if we already had a valid client hello, don't process further packets
|
||||
if self.tls:
|
||||
return (yield from super().receive_handshake_data(data))
|
||||
|
||||
# fail if the received data is not a QUIC packet
|
||||
buffer = QuicBuffer(data=data)
|
||||
try:
|
||||
header = pull_quic_header(buffer)
|
||||
except TypeError:
|
||||
return False, f"Cannot parse QUIC header: Malformed head ({data.hex()})"
|
||||
except ValueError as e:
|
||||
return False, f"Cannot parse QUIC header: {e} ({data.hex()})"
|
||||
|
||||
# negotiate version, support all versions known to aioquic
|
||||
if (
|
||||
header.version is not None
|
||||
and header.version not in SUPPORTED_QUIC_VERSIONS_SERVER
|
||||
):
|
||||
yield commands.SendData(
|
||||
self.tunnel_connection,
|
||||
encode_quic_version_negotiation(
|
||||
source_cid=header.destination_cid,
|
||||
destination_cid=header.source_cid,
|
||||
supported_versions=SUPPORTED_QUIC_VERSIONS_SERVER,
|
||||
),
|
||||
)
|
||||
return False, None
|
||||
|
||||
# ensure it's (likely) a client handshake packet
|
||||
if len(data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL:
|
||||
return (
|
||||
False,
|
||||
f"Invalid handshake received, roaming not supported. ({data.hex()})",
|
||||
)
|
||||
|
||||
self.handshake_datagram_buf.append(data)
|
||||
# extract the client hello
|
||||
try:
|
||||
client_hello = quic_parse_client_hello_from_datagrams(
|
||||
self.handshake_datagram_buf
|
||||
)
|
||||
except ValueError as e:
|
||||
msgs = b"\n".join(self.handshake_datagram_buf)
|
||||
dbg = f"Cannot parse ClientHello: {e} ({msgs.hex()})"
|
||||
self.handshake_datagram_buf.clear()
|
||||
return False, dbg
|
||||
|
||||
if not client_hello:
|
||||
return False, None
|
||||
|
||||
# copy the client hello information
|
||||
self.conn.sni = client_hello.sni
|
||||
self.conn.alpn_offers = client_hello.alpn_protocols
|
||||
|
||||
# check with addons what we shall do
|
||||
tls_clienthello = ClientHelloData(self.context, client_hello)
|
||||
yield TlsClienthelloHook(tls_clienthello)
|
||||
|
||||
# replace the QUIC layer with an UDP layer if requested
|
||||
if tls_clienthello.ignore_connection:
|
||||
self.conn = self.tunnel_connection = connection.Client(
|
||||
peername=("ignore-conn", 0),
|
||||
sockname=("ignore-conn", 0),
|
||||
transport_protocol="udp",
|
||||
state=connection.ConnectionState.OPEN,
|
||||
)
|
||||
|
||||
# we need to replace the server layer as well, if there is one
|
||||
parent_layer = self.context.layers[self.context.layers.index(self) - 1]
|
||||
if isinstance(parent_layer, ServerQuicLayer):
|
||||
parent_layer.conn = parent_layer.tunnel_connection = connection.Server(
|
||||
address=None
|
||||
)
|
||||
replacement_layer = UDPLayer(self.context, ignore=True)
|
||||
parent_layer.handle_event = replacement_layer.handle_event # type: ignore
|
||||
parent_layer._handle_event = replacement_layer._handle_event # type: ignore
|
||||
yield from parent_layer.handle_event(events.Start())
|
||||
for dgm in self.handshake_datagram_buf:
|
||||
yield from parent_layer.handle_event(
|
||||
events.DataReceived(self.context.client, dgm)
|
||||
)
|
||||
self.handshake_datagram_buf.clear()
|
||||
return True, None
|
||||
|
||||
# start the server QUIC connection if demanded and available
|
||||
if (
|
||||
tls_clienthello.establish_server_tls_first
|
||||
and not self.context.server.tls_established
|
||||
):
|
||||
err = yield from self.start_server_tls()
|
||||
if err:
|
||||
yield commands.Log(
|
||||
f"Unable to establish QUIC connection with server ({err}). "
|
||||
f"Trying to establish QUIC with client anyway. "
|
||||
f"If you plan to redirect requests away from this server, "
|
||||
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
|
||||
)
|
||||
|
||||
# start the client QUIC connection
|
||||
yield from self.start_tls(header.destination_cid)
|
||||
# XXX copied from TLS, we assume that `CloseConnection` in `start_tls` takes effect immediately
|
||||
if not self.conn.connected:
|
||||
return False, "connection closed early"
|
||||
|
||||
# send the client hello to aioquic
|
||||
assert self.quic
|
||||
for dgm in self.handshake_datagram_buf:
|
||||
self.quic.receive_datagram(dgm, self.conn.peername, now=self._time())
|
||||
self.handshake_datagram_buf.clear()
|
||||
|
||||
# handle events emanating from `self.quic`
|
||||
return (yield from super().receive_handshake_data(b""))
|
||||
|
||||
def start_server_tls(self) -> layer.CommandGenerator[str | None]:
|
||||
if not self.server_tls_available:
|
||||
return f"No server QUIC available."
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
return err
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.Log(f"Client QUIC handshake failed. {err}", level=WARNING)
|
||||
yield from super().on_handshake_error(err)
|
||||
self.event_to_child = self.errored # type: ignore
|
||||
|
||||
def errored(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if self.debug is not None:
|
||||
yield commands.Log(
|
||||
f"{self.debug}[quic] Swallowing {event} as handshake failed.", DEBUG
|
||||
)
|
||||
|
||||
|
||||
class QuicSecretsLogger:
|
||||
logger: tls.MasterSecretLogger
|
||||
|
||||
def __init__(self, logger: tls.MasterSecretLogger) -> None:
|
||||
super().__init__()
|
||||
self.logger = logger
|
||||
|
||||
def write(self, s: str) -> int:
|
||||
if s[-1:] == "\n":
|
||||
s = s[:-1]
|
||||
data = s.encode("ascii")
|
||||
self.logger(None, data) # type: ignore
|
||||
return len(data) + 1
|
||||
|
||||
def flush(self) -> None:
|
||||
# done by the logger during write
|
||||
pass
|
||||
|
||||
|
||||
def error_code_to_str(error_code: int) -> str:
|
||||
"""Returns the corresponding name of the given error code or a string containing its numeric value."""
|
||||
|
||||
try:
|
||||
return H3ErrorCode(error_code).name
|
||||
except ValueError:
|
||||
try:
|
||||
return QuicErrorCode(error_code).name
|
||||
except ValueError:
|
||||
return f"unknown error (0x{error_code:x})"
|
||||
|
||||
|
||||
def is_success_error_code(error_code: int) -> bool:
|
||||
"""Returns whether the given error code actually indicates no error."""
|
||||
|
||||
return error_code in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR)
|
||||
|
||||
|
||||
def tls_settings_to_configuration(
|
||||
settings: QuicTlsSettings,
|
||||
is_client: bool,
|
||||
server_name: str | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""Converts `QuicTlsSettings` to `QuicConfiguration`."""
|
||||
|
||||
return QuicConfiguration(
|
||||
alpn_protocols=settings.alpn_protocols,
|
||||
is_client=is_client,
|
||||
secrets_log_file=(
|
||||
QuicSecretsLogger(tls.log_master_secret) # type: ignore
|
||||
if tls.log_master_secret is not None
|
||||
else None
|
||||
),
|
||||
server_name=server_name,
|
||||
cafile=settings.ca_file,
|
||||
capath=settings.ca_path,
|
||||
certificate=settings.certificate,
|
||||
certificate_chain=settings.certificate_chain,
|
||||
cipher_suites=settings.cipher_suites,
|
||||
private_key=settings.certificate_private_key,
|
||||
verify_mode=settings.verify_mode,
|
||||
max_datagram_frame_size=65536,
|
||||
)
|
||||
143
venv/Lib/site-packages/mitmproxy/proxy/layers/tcp.py
Normal file
143
venv/Lib/site-packages/mitmproxy/proxy/layers/tcp.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import flow
|
||||
from mitmproxy import tcp
|
||||
from mitmproxy.connection import Connection
|
||||
from mitmproxy.connection import ConnectionState
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.events import MessageInjected
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcpStartHook(StartHook):
|
||||
"""
|
||||
A TCP connection has started.
|
||||
"""
|
||||
|
||||
flow: tcp.TCPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcpMessageHook(StartHook):
|
||||
"""
|
||||
A TCP connection has received a message. The most recent message
|
||||
will be flow.messages[-1]. The message is user-modifiable.
|
||||
"""
|
||||
|
||||
flow: tcp.TCPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcpEndHook(StartHook):
|
||||
"""
|
||||
A TCP connection has ended.
|
||||
"""
|
||||
|
||||
flow: tcp.TCPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class TcpErrorHook(StartHook):
|
||||
"""
|
||||
A TCP error has occurred.
|
||||
|
||||
Every TCP flow will receive either a tcp_error or a tcp_end event, but not both.
|
||||
"""
|
||||
|
||||
flow: tcp.TCPFlow
|
||||
|
||||
|
||||
class TcpMessageInjected(MessageInjected[tcp.TCPMessage]):
|
||||
"""
|
||||
The user has injected a custom TCP message.
|
||||
"""
|
||||
|
||||
|
||||
class TCPLayer(layer.Layer):
|
||||
"""
|
||||
Simple TCP layer that just relays messages right now.
|
||||
"""
|
||||
|
||||
flow: tcp.TCPFlow | None
|
||||
|
||||
def __init__(self, context: Context, ignore: bool = False):
|
||||
super().__init__(context)
|
||||
if ignore:
|
||||
self.flow = None
|
||||
else:
|
||||
self.flow = tcp.TCPFlow(self.context.client, self.context.server, True)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
if self.flow:
|
||||
yield TcpStartHook(self.flow)
|
||||
|
||||
if self.context.server.timestamp_start is None:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
if self.flow:
|
||||
self.flow.error = flow.Error(str(err))
|
||||
yield TcpErrorHook(self.flow)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
self._handle_event = self.done
|
||||
return
|
||||
self._handle_event = self.relay_messages
|
||||
|
||||
_handle_event = start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected)
|
||||
def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, TcpMessageInjected):
|
||||
# we just spoof that we received data here and then process that regularly.
|
||||
event = events.DataReceived(
|
||||
self.context.client
|
||||
if event.message.from_client
|
||||
else self.context.server,
|
||||
event.message.content,
|
||||
)
|
||||
|
||||
assert isinstance(event, events.ConnectionEvent)
|
||||
|
||||
from_client = event.connection == self.context.client
|
||||
send_to: Connection
|
||||
if from_client:
|
||||
send_to = self.context.server
|
||||
else:
|
||||
send_to = self.context.client
|
||||
|
||||
if isinstance(event, events.DataReceived):
|
||||
if self.flow:
|
||||
tcp_message = tcp.TCPMessage(from_client, event.data)
|
||||
self.flow.messages.append(tcp_message)
|
||||
yield TcpMessageHook(self.flow)
|
||||
yield commands.SendData(send_to, tcp_message.content)
|
||||
else:
|
||||
yield commands.SendData(send_to, event.data)
|
||||
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
all_done = not (
|
||||
(self.context.client.state & ConnectionState.CAN_READ)
|
||||
or (self.context.server.state & ConnectionState.CAN_READ)
|
||||
)
|
||||
if all_done:
|
||||
self._handle_event = self.done
|
||||
if self.context.server.state is not ConnectionState.CLOSED:
|
||||
yield commands.CloseConnection(self.context.server)
|
||||
if self.context.client.state is not ConnectionState.CLOSED:
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
if self.flow:
|
||||
yield TcpEndHook(self.flow)
|
||||
self.flow.live = False
|
||||
else:
|
||||
yield commands.CloseTcpConnection(send_to, half_close=True)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
692
venv/Lib/site-packages/mitmproxy/proxy/layers/tls.py
Normal file
692
venv/Lib/site-packages/mitmproxy/proxy/layers/tls.py
Normal file
@@ -0,0 +1,692 @@
|
||||
import struct
|
||||
import time
|
||||
import typing
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from logging import DEBUG
|
||||
from logging import ERROR
|
||||
from logging import INFO
|
||||
from logging import WARNING
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy import connection
|
||||
from mitmproxy.connection import TlsVersion
|
||||
from mitmproxy.net.tls import starts_like_dtls_record
|
||||
from mitmproxy.net.tls import starts_like_tls_record
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy import tunnel
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.layers import tcp
|
||||
from mitmproxy.proxy.layers import udp
|
||||
from mitmproxy.tls import ClientHello
|
||||
from mitmproxy.tls import ClientHelloData
|
||||
from mitmproxy.tls import TlsData
|
||||
from mitmproxy.utils import human
|
||||
|
||||
|
||||
def handshake_record_contents(data: bytes) -> Iterator[bytes]:
|
||||
"""
|
||||
Returns a generator that yields the bytes contained in each handshake record.
|
||||
This will raise an error on the first non-handshake record, so fully exhausting this
|
||||
generator is a bad idea.
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
if len(data) < offset + 5:
|
||||
return
|
||||
record_header = data[offset : offset + 5]
|
||||
if not starts_like_tls_record(record_header):
|
||||
raise ValueError(f"Expected TLS record, got {record_header!r} instead.")
|
||||
record_size = struct.unpack("!H", record_header[3:])[0]
|
||||
if record_size == 0:
|
||||
raise ValueError("Record must not be empty.")
|
||||
offset += 5
|
||||
|
||||
if len(data) < offset + record_size:
|
||||
return
|
||||
record_body = data[offset : offset + record_size]
|
||||
yield record_body
|
||||
offset += record_size
|
||||
|
||||
|
||||
def get_client_hello(data: bytes) -> bytes | None:
|
||||
"""
|
||||
Read all TLS records that contain the initial ClientHello.
|
||||
Returns the raw handshake packet bytes, without TLS record headers.
|
||||
"""
|
||||
client_hello = b""
|
||||
for d in handshake_record_contents(data):
|
||||
client_hello += d
|
||||
if len(client_hello) >= 4:
|
||||
client_hello_size = struct.unpack("!I", b"\x00" + client_hello[1:4])[0] + 4
|
||||
if len(client_hello) >= client_hello_size:
|
||||
return client_hello[:client_hello_size]
|
||||
return None
|
||||
|
||||
|
||||
def parse_client_hello(data: bytes) -> ClientHello | None:
|
||||
"""
|
||||
Check if the supplied bytes contain a full ClientHello message,
|
||||
and if so, parse it.
|
||||
|
||||
Returns:
|
||||
- A ClientHello object on success
|
||||
- None, if the TLS record is not complete
|
||||
|
||||
Raises:
|
||||
- A ValueError, if the passed ClientHello is invalid
|
||||
"""
|
||||
# Check if ClientHello is complete
|
||||
client_hello = get_client_hello(data)
|
||||
if client_hello:
|
||||
try:
|
||||
return ClientHello(client_hello[4:])
|
||||
except EOFError as e:
|
||||
raise ValueError("Invalid ClientHello") from e
|
||||
return None
|
||||
|
||||
|
||||
def dtls_handshake_record_contents(data: bytes) -> Iterator[bytes]:
|
||||
"""
|
||||
Returns a generator that yields the bytes contained in each handshake record.
|
||||
This will raise an error on the first non-handshake record, so fully exhausting this
|
||||
generator is a bad idea.
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
# DTLS includes two new fields, totaling 8 bytes, between Version and Length
|
||||
if len(data) < offset + 13:
|
||||
return
|
||||
record_header = data[offset : offset + 13]
|
||||
if not starts_like_dtls_record(record_header):
|
||||
raise ValueError(f"Expected DTLS record, got {record_header!r} instead.")
|
||||
# Length fields starts at 11
|
||||
record_size = struct.unpack("!H", record_header[11:])[0]
|
||||
if record_size == 0:
|
||||
raise ValueError("Record must not be empty.")
|
||||
offset += 13
|
||||
|
||||
if len(data) < offset + record_size:
|
||||
return
|
||||
record_body = data[offset : offset + record_size]
|
||||
yield record_body
|
||||
offset += record_size
|
||||
|
||||
|
||||
def get_dtls_client_hello(data: bytes) -> bytes | None:
|
||||
"""
|
||||
Read all DTLS records that contain the initial ClientHello.
|
||||
Returns the raw handshake packet bytes, without TLS record headers.
|
||||
"""
|
||||
client_hello = b""
|
||||
for d in dtls_handshake_record_contents(data):
|
||||
client_hello += d
|
||||
if len(client_hello) >= 13:
|
||||
# comment about slicing: we skip the epoch and sequence number
|
||||
client_hello_size = (
|
||||
struct.unpack("!I", b"\x00" + client_hello[9:12])[0] + 12
|
||||
)
|
||||
if len(client_hello) >= client_hello_size:
|
||||
return client_hello[:client_hello_size]
|
||||
return None
|
||||
|
||||
|
||||
def dtls_parse_client_hello(data: bytes) -> ClientHello | None:
|
||||
"""
|
||||
Check if the supplied bytes contain a full ClientHello message,
|
||||
and if so, parse it.
|
||||
|
||||
Returns:
|
||||
- A ClientHello object on success
|
||||
- None, if the TLS record is not complete
|
||||
|
||||
Raises:
|
||||
- A ValueError, if the passed ClientHello is invalid
|
||||
"""
|
||||
# Check if ClientHello is complete
|
||||
client_hello = get_dtls_client_hello(data)
|
||||
if client_hello:
|
||||
try:
|
||||
return ClientHello(client_hello[12:], dtls=True)
|
||||
except EOFError as e:
|
||||
raise ValueError("Invalid ClientHello") from e
|
||||
return None
|
||||
|
||||
|
||||
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
HTTP2_ALPN = b"h2"
|
||||
HTTP3_ALPN = b"h3"
|
||||
HTTP_ALPNS = (HTTP3_ALPN, HTTP2_ALPN, *HTTP1_ALPNS)
|
||||
|
||||
|
||||
# We need these classes as hooks can only have one argument at the moment.
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsClienthelloHook(StartHook):
|
||||
"""
|
||||
Mitmproxy has received a TLS ClientHello message.
|
||||
|
||||
This hook decides whether a server connection is needed
|
||||
to negotiate TLS with the client (data.establish_server_tls_first)
|
||||
"""
|
||||
|
||||
data: ClientHelloData
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsStartClientHook(StartHook):
|
||||
"""
|
||||
TLS negotation between mitmproxy and a client is about to start.
|
||||
|
||||
An addon is expected to initialize data.ssl_conn.
|
||||
(by default, this is done by `mitmproxy.addons.tlsconfig`)
|
||||
"""
|
||||
|
||||
data: TlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsStartServerHook(StartHook):
|
||||
"""
|
||||
TLS negotation between mitmproxy and a server is about to start.
|
||||
|
||||
An addon is expected to initialize data.ssl_conn.
|
||||
(by default, this is done by `mitmproxy.addons.tlsconfig`)
|
||||
"""
|
||||
|
||||
data: TlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsEstablishedClientHook(StartHook):
|
||||
"""
|
||||
The TLS handshake with the client has been completed successfully.
|
||||
"""
|
||||
|
||||
data: TlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsEstablishedServerHook(StartHook):
|
||||
"""
|
||||
The TLS handshake with the server has been completed successfully.
|
||||
"""
|
||||
|
||||
data: TlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsFailedClientHook(StartHook):
|
||||
"""
|
||||
The TLS handshake with the client has failed.
|
||||
"""
|
||||
|
||||
data: TlsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsFailedServerHook(StartHook):
|
||||
"""
|
||||
The TLS handshake with the server has failed.
|
||||
"""
|
||||
|
||||
data: TlsData
|
||||
|
||||
|
||||
class TLSLayer(tunnel.TunnelLayer):
|
||||
tls: SSL.Connection = None # type: ignore
|
||||
"""The OpenSSL connection object"""
|
||||
|
||||
def __init__(self, context: context.Context, conn: connection.Connection):
|
||||
super().__init__(
|
||||
context,
|
||||
tunnel_connection=conn,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
conn.tls = True
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
super().__repr__().replace(")", f" {self.conn.sni!r} {self.conn.alpn!r})")
|
||||
)
|
||||
|
||||
@property
|
||||
def is_dtls(self):
|
||||
return self.conn.transport_protocol == "udp"
|
||||
|
||||
@property
|
||||
def proto_name(self):
|
||||
return "DTLS" if self.is_dtls else "TLS"
|
||||
|
||||
def start_tls(self) -> layer.CommandGenerator[None]:
|
||||
assert not self.tls
|
||||
|
||||
tls_start = TlsData(self.conn, self.context, is_dtls=self.is_dtls)
|
||||
if self.conn == self.context.client:
|
||||
yield TlsStartClientHook(tls_start)
|
||||
else:
|
||||
yield TlsStartServerHook(tls_start)
|
||||
if not tls_start.ssl_conn:
|
||||
yield commands.Log(
|
||||
f"No {self.proto_name} context was provided, failing connection.", ERROR
|
||||
)
|
||||
yield commands.CloseConnection(self.conn)
|
||||
return
|
||||
assert tls_start.ssl_conn
|
||||
self.tls = tls_start.ssl_conn
|
||||
|
||||
def tls_interact(self) -> layer.CommandGenerator[None]:
|
||||
while True:
|
||||
try:
|
||||
data = self.tls.bio_read(65535)
|
||||
except SSL.WantReadError:
|
||||
return # Okay, nothing more waiting to be sent.
|
||||
else:
|
||||
yield commands.SendData(self.conn, data)
|
||||
|
||||
def receive_handshake_data(
|
||||
self, data: bytes
|
||||
) -> layer.CommandGenerator[tuple[bool, str | None]]:
|
||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||
if data:
|
||||
self.tls.bio_write(data)
|
||||
try:
|
||||
self.tls.do_handshake()
|
||||
except SSL.WantReadError:
|
||||
yield from self.tls_interact()
|
||||
return False, None
|
||||
except SSL.Error as e:
|
||||
# provide more detailed information for some errors.
|
||||
last_err = (
|
||||
e.args and isinstance(e.args[0], list) and e.args[0] and e.args[0][-1]
|
||||
)
|
||||
if last_err in [
|
||||
(
|
||||
"SSL routines",
|
||||
"tls_process_server_certificate",
|
||||
"certificate verify failed",
|
||||
),
|
||||
("SSL routines", "", "certificate verify failed"), # OpenSSL 3+
|
||||
]:
|
||||
verify_result = SSL._lib.SSL_get_verify_result(self.tls._ssl) # type: ignore
|
||||
error = SSL._ffi.string( # type: ignore
|
||||
SSL._lib.X509_verify_cert_error_string(verify_result) # type: ignore
|
||||
).decode()
|
||||
err = f"Certificate verify failed: {error}"
|
||||
elif last_err in [
|
||||
("SSL routines", "ssl3_read_bytes", "tlsv1 alert unknown ca"),
|
||||
("SSL routines", "ssl3_read_bytes", "sslv3 alert bad certificate"),
|
||||
("SSL routines", "ssl3_read_bytes", "ssl/tls alert bad certificate"),
|
||||
("SSL routines", "", "tlsv1 alert unknown ca"), # OpenSSL 3+
|
||||
("SSL routines", "", "sslv3 alert bad certificate"), # OpenSSL 3+
|
||||
("SSL routines", "", "ssl/tls alert bad certificate"), # OpenSSL 3.2+
|
||||
]:
|
||||
assert isinstance(last_err, tuple)
|
||||
err = last_err[2]
|
||||
elif (
|
||||
last_err
|
||||
in [
|
||||
("SSL routines", "ssl3_get_record", "wrong version number"),
|
||||
("SSL routines", "", "wrong version number"), # OpenSSL 3+
|
||||
("SSL routines", "", "packet length too long"), # OpenSSL 3+
|
||||
("SSL routines", "", "record layer failure"), # OpenSSL 3+
|
||||
]
|
||||
and data[:4].isascii()
|
||||
):
|
||||
err = f"The remote server does not speak TLS."
|
||||
elif last_err in [
|
||||
("SSL routines", "ssl3_read_bytes", "tlsv1 alert protocol version"),
|
||||
("SSL routines", "", "tlsv1 alert protocol version"), # OpenSSL 3+
|
||||
]:
|
||||
err = (
|
||||
f"The remote server and mitmproxy cannot agree on a TLS version to use. "
|
||||
f"You may need to adjust mitmproxy's tls_version_server_min option."
|
||||
)
|
||||
else:
|
||||
err = f"OpenSSL {e!r}"
|
||||
return False, err
|
||||
else:
|
||||
# Here we set all attributes that are only known *after* the handshake.
|
||||
|
||||
# Get all peer certificates.
|
||||
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||
# side, the peer's certificate must be obtained separately using SSL_get_peer_certificate(3).
|
||||
all_certs = self.tls.get_peer_cert_chain() or []
|
||||
if self.conn == self.context.client:
|
||||
cert = self.tls.get_peer_certificate()
|
||||
if cert:
|
||||
all_certs.insert(0, cert)
|
||||
self.conn.certificate_list = []
|
||||
for cert in all_certs:
|
||||
try:
|
||||
# This may fail for weird certs, https://github.com/mitmproxy/mitmproxy/issues/6968.
|
||||
parsed_cert = certs.Cert.from_pyopenssl(cert)
|
||||
except ValueError as e:
|
||||
yield commands.Log(
|
||||
f"{self.debug}[tls] failed to parse certificate: {e}", WARNING
|
||||
)
|
||||
else:
|
||||
self.conn.certificate_list.append(parsed_cert)
|
||||
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
|
||||
self.conn.cipher = self.tls.get_cipher_name()
|
||||
self.conn.tls_version = typing.cast(
|
||||
TlsVersion, self.tls.get_protocol_version_name()
|
||||
)
|
||||
if self.debug:
|
||||
yield commands.Log(
|
||||
f"{self.debug}[tls] tls established: {self.conn}", DEBUG
|
||||
)
|
||||
if self.conn == self.context.client:
|
||||
yield TlsEstablishedClientHook(
|
||||
TlsData(self.conn, self.context, self.tls)
|
||||
)
|
||||
else:
|
||||
yield TlsEstablishedServerHook(
|
||||
TlsData(self.conn, self.context, self.tls)
|
||||
)
|
||||
yield from self.receive_data(b"")
|
||||
return True, None
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
self.conn.error = err
|
||||
if self.conn == self.context.client:
|
||||
yield TlsFailedClientHook(TlsData(self.conn, self.context, self.tls))
|
||||
else:
|
||||
yield TlsFailedServerHook(TlsData(self.conn, self.context, self.tls))
|
||||
yield from super().on_handshake_error(err)
|
||||
|
||||
def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
|
||||
if data:
|
||||
self.tls.bio_write(data)
|
||||
|
||||
plaintext = bytearray()
|
||||
close = False
|
||||
while True:
|
||||
try:
|
||||
plaintext.extend(self.tls.recv(65535))
|
||||
except SSL.WantReadError:
|
||||
break
|
||||
except SSL.ZeroReturnError:
|
||||
close = True
|
||||
break
|
||||
except SSL.Error as e:
|
||||
# This may be happening because the other side send an alert.
|
||||
# There's somewhat ugly behavior with Firefox on Android here,
|
||||
# which upon mistrusting a certificate still completes the handshake
|
||||
# and then sends an alert in the next packet. At this point we have unfortunately
|
||||
# already fired out `tls_established_client` hook.
|
||||
yield commands.Log(f"TLS Error: {e}", WARNING)
|
||||
break
|
||||
|
||||
# Can we send something?
|
||||
# Note that this must happen after `recv()`, which may have advanced the state machine.
|
||||
# https://github.com/mitmproxy/mitmproxy/discussions/7550
|
||||
yield from self.tls_interact()
|
||||
|
||||
if plaintext:
|
||||
yield from self.event_to_child(
|
||||
events.DataReceived(self.conn, bytes(plaintext))
|
||||
)
|
||||
if close:
|
||||
self.conn.state &= ~connection.ConnectionState.CAN_READ
|
||||
if self.debug:
|
||||
yield commands.Log(f"{self.debug}[tls] close_notify {self.conn}", DEBUG)
|
||||
yield from self.event_to_child(events.ConnectionClosed(self.conn))
|
||||
|
||||
def receive_close(self) -> layer.CommandGenerator[None]:
|
||||
if self.tls.get_shutdown() & SSL.RECEIVED_SHUTDOWN:
|
||||
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||
else:
|
||||
yield from super().receive_close()
|
||||
|
||||
def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
|
||||
try:
|
||||
self.tls.sendall(data)
|
||||
except (SSL.ZeroReturnError, SSL.SysCallError):
|
||||
# The other peer may still be trying to send data over, which we discard here.
|
||||
pass
|
||||
yield from self.tls_interact()
|
||||
|
||||
def send_close(
|
||||
self, command: commands.CloseConnection
|
||||
) -> layer.CommandGenerator[None]:
|
||||
# We should probably shutdown the TLS connection properly here.
|
||||
yield from super().send_close(command)
|
||||
|
||||
|
||||
class ServerTLSLayer(TLSLayer):
|
||||
"""
|
||||
This layer establishes TLS for a single server connection.
|
||||
"""
|
||||
|
||||
wait_for_clienthello: bool = False
|
||||
|
||||
def __init__(self, context: context.Context, conn: connection.Server | None = None):
|
||||
super().__init__(context, conn or context.server)
|
||||
|
||||
def start_handshake(self) -> layer.CommandGenerator[None]:
|
||||
wait_for_clienthello = (
|
||||
# if command_to_reply_to is set, we've been instructed to open the connection from the child layer.
|
||||
# in that case any potential ClientHello is already parsed (by the ClientTLS child layer).
|
||||
not self.command_to_reply_to
|
||||
# if command_to_reply_to is not set, the connection was already open when this layer received its Start
|
||||
# event (eager connection strategy). We now want to establish TLS right away, _unless_ we already know
|
||||
# that there's TLS on the client side as well (we check if our immediate child layer is set to be ClientTLS)
|
||||
# In this case want to wait for ClientHello to be parsed, so that we can incorporate SNI/ALPN from there.
|
||||
and isinstance(self.child_layer, ClientTLSLayer)
|
||||
)
|
||||
if wait_for_clienthello:
|
||||
self.wait_for_clienthello = True
|
||||
self.tunnel_state = tunnel.TunnelState.CLOSED
|
||||
else:
|
||||
yield from self.start_tls()
|
||||
if self.tls:
|
||||
yield from self.receive_handshake_data(b"")
|
||||
|
||||
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if self.wait_for_clienthello:
|
||||
for command in super().event_to_child(event):
|
||||
if (
|
||||
isinstance(command, commands.OpenConnection)
|
||||
and command.connection == self.conn
|
||||
):
|
||||
self.wait_for_clienthello = False
|
||||
# swallow OpenConnection here by not re-yielding it.
|
||||
else:
|
||||
yield command
|
||||
else:
|
||||
yield from super().event_to_child(event)
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
yield commands.Log(f"Server TLS handshake failed. {err}", level=WARNING)
|
||||
yield from super().on_handshake_error(err)
|
||||
|
||||
|
||||
class ClientTLSLayer(TLSLayer):
|
||||
"""
|
||||
This layer establishes TLS on a single client connection.
|
||||
|
||||
┌─────┐
|
||||
│Start│
|
||||
└┬────┘
|
||||
↓
|
||||
┌────────────────────┐
|
||||
│Wait for ClientHello│
|
||||
└┬───────────────────┘
|
||||
↓
|
||||
┌────────────────┐
|
||||
│Process messages│
|
||||
└────────────────┘
|
||||
|
||||
"""
|
||||
|
||||
recv_buffer: bytearray
|
||||
server_tls_available: bool
|
||||
client_hello_parsed: bool = False
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
if context.client.tls:
|
||||
# In the case of TLS-over-TLS, we already have client TLS. As the outer TLS connection between client
|
||||
# and proxy isn't that interesting to us, we just unset the attributes here and keep the inner TLS
|
||||
# session's attributes.
|
||||
# Alternatively we could create a new Client instance,
|
||||
# but for now we keep it simple. There is a proof-of-concept at
|
||||
# https://github.com/mitmproxy/mitmproxy/commit/9b6e2a716888b7787514733b76a5936afa485352.
|
||||
context.client.alpn = None
|
||||
context.client.cipher = None
|
||||
context.client.sni = None
|
||||
context.client.timestamp_tls_setup = None
|
||||
context.client.tls_version = None
|
||||
context.client.certificate_list = []
|
||||
context.client.mitmcert = None
|
||||
context.client.alpn_offers = []
|
||||
context.client.cipher_list = []
|
||||
|
||||
super().__init__(context, context.client)
|
||||
self.server_tls_available = isinstance(self.context.layers[-2], ServerTLSLayer)
|
||||
self.recv_buffer = bytearray()
|
||||
|
||||
def start_handshake(self) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
def receive_handshake_data(
|
||||
self, data: bytes
|
||||
) -> layer.CommandGenerator[tuple[bool, str | None]]:
|
||||
if self.client_hello_parsed:
|
||||
return (yield from super().receive_handshake_data(data))
|
||||
self.recv_buffer.extend(data)
|
||||
try:
|
||||
if self.is_dtls:
|
||||
client_hello = dtls_parse_client_hello(self.recv_buffer)
|
||||
else:
|
||||
client_hello = parse_client_hello(self.recv_buffer)
|
||||
except ValueError:
|
||||
return False, f"Cannot parse ClientHello: {self.recv_buffer.hex()}"
|
||||
|
||||
if client_hello:
|
||||
self.client_hello_parsed = True
|
||||
else:
|
||||
return False, None
|
||||
|
||||
self.conn.sni = client_hello.sni
|
||||
self.conn.alpn_offers = client_hello.alpn_protocols
|
||||
tls_clienthello = ClientHelloData(self.context, client_hello)
|
||||
yield TlsClienthelloHook(tls_clienthello)
|
||||
|
||||
if tls_clienthello.ignore_connection:
|
||||
# we've figured out that we don't want to intercept this connection, so we assign fake connection objects
|
||||
# to all TLS layers. This makes the real connection contents just go through.
|
||||
self.conn = self.tunnel_connection = connection.Client(
|
||||
peername=("ignore-conn", 0), sockname=("ignore-conn", 0)
|
||||
)
|
||||
parent_layer = self.context.layers[self.context.layers.index(self) - 1]
|
||||
if isinstance(parent_layer, ServerTLSLayer):
|
||||
parent_layer.conn = parent_layer.tunnel_connection = connection.Server(
|
||||
address=None
|
||||
)
|
||||
if self.is_dtls:
|
||||
self.child_layer = udp.UDPLayer(self.context, ignore=True)
|
||||
else:
|
||||
self.child_layer = tcp.TCPLayer(self.context, ignore=True)
|
||||
yield from self.event_to_child(
|
||||
events.DataReceived(self.context.client, bytes(self.recv_buffer))
|
||||
)
|
||||
self.recv_buffer.clear()
|
||||
return True, None
|
||||
if (
|
||||
tls_clienthello.establish_server_tls_first
|
||||
and not self.context.server.tls_established
|
||||
):
|
||||
err = yield from self.start_server_tls()
|
||||
if err:
|
||||
yield commands.Log(
|
||||
f"Unable to establish {self.proto_name} connection with server ({err}). "
|
||||
f"Trying to establish {self.proto_name} with client anyway. "
|
||||
f"If you plan to redirect requests away from this server, "
|
||||
f"consider setting `connection_strategy` to `lazy` to suppress early connections."
|
||||
)
|
||||
|
||||
yield from self.start_tls()
|
||||
if not self.conn.connected:
|
||||
return False, "connection closed early"
|
||||
|
||||
ret = yield from super().receive_handshake_data(bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
return ret
|
||||
|
||||
def start_server_tls(self) -> layer.CommandGenerator[str | None]:
|
||||
"""
|
||||
We often need information from the upstream connection to establish TLS with the client.
|
||||
For example, we need to check if the client does ALPN or not.
|
||||
"""
|
||||
if not self.server_tls_available:
|
||||
return f"No server {self.proto_name} available."
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
return err
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
if self.conn.sni:
|
||||
dest = self.conn.sni
|
||||
else:
|
||||
dest = human.format_address(self.context.server.address)
|
||||
level: int = WARNING
|
||||
if err.startswith("Cannot parse ClientHello"):
|
||||
pass
|
||||
elif (
|
||||
"('SSL routines', 'tls_early_post_process_client_hello', 'unsupported protocol')"
|
||||
in err
|
||||
or "('SSL routines', '', 'unsupported protocol')" in err # OpenSSL 3+
|
||||
):
|
||||
err = (
|
||||
f"Client and mitmproxy cannot agree on a TLS version to use. "
|
||||
f"You may need to adjust mitmproxy's tls_version_client_min option."
|
||||
)
|
||||
elif (
|
||||
"unknown ca" in err
|
||||
or "bad certificate" in err
|
||||
or "certificate unknown" in err
|
||||
):
|
||||
err = (
|
||||
f"The client does not trust the proxy's certificate for {dest} ({err})"
|
||||
)
|
||||
elif err == "connection closed":
|
||||
err = (
|
||||
f"The client disconnected during the handshake. If this happens consistently for {dest}, "
|
||||
f"this may indicate that the client does not trust the proxy's certificate."
|
||||
)
|
||||
level = INFO
|
||||
elif err == "connection closed early":
|
||||
pass
|
||||
else:
|
||||
err = f"The client may not trust the proxy's certificate for {dest} ({err})"
|
||||
if err != "connection closed early":
|
||||
yield commands.Log(f"Client TLS handshake failed. {err}", level=level)
|
||||
yield from super().on_handshake_error(err)
|
||||
self.event_to_child = self.errored # type: ignore
|
||||
|
||||
def errored(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if self.debug is not None:
|
||||
yield commands.Log(
|
||||
f"{self.debug}[tls] Swallowing {event} as handshake failed.", DEBUG
|
||||
)
|
||||
|
||||
|
||||
class MockTLSLayer(TLSLayer):
|
||||
"""Mock layer to disable actual TLS and use cleartext in tests.
|
||||
|
||||
Use like so:
|
||||
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: context.Context):
|
||||
super().__init__(ctx, connection.Server(address=None))
|
||||
132
venv/Lib/site-packages/mitmproxy/proxy/layers/udp.py
Normal file
132
venv/Lib/site-packages/mitmproxy/proxy/layers/udp.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mitmproxy import flow
|
||||
from mitmproxy import udp
|
||||
from mitmproxy.connection import Connection
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.events import MessageInjected
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
|
||||
@dataclass
|
||||
class UdpStartHook(StartHook):
|
||||
"""
|
||||
A UDP connection has started.
|
||||
"""
|
||||
|
||||
flow: udp.UDPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class UdpMessageHook(StartHook):
|
||||
"""
|
||||
A UDP connection has received a message. The most recent message
|
||||
will be flow.messages[-1]. The message is user-modifiable.
|
||||
"""
|
||||
|
||||
flow: udp.UDPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class UdpEndHook(StartHook):
|
||||
"""
|
||||
A UDP connection has ended.
|
||||
"""
|
||||
|
||||
flow: udp.UDPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class UdpErrorHook(StartHook):
|
||||
"""
|
||||
A UDP error has occurred.
|
||||
|
||||
Every UDP flow will receive either a udp_error or a udp_end event, but not both.
|
||||
"""
|
||||
|
||||
flow: udp.UDPFlow
|
||||
|
||||
|
||||
class UdpMessageInjected(MessageInjected[udp.UDPMessage]):
|
||||
"""
|
||||
The user has injected a custom UDP message.
|
||||
"""
|
||||
|
||||
|
||||
class UDPLayer(layer.Layer):
|
||||
"""
|
||||
Simple UDP layer that just relays messages right now.
|
||||
"""
|
||||
|
||||
flow: udp.UDPFlow | None
|
||||
|
||||
def __init__(self, context: Context, ignore: bool = False):
|
||||
super().__init__(context)
|
||||
if ignore:
|
||||
self.flow = None
|
||||
else:
|
||||
self.flow = udp.UDPFlow(self.context.client, self.context.server, True)
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
if self.flow:
|
||||
yield UdpStartHook(self.flow)
|
||||
|
||||
if self.context.server.timestamp_start is None:
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if err:
|
||||
if self.flow:
|
||||
self.flow.error = flow.Error(str(err))
|
||||
yield UdpErrorHook(self.flow)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
self._handle_event = self.done
|
||||
return
|
||||
self._handle_event = self.relay_messages
|
||||
|
||||
_handle_event = start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, UdpMessageInjected)
|
||||
def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, UdpMessageInjected):
|
||||
# we just spoof that we received data here and then process that regularly.
|
||||
event = events.DataReceived(
|
||||
self.context.client
|
||||
if event.message.from_client
|
||||
else self.context.server,
|
||||
event.message.content,
|
||||
)
|
||||
|
||||
assert isinstance(event, events.ConnectionEvent)
|
||||
|
||||
from_client = event.connection == self.context.client
|
||||
send_to: Connection
|
||||
if from_client:
|
||||
send_to = self.context.server
|
||||
else:
|
||||
send_to = self.context.client
|
||||
|
||||
if isinstance(event, events.DataReceived):
|
||||
if self.flow:
|
||||
udp_message = udp.UDPMessage(from_client, event.data)
|
||||
self.flow.messages.append(udp_message)
|
||||
yield UdpMessageHook(self.flow)
|
||||
yield commands.SendData(send_to, udp_message.content)
|
||||
else:
|
||||
yield commands.SendData(send_to, event.data)
|
||||
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
self._handle_event = self.done
|
||||
yield commands.CloseConnection(send_to)
|
||||
if self.flow:
|
||||
yield UdpEndHook(self.flow)
|
||||
self.flow.live = False
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, UdpMessageInjected)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
272
venv/Lib/site-packages/mitmproxy/proxy/layers/websocket.py
Normal file
272
venv/Lib/site-packages/mitmproxy/proxy/layers/websocket.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
import wsproto.extensions
|
||||
import wsproto.frame_protocol
|
||||
import wsproto.utilities
|
||||
from wsproto import ConnectionState
|
||||
from wsproto.frame_protocol import Opcode
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy import http
|
||||
from mitmproxy import websocket
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.events import MessageInjected
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsocketStartHook(StartHook):
|
||||
"""
|
||||
A WebSocket connection has commenced.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsocketMessageHook(StartHook):
|
||||
"""
|
||||
Called when a WebSocket message is received from the client or
|
||||
server. The most recent message will be flow.messages[-1]. The
|
||||
message is user-modifiable. Currently there are two types of
|
||||
messages, corresponding to the BINARY and TEXT frame types.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsocketEndHook(StartHook):
|
||||
"""
|
||||
A WebSocket connection has ended.
|
||||
You can check `flow.websocket.close_code` to determine why it ended.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class WebSocketMessageInjected(MessageInjected[websocket.WebSocketMessage]):
|
||||
"""
|
||||
The user has injected a custom WebSocket message.
|
||||
"""
|
||||
|
||||
|
||||
class WebsocketConnection(wsproto.Connection):
|
||||
"""
|
||||
A very thin wrapper around wsproto.Connection:
|
||||
|
||||
- we keep the underlying connection as an attribute for easy access.
|
||||
- we add a framebuffer for incomplete messages
|
||||
- we wrap .send() so that we can directly yield it.
|
||||
"""
|
||||
|
||||
conn: connection.Connection
|
||||
frame_buf: list[bytes]
|
||||
|
||||
def __init__(self, *args, conn: connection.Connection, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conn = conn
|
||||
self.frame_buf = [b""]
|
||||
|
||||
def send2(self, event: wsproto.events.Event) -> commands.SendData:
|
||||
data = self.send(event)
|
||||
return commands.SendData(self.conn, data)
|
||||
|
||||
def __repr__(self):
|
||||
return f"WebsocketConnection<{self.state.name}, {self.conn}>"
|
||||
|
||||
|
||||
class WebsocketLayer(layer.Layer):
|
||||
"""
|
||||
WebSocket layer that intercepts and relays messages.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
client_ws: WebsocketConnection
|
||||
server_ws: WebsocketConnection
|
||||
|
||||
def __init__(self, context: Context, flow: http.HTTPFlow):
|
||||
super().__init__(context)
|
||||
self.flow = flow
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
client_extensions = []
|
||||
server_extensions = []
|
||||
|
||||
# Parse extension headers. We only support deflate at the moment and ignore everything else.
|
||||
assert self.flow.response # satisfy type checker
|
||||
ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
|
||||
if ext_header:
|
||||
for ext in wsproto.utilities.split_comma_header(
|
||||
ext_header.encode("ascii", "replace")
|
||||
):
|
||||
ext_name = ext.split(";", 1)[0].strip()
|
||||
if ext_name == wsproto.extensions.PerMessageDeflate.name:
|
||||
client_deflate = wsproto.extensions.PerMessageDeflate()
|
||||
client_deflate.finalize(ext)
|
||||
client_extensions.append(client_deflate)
|
||||
server_deflate = wsproto.extensions.PerMessageDeflate()
|
||||
server_deflate.finalize(ext)
|
||||
server_extensions.append(server_deflate)
|
||||
else:
|
||||
yield commands.Log(
|
||||
f"Ignoring unknown WebSocket extension {ext_name!r}."
|
||||
)
|
||||
|
||||
self.client_ws = WebsocketConnection(
|
||||
wsproto.ConnectionType.SERVER, client_extensions, conn=self.context.client
|
||||
)
|
||||
self.server_ws = WebsocketConnection(
|
||||
wsproto.ConnectionType.CLIENT, server_extensions, conn=self.context.server
|
||||
)
|
||||
|
||||
yield WebsocketStartHook(self.flow)
|
||||
|
||||
self._handle_event = self.relay_messages
|
||||
|
||||
_handle_event = start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
|
||||
def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.flow.websocket # satisfy type checker
|
||||
|
||||
if isinstance(event, events.ConnectionEvent):
|
||||
from_client = event.connection == self.context.client
|
||||
injected = False
|
||||
elif isinstance(event, WebSocketMessageInjected):
|
||||
from_client = event.message.from_client
|
||||
injected = True
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
from_str = "client" if from_client else "server"
|
||||
if from_client:
|
||||
src_ws = self.client_ws
|
||||
dst_ws = self.server_ws
|
||||
else:
|
||||
src_ws = self.server_ws
|
||||
dst_ws = self.client_ws
|
||||
|
||||
if isinstance(event, events.DataReceived):
|
||||
src_ws.receive_data(event.data)
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
src_ws.receive_data(None)
|
||||
elif isinstance(event, WebSocketMessageInjected):
|
||||
fragmentizer = Fragmentizer([], event.message.type == Opcode.TEXT)
|
||||
src_ws._events.extend(fragmentizer(event.message.content))
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
for ws_event in src_ws.events():
|
||||
if isinstance(ws_event, wsproto.events.Message):
|
||||
is_text = isinstance(ws_event.data, str)
|
||||
if is_text:
|
||||
typ = Opcode.TEXT
|
||||
src_ws.frame_buf[-1] += ws_event.data.encode()
|
||||
else:
|
||||
typ = Opcode.BINARY
|
||||
src_ws.frame_buf[-1] += ws_event.data
|
||||
|
||||
if ws_event.message_finished:
|
||||
content = b"".join(src_ws.frame_buf)
|
||||
|
||||
fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
|
||||
src_ws.frame_buf = [b""]
|
||||
|
||||
message = websocket.WebSocketMessage(
|
||||
typ, from_client, content, injected=injected
|
||||
)
|
||||
self.flow.websocket.messages.append(message)
|
||||
yield WebsocketMessageHook(self.flow)
|
||||
|
||||
if not message.dropped:
|
||||
for msg in fragmentizer(message.content):
|
||||
yield dst_ws.send2(msg)
|
||||
|
||||
elif ws_event.frame_finished:
|
||||
src_ws.frame_buf.append(b"")
|
||||
|
||||
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
|
||||
yield commands.Log(
|
||||
f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
|
||||
f"(payload: {bytes(ws_event.payload)!r})"
|
||||
)
|
||||
yield dst_ws.send2(ws_event)
|
||||
elif isinstance(ws_event, wsproto.events.CloseConnection):
|
||||
self.flow.websocket.timestamp_end = time.time()
|
||||
self.flow.websocket.closed_by_client = from_client
|
||||
self.flow.websocket.close_code = ws_event.code
|
||||
self.flow.websocket.close_reason = ws_event.reason
|
||||
|
||||
for ws in [self.server_ws, self.client_ws]:
|
||||
if ws.state in {
|
||||
ConnectionState.OPEN,
|
||||
ConnectionState.REMOTE_CLOSING,
|
||||
}:
|
||||
# response == original event, so no need to differentiate here.
|
||||
yield ws.send2(ws_event)
|
||||
yield commands.CloseConnection(ws.conn)
|
||||
yield WebsocketEndHook(self.flow)
|
||||
self.flow.live = False
|
||||
self._handle_event = self.done
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected WebSocket event: {ws_event}")
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
|
||||
class Fragmentizer:
|
||||
"""
|
||||
Theory (RFC 6455):
|
||||
Unless specified otherwise by an extension, frames have no semantic
|
||||
meaning. An intermediary might coalesce and/or split frames, [...]
|
||||
|
||||
Practice:
|
||||
Some WebSocket servers reject large payload sizes.
|
||||
Other WebSocket servers reject CONTINUATION frames.
|
||||
|
||||
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
|
||||
If one deals with web servers that do not support CONTINUATION frames, addons need to monkeypatch FRAGMENT_SIZE
|
||||
if they need to modify the message.
|
||||
"""
|
||||
|
||||
# A bit less than 4kb to accommodate for headers.
|
||||
FRAGMENT_SIZE = 4000
|
||||
|
||||
def __init__(self, fragments: list[bytes], is_text: bool):
|
||||
self.fragment_lengths = [len(x) for x in fragments]
|
||||
self.is_text = is_text
|
||||
|
||||
def msg(self, data: bytes, message_finished: bool):
|
||||
if self.is_text:
|
||||
data_str = data.decode(errors="replace")
|
||||
return wsproto.events.TextMessage(
|
||||
data_str, message_finished=message_finished
|
||||
)
|
||||
else:
|
||||
return wsproto.events.BytesMessage(data, message_finished=message_finished)
|
||||
|
||||
def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]:
|
||||
if len(content) == sum(self.fragment_lengths):
|
||||
# message has the same length, we can reuse the same sizes
|
||||
offset = 0
|
||||
for fl in self.fragment_lengths[:-1]:
|
||||
yield self.msg(content[offset : offset + fl], False)
|
||||
offset += fl
|
||||
yield self.msg(content[offset:], True)
|
||||
else:
|
||||
offset = 0
|
||||
total = len(content) - self.FRAGMENT_SIZE
|
||||
while offset < total:
|
||||
yield self.msg(content[offset : offset + self.FRAGMENT_SIZE], False)
|
||||
offset += self.FRAGMENT_SIZE
|
||||
yield self.msg(content[offset:], True)
|
||||
Reference in New Issue
Block a user