Files
baijiahao_data_crawl/venv/Lib/site-packages/mitmproxy/addons/clientplayback.py

299 lines
10 KiB
Python
Raw Normal View History

2025-12-25 11:16:59 +08:00
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Sequence
from types import TracebackType
from typing import cast
from typing import Literal
import mitmproxy.types
from mitmproxy import command
from mitmproxy import ctx
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy import http
from mitmproxy import io
from mitmproxy.connection import ConnectionState
from mitmproxy.connection import Server
from mitmproxy.hooks import UpdateHook
from mitmproxy.log import ALERT
from mitmproxy.options import Options
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layers
from mitmproxy.proxy import server
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.layer import CommandGenerator
from mitmproxy.proxy.layers.http import HTTPMode
from mitmproxy.proxy.mode_specs import UpstreamMode
from mitmproxy.utils import asyncio_utils
logger = logging.getLogger(__name__)
class MockServer(layers.http.HttpConnection):
    """
    A mock HTTP "server" that just pretends it received a full HTTP request,
    which is then processed by the proxy core.
    """

    flow: http.HTTPFlow

    def __init__(self, flow: http.HTTPFlow, context: Context):
        super().__init__(context, context.client)
        self.flow = flow

    def _inject_request(self) -> CommandGenerator[None]:
        # Emit the recorded request as if the client had just sent it:
        # headers, optional body, optional trailers, then end-of-message.
        request = self.flow.request
        body = request.raw_content
        request.timestamp_start = request.timestamp_end = time.time()
        yield layers.http.ReceiveHttp(
            layers.http.RequestHeaders(
                1,
                request,
                end_stream=not (body or request.trailers),
                replay_flow=self.flow,
            )
        )
        if body:
            yield layers.http.ReceiveHttp(layers.http.RequestData(1, body))
        if request.trailers:  # pragma: no cover
            # TODO: Cover this once we support HTTP/1 trailers.
            yield layers.http.ReceiveHttp(
                layers.http.RequestTrailers(1, request.trailers)
            )
        yield layers.http.ReceiveHttp(layers.http.RequestEndOfMessage(1))

    def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
        if isinstance(event, events.Start):
            yield from self._inject_request()
            return
        if isinstance(
            event,
            (
                layers.http.ResponseHeaders,
                layers.http.ResponseData,
                layers.http.ResponseTrailers,
                layers.http.ResponseEndOfMessage,
                layers.http.ResponseProtocolError,
            ),
        ):
            # The response side is handled by the proxy core; nothing to do.
            return
        logger.warning(f"Unexpected event during replay: {event}")  # pragma: no cover
class ReplayHandler(server.ConnectionHandler):
    """Drives a single replayed flow through the proxy core, using a
    MockServer in place of a real client connection."""

    layer: layers.HttpLayer

    def __init__(self, flow: http.HTTPFlow, options: Options) -> None:
        client = flow.client_conn.copy()
        client.state = ConnectionState.OPEN

        context = Context(client, options)
        context.server = Server(address=(flow.request.host, flow.request.port))
        if flow.request.scheme == "https":
            context.server.tls = True
            context.server.sni = flow.request.pretty_host

        # Route through an upstream proxy when one is configured.
        upstream = bool(options.mode) and options.mode[0].startswith("upstream:")
        if upstream:
            mode = UpstreamMode.parse(options.mode[0])
            assert isinstance(mode, UpstreamMode)  # remove once mypy supports Self.
            context.server.via = flow.server_conn.via = (mode.scheme, mode.address)

        super().__init__(context)

        http_mode = HTTPMode.upstream if upstream else HTTPMode.transparent
        self.layer = layers.HttpLayer(context, http_mode)
        self.layer.connections[client] = MockServer(flow, context.fork())
        self.flow = flow
        self.done = asyncio.Event()

    async def replay(self) -> None:
        """Kick off the replay and block until a response (or error) arrives."""
        await self.server_event(events.Start())
        await self.done.wait()

    def log(
        self,
        message: str,
        level: int = logging.INFO,
        exc_info: Literal[True]
        | tuple[type[BaseException] | None, BaseException | None, TracebackType | None]
        | None = None,
    ) -> None:
        """Log *message* with a ``[replay]`` prefix.

        ``exc_info`` is accepted for interface compatibility but not forwarded.
        """
        assert isinstance(level, int)
        logger.log(level=level, msg=f"[replay] {message}")

    async def handle_hook(self, hook: commands.StartHook) -> None:
        (data,) = hook.args()
        await ctx.master.addons.handle_lifecycle(hook)
        if isinstance(data, flow.Flow):
            await data.wait_for_resume()
        if isinstance(hook, (layers.http.HttpResponseHook, layers.http.HttpErrorHook)):
            # A response (or error) means the replay is finished.
            if self.transports:
                # Close any server connections we opened along the way.
                for transport in self.transports.values():
                    if transport.handler:
                        transport.handler.cancel()
                await asyncio.wait(
                    [t.handler for t in self.transports.values() if t.handler]
                )
            # signal completion
            self.done.set()
class ClientPlayback:
    """
    Addon that replays recorded client requests against their servers.

    Flows are queued via the ``replay.client`` commands (or the
    ``client_replay`` option) and replayed sequentially by a background task,
    or without limit when ``client_replay_concurrency`` is -1.
    """

    # Background task running playback(); created in running().
    playback_task: asyncio.Task | None = None
    # Flow currently being replayed (None when idle).
    inflight: http.HTTPFlow | None
    # FIFO of flows awaiting replay.
    queue: asyncio.Queue
    options: Options
    # Strong references to per-flow replay tasks in unlimited-concurrency
    # mode, so they are not garbage-collected mid-flight.
    replay_tasks: set[asyncio.Task]

    def __init__(self):
        self.queue = asyncio.Queue()
        self.inflight = None
        # Fix: initialize the attribute declared above; the original assigned
        # a never-read `self.task` leftover instead.
        self.playback_task = None
        self.replay_tasks = set()

    def running(self):
        # Hook: proxy is up; start consuming the replay queue.
        self.options = ctx.options
        self.playback_task = asyncio_utils.create_task(
            self.playback(),
            name="client playback",
            keep_ref=False,
        )

    async def done(self):
        # Hook: shutdown; cancel the playback loop and wait for it to stop.
        if self.playback_task:
            self.playback_task.cancel()
            try:
                await self.playback_task
            except asyncio.CancelledError:
                pass

    async def playback(self):
        """Consume the queue forever, replaying one flow at a time (or
        spawning a detached task per flow when concurrency is unlimited)."""
        while True:
            self.inflight = await self.queue.get()
            try:
                assert self.inflight
                h = ReplayHandler(self.inflight, self.options)
                if ctx.options.client_replay_concurrency == -1:
                    t = asyncio_utils.create_task(
                        h.replay(),
                        name="client playback awaiting response",
                        keep_ref=False,
                    )
                    # keep a reference so this is not garbage collected
                    self.replay_tasks.add(t)
                    t.add_done_callback(self.replay_tasks.remove)
                else:
                    await h.replay()
            except Exception:
                # Fix: plain string; the original used an f-string with no
                # placeholders.
                logger.exception("Client replay has crashed!")
            self.queue.task_done()
            self.inflight = None

    def check(self, f: flow.Flow) -> str | None:
        """Return a human-readable reason why *f* cannot be replayed,
        or None if it is replayable."""
        if f.live or f == self.inflight:
            return "Can't replay live flow."
        if f.intercepted:
            return "Can't replay intercepted flow."
        if isinstance(f, http.HTTPFlow):
            if not f.request:
                return "Can't replay flow with missing request."
            if f.request.raw_content is None:
                return "Can't replay flow with missing content."
            if f.websocket is not None:
                return "Can't replay WebSocket flows."
        else:
            return "Can only replay HTTP flows."
        return None

    def load(self, loader):
        loader.add_option(
            "client_replay",
            Sequence[str],
            [],
            "Replay client requests from a saved file.",
        )
        loader.add_option(
            "client_replay_concurrency",
            int,
            1,
            "Concurrency limit on in-flight client replay requests. Currently the only valid values are 1 and -1 (no limit).",
        )

    def configure(self, updated):
        if "client_replay" in updated and ctx.options.client_replay:
            try:
                flows = io.read_flows_from_paths(ctx.options.client_replay)
            except exceptions.FlowReadException as e:
                # Chain the cause so the original parse error stays visible.
                raise exceptions.OptionsError(str(e)) from e
            self.start_replay(flows)
        if "client_replay_concurrency" in updated:
            if ctx.options.client_replay_concurrency not in [-1, 1]:
                raise exceptions.OptionsError(
                    "Currently the only valid client_replay_concurrency values are -1 and 1."
                )

    @command.command("replay.client.count")
    def count(self) -> int:
        """
        Approximate number of flows queued for replay.
        """
        return self.queue.qsize() + int(bool(self.inflight))

    @command.command("replay.client.stop")
    def stop_replay(self) -> None:
        """
        Clear the replay queue.
        """
        updated = []
        while True:
            try:
                f = self.queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            else:
                self.queue.task_done()
                # Undo the replay preparation done in start_replay().
                f.revert()
                updated.append(f)
        ctx.master.addons.trigger(UpdateHook(updated))
        logger.log(ALERT, "Client replay queue cleared.")

    @command.command("replay.client")
    def start_replay(self, flows: Sequence[flow.Flow]) -> None:
        """
        Add flows to the replay queue, skipping flows that can't be replayed.
        """
        updated: list[http.HTTPFlow] = []
        for f in flows:
            err = self.check(f)
            if err:
                logger.warning(err)
                continue
            http_flow = cast(http.HTTPFlow, f)
            # Prepare the flow for replay: drop any previous response/error
            # and keep a backup so stop_replay() can revert.
            http_flow.backup()
            http_flow.is_replay = "request"
            http_flow.response = None
            http_flow.error = None
            self.queue.put_nowait(http_flow)
            updated.append(http_flow)
        ctx.master.addons.trigger(UpdateHook(updated))

    @command.command("replay.client.file")
    def load_file(self, path: mitmproxy.types.Path) -> None:
        """
        Load flows from file, and add them to the replay queue.
        """
        try:
            flows = io.read_flows_from_paths([path])
        except exceptions.FlowReadException as e:
            raise exceptions.CommandError(str(e)) from e
        self.start_replay(flows)