Files
“shengyudong” 322ac74336 2025-12-25 upload
2025-12-25 11:16:59 +08:00

191 lines
6.3 KiB
Python

import struct
import time
from dataclasses import dataclass
from typing import List
from typing import Literal
from mitmproxy import dns
from mitmproxy import flow as mflow
from mitmproxy.net.dns import response_codes
from mitmproxy.proxy import commands
from mitmproxy.proxy import events
from mitmproxy.proxy import layer
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.utils import expect
# Length prefix that frames each DNS message on TCP connections: an unsigned
# 16-bit integer ("H") in network, i.e. big-endian, byte order ("!").
_LENGTH_LABEL = struct.Struct("!H")
@dataclass
class DnsRequestHook(commands.StartHook):
    """
    A DNS query has been received.

    Addons may answer the query directly by setting ``flow.response`` or fail
    it by setting ``flow.error``. Otherwise the query is forwarded to the
    upstream server, if one is configured.
    """

    flow: dns.DNSFlow
@dataclass
class DnsResponseHook(commands.StartHook):
    """
    A DNS response has been received or set.

    This runs before the response is relayed to the client. If an addon clears
    ``flow.response`` (e.g. sets it to ``None``), nothing is sent to the client.
    """

    flow: dns.DNSFlow
@dataclass
class DnsErrorHook(commands.StartHook):
    """
    A DNS error has occurred.

    ``flow.error`` holds the error. After this hook runs, a SERVFAIL response
    derived from ``flow.request`` is sent to the client.
    """

    flow: dns.DNSFlow
def pack_message(
    message: dns.DNSMessage, transport_protocol: Literal["tcp", "udp"]
) -> bytes:
    """
    Serialize a DNS message for sending over the given transport protocol.

    Over UDP (or any protocol other than "tcp") the packed message is sent as
    is. Over TCP each message is prefixed with its length, using the same
    ``_LENGTH_LABEL`` format that the receiving side parses.

    Args:
        message: The message to serialize.
        transport_protocol: Transport protocol of the connection the bytes
            will be written to.

    Returns:
        The wire representation of ``message``.

    Raises:
        struct.error: On TCP, if the packed message is too long for the
            16-bit length prefix (more than 65535 bytes).
    """
    packed = message.packed
    if transport_protocol != "tcp":
        return packed
    # Share the framing definition with the parser so both directions agree.
    return _LENGTH_LABEL.pack(len(packed)) + packed
class DNSLayer(layer.Layer):
    """
    Layer that handles resolving DNS queries.

    Queries from the client go through ``DnsRequestHook``. Unless an addon
    answers or fails them, they are forwarded to ``self.context.server``.
    Responses from the server go through ``DnsResponseHook`` and are relayed
    to the client. Flows are tracked per DNS message ID in ``self.flows``.
    """

    # Flows keyed by DNS message ID. A new client query that reuses the ID of a
    # completed exchange replaces the old entry (see _flow_for).
    flows: dict[int, dns.DNSFlow]
    # Reassembly buffers for length-prefixed TCP data: from the client
    # (req_buf) and from the server (resp_buf).
    req_buf: bytearray
    resp_buf: bytearray

    def __init__(self, context: Context):
        super().__init__(context)
        self.flows = {}
        self.req_buf = bytearray()
        self.resp_buf = bytearray()

    def handle_request(
        self, flow: dns.DNSFlow, msg: dns.DNSMessage
    ) -> layer.CommandGenerator[None]:
        """
        Process a query received from the client.

        Runs DnsRequestHook, then does one of the following:
        - replies with the response an addon set;
        - fails with the error an addon set;
        - fails because there is no upstream server;
        - forwards the query upstream, opening the server connection first if
          needed.
        """
        # For a retransmitted query whose flow is still pending, this
        # overwrites the request and queries upstream again.
        flow.request = msg
        yield DnsRequestHook(flow)
        if flow.response:
            yield from self.handle_response(flow, flow.response)
        elif flow.error:
            yield from self.handle_error(flow, flow.error.msg)
        elif not self.context.server.address:
            yield from self.handle_error(
                flow, "No hook has set a response and there is no upstream server."
            )
        else:
            if not self.context.server.connected:
                err = yield commands.OpenConnection(self.context.server)
                if err:
                    yield from self.handle_error(flow, str(err))
                    # cannot recover from this
                    return
            packed = pack_message(flow.request, flow.server_conn.transport_protocol)
            yield commands.SendData(self.context.server, packed)

    def handle_response(
        self, flow: dns.DNSFlow, msg: dns.DNSMessage
    ) -> layer.CommandGenerator[None]:
        """
        Record ``msg`` as the flow's response, run DnsResponseHook and relay
        the (possibly modified) response to the client.

        Nothing is sent if a hook cleared ``flow.response``.
        """
        flow.response = msg
        yield DnsResponseHook(flow)
        if flow.response:
            packed = pack_message(flow.response, flow.client_conn.transport_protocol)
            yield commands.SendData(self.context.client, packed)

    def handle_error(self, flow: dns.DNSFlow, err: str) -> layer.CommandGenerator[None]:
        """
        Record ``err`` on the flow, run DnsErrorHook and answer the client
        with a SERVFAIL response derived from ``flow.request``.
        """
        flow.error = mflow.Error(err)
        yield DnsErrorHook(flow)
        servfail = flow.request.fail(response_codes.SERVFAIL)
        yield commands.SendData(
            self.context.client,
            pack_message(servfail, flow.client_conn.transport_protocol),
        )

    def unpack_message(self, data: bytes, from_client: bool) -> List[dns.DNSMessage]:
        """
        Parse all complete DNS messages contained in ``data``.

        Framing follows the transport protocol of the connection the data
        arrived on:
        - UDP: the payload is exactly one message.
        - TCP: data is appended to that direction's buffer and split on
          ``_LENGTH_LABEL`` length prefixes. A trailing partial message stays
          buffered for the next call.

        Args:
            data: Bytes just received.
            from_client: True if ``data`` came from the client, False if it
                came from the server.

        Returns:
            The parsed messages, in wire order (possibly empty).

        Raises:
            struct.error: If a TCP length prefix is zero. The caller
                (state_query) also treats struct.error from message parsing
                as an invalid message.
        """
        msgs: List[dns.DNSMessage] = []
        # Client and server may use different transport protocols: requests
        # are packed for the server's protocol in handle_request. Parse
        # according to the side the data actually came from.
        if from_client:
            conn, buf = self.context.client, self.req_buf
        else:
            conn, buf = self.context.server, self.resp_buf

        if conn.transport_protocol == "udp":
            msgs.append(dns.DNSMessage.unpack(data, timestamp=time.time()))
        elif conn.transport_protocol == "tcp":
            buf.extend(data)
            size = len(buf)
            offset = 0
            while size - offset >= _LENGTH_LABEL.size:
                (expected_size,) = _LENGTH_LABEL.unpack_from(buf, offset)
                if expected_size == 0:
                    raise struct.error("Message length field cannot be zero")
                start = offset + _LENGTH_LABEL.size
                end = start + expected_size
                if end > size:
                    # Incomplete message: leave it and its length prefix
                    # buffered until more data arrives.
                    break
                msgs.append(
                    dns.DNSMessage.unpack(bytes(buf[start:end]), timestamp=time.time())
                )
                offset = end
            # Drop everything that was consumed.
            del buf[:offset]
        return msgs

    def _flow_for(self, msg: dns.DNSMessage, from_client: bool) -> dns.DNSFlow:
        """Return the flow ``msg`` belongs to, creating and registering one if needed."""
        flow = self.flows.get(msg.id)
        if flow is not None and from_client and (flow.response or flow.error):
            # The earlier exchange with this ID is already finished. DNS IDs
            # are reused, so treat this as a new query rather than replaying
            # the stale response/error to the client.
            flow.live = False
            flow = None
        if flow is None:
            # NOTE(review): a server message with an unknown ID also lands
            # here and gets a flow without a request.
            flow = dns.DNSFlow(self.context.client, self.context.server, live=True)
            self.flows[msg.id] = flow
        return flow

    @expect(events.Start)
    def state_start(self, _) -> layer.CommandGenerator[None]:
        """Initial state: switch to state_query without emitting any commands."""
        self._handle_event = self.state_query
        yield from ()

    @expect(events.DataReceived, events.ConnectionClosed)
    def state_query(self, event: events.Event) -> layer.CommandGenerator[None]:
        """
        Main state: dispatch messages from either side to their flows.

        Invalid data closes the offending connection. When either side
        closes, the other side is closed too and all flows stop being live.
        """
        assert isinstance(event, events.ConnectionEvent)
        from_client = event.connection is self.context.client
        if isinstance(event, events.DataReceived):
            msgs: List[dns.DNSMessage] = []
            try:
                msgs = self.unpack_message(event.data, from_client)
            except struct.error as e:
                yield commands.Log(f"{event.connection} sent an invalid message: {e}")
                yield commands.CloseConnection(event.connection)
                self._handle_event = self.state_done
            else:
                for msg in msgs:
                    flow = self._flow_for(msg, from_client)
                    if from_client:
                        yield from self.handle_request(flow, msg)
                    else:
                        yield from self.handle_response(flow, msg)
        elif isinstance(event, events.ConnectionClosed):
            other_conn = self.context.server if from_client else self.context.client
            if other_conn.connected:
                yield commands.CloseConnection(other_conn)
            self._handle_event = self.state_done
            for flow in self.flows.values():
                flow.live = False
        else:
            raise AssertionError(f"Unexpected event: {event}")

    @expect(events.DataReceived, events.ConnectionClosed)
    def state_done(self, _) -> layer.CommandGenerator[None]:
        """Terminal state: further data and close events are ignored."""
        yield from ()

    _handle_event = state_start