2025-12-25 upload
This commit is contained in:
588
venv/Lib/site-packages/mitmproxy/dns.py
Normal file
588
venv/Lib/site-packages/mitmproxy/dns.py
Normal file
@@ -0,0 +1,588 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import itertools
|
||||
import random
|
||||
import struct
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv6Address
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import ClassVar
|
||||
from typing import Self
|
||||
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.coretypes import serializable
|
||||
from mitmproxy.net.dns import classes
|
||||
from mitmproxy.net.dns import domain_names
|
||||
from mitmproxy.net.dns import https_records
|
||||
from mitmproxy.net.dns import op_codes
|
||||
from mitmproxy.net.dns import response_codes
|
||||
from mitmproxy.net.dns import types
|
||||
from mitmproxy.net.dns.https_records import HTTPSRecord
|
||||
from mitmproxy.net.dns.https_records import HTTPSRecordJSON
|
||||
from mitmproxy.net.dns.https_records import SVCParamKeys
|
||||
|
||||
# DNS parameters taken from https://www.iana.org/assignments/dns-parameters/dns-parameters.xml
|
||||
|
||||
|
||||
@dataclass
|
||||
class Question(serializable.SerializableDataclass):
|
||||
HEADER: ClassVar[struct.Struct] = struct.Struct("!HH")
|
||||
|
||||
name: str
|
||||
type: int
|
||||
class_: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""
|
||||
Converts the question into json for mitmweb.
|
||||
Sync with web/src/flow.ts.
|
||||
"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": types.to_str(self.type),
|
||||
"class": classes.to_str(self.class_),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: dict[str, str]) -> Self:
|
||||
return cls(
|
||||
name=data["name"],
|
||||
type=types.from_str(data["type"]),
|
||||
class_=classes.from_str(data["class"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceRecord(serializable.SerializableDataclass):
|
||||
DEFAULT_TTL: ClassVar[int] = 60
|
||||
HEADER: ClassVar[struct.Struct] = struct.Struct("!HHIH")
|
||||
|
||||
name: str
|
||||
type: int
|
||||
class_: int
|
||||
ttl: int
|
||||
data: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._data_json())
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self.data.decode("utf-8")
|
||||
|
||||
@text.setter
|
||||
def text(self, value: str) -> None:
|
||||
self.data = value.encode("utf-8")
|
||||
|
||||
@property
|
||||
def ipv4_address(self) -> IPv4Address:
|
||||
return IPv4Address(self.data)
|
||||
|
||||
@ipv4_address.setter
|
||||
def ipv4_address(self, ip: IPv4Address) -> None:
|
||||
self.data = ip.packed
|
||||
|
||||
@property
|
||||
def ipv6_address(self) -> IPv6Address:
|
||||
return IPv6Address(self.data)
|
||||
|
||||
@ipv6_address.setter
|
||||
def ipv6_address(self, ip: IPv6Address) -> None:
|
||||
self.data = ip.packed
|
||||
|
||||
@property
|
||||
def domain_name(self) -> str:
|
||||
return domain_names.unpack(self.data)
|
||||
|
||||
@domain_name.setter
|
||||
def domain_name(self, name: str) -> None:
|
||||
self.data = domain_names.pack(name)
|
||||
|
||||
@property
|
||||
def https_alpn(self) -> tuple[bytes, ...] | None:
|
||||
record = https_records.unpack(self.data)
|
||||
alpn_bytes = record.params.get(SVCParamKeys.ALPN.value, None)
|
||||
if alpn_bytes is not None:
|
||||
i = 0
|
||||
ret = []
|
||||
while i < len(alpn_bytes):
|
||||
token_len = alpn_bytes[i]
|
||||
ret.append(alpn_bytes[i + 1 : i + 1 + token_len])
|
||||
i += token_len + 1
|
||||
return tuple(ret)
|
||||
else:
|
||||
return None
|
||||
|
||||
@https_alpn.setter
|
||||
def https_alpn(self, alpn: Iterable[bytes] | None) -> None:
|
||||
record = https_records.unpack(self.data)
|
||||
if alpn is None:
|
||||
record.params.pop(SVCParamKeys.ALPN.value, None)
|
||||
else:
|
||||
alpn_bytes = b"".join(bytes([len(a)]) + a for a in alpn)
|
||||
record.params[SVCParamKeys.ALPN.value] = alpn_bytes
|
||||
self.data = https_records.pack(record)
|
||||
|
||||
@property
|
||||
def https_ech(self) -> str | None:
|
||||
record = https_records.unpack(self.data)
|
||||
ech_bytes = record.params.get(SVCParamKeys.ECH.value, None)
|
||||
if ech_bytes is not None:
|
||||
return base64.b64encode(ech_bytes).decode("utf-8")
|
||||
else:
|
||||
return None
|
||||
|
||||
@https_ech.setter
|
||||
def https_ech(self, ech: str | None) -> None:
|
||||
record = https_records.unpack(self.data)
|
||||
if ech is None:
|
||||
record.params.pop(SVCParamKeys.ECH.value, None)
|
||||
else:
|
||||
ech_bytes = base64.b64decode(ech.encode("utf-8"))
|
||||
record.params[SVCParamKeys.ECH.value] = ech_bytes
|
||||
self.data = https_records.pack(record)
|
||||
|
||||
def _data_json(self) -> str | HTTPSRecordJSON:
|
||||
try:
|
||||
match self.type:
|
||||
case types.A:
|
||||
return str(self.ipv4_address)
|
||||
case types.AAAA:
|
||||
return str(self.ipv6_address)
|
||||
case types.NS | types.CNAME | types.PTR:
|
||||
return self.domain_name
|
||||
case types.TXT:
|
||||
return self.text
|
||||
case types.HTTPS:
|
||||
return https_records.unpack(self.data).to_json()
|
||||
case _:
|
||||
return f"0x{self.data.hex()}"
|
||||
except Exception:
|
||||
return f"0x{self.data.hex()} (invalid {types.to_str(self.type)} data)"
|
||||
|
||||
def to_json(self) -> dict[str, str | int | HTTPSRecordJSON]:
|
||||
"""
|
||||
Converts the resource record into json for mitmweb.
|
||||
Sync with web/src/flow.ts.
|
||||
"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": types.to_str(self.type),
|
||||
"class": classes.to_str(self.class_),
|
||||
"ttl": self.ttl,
|
||||
"data": self._data_json(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: dict[str, Any]) -> Self:
|
||||
inst = cls(
|
||||
name=data["name"],
|
||||
type=types.from_str(data["type"]),
|
||||
class_=classes.from_str(data["class"]),
|
||||
ttl=data["ttl"],
|
||||
data=b"",
|
||||
)
|
||||
|
||||
d: str = data["data"]
|
||||
try:
|
||||
match inst.type:
|
||||
case types.A:
|
||||
inst.ipv4_address = IPv4Address(d)
|
||||
case types.AAAA:
|
||||
inst.ipv6_address = IPv6Address(d)
|
||||
case types.NS | types.CNAME | types.PTR:
|
||||
inst.domain_name = d
|
||||
case types.TXT:
|
||||
inst.text = d
|
||||
case types.HTTPS:
|
||||
record = HTTPSRecord.from_json(cast(HTTPSRecordJSON, d))
|
||||
inst.data = https_records.pack(record)
|
||||
case _:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
inst.data = bytes.fromhex(d.removeprefix("0x").partition(" (")[0])
|
||||
|
||||
return inst
|
||||
|
||||
@classmethod
|
||||
def A(cls, name: str, ip: IPv4Address, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
|
||||
"""Create an IPv4 resource record."""
|
||||
return cls(name, types.A, classes.IN, ttl, ip.packed)
|
||||
|
||||
@classmethod
|
||||
def AAAA(
|
||||
cls, name: str, ip: IPv6Address, *, ttl: int = DEFAULT_TTL
|
||||
) -> ResourceRecord:
|
||||
"""Create an IPv6 resource record."""
|
||||
return cls(name, types.AAAA, classes.IN, ttl, ip.packed)
|
||||
|
||||
@classmethod
|
||||
def CNAME(
|
||||
cls, alias: str, canonical: str, *, ttl: int = DEFAULT_TTL
|
||||
) -> ResourceRecord:
|
||||
"""Create a canonical internet name resource record."""
|
||||
return cls(alias, types.CNAME, classes.IN, ttl, domain_names.pack(canonical))
|
||||
|
||||
@classmethod
|
||||
def PTR(cls, inaddr: str, ptr: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
|
||||
"""Create a canonical internet name resource record."""
|
||||
return cls(inaddr, types.PTR, classes.IN, ttl, domain_names.pack(ptr))
|
||||
|
||||
@classmethod
|
||||
def TXT(cls, name: str, text: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
|
||||
"""Create a textual resource record."""
|
||||
return cls(name, types.TXT, classes.IN, ttl, text.encode("utf-8"))
|
||||
|
||||
@classmethod
|
||||
def HTTPS(
|
||||
cls, name: str, record: HTTPSRecord, ttl: int = DEFAULT_TTL
|
||||
) -> ResourceRecord:
|
||||
"""Create a HTTPS resource record"""
|
||||
return cls(name, types.HTTPS, classes.IN, ttl, https_records.pack(record))
|
||||
|
||||
|
||||
# comments are taken from rfc1035
|
||||
@dataclass
|
||||
class DNSMessage(serializable.SerializableDataclass):
|
||||
HEADER: ClassVar[struct.Struct] = struct.Struct("!HHHHHH")
|
||||
|
||||
id: int
|
||||
"""An identifier assigned by the program that generates any kind of query."""
|
||||
query: bool
|
||||
"""A field that specifies whether this message is a query."""
|
||||
op_code: int
|
||||
"""
|
||||
A field that specifies kind of query in this message.
|
||||
This value is set by the originator of a request and copied into the response.
|
||||
"""
|
||||
authoritative_answer: bool
|
||||
"""
|
||||
This field is valid in responses, and specifies that the responding name server
|
||||
is an authority for the domain name in question section.
|
||||
"""
|
||||
truncation: bool
|
||||
"""Specifies that this message was truncated due to length greater than that permitted on the transmission channel."""
|
||||
recursion_desired: bool
|
||||
"""
|
||||
This field may be set in a query and is copied into the response.
|
||||
If set, it directs the name server to pursue the query recursively.
|
||||
"""
|
||||
recursion_available: bool
|
||||
"""This field is set or cleared in a response, and denotes whether recursive query support is available in the name server."""
|
||||
reserved: int
|
||||
"""Reserved for future use. Must be zero in all queries and responses."""
|
||||
response_code: int
|
||||
"""This field is set as part of responses."""
|
||||
questions: list[Question]
|
||||
"""
|
||||
The question section is used to carry the "question" in most queries, i.e.
|
||||
the parameters that define what is being asked.
|
||||
"""
|
||||
answers: list[ResourceRecord]
|
||||
"""First resource record section."""
|
||||
authorities: list[ResourceRecord]
|
||||
"""Second resource record section."""
|
||||
additionals: list[ResourceRecord]
|
||||
"""Third resource record section."""
|
||||
|
||||
timestamp: float | None = None
|
||||
"""The time at which the message was sent or received."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "\r\n".join(
|
||||
map(
|
||||
str,
|
||||
itertools.chain(
|
||||
self.questions, self.answers, self.authorities, self.additionals
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def content(self) -> bytes:
|
||||
return self.packed
|
||||
|
||||
@property
|
||||
def question(self) -> Question | None:
|
||||
"""DNS practically only supports a single question at the
|
||||
same time, so this is a shorthand for this."""
|
||||
if len(self.questions) == 1:
|
||||
return self.questions[0]
|
||||
return None
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Returns the cumulative data size of all resource record sections."""
|
||||
return sum(
|
||||
len(x.data)
|
||||
for x in itertools.chain.from_iterable(
|
||||
[self.answers, self.authorities, self.additionals]
|
||||
)
|
||||
)
|
||||
|
||||
def fail(self, response_code: int) -> DNSMessage:
|
||||
if response_code == response_codes.NOERROR:
|
||||
raise ValueError("response_code must be an error code.")
|
||||
return DNSMessage(
|
||||
timestamp=time.time(),
|
||||
id=self.id,
|
||||
query=False,
|
||||
op_code=self.op_code,
|
||||
authoritative_answer=False,
|
||||
truncation=False,
|
||||
recursion_desired=self.recursion_desired,
|
||||
recursion_available=False,
|
||||
reserved=0,
|
||||
response_code=response_code,
|
||||
questions=self.questions,
|
||||
answers=[],
|
||||
authorities=[],
|
||||
additionals=[],
|
||||
)
|
||||
|
||||
def succeed(self, answers: list[ResourceRecord]) -> DNSMessage:
|
||||
return DNSMessage(
|
||||
timestamp=time.time(),
|
||||
id=self.id,
|
||||
query=False,
|
||||
op_code=self.op_code,
|
||||
authoritative_answer=False,
|
||||
truncation=False,
|
||||
recursion_desired=self.recursion_desired,
|
||||
recursion_available=True,
|
||||
reserved=0,
|
||||
response_code=response_codes.NOERROR,
|
||||
questions=self.questions,
|
||||
answers=answers,
|
||||
authorities=[],
|
||||
additionals=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, buffer: bytes, timestamp: float | None = None) -> DNSMessage:
|
||||
"""Converts the entire given buffer into a DNS message."""
|
||||
length, msg = cls.unpack_from(buffer, 0, timestamp)
|
||||
if length != len(buffer):
|
||||
raise struct.error(f"unpack requires a buffer of {length} bytes")
|
||||
return msg
|
||||
|
||||
@classmethod
|
||||
def unpack_from(
|
||||
cls, buffer: bytes | bytearray, offset: int, timestamp: float | None = None
|
||||
) -> tuple[int, DNSMessage]:
|
||||
"""Converts the buffer from a given offset into a DNS message and also returns its length."""
|
||||
(
|
||||
id,
|
||||
flags,
|
||||
len_questions,
|
||||
len_answers,
|
||||
len_authorities,
|
||||
len_additionals,
|
||||
) = DNSMessage.HEADER.unpack_from(buffer, offset)
|
||||
msg = DNSMessage(
|
||||
timestamp=timestamp,
|
||||
id=id,
|
||||
query=(flags & (1 << 15)) == 0,
|
||||
op_code=(flags >> 11) & 0b1111,
|
||||
authoritative_answer=(flags & (1 << 10)) != 0,
|
||||
truncation=(flags & (1 << 9)) != 0,
|
||||
recursion_desired=(flags & (1 << 8)) != 0,
|
||||
recursion_available=(flags & (1 << 7)) != 0,
|
||||
reserved=(flags >> 4) & 0b111,
|
||||
response_code=flags & 0b1111,
|
||||
questions=[],
|
||||
answers=[],
|
||||
authorities=[],
|
||||
additionals=[],
|
||||
)
|
||||
offset += DNSMessage.HEADER.size
|
||||
cached_names = domain_names.cache()
|
||||
|
||||
def unpack_domain_name() -> str:
|
||||
nonlocal buffer, offset, cached_names
|
||||
name, length = domain_names.unpack_from_with_compression(
|
||||
buffer, offset, cached_names
|
||||
)
|
||||
offset += length
|
||||
return name
|
||||
|
||||
for i in range(0, len_questions):
|
||||
try:
|
||||
name = unpack_domain_name()
|
||||
type, class_ = Question.HEADER.unpack_from(buffer, offset)
|
||||
offset += Question.HEADER.size
|
||||
msg.questions.append(Question(name=name, type=type, class_=class_))
|
||||
except struct.error as e:
|
||||
raise struct.error(f"question #{i}: {e}")
|
||||
|
||||
def unpack_rrs(
|
||||
section: list[ResourceRecord], section_name: str, count: int
|
||||
) -> None:
|
||||
nonlocal buffer, offset
|
||||
for i in range(0, count):
|
||||
try:
|
||||
name = unpack_domain_name()
|
||||
type, class_, ttl, len_data = ResourceRecord.HEADER.unpack_from(
|
||||
buffer, offset
|
||||
)
|
||||
offset += ResourceRecord.HEADER.size
|
||||
end_data = offset + len_data
|
||||
if len(buffer) < end_data:
|
||||
raise struct.error(
|
||||
f"unpack requires a data buffer of {len_data} bytes"
|
||||
)
|
||||
data = buffer[offset:end_data]
|
||||
|
||||
if domain_names.record_data_can_have_compression(type):
|
||||
data = domain_names.decompress_from_record_data(
|
||||
buffer, offset, end_data, cached_names
|
||||
)
|
||||
|
||||
section.append(ResourceRecord(name, type, class_, ttl, data))
|
||||
offset += len_data
|
||||
except struct.error as e:
|
||||
raise struct.error(f"{section_name} #{i}: {e}")
|
||||
|
||||
unpack_rrs(msg.answers, "answer", len_answers)
|
||||
unpack_rrs(msg.authorities, "authority", len_authorities)
|
||||
unpack_rrs(msg.additionals, "additional", len_additionals)
|
||||
return (offset, msg)
|
||||
|
||||
@property
|
||||
def packed(self) -> bytes:
|
||||
"""Converts the message into network bytes."""
|
||||
if self.id < 0 or self.id > 65535:
|
||||
raise ValueError(f"DNS message's id {self.id} is out of bounds.")
|
||||
flags = 0
|
||||
if not self.query:
|
||||
flags |= 1 << 15
|
||||
if self.op_code < 0 or self.op_code > 0b1111:
|
||||
raise ValueError(f"DNS message's op_code {self.op_code} is out of bounds.")
|
||||
flags |= self.op_code << 11
|
||||
if self.authoritative_answer:
|
||||
flags |= 1 << 10
|
||||
if self.truncation:
|
||||
flags |= 1 << 9
|
||||
if self.recursion_desired:
|
||||
flags |= 1 << 8
|
||||
if self.recursion_available:
|
||||
flags |= 1 << 7
|
||||
if self.reserved < 0 or self.reserved > 0b111:
|
||||
raise ValueError(
|
||||
f"DNS message's reserved value of {self.reserved} is out of bounds."
|
||||
)
|
||||
flags |= self.reserved << 4
|
||||
if self.response_code < 0 or self.response_code > 0b1111:
|
||||
raise ValueError(
|
||||
f"DNS message's response_code {self.response_code} is out of bounds."
|
||||
)
|
||||
flags |= self.response_code
|
||||
data = bytearray()
|
||||
data.extend(
|
||||
DNSMessage.HEADER.pack(
|
||||
self.id,
|
||||
flags,
|
||||
len(self.questions),
|
||||
len(self.answers),
|
||||
len(self.authorities),
|
||||
len(self.additionals),
|
||||
)
|
||||
)
|
||||
# TODO implement compression
|
||||
for question in self.questions:
|
||||
data.extend(domain_names.pack(question.name))
|
||||
data.extend(Question.HEADER.pack(question.type, question.class_))
|
||||
for rr in (*self.answers, *self.authorities, *self.additionals):
|
||||
data.extend(domain_names.pack(rr.name))
|
||||
data.extend(
|
||||
ResourceRecord.HEADER.pack(rr.type, rr.class_, rr.ttl, len(rr.data))
|
||||
)
|
||||
data.extend(rr.data)
|
||||
return bytes(data)
|
||||
|
||||
def to_json(self) -> dict:
|
||||
"""
|
||||
Converts the message into json for mitmweb.
|
||||
Sync with web/src/flow.ts.
|
||||
"""
|
||||
ret = {
|
||||
"id": self.id,
|
||||
"query": self.query,
|
||||
"op_code": op_codes.to_str(self.op_code),
|
||||
"authoritative_answer": self.authoritative_answer,
|
||||
"truncation": self.truncation,
|
||||
"recursion_desired": self.recursion_desired,
|
||||
"recursion_available": self.recursion_available,
|
||||
"response_code": response_codes.to_str(self.response_code),
|
||||
"status_code": response_codes.http_equiv_status_code(self.response_code),
|
||||
"questions": [question.to_json() for question in self.questions],
|
||||
"answers": [rr.to_json() for rr in self.answers],
|
||||
"authorities": [rr.to_json() for rr in self.authorities],
|
||||
"additionals": [rr.to_json() for rr in self.additionals],
|
||||
"size": self.size,
|
||||
}
|
||||
if self.timestamp:
|
||||
ret["timestamp"] = self.timestamp
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: Any) -> DNSMessage:
|
||||
"""Reconstruct a DNS message from JSON."""
|
||||
inst = cls(
|
||||
id=data["id"],
|
||||
query=data["query"],
|
||||
op_code=op_codes.from_str(data["op_code"]),
|
||||
authoritative_answer=data["authoritative_answer"],
|
||||
truncation=data["truncation"],
|
||||
recursion_desired=data["recursion_desired"],
|
||||
recursion_available=data["recursion_available"],
|
||||
reserved=0,
|
||||
response_code=response_codes.from_str(data["response_code"]),
|
||||
questions=[Question.from_json(x) for x in data["questions"]],
|
||||
answers=[ResourceRecord.from_json(x) for x in data["answers"]],
|
||||
authorities=[ResourceRecord.from_json(x) for x in data["authorities"]],
|
||||
additionals=[ResourceRecord.from_json(x) for x in data["additionals"]],
|
||||
)
|
||||
if ts := data.get("timestamp"):
|
||||
inst.timestamp = ts
|
||||
return inst
|
||||
|
||||
def copy(self) -> DNSMessage:
|
||||
# we keep the copy semantics but change the ID generation
|
||||
state = self.get_state()
|
||||
state["id"] = random.randint(0, 65535)
|
||||
return DNSMessage.from_state(state)
|
||||
|
||||
|
||||
class DNSFlow(flow.Flow):
|
||||
"""A DNSFlow is a collection of DNS messages representing a single DNS query."""
|
||||
|
||||
request: DNSMessage
|
||||
"""The DNS request."""
|
||||
response: DNSMessage | None = None
|
||||
"""The DNS response."""
|
||||
|
||||
def get_state(self) -> serializable.State:
|
||||
return {
|
||||
**super().get_state(),
|
||||
"request": self.request.get_state(),
|
||||
"response": self.response.get_state() if self.response else None,
|
||||
}
|
||||
|
||||
def set_state(self, state: serializable.State) -> None:
|
||||
self.request = DNSMessage.from_state(state.pop("request"))
|
||||
self.response = (
|
||||
DNSMessage.from_state(r) if (r := state.pop("response")) else None
|
||||
)
|
||||
super().set_state(state)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<DNSFlow\r\n request={self.request!r}\r\n response={self.response!r}\r\n>"
|
||||
Reference in New Issue
Block a user