191 lines
6.3 KiB
Python
191 lines
6.3 KiB
Python
|
|
import struct
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import List
|
||
|
|
from typing import Literal
|
||
|
|
|
||
|
|
from mitmproxy import dns
|
||
|
|
from mitmproxy import flow as mflow
|
||
|
|
from mitmproxy.net.dns import response_codes
|
||
|
|
from mitmproxy.proxy import commands
|
||
|
|
from mitmproxy.proxy import events
|
||
|
|
from mitmproxy.proxy import layer
|
||
|
|
from mitmproxy.proxy.context import Context
|
||
|
|
from mitmproxy.proxy.utils import expect
|
||
|
|
|
||
|
|
# Two-byte, network-order (big-endian) unsigned integer. DNSLayer.unpack_message
# uses it to read the length field that precedes each DNS message on a TCP stream.
_LENGTH_LABEL = struct.Struct("!H")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class DnsRequestHook(commands.StartHook):
    """
    A DNS query has been received.

    The flow passed to this hook has its request set to the received query.
    """

    # The DNS flow carrying the received query.
    flow: dns.DNSFlow
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class DnsResponseHook(commands.StartHook):
    """
    A DNS response has been received or set.

    The flow passed to this hook has its response set to that message.
    """

    # The DNS flow carrying the response.
    flow: dns.DNSFlow
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class DnsErrorHook(commands.StartHook):
    """
    A DNS error has occurred.

    The flow passed to this hook has its error set to describe the failure.
    """

    # The DNS flow on which the error occurred.
    flow: dns.DNSFlow
|
||
|
|
|
||
|
|
|
||
|
|
def pack_message(
    message: dns.DNSMessage, transport_protocol: Literal["tcp", "udp"]
) -> bytes:
    """
    Serialize a DNS message for sending over the given transport.

    For "tcp" the packed message is preceded by its length, encoded as a
    two-byte network-order unsigned integer. For any other value the packed
    message is returned as is.
    """
    payload = message.packed
    if transport_protocol != "tcp":
        return payload
    length_prefix = struct.pack("!H", len(payload))
    return length_prefix + payload
|
||
|
|
|
||
|
|
|
||
|
|
class DNSLayer(layer.Layer):
    """
    Layer that handles resolving DNS queries.

    Messages from the client are treated as queries: they are passed to
    DnsRequestHook and, unless a response or an error is set on the flow
    afterwards, forwarded to the upstream server. Messages from the server are
    treated as responses: they are passed to DnsResponseHook and relayed to the
    client. Flows are keyed by DNS message id for the lifetime of the layer.
    """

    # Flows seen on this layer, keyed by DNS message id.
    flows: dict[int, dns.DNSFlow]
    # Bytes received from the client that do not yet form a complete message.
    req_buf: bytearray
    # Bytes received from the server that do not yet form a complete message.
    resp_buf: bytearray

    def __init__(self, context: Context):
        """Create the layer with no known flows and empty reassembly buffers."""
        super().__init__(context)
        self.flows = {}
        self.req_buf = bytearray()
        self.resp_buf = bytearray()

    def handle_request(
        self, flow: dns.DNSFlow, msg: dns.DNSMessage
    ) -> layer.CommandGenerator[None]:
        """
        Process a query received from the client.

        The query is stored on the flow and DnsRequestHook is triggered. After
        the hook, a response set on the flow is sent to the client, an error set
        on the flow is reported, and otherwise the query is forwarded to the
        upstream server (opening the connection first if necessary).
        """
        flow.request = msg  # if already set, continue and query upstream again
        yield DnsRequestHook(flow)

        # NOTE(review): state_query reuses the flow stored for a message id, so
        # a flow that already carries a response from an earlier exchange takes
        # this branch and re-sends that response - confirm this is intended.
        if flow.response:
            yield from self.handle_response(flow, flow.response)
            return

        if flow.error:
            yield from self.handle_error(flow, flow.error.msg)
            return

        if not self.context.server.address:
            yield from self.handle_error(
                flow, "No hook has set a response and there is no upstream server."
            )
            return

        if not self.context.server.connected:
            err = yield commands.OpenConnection(self.context.server)
            if err:
                yield from self.handle_error(flow, str(err))
                # cannot recover from this
                return

        packed = pack_message(flow.request, flow.server_conn.transport_protocol)
        yield commands.SendData(self.context.server, packed)

    def handle_response(
        self, flow: dns.DNSFlow, msg: dns.DNSMessage
    ) -> layer.CommandGenerator[None]:
        """
        Process a response for the given flow.

        The response is stored on the flow and DnsResponseHook is triggered. If
        the flow still has a response after the hook, it is sent to the client.
        """
        flow.response = msg
        yield DnsResponseHook(flow)
        if flow.response:
            packed = pack_message(flow.response, flow.client_conn.transport_protocol)
            yield commands.SendData(self.context.client, packed)

    def handle_error(self, flow: dns.DNSFlow, err: str) -> layer.CommandGenerator[None]:
        """
        Record an error on the flow and answer the client with SERVFAIL.

        The error is stored on the flow, DnsErrorHook is triggered, and a
        SERVFAIL reply derived from the flow's request is sent to the client.
        """
        flow.error = mflow.Error(err)
        yield DnsErrorHook(flow)
        servfail = flow.request.fail(response_codes.SERVFAIL)
        yield commands.SendData(
            self.context.client,
            pack_message(servfail, flow.client_conn.transport_protocol),
        )

    def unpack_message(self, data: bytes, from_client: bool) -> List[dns.DNSMessage]:
        """
        Parse the DNS messages contained in newly received data.

        With "udp", `data` is parsed as exactly one message. With "tcp", `data`
        is appended to the per-direction buffer and every complete
        length-prefixed message is parsed; an incomplete trailing message stays
        in the buffer until more data arrives. Any other transport protocol
        yields no messages.

        :param data: the bytes that were just received.
        :param from_client: True if the data came from the client, False if it
            came from the server.
        :return: the parsed messages, possibly empty.
        :raises struct.error: if a TCP length field is zero. Errors raised by
            DNSMessage.unpack are not caught here; state_query handles
            struct.error.
        """
        if from_client:
            conn, buf = self.context.client, self.req_buf
        else:
            conn, buf = self.context.server, self.resp_buf

        # Use the framing of the connection the data arrived on. Upstream
        # queries are packed with the server connection's protocol (see
        # handle_request), so replies must be unpacked with it as well.
        transport_protocol = conn.transport_protocol

        if transport_protocol == "udp":
            return [dns.DNSMessage.unpack(data, timestamp=time.time())]

        msgs: List[dns.DNSMessage] = []
        if transport_protocol != "tcp":
            return msgs

        buf.extend(data)
        size = len(buf)
        header_size = _LENGTH_LABEL.size
        offset = 0  # start of the first message that has not been consumed yet

        while size - offset >= header_size:
            (expected_size,) = _LENGTH_LABEL.unpack_from(buf, offset)
            if expected_size == 0:
                raise struct.error("Message length field cannot be zero")

            body_start = offset + header_size
            body_end = body_start + expected_size
            if body_end > size:
                # Message is incomplete: keep its length field in the buffer.
                break

            payload = bytes(buf[body_start:body_end])
            msgs.append(dns.DNSMessage.unpack(payload, timestamp=time.time()))
            offset = body_end

        # Drop everything that was parsed, keep any partial message.
        del buf[:offset]
        return msgs

    @expect(events.Start)
    def state_start(self, _) -> layer.CommandGenerator[None]:
        """Switch to query handling once the layer has been started."""
        self._handle_event = self.state_query
        yield from ()

    @expect(events.DataReceived, events.ConnectionClosed)
    def state_query(self, event: events.Event) -> layer.CommandGenerator[None]:
        """
        Handle received data and closed connections while the layer is active.

        Received data is unpacked into messages; each one is attached to the
        flow with the same message id (creating the flow if needed) and handled
        as a request or a response depending on its origin. If unpacking raises
        struct.error, the sending connection is closed and the layer stops.
        When a connection closes, the other side is closed too, the layer stops
        and all flows are marked as no longer live.
        """
        assert isinstance(event, events.ConnectionEvent)
        from_client = event.connection is self.context.client

        if isinstance(event, events.DataReceived):
            msgs: List[dns.DNSMessage] = []
            try:
                msgs = self.unpack_message(event.data, from_client)
            except struct.error as e:
                yield commands.Log(f"{event.connection} sent an invalid message: {e}")
                yield commands.CloseConnection(event.connection)
                self._handle_event = self.state_done
            else:
                for msg in msgs:
                    flow = self.flows.get(msg.id)
                    if flow is None:
                        flow = dns.DNSFlow(
                            self.context.client, self.context.server, live=True
                        )
                        self.flows[msg.id] = flow
                    if from_client:
                        yield from self.handle_request(flow, msg)
                    else:
                        yield from self.handle_response(flow, msg)

        elif isinstance(event, events.ConnectionClosed):
            other_conn = self.context.server if from_client else self.context.client
            if other_conn.connected:
                yield commands.CloseConnection(other_conn)
            self._handle_event = self.state_done
            for flow in self.flows.values():
                flow.live = False

        else:
            raise AssertionError(f"Unexpected event: {event}")

    @expect(events.DataReceived, events.ConnectionClosed)
    def state_done(self, _) -> layer.CommandGenerator[None]:
        """Ignore any further events once the layer has finished."""
        yield from ()

    # The layer begins in the start state.
    _handle_event = state_start
|