2025-12-25 upload
This commit is contained in:
272
venv/Lib/site-packages/mitmproxy/proxy/layers/websocket.py
Normal file
272
venv/Lib/site-packages/mitmproxy/proxy/layers/websocket.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
|
||||
import wsproto.extensions
|
||||
import wsproto.frame_protocol
|
||||
import wsproto.utilities
|
||||
from wsproto import ConnectionState
|
||||
from wsproto.frame_protocol import Opcode
|
||||
|
||||
from mitmproxy import connection
|
||||
from mitmproxy import http
|
||||
from mitmproxy import websocket
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import events
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy.commands import StartHook
|
||||
from mitmproxy.proxy.context import Context
|
||||
from mitmproxy.proxy.events import MessageInjected
|
||||
from mitmproxy.proxy.utils import expect
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsocketStartHook(StartHook):
|
||||
"""
|
||||
A WebSocket connection has commenced.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsocketMessageHook(StartHook):
|
||||
"""
|
||||
Called when a WebSocket message is received from the client or
|
||||
server. The most recent message will be flow.messages[-1]. The
|
||||
message is user-modifiable. Currently there are two types of
|
||||
messages, corresponding to the BINARY and TEXT frame types.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsocketEndHook(StartHook):
|
||||
"""
|
||||
A WebSocket connection has ended.
|
||||
You can check `flow.websocket.close_code` to determine why it ended.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
|
||||
|
||||
class WebSocketMessageInjected(MessageInjected[websocket.WebSocketMessage]):
|
||||
"""
|
||||
The user has injected a custom WebSocket message.
|
||||
"""
|
||||
|
||||
|
||||
class WebsocketConnection(wsproto.Connection):
|
||||
"""
|
||||
A very thin wrapper around wsproto.Connection:
|
||||
|
||||
- we keep the underlying connection as an attribute for easy access.
|
||||
- we add a framebuffer for incomplete messages
|
||||
- we wrap .send() so that we can directly yield it.
|
||||
"""
|
||||
|
||||
conn: connection.Connection
|
||||
frame_buf: list[bytes]
|
||||
|
||||
def __init__(self, *args, conn: connection.Connection, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conn = conn
|
||||
self.frame_buf = [b""]
|
||||
|
||||
def send2(self, event: wsproto.events.Event) -> commands.SendData:
|
||||
data = self.send(event)
|
||||
return commands.SendData(self.conn, data)
|
||||
|
||||
def __repr__(self):
|
||||
return f"WebsocketConnection<{self.state.name}, {self.conn}>"
|
||||
|
||||
|
||||
class WebsocketLayer(layer.Layer):
|
||||
"""
|
||||
WebSocket layer that intercepts and relays messages.
|
||||
"""
|
||||
|
||||
flow: http.HTTPFlow
|
||||
client_ws: WebsocketConnection
|
||||
server_ws: WebsocketConnection
|
||||
|
||||
def __init__(self, context: Context, flow: http.HTTPFlow):
|
||||
super().__init__(context)
|
||||
self.flow = flow
|
||||
|
||||
@expect(events.Start)
|
||||
def start(self, _) -> layer.CommandGenerator[None]:
|
||||
client_extensions = []
|
||||
server_extensions = []
|
||||
|
||||
# Parse extension headers. We only support deflate at the moment and ignore everything else.
|
||||
assert self.flow.response # satisfy type checker
|
||||
ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
|
||||
if ext_header:
|
||||
for ext in wsproto.utilities.split_comma_header(
|
||||
ext_header.encode("ascii", "replace")
|
||||
):
|
||||
ext_name = ext.split(";", 1)[0].strip()
|
||||
if ext_name == wsproto.extensions.PerMessageDeflate.name:
|
||||
client_deflate = wsproto.extensions.PerMessageDeflate()
|
||||
client_deflate.finalize(ext)
|
||||
client_extensions.append(client_deflate)
|
||||
server_deflate = wsproto.extensions.PerMessageDeflate()
|
||||
server_deflate.finalize(ext)
|
||||
server_extensions.append(server_deflate)
|
||||
else:
|
||||
yield commands.Log(
|
||||
f"Ignoring unknown WebSocket extension {ext_name!r}."
|
||||
)
|
||||
|
||||
self.client_ws = WebsocketConnection(
|
||||
wsproto.ConnectionType.SERVER, client_extensions, conn=self.context.client
|
||||
)
|
||||
self.server_ws = WebsocketConnection(
|
||||
wsproto.ConnectionType.CLIENT, server_extensions, conn=self.context.server
|
||||
)
|
||||
|
||||
yield WebsocketStartHook(self.flow)
|
||||
|
||||
self._handle_event = self.relay_messages
|
||||
|
||||
_handle_event = start
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
|
||||
def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||
assert self.flow.websocket # satisfy type checker
|
||||
|
||||
if isinstance(event, events.ConnectionEvent):
|
||||
from_client = event.connection == self.context.client
|
||||
injected = False
|
||||
elif isinstance(event, WebSocketMessageInjected):
|
||||
from_client = event.message.from_client
|
||||
injected = True
|
||||
else:
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
from_str = "client" if from_client else "server"
|
||||
if from_client:
|
||||
src_ws = self.client_ws
|
||||
dst_ws = self.server_ws
|
||||
else:
|
||||
src_ws = self.server_ws
|
||||
dst_ws = self.client_ws
|
||||
|
||||
if isinstance(event, events.DataReceived):
|
||||
src_ws.receive_data(event.data)
|
||||
elif isinstance(event, events.ConnectionClosed):
|
||||
src_ws.receive_data(None)
|
||||
elif isinstance(event, WebSocketMessageInjected):
|
||||
fragmentizer = Fragmentizer([], event.message.type == Opcode.TEXT)
|
||||
src_ws._events.extend(fragmentizer(event.message.content))
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected event: {event}")
|
||||
|
||||
for ws_event in src_ws.events():
|
||||
if isinstance(ws_event, wsproto.events.Message):
|
||||
is_text = isinstance(ws_event.data, str)
|
||||
if is_text:
|
||||
typ = Opcode.TEXT
|
||||
src_ws.frame_buf[-1] += ws_event.data.encode()
|
||||
else:
|
||||
typ = Opcode.BINARY
|
||||
src_ws.frame_buf[-1] += ws_event.data
|
||||
|
||||
if ws_event.message_finished:
|
||||
content = b"".join(src_ws.frame_buf)
|
||||
|
||||
fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
|
||||
src_ws.frame_buf = [b""]
|
||||
|
||||
message = websocket.WebSocketMessage(
|
||||
typ, from_client, content, injected=injected
|
||||
)
|
||||
self.flow.websocket.messages.append(message)
|
||||
yield WebsocketMessageHook(self.flow)
|
||||
|
||||
if not message.dropped:
|
||||
for msg in fragmentizer(message.content):
|
||||
yield dst_ws.send2(msg)
|
||||
|
||||
elif ws_event.frame_finished:
|
||||
src_ws.frame_buf.append(b"")
|
||||
|
||||
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
|
||||
yield commands.Log(
|
||||
f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
|
||||
f"(payload: {bytes(ws_event.payload)!r})"
|
||||
)
|
||||
yield dst_ws.send2(ws_event)
|
||||
elif isinstance(ws_event, wsproto.events.CloseConnection):
|
||||
self.flow.websocket.timestamp_end = time.time()
|
||||
self.flow.websocket.closed_by_client = from_client
|
||||
self.flow.websocket.close_code = ws_event.code
|
||||
self.flow.websocket.close_reason = ws_event.reason
|
||||
|
||||
for ws in [self.server_ws, self.client_ws]:
|
||||
if ws.state in {
|
||||
ConnectionState.OPEN,
|
||||
ConnectionState.REMOTE_CLOSING,
|
||||
}:
|
||||
# response == original event, so no need to differentiate here.
|
||||
yield ws.send2(ws_event)
|
||||
yield commands.CloseConnection(ws.conn)
|
||||
yield WebsocketEndHook(self.flow)
|
||||
self.flow.live = False
|
||||
self._handle_event = self.done
|
||||
else: # pragma: no cover
|
||||
raise AssertionError(f"Unexpected WebSocket event: {ws_event}")
|
||||
|
||||
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
|
||||
def done(self, _) -> layer.CommandGenerator[None]:
|
||||
yield from ()
|
||||
|
||||
|
||||
class Fragmentizer:
|
||||
"""
|
||||
Theory (RFC 6455):
|
||||
Unless specified otherwise by an extension, frames have no semantic
|
||||
meaning. An intermediary might coalesce and/or split frames, [...]
|
||||
|
||||
Practice:
|
||||
Some WebSocket servers reject large payload sizes.
|
||||
Other WebSocket servers reject CONTINUATION frames.
|
||||
|
||||
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
|
||||
If one deals with web servers that do not support CONTINUATION frames, addons need to monkeypatch FRAGMENT_SIZE
|
||||
if they need to modify the message.
|
||||
"""
|
||||
|
||||
# A bit less than 4kb to accommodate for headers.
|
||||
FRAGMENT_SIZE = 4000
|
||||
|
||||
def __init__(self, fragments: list[bytes], is_text: bool):
|
||||
self.fragment_lengths = [len(x) for x in fragments]
|
||||
self.is_text = is_text
|
||||
|
||||
def msg(self, data: bytes, message_finished: bool):
|
||||
if self.is_text:
|
||||
data_str = data.decode(errors="replace")
|
||||
return wsproto.events.TextMessage(
|
||||
data_str, message_finished=message_finished
|
||||
)
|
||||
else:
|
||||
return wsproto.events.BytesMessage(data, message_finished=message_finished)
|
||||
|
||||
def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]:
|
||||
if len(content) == sum(self.fragment_lengths):
|
||||
# message has the same length, we can reuse the same sizes
|
||||
offset = 0
|
||||
for fl in self.fragment_lengths[:-1]:
|
||||
yield self.msg(content[offset : offset + fl], False)
|
||||
offset += fl
|
||||
yield self.msg(content[offset:], True)
|
||||
else:
|
||||
offset = 0
|
||||
total = len(content) - self.FRAGMENT_SIZE
|
||||
while offset < total:
|
||||
yield self.msg(content[offset : offset + self.FRAGMENT_SIZE], False)
|
||||
offset += self.FRAGMENT_SIZE
|
||||
yield self.msg(content[offset:], True)
|
||||
Reference in New Issue
Block a user