183 lines
6.0 KiB
Python
183 lines
6.0 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import ipaddress
|
|
import logging
|
|
import socket
|
|
from collections.abc import Sequence
|
|
from functools import cache
|
|
from typing import Protocol
|
|
|
|
import mitmproxy_rs
|
|
from mitmproxy import ctx
|
|
from mitmproxy import dns
|
|
from mitmproxy.flow import Error
|
|
from mitmproxy.proxy import mode_specs
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DnsResolver:
    """
    Addon that answers DNS flows received in DNS mode, or in WireGuard mode
    when the flow is addressed to 10.0.0.53:53.

    A/AAAA questions are resolved locally via `resolver()`. Any other query
    is forwarded to the first known name server.
    """

    def load(self, loader):
        """Register the `dns_use_hosts_file` and `dns_name_servers` options."""
        loader.add_option(
            "dns_use_hosts_file",
            bool,
            True,
            "Use the hosts file for DNS lookups in regular DNS mode/wireguard mode.",
        )
        loader.add_option(
            "dns_name_servers",
            Sequence[str],
            [],
            "Name servers to use for lookups in regular DNS mode/wireguard mode. Default: operating system's name servers",
        )

    def configure(self, updated):
        """Invalidate the memoized resolver and name servers if a DNS option changed."""
        if "dns_use_hosts_file" in updated or "dns_name_servers" in updated:
            self.resolver.cache_clear()
            self.name_servers.cache_clear()

    # NOTE: functools.cache on a method keys on `self`; `configure` relies on
    # the resulting `cache_clear()` attribute, so the decorator is kept as is.
    @cache
    def name_servers(self) -> list[str]:
        """
        Returns the operating system's name servers unless custom name servers are set.
        On error, a warning is logged and an empty list is returned.

        The result is memoized until `configure` clears the cache.
        """
        try:
            return (
                ctx.options.dns_name_servers
                or mitmproxy_rs.dns.get_system_dns_servers()
            )
        except RuntimeError as e:
            logger.warning(
                f"Failed to get system dns servers: {e}\n"
                "The dns_name_servers option needs to be set manually."
            )
            return []

    @cache
    def resolver(self) -> Resolver:
        """
        Returns:
            The DNS resolver to use. The result is memoized until `configure`
            clears the cache.
        Raises:
            MissingNameServers, if name servers are unknown and `dns_use_hosts_file` is disabled.
        """
        if ns := self.name_servers():
            # We always want to use our own resolver if name server info is available.
            return mitmproxy_rs.dns.DnsResolver(
                name_servers=ns,
                use_hosts_file=ctx.options.dns_use_hosts_file,
            )
        elif ctx.options.dns_use_hosts_file:
            # Fallback to getaddrinfo as hickory's resolver isn't as reliable
            # as we would like it to be (https://github.com/mitmproxy/mitmproxy/issues/7064).
            return GetaddrinfoFallbackResolver()
        else:
            raise MissingNameServers()

    async def dns_request(self, flow: dns.DNSFlow) -> None:
        """
        Handle a DNS request flow.

        Plain IN-class A/AAAA queries get a locally produced response. Every
        other query is pointed at the first known name server on port 53. If
        no name servers are known, `flow.error` is set instead.
        """
        if not self._should_resolve(flow):
            return

        all_ip_lookups = (
            flow.request.query
            and flow.request.op_code == dns.op_codes.QUERY
            and flow.request.question
            and flow.request.question.class_ == dns.classes.IN
            and flow.request.question.type in (dns.types.A, dns.types.AAAA)
        )
        if all_ip_lookups:
            try:
                flow.response = await self.resolve(flow.request)
            except MissingNameServers:
                flow.error = Error("Cannot resolve, dns_name_servers unknown.")
        elif name_servers := self.name_servers():
            # For other records, the best we can do is to forward the query
            # to an upstream server.
            flow.server_conn.address = (name_servers[0], 53)
        else:
            flow.error = Error("Cannot resolve, dns_name_servers unknown.")

    @staticmethod
    def _should_resolve(flow: dns.DNSFlow) -> bool:
        """
        Return whether this addon should handle `flow`.

        True only for live flows without a response or error that arrived in
        DNS mode, or in WireGuard mode with server address ("10.0.0.53", 53).
        """
        return (
            (
                isinstance(flow.client_conn.proxy_mode, mode_specs.DnsMode)
                or (
                    isinstance(flow.client_conn.proxy_mode, mode_specs.WireGuardMode)
                    and flow.server_conn.address == ("10.0.0.53", 53)
                )
            )
            and flow.live
            and not flow.response
            and not flow.error
        )

    async def resolve(
        self,
        message: dns.DNSMessage,
    ) -> dns.DNSMessage:
        """
        Resolve the A or AAAA question in `message` and build the reply.

        Returns:
            A success response with one record per resolved address (possibly
            none), or a failure response with NXDOMAIN / SERVFAIL when the
            lookup raises `socket.gaierror`.
        Raises:
            MissingNameServers, propagated from `resolver()`.
        """
        q = message.question
        assert q
        try:
            if q.type == dns.types.A:
                ip_addrs = await self.resolver().lookup_ipv4(q.name)
            else:
                ip_addrs = await self.resolver().lookup_ipv6(q.name)
        except socket.gaierror as e:
            # A gaierror raised without arguments carries no code; treat it
            # like any other unknown failure instead of raising IndexError.
            code = e.args[0] if e.args else None
            if code == socket.EAI_NONAME:
                return message.fail(dns.response_codes.NXDOMAIN)
            # EAI_NODATA is only exposed by the socket module on platforms
            # whose C library defines it, so it must not be referenced
            # unconditionally.
            if hasattr(socket, "EAI_NODATA") and code == socket.EAI_NODATA:
                ip_addrs = []
            else:
                return message.fail(dns.response_codes.SERVFAIL)

        return message.succeed(
            [
                dns.ResourceRecord(
                    name=q.name,
                    type=q.type,
                    class_=q.class_,
                    ttl=dns.ResourceRecord.DEFAULT_TTL,
                    data=ipaddress.ip_address(ip).packed,
                )
                for ip in ip_addrs
            ]
        )
|
|
|
|
|
|
class Resolver(Protocol):
    """Structural interface for objects that can look up IP addresses for a domain."""

    async def lookup_ip(self, domain: str) -> list[str]:  # pragma: no cover
        """Look up addresses for `domain` and return them as strings."""

    async def lookup_ipv4(self, domain: str) -> list[str]:  # pragma: no cover
        """Look up IPv4 addresses for `domain` and return them as strings."""

    async def lookup_ipv6(self, domain: str) -> list[str]:  # pragma: no cover
        """Look up IPv6 addresses for `domain` and return them as strings."""
|
|
|
|
|
|
class GetaddrinfoFallbackResolver(Resolver):
    """Resolver that delegates lookups to the running event loop's `getaddrinfo`."""

    async def lookup_ip(self, domain: str) -> list[str]:
        """Look up `domain` without restricting the address family."""
        addresses = await self._lookup(domain, socket.AF_UNSPEC)
        return addresses

    async def lookup_ipv4(self, domain: str) -> list[str]:
        """Look up `domain`, restricted to the AF_INET family."""
        addresses = await self._lookup(domain, socket.AF_INET)
        return addresses

    async def lookup_ipv6(self, domain: str) -> list[str]:
        """Look up `domain`, restricted to the AF_INET6 family."""
        addresses = await self._lookup(domain, socket.AF_INET6)
        return addresses

    async def _lookup(self, domain: str, family: socket.AddressFamily) -> list[str]:
        """
        Call `getaddrinfo` for `domain` with the given `family` and return the
        host part of every resulting socket address.

        Any exception raised by `getaddrinfo` propagates to the caller.
        """
        loop = asyncio.get_running_loop()
        results = await loop.getaddrinfo(
            host=domain,
            port=None,
            family=family,
            type=socket.SOCK_STREAM,
        )
        # Each result is (family, type, proto, canonname, sockaddr); the host
        # is the first element of sockaddr.
        return [
            sockaddr[0]
            for _family, _type, _proto, _canonname, sockaddr in results
        ]
|
|
|
|
|
|
class MissingNameServers(RuntimeError):
    """
    Raised by `DnsResolver.resolver` when no name servers are known and
    `dns_use_hosts_file` is disabled.
    """
|