Files
baijiahao_data_crawl/venv/Lib/site-packages/mitmproxy/dns.py
“shengyudong” 322ac74336 2025-12-25 upload
2025-12-25 11:16:59 +08:00

589 lines
20 KiB
Python

from __future__ import annotations
import base64
import itertools
import random
import struct
import time
from collections.abc import Iterable
from dataclasses import dataclass
from ipaddress import IPv4Address
from ipaddress import IPv6Address
from typing import Any
from typing import cast
from typing import ClassVar
from typing import Self
from mitmproxy import flow
from mitmproxy.coretypes import serializable
from mitmproxy.net.dns import classes
from mitmproxy.net.dns import domain_names
from mitmproxy.net.dns import https_records
from mitmproxy.net.dns import op_codes
from mitmproxy.net.dns import response_codes
from mitmproxy.net.dns import types
from mitmproxy.net.dns.https_records import HTTPSRecord
from mitmproxy.net.dns.https_records import HTTPSRecordJSON
from mitmproxy.net.dns.https_records import SVCParamKeys
# DNS parameters taken from https://www.iana.org/assignments/dns-parameters/dns-parameters.xml
@dataclass
class Question(serializable.SerializableDataclass):
HEADER: ClassVar[struct.Struct] = struct.Struct("!HH")
name: str
type: int
class_: int
def __str__(self) -> str:
return self.name
def to_json(self) -> dict:
"""
Converts the question into json for mitmweb.
Sync with web/src/flow.ts.
"""
return {
"name": self.name,
"type": types.to_str(self.type),
"class": classes.to_str(self.class_),
}
@classmethod
def from_json(cls, data: dict[str, str]) -> Self:
return cls(
name=data["name"],
type=types.from_str(data["type"]),
class_=classes.from_str(data["class"]),
)
@dataclass
class ResourceRecord(serializable.SerializableDataclass):
DEFAULT_TTL: ClassVar[int] = 60
HEADER: ClassVar[struct.Struct] = struct.Struct("!HHIH")
name: str
type: int
class_: int
ttl: int
data: bytes
def __str__(self) -> str:
return str(self._data_json())
@property
def text(self) -> str:
return self.data.decode("utf-8")
@text.setter
def text(self, value: str) -> None:
self.data = value.encode("utf-8")
@property
def ipv4_address(self) -> IPv4Address:
return IPv4Address(self.data)
@ipv4_address.setter
def ipv4_address(self, ip: IPv4Address) -> None:
self.data = ip.packed
@property
def ipv6_address(self) -> IPv6Address:
return IPv6Address(self.data)
@ipv6_address.setter
def ipv6_address(self, ip: IPv6Address) -> None:
self.data = ip.packed
@property
def domain_name(self) -> str:
return domain_names.unpack(self.data)
@domain_name.setter
def domain_name(self, name: str) -> None:
self.data = domain_names.pack(name)
@property
def https_alpn(self) -> tuple[bytes, ...] | None:
record = https_records.unpack(self.data)
alpn_bytes = record.params.get(SVCParamKeys.ALPN.value, None)
if alpn_bytes is not None:
i = 0
ret = []
while i < len(alpn_bytes):
token_len = alpn_bytes[i]
ret.append(alpn_bytes[i + 1 : i + 1 + token_len])
i += token_len + 1
return tuple(ret)
else:
return None
@https_alpn.setter
def https_alpn(self, alpn: Iterable[bytes] | None) -> None:
record = https_records.unpack(self.data)
if alpn is None:
record.params.pop(SVCParamKeys.ALPN.value, None)
else:
alpn_bytes = b"".join(bytes([len(a)]) + a for a in alpn)
record.params[SVCParamKeys.ALPN.value] = alpn_bytes
self.data = https_records.pack(record)
@property
def https_ech(self) -> str | None:
record = https_records.unpack(self.data)
ech_bytes = record.params.get(SVCParamKeys.ECH.value, None)
if ech_bytes is not None:
return base64.b64encode(ech_bytes).decode("utf-8")
else:
return None
@https_ech.setter
def https_ech(self, ech: str | None) -> None:
record = https_records.unpack(self.data)
if ech is None:
record.params.pop(SVCParamKeys.ECH.value, None)
else:
ech_bytes = base64.b64decode(ech.encode("utf-8"))
record.params[SVCParamKeys.ECH.value] = ech_bytes
self.data = https_records.pack(record)
def _data_json(self) -> str | HTTPSRecordJSON:
try:
match self.type:
case types.A:
return str(self.ipv4_address)
case types.AAAA:
return str(self.ipv6_address)
case types.NS | types.CNAME | types.PTR:
return self.domain_name
case types.TXT:
return self.text
case types.HTTPS:
return https_records.unpack(self.data).to_json()
case _:
return f"0x{self.data.hex()}"
except Exception:
return f"0x{self.data.hex()} (invalid {types.to_str(self.type)} data)"
def to_json(self) -> dict[str, str | int | HTTPSRecordJSON]:
"""
Converts the resource record into json for mitmweb.
Sync with web/src/flow.ts.
"""
return {
"name": self.name,
"type": types.to_str(self.type),
"class": classes.to_str(self.class_),
"ttl": self.ttl,
"data": self._data_json(),
}
@classmethod
def from_json(cls, data: dict[str, Any]) -> Self:
inst = cls(
name=data["name"],
type=types.from_str(data["type"]),
class_=classes.from_str(data["class"]),
ttl=data["ttl"],
data=b"",
)
d: str = data["data"]
try:
match inst.type:
case types.A:
inst.ipv4_address = IPv4Address(d)
case types.AAAA:
inst.ipv6_address = IPv6Address(d)
case types.NS | types.CNAME | types.PTR:
inst.domain_name = d
case types.TXT:
inst.text = d
case types.HTTPS:
record = HTTPSRecord.from_json(cast(HTTPSRecordJSON, d))
inst.data = https_records.pack(record)
case _:
raise ValueError
except Exception:
inst.data = bytes.fromhex(d.removeprefix("0x").partition(" (")[0])
return inst
@classmethod
def A(cls, name: str, ip: IPv4Address, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
"""Create an IPv4 resource record."""
return cls(name, types.A, classes.IN, ttl, ip.packed)
@classmethod
def AAAA(
cls, name: str, ip: IPv6Address, *, ttl: int = DEFAULT_TTL
) -> ResourceRecord:
"""Create an IPv6 resource record."""
return cls(name, types.AAAA, classes.IN, ttl, ip.packed)
@classmethod
def CNAME(
cls, alias: str, canonical: str, *, ttl: int = DEFAULT_TTL
) -> ResourceRecord:
"""Create a canonical internet name resource record."""
return cls(alias, types.CNAME, classes.IN, ttl, domain_names.pack(canonical))
@classmethod
def PTR(cls, inaddr: str, ptr: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
"""Create a canonical internet name resource record."""
return cls(inaddr, types.PTR, classes.IN, ttl, domain_names.pack(ptr))
@classmethod
def TXT(cls, name: str, text: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
"""Create a textual resource record."""
return cls(name, types.TXT, classes.IN, ttl, text.encode("utf-8"))
@classmethod
def HTTPS(
cls, name: str, record: HTTPSRecord, ttl: int = DEFAULT_TTL
) -> ResourceRecord:
"""Create a HTTPS resource record"""
return cls(name, types.HTTPS, classes.IN, ttl, https_records.pack(record))
# comments are taken from rfc1035
@dataclass
class DNSMessage(serializable.SerializableDataclass):
HEADER: ClassVar[struct.Struct] = struct.Struct("!HHHHHH")
id: int
"""An identifier assigned by the program that generates any kind of query."""
query: bool
"""A field that specifies whether this message is a query."""
op_code: int
"""
A field that specifies kind of query in this message.
This value is set by the originator of a request and copied into the response.
"""
authoritative_answer: bool
"""
This field is valid in responses, and specifies that the responding name server
is an authority for the domain name in question section.
"""
truncation: bool
"""Specifies that this message was truncated due to length greater than that permitted on the transmission channel."""
recursion_desired: bool
"""
This field may be set in a query and is copied into the response.
If set, it directs the name server to pursue the query recursively.
"""
recursion_available: bool
"""This field is set or cleared in a response, and denotes whether recursive query support is available in the name server."""
reserved: int
"""Reserved for future use. Must be zero in all queries and responses."""
response_code: int
"""This field is set as part of responses."""
questions: list[Question]
"""
The question section is used to carry the "question" in most queries, i.e.
the parameters that define what is being asked.
"""
answers: list[ResourceRecord]
"""First resource record section."""
authorities: list[ResourceRecord]
"""Second resource record section."""
additionals: list[ResourceRecord]
"""Third resource record section."""
timestamp: float | None = None
"""The time at which the message was sent or received."""
def __str__(self) -> str:
return "\r\n".join(
map(
str,
itertools.chain(
self.questions, self.answers, self.authorities, self.additionals
),
)
)
@property
def content(self) -> bytes:
return self.packed
@property
def question(self) -> Question | None:
"""DNS practically only supports a single question at the
same time, so this is a shorthand for this."""
if len(self.questions) == 1:
return self.questions[0]
return None
@property
def size(self) -> int:
"""Returns the cumulative data size of all resource record sections."""
return sum(
len(x.data)
for x in itertools.chain.from_iterable(
[self.answers, self.authorities, self.additionals]
)
)
def fail(self, response_code: int) -> DNSMessage:
if response_code == response_codes.NOERROR:
raise ValueError("response_code must be an error code.")
return DNSMessage(
timestamp=time.time(),
id=self.id,
query=False,
op_code=self.op_code,
authoritative_answer=False,
truncation=False,
recursion_desired=self.recursion_desired,
recursion_available=False,
reserved=0,
response_code=response_code,
questions=self.questions,
answers=[],
authorities=[],
additionals=[],
)
def succeed(self, answers: list[ResourceRecord]) -> DNSMessage:
return DNSMessage(
timestamp=time.time(),
id=self.id,
query=False,
op_code=self.op_code,
authoritative_answer=False,
truncation=False,
recursion_desired=self.recursion_desired,
recursion_available=True,
reserved=0,
response_code=response_codes.NOERROR,
questions=self.questions,
answers=answers,
authorities=[],
additionals=[],
)
@classmethod
def unpack(cls, buffer: bytes, timestamp: float | None = None) -> DNSMessage:
"""Converts the entire given buffer into a DNS message."""
length, msg = cls.unpack_from(buffer, 0, timestamp)
if length != len(buffer):
raise struct.error(f"unpack requires a buffer of {length} bytes")
return msg
@classmethod
def unpack_from(
cls, buffer: bytes | bytearray, offset: int, timestamp: float | None = None
) -> tuple[int, DNSMessage]:
"""Converts the buffer from a given offset into a DNS message and also returns its length."""
(
id,
flags,
len_questions,
len_answers,
len_authorities,
len_additionals,
) = DNSMessage.HEADER.unpack_from(buffer, offset)
msg = DNSMessage(
timestamp=timestamp,
id=id,
query=(flags & (1 << 15)) == 0,
op_code=(flags >> 11) & 0b1111,
authoritative_answer=(flags & (1 << 10)) != 0,
truncation=(flags & (1 << 9)) != 0,
recursion_desired=(flags & (1 << 8)) != 0,
recursion_available=(flags & (1 << 7)) != 0,
reserved=(flags >> 4) & 0b111,
response_code=flags & 0b1111,
questions=[],
answers=[],
authorities=[],
additionals=[],
)
offset += DNSMessage.HEADER.size
cached_names = domain_names.cache()
def unpack_domain_name() -> str:
nonlocal buffer, offset, cached_names
name, length = domain_names.unpack_from_with_compression(
buffer, offset, cached_names
)
offset += length
return name
for i in range(0, len_questions):
try:
name = unpack_domain_name()
type, class_ = Question.HEADER.unpack_from(buffer, offset)
offset += Question.HEADER.size
msg.questions.append(Question(name=name, type=type, class_=class_))
except struct.error as e:
raise struct.error(f"question #{i}: {e}")
def unpack_rrs(
section: list[ResourceRecord], section_name: str, count: int
) -> None:
nonlocal buffer, offset
for i in range(0, count):
try:
name = unpack_domain_name()
type, class_, ttl, len_data = ResourceRecord.HEADER.unpack_from(
buffer, offset
)
offset += ResourceRecord.HEADER.size
end_data = offset + len_data
if len(buffer) < end_data:
raise struct.error(
f"unpack requires a data buffer of {len_data} bytes"
)
data = buffer[offset:end_data]
if domain_names.record_data_can_have_compression(type):
data = domain_names.decompress_from_record_data(
buffer, offset, end_data, cached_names
)
section.append(ResourceRecord(name, type, class_, ttl, data))
offset += len_data
except struct.error as e:
raise struct.error(f"{section_name} #{i}: {e}")
unpack_rrs(msg.answers, "answer", len_answers)
unpack_rrs(msg.authorities, "authority", len_authorities)
unpack_rrs(msg.additionals, "additional", len_additionals)
return (offset, msg)
@property
def packed(self) -> bytes:
"""Converts the message into network bytes."""
if self.id < 0 or self.id > 65535:
raise ValueError(f"DNS message's id {self.id} is out of bounds.")
flags = 0
if not self.query:
flags |= 1 << 15
if self.op_code < 0 or self.op_code > 0b1111:
raise ValueError(f"DNS message's op_code {self.op_code} is out of bounds.")
flags |= self.op_code << 11
if self.authoritative_answer:
flags |= 1 << 10
if self.truncation:
flags |= 1 << 9
if self.recursion_desired:
flags |= 1 << 8
if self.recursion_available:
flags |= 1 << 7
if self.reserved < 0 or self.reserved > 0b111:
raise ValueError(
f"DNS message's reserved value of {self.reserved} is out of bounds."
)
flags |= self.reserved << 4
if self.response_code < 0 or self.response_code > 0b1111:
raise ValueError(
f"DNS message's response_code {self.response_code} is out of bounds."
)
flags |= self.response_code
data = bytearray()
data.extend(
DNSMessage.HEADER.pack(
self.id,
flags,
len(self.questions),
len(self.answers),
len(self.authorities),
len(self.additionals),
)
)
# TODO implement compression
for question in self.questions:
data.extend(domain_names.pack(question.name))
data.extend(Question.HEADER.pack(question.type, question.class_))
for rr in (*self.answers, *self.authorities, *self.additionals):
data.extend(domain_names.pack(rr.name))
data.extend(
ResourceRecord.HEADER.pack(rr.type, rr.class_, rr.ttl, len(rr.data))
)
data.extend(rr.data)
return bytes(data)
def to_json(self) -> dict:
"""
Converts the message into json for mitmweb.
Sync with web/src/flow.ts.
"""
ret = {
"id": self.id,
"query": self.query,
"op_code": op_codes.to_str(self.op_code),
"authoritative_answer": self.authoritative_answer,
"truncation": self.truncation,
"recursion_desired": self.recursion_desired,
"recursion_available": self.recursion_available,
"response_code": response_codes.to_str(self.response_code),
"status_code": response_codes.http_equiv_status_code(self.response_code),
"questions": [question.to_json() for question in self.questions],
"answers": [rr.to_json() for rr in self.answers],
"authorities": [rr.to_json() for rr in self.authorities],
"additionals": [rr.to_json() for rr in self.additionals],
"size": self.size,
}
if self.timestamp:
ret["timestamp"] = self.timestamp
return ret
@classmethod
def from_json(cls, data: Any) -> DNSMessage:
"""Reconstruct a DNS message from JSON."""
inst = cls(
id=data["id"],
query=data["query"],
op_code=op_codes.from_str(data["op_code"]),
authoritative_answer=data["authoritative_answer"],
truncation=data["truncation"],
recursion_desired=data["recursion_desired"],
recursion_available=data["recursion_available"],
reserved=0,
response_code=response_codes.from_str(data["response_code"]),
questions=[Question.from_json(x) for x in data["questions"]],
answers=[ResourceRecord.from_json(x) for x in data["answers"]],
authorities=[ResourceRecord.from_json(x) for x in data["authorities"]],
additionals=[ResourceRecord.from_json(x) for x in data["additionals"]],
)
if ts := data.get("timestamp"):
inst.timestamp = ts
return inst
def copy(self) -> DNSMessage:
# we keep the copy semantics but change the ID generation
state = self.get_state()
state["id"] = random.randint(0, 65535)
return DNSMessage.from_state(state)
class DNSFlow(flow.Flow):
"""A DNSFlow is a collection of DNS messages representing a single DNS query."""
request: DNSMessage
"""The DNS request."""
response: DNSMessage | None = None
"""The DNS response."""
def get_state(self) -> serializable.State:
return {
**super().get_state(),
"request": self.request.get_state(),
"response": self.response.get_state() if self.response else None,
}
def set_state(self, state: serializable.State) -> None:
self.request = DNSMessage.from_state(state.pop("request"))
self.response = (
DNSMessage.from_state(r) if (r := state.pop("response")) else None
)
super().set_state(state)
def __repr__(self) -> str:
return f"<DNSFlow\r\n request={self.request!r}\r\n response={self.response!r}\r\n>"