2025-12-25 upload

This commit is contained in:
“shengyudong”
2025-12-25 11:16:59 +08:00
commit 322ac74336
2241 changed files with 639966 additions and 0 deletions

View File

@@ -0,0 +1,43 @@
import ipaddress
import re
from typing import AnyStr
# Allow underscore in host name
# Note: This could be a DNS label, a hostname, a FQDN, or an IP
# Anchored with \Z (not $): "$" also matches just before a trailing newline,
# which would let a host like b"example.com\n" validate.
_label_valid = re.compile(rb"[A-Z\d\-_]{1,63}\Z", re.IGNORECASE)


def is_valid_host(host: AnyStr) -> bool:
    """
    Checks if the passed bytes are a valid DNS hostname or an IPv4/IPv6 address.

    Accepts either `str` (IDNA-encoded first) or `bytes`.
    """
    if isinstance(host, str):
        try:
            host_bytes = host.encode("idna")
        except UnicodeError:
            return False
    else:
        host_bytes = host
    # The raw bytes must round-trip through the IDNA codec.
    try:
        host_bytes.decode("idna")
    except ValueError:
        return False
    # RFC1035: 255 bytes or less.
    if len(host_bytes) > 255:
        return False
    # Tolerate the fully-qualified (trailing dot) form.
    if host_bytes and host_bytes.endswith(b"."):
        host_bytes = host_bytes[:-1]
    # DNS hostname: every dot-separated label must be valid.
    if all(_label_valid.match(x) for x in host_bytes.split(b".")):
        return True
    # IPv4/IPv6 address
    try:
        ipaddress.ip_address(host_bytes.decode("idna"))
        return True
    except ValueError:
        return False
def is_valid_port(port: int) -> bool:
    """True if `port` lies within the valid TCP/UDP port range (0-65535)."""
    return port in range(65536)

View File

@@ -0,0 +1,19 @@
# DNS CLASS values (RFC 1035 §3.2.4, RFC 2136).
IN = 1
CH = 3
HS = 4
NONE = 254
ANY = 255

_STRINGS = {
    IN: "IN",
    CH: "CH",
    HS: "HS",
    NONE: "NONE",
    ANY: "ANY",
}
_INTS = {v: k for k, v in _STRINGS.items()}


def to_str(class_: int) -> str:
    """Render a DNS class as its mnemonic, falling back to "CLASS(n)"."""
    try:
        return _STRINGS[class_]
    except KeyError:
        return f"CLASS({class_})"


def from_str(class_: str) -> int:
    """Parse a mnemonic or "CLASS(n)" string back into a DNS class integer."""
    if class_ in _INTS:
        return _INTS[class_]
    return int(class_.removeprefix("CLASS(").removesuffix(")"))

View File

@@ -0,0 +1,169 @@
import struct
from typing import Optional
from . import types
# Single-octet label length prefix (RFC 1035 §3.1).
_LABEL_SIZE = struct.Struct("!B")
# Two-octet compression pointer (RFC 1035 §4.1.4).
_POINTER_OFFSET = struct.Struct("!H")
# Top two bits set in a length octet mark it as a compression pointer.
_POINTER_INDICATOR = 0b11000000
# Maps a buffer offset to its unpacked (name, size) result, or None while the
# offset is still being unpacked (used to detect pointer loops).
Cache = dict[int, Optional[tuple[str, int]]]
def cache() -> Cache:
    """Return a fresh, empty name-decompression cache."""
    return {}
def _unpack_label_into(labels: list[str], buffer: bytes, offset: int) -> int:
    """
    Read one length-prefixed label at `offset`, append its decoded text to
    `labels`, and return the number of bytes consumed.

    Raises:
        struct.error on oversized labels, truncated buffers, or bad IDNA.
    """
    (size,) = _LABEL_SIZE.unpack_from(buffer, offset)
    if size >= 64:
        # Lengths >= 64 collide with the compression-pointer bit pattern.
        raise struct.error(f"unpack encountered a label of length {size}")
    if size == 0:
        # Root label: just the length octet itself.
        return _LABEL_SIZE.size
    start = offset + _LABEL_SIZE.size
    end_label = start + size
    if len(buffer) < end_label:
        raise struct.error(f"unpack requires a label buffer of {size} bytes")
    try:
        labels.append(buffer[start:end_label].decode("idna"))
    except UnicodeDecodeError:
        raise struct.error(
            f"unpack encountered an illegal characters at offset {start}"
        )
    return _LABEL_SIZE.size + size
def unpack_from_with_compression(
    buffer: bytes, offset: int, cache: Cache
) -> tuple[str, int]:
    """
    Unpack a domain name (which may contain compression pointers) from
    `buffer` at `offset`.

    Returns:
        (domain name, number of bytes consumed at `offset`).

    Raises:
        struct.error on malformed input or compression-pointer loops.
    """
    if offset in cache:
        result = cache[offset]
        if result is None:
            # None marks an offset currently being unpacked further up the
            # call stack: following it again would recurse forever.
            raise struct.error(f"unpack encountered domain name loop")
    else:
        cache[offset] = None  # this will indicate that the offset is being unpacked
        start_offset = offset
        labels = []
        while True:
            (size,) = _LABEL_SIZE.unpack_from(buffer, offset)
            if size & _POINTER_INDICATOR == _POINTER_INDICATOR:
                # Top two bits set: a 14-bit compression pointer (RFC 1035
                # §4.1.4) rather than a label length.
                (pointer,) = _POINTER_OFFSET.unpack_from(buffer, offset)
                offset += _POINTER_OFFSET.size
                # Mask off the two indicator bits to obtain the target offset.
                label, _ = unpack_from_with_compression(
                    buffer, pointer & ~(_POINTER_INDICATOR << 8), cache
                )
                labels.append(label)
                break
            else:
                offset += _unpack_label_into(labels, buffer, offset)
                if size == 0:
                    break
        result = ".".join(labels), (offset - start_offset)
        cache[start_offset] = result
    return result
def unpack_from(buffer: bytes, offset: int) -> tuple[str, int]:
    """Converts RDATA into a domain name without pointer compression from a given offset and also returns the binary size."""
    labels: list[str] = []
    while True:
        (size,) = _LABEL_SIZE.unpack_from(buffer, offset)
        if size & _POINTER_INDICATOR == _POINTER_INDICATOR:
            raise struct.error(
                f"unpack encountered a pointer which is not supported in RDATA"
            )
        offset += _unpack_label_into(labels, buffer, offset)
        if size == 0:
            # Root label terminates the name.
            return ".".join(labels), offset
def unpack(buffer: bytes) -> str:
    """Converts RDATA into a domain name without pointer compression."""
    name, consumed = unpack_from(buffer, 0)
    # The name must cover the entire buffer, no trailing bytes allowed.
    if consumed != len(buffer):
        raise struct.error(f"unpack requires a buffer of {consumed} bytes")
    return name
def pack(name: str) -> bytes:
    """Converts a domain name into RDATA without pointer compression."""
    out = bytearray()
    if name:
        for part in name.split("."):
            label = part.encode("idna")
            size = len(label)
            if size == 0:
                raise ValueError(f"domain name '{name}' contains empty labels")
            if size >= 64:  # pragma: no cover
                # encoding with 'idna' will already have raised an exception earlier
                raise ValueError(
                    f"encoded label '{part}' of domain name '{name}' is too long ({size} bytes)"
                )
            # One length octet followed by the label bytes (RFC 1035 §3.1).
            out.append(size)
            out += label
    # Terminating root label.
    out.append(0)
    return bytes(out)
def record_data_can_have_compression(record_type: int) -> bool:
    """True if RDATA of this record type may contain compressed domain names."""
    compressible = {
        types.CNAME,
        types.HINFO,
        types.MB,
        types.MD,
        types.MF,
        types.MG,
        types.MINFO,
        types.MR,
        types.MX,
        types.NS,
        types.PTR,
        types.SOA,
        types.TXT,
        types.RP,
        types.AFSDB,
        types.RT,
        types.SIG,
        types.PX,
        types.NXT,
        types.NAPTR,
        types.SRV,
    }
    return record_type in compressible
def decompress_from_record_data(
    buffer: bytes, offset: int, end_data: int, cached_names: Cache
) -> bytes:
    """
    Return buffer[offset:end_data] with any DNS compression pointers found in
    it expanded to full, length-prefixed domain names.
    """
    # we decompress compression pointers in RDATA by iterating through each byte and checking
    # if it has a leading 0b11, if so we try to decompress it and update it in the data variable.
    data = bytearray(buffer[offset:end_data])
    data_offset = 0
    decompress_size = 0
    while data_offset < end_data - offset:
        if buffer[offset + data_offset] & _POINTER_INDICATOR == _POINTER_INDICATOR:
            try:
                (
                    rr_name,
                    rr_name_len,
                ) = unpack_from_with_compression(
                    buffer, offset + data_offset, cached_names
                )
                data[
                    data_offset + decompress_size : data_offset
                    + decompress_size
                    + rr_name_len
                ] = pack(rr_name)
                # NOTE(review): the slice above is replaced by pack(rr_name),
                # whose length includes one length octet per label plus the
                # terminating zero, yet only len(rr_name) is added to
                # decompress_size — verify the bookkeeping for names with
                # multiple labels.
                decompress_size += len(rr_name)
                data_offset += rr_name_len
                continue
            except struct.error:
                # the byte isn't actually a domain name compression pointer but some other data type
                pass
        data_offset += 1
    return bytes(data)

View File

@@ -0,0 +1,134 @@
import enum
import struct
from dataclasses import dataclass
from typing import Self
from ...utils import strutils
from . import domain_names
"""
HTTPS records are formatted as follows (as per RFC9460):
- a 2-octet field for SvcPriority as an integer in network byte order.
- the uncompressed, fully qualified TargetName, represented as a sequence of length-prefixed labels per Section 3.1 of [RFC1035].
- the SvcParams, consuming the remainder of the record (so smaller than 65535 octets and constrained by the RDATA and DNS message sizes).
When the list of SvcParams is non-empty, it contains a series of SvcParamKey=SvcParamValue pairs, represented as:
- a 2-octet field containing the SvcParamKey as an integer in network byte order. (See Section 14.3.2 for the defined values.)
- a 2-octet field containing the length of the SvcParamValue as an integer between 0 and 65535 in network byte order.
- an octet string of this length whose contents are the SvcParamValue in a format determined by the SvcParamKey.
https://datatracker.ietf.org/doc/rfc9460/
https://datatracker.ietf.org/doc/rfc1035/
"""
class SVCParamKeys(enum.Enum):
    """Registered SvcParamKey values (RFC 9460 §14.3.2)."""

    MANDATORY = 0
    ALPN = 1
    NO_DEFAULT_ALPN = 2
    PORT = 3
    IPV4HINT = 4
    ECH = 5
    IPV6HINT = 6


# JSON-friendly form of a record: "target_name"/"priority" entries plus one
# entry per SvcParam (key is the lower-cased mnemonic, or the raw int for
# unknown keys; values are escaped strings).
type HTTPSRecordJSON = dict[str | int, str | int]
@dataclass
class HTTPSRecord:
    """In-memory representation of HTTPS (type 65) record data (RFC 9460)."""

    # SvcPriority: 0 means AliasMode, greater than 0 means ServiceMode.
    priority: int
    # TargetName in dotted form.
    target_name: str
    # Raw SvcParamKey -> SvcParamValue pairs.
    params: dict[int, bytes]

    def __repr__(self):
        return str(self.to_json())

    def to_json(self) -> HTTPSRecordJSON:
        """Render the record as a JSON-friendly dict; known SvcParam keys are
        replaced by their lower-cased mnemonic, values are escaped strings."""
        ret: HTTPSRecordJSON = {
            "target_name": self.target_name,
            "priority": self.priority,
        }
        typ: str | int
        for typ, val in self.params.items():
            try:
                typ = SVCParamKeys(typ).name.lower()
            except ValueError:
                # Unknown key: keep the numeric form.
                pass
            ret[typ] = strutils.bytes_to_escaped_str(val)
        return ret

    @classmethod
    def from_json(cls, data: HTTPSRecordJSON) -> Self:
        """Inverse of `to_json`. Does not modify *data*."""
        # Work on a shallow copy: the previous implementation popped entries
        # from the caller's dict, mutating it as a side effect.
        data = dict(data)
        target_name = data.pop("target_name")
        assert isinstance(target_name, str)
        priority = data.pop("priority")
        assert isinstance(priority, int)
        params: dict[int, bytes] = {}
        for k, v in data.items():
            if isinstance(k, str):
                k = SVCParamKeys[k.upper()].value
            assert isinstance(v, str)
            params[k] = strutils.escaped_str_to_bytes(v)
        return cls(target_name=target_name, priority=priority, params=params)
def _unpack_params(data: bytes, offset: int) -> dict[int, bytes]:
    """Unpacks the service parameters from the given offset."""
    params: dict[int, bytes] = {}
    while offset < len(data):
        # Each SvcParam: 2-octet key, 2-octet length, then the value bytes.
        (key,) = struct.unpack("!H", data[offset : offset + 2])
        (length,) = struct.unpack("!H", data[offset + 2 : offset + 4])
        offset += 4
        if offset + length > len(data):
            raise struct.error(
                "unpack requires a buffer of %i bytes" % (offset + length)
            )
        params[key] = data[offset : offset + length]
        offset += length
    return params
def unpack(data: bytes) -> HTTPSRecord:
    """
    Unpacks HTTPS RDATA from byte data.

    Raises:
        struct.error if the record is malformed.
    """
    offset = 0
    # SvcPriority is a 2-octet *unsigned* integer (RFC 9460 §2.2); the signed
    # "!h" format would misread priorities >= 0x8000 as negative.
    priority = struct.unpack("!H", data[offset : offset + 2])[0]
    offset += 2
    # TargetName (variable length)
    target_name, offset = domain_names.unpack_from(data, offset)
    # Service Parameters (remaining bytes)
    params = _unpack_params(data, offset)
    return HTTPSRecord(priority=priority, target_name=target_name, params=params)
def _pack_params(params: dict[int, bytes]) -> bytes:
    """Converts the service parameters into the raw byte format"""
    out = bytearray()
    for key, value in params.items():
        # 2-octet key, 2-octet length, then the raw value.
        out += struct.pack("!HH", key, len(value))
        out += value
    return bytes(out)
def pack(record: HTTPSRecord) -> bytes:
    """Packs the HTTPS record into its bytes form."""
    buffer = bytearray()
    # SvcPriority is a 2-octet *unsigned* integer (RFC 9460 §2.2); the signed
    # "!h" format would reject valid priorities >= 0x8000.
    buffer.extend(struct.pack("!H", record.priority))
    buffer.extend(domain_names.pack(record.target_name))
    buffer.extend(_pack_params(record.params))
    return bytes(buffer)

View File

@@ -0,0 +1,27 @@
# DNS OPCODE values (RFC 1035, RFC 1996, RFC 2136, RFC 8490).
QUERY = 0
IQUERY = 1
STATUS = 2
NOTIFY = 4
UPDATE = 5
DSO = 6

_STRINGS = {
    QUERY: "QUERY",
    IQUERY: "IQUERY",
    STATUS: "STATUS",
    NOTIFY: "NOTIFY",
    UPDATE: "UPDATE",
    DSO: "DSO",
}
_INTS = {v: k for k, v in _STRINGS.items()}


def to_str(op_code: int) -> str:
    """Render an opcode as its mnemonic, falling back to "OPCODE(n)"."""
    try:
        return _STRINGS[op_code]
    except KeyError:
        return f"OPCODE({op_code})"


def from_str(op_code: str) -> int:
    """Parse a mnemonic or "OPCODE(n)" string back into an opcode integer."""
    if op_code in _INTS:
        return _INTS[op_code]
    return int(op_code.removeprefix("OPCODE(").removesuffix(")"))

View File

@@ -0,0 +1,58 @@
# DNS RCODE values (RFC 1035, RFC 2136, RFC 8490).
NOERROR = 0
FORMERR = 1
SERVFAIL = 2
NXDOMAIN = 3
NOTIMP = 4
REFUSED = 5
YXDOMAIN = 6
YXRRSET = 7
NXRRSET = 8
NOTAUTH = 9
NOTZONE = 10
DSOTYPENI = 11

# Rough HTTP status-code equivalents for each RCODE.
_CODES = {
    NOERROR: 200,
    FORMERR: 400,
    SERVFAIL: 500,
    NXDOMAIN: 404,
    NOTIMP: 501,
    REFUSED: 403,
    YXDOMAIN: 409,
    YXRRSET: 409,
    NXRRSET: 410,
    NOTAUTH: 401,
    NOTZONE: 404,
    DSOTYPENI: 501,
}
_STRINGS = {
    NOERROR: "NOERROR",
    FORMERR: "FORMERR",
    SERVFAIL: "SERVFAIL",
    NXDOMAIN: "NXDOMAIN",
    NOTIMP: "NOTIMP",
    REFUSED: "REFUSED",
    YXDOMAIN: "YXDOMAIN",
    YXRRSET: "YXRRSET",
    NXRRSET: "NXRRSET",
    NOTAUTH: "NOTAUTH",
    NOTZONE: "NOTZONE",
    DSOTYPENI: "DSOTYPENI",
}
_INTS = {v: k for k, v in _STRINGS.items()}


def http_equiv_status_code(response_code: int) -> int:
    """Map a DNS RCODE to a roughly equivalent HTTP status (500 if unknown)."""
    try:
        return _CODES[response_code]
    except KeyError:
        return 500


def to_str(response_code: int) -> str:
    """Render an RCODE as its mnemonic, falling back to "RCODE(n)"."""
    try:
        return _STRINGS[response_code]
    except KeyError:
        return f"RCODE({response_code})"


def from_str(response_code: str) -> int:
    """Parse a mnemonic or "RCODE(n)" string back into an RCODE integer."""
    if response_code in _INTS:
        return _INTS[response_code]
    return int(response_code.removeprefix("RCODE(").removesuffix(")"))

View File

@@ -0,0 +1,193 @@
# DNS record TYPE values (IANA DNS parameters registry).
A = 1
NS = 2
MD = 3
MF = 4
CNAME = 5
SOA = 6
MB = 7
MG = 8
MR = 9
NULL = 10
WKS = 11
PTR = 12
HINFO = 13
MINFO = 14
MX = 15
TXT = 16
RP = 17
AFSDB = 18
X25 = 19
ISDN = 20
RT = 21
NSAP = 22
NSAP_PTR = 23
SIG = 24
KEY = 25
PX = 26
GPOS = 27
AAAA = 28
LOC = 29
NXT = 30
EID = 31
NIMLOC = 32
SRV = 33
ATMA = 34
NAPTR = 35
KX = 36
CERT = 37
A6 = 38
DNAME = 39
SINK = 40
OPT = 41
APL = 42
DS = 43
SSHFP = 44
IPSECKEY = 45
RRSIG = 46
NSEC = 47
DNSKEY = 48
DHCID = 49
NSEC3 = 50
NSEC3PARAM = 51
TLSA = 52
SMIMEA = 53
HIP = 55
NINFO = 56
RKEY = 57
TALINK = 58
CDS = 59
CDNSKEY = 60
OPENPGPKEY = 61
CSYNC = 62
ZONEMD = 63
SVCB = 64
HTTPS = 65
SPF = 99
UINFO = 100
UID = 101
GID = 102
UNSPEC = 103
NID = 104
L32 = 105
L64 = 106
LP = 107
EUI48 = 108
EUI64 = 109
TKEY = 249
TSIG = 250
IXFR = 251
AXFR = 252
MAILB = 253
MAILA = 254
ANY = 255
URI = 256
CAA = 257
AVC = 258
DOA = 259
AMTRELAY = 260
TA = 32768
DLV = 32769

_STRINGS = {
    A: "A",
    NS: "NS",
    MD: "MD",
    MF: "MF",
    CNAME: "CNAME",
    SOA: "SOA",
    MB: "MB",
    MG: "MG",
    MR: "MR",
    NULL: "NULL",
    WKS: "WKS",
    PTR: "PTR",
    HINFO: "HINFO",
    MINFO: "MINFO",
    MX: "MX",
    TXT: "TXT",
    RP: "RP",
    AFSDB: "AFSDB",
    X25: "X25",
    ISDN: "ISDN",
    RT: "RT",
    NSAP: "NSAP",
    NSAP_PTR: "NSAP_PTR",
    SIG: "SIG",
    KEY: "KEY",
    PX: "PX",
    GPOS: "GPOS",
    AAAA: "AAAA",
    LOC: "LOC",
    NXT: "NXT",
    EID: "EID",
    NIMLOC: "NIMLOC",
    SRV: "SRV",
    ATMA: "ATMA",
    NAPTR: "NAPTR",
    KX: "KX",
    CERT: "CERT",
    A6: "A6",
    DNAME: "DNAME",
    SINK: "SINK",
    OPT: "OPT",
    APL: "APL",
    DS: "DS",
    SSHFP: "SSHFP",
    IPSECKEY: "IPSECKEY",
    RRSIG: "RRSIG",
    NSEC: "NSEC",
    DNSKEY: "DNSKEY",
    DHCID: "DHCID",
    NSEC3: "NSEC3",
    NSEC3PARAM: "NSEC3PARAM",
    TLSA: "TLSA",
    SMIMEA: "SMIMEA",
    HIP: "HIP",
    NINFO: "NINFO",
    RKEY: "RKEY",
    TALINK: "TALINK",
    CDS: "CDS",
    CDNSKEY: "CDNSKEY",
    OPENPGPKEY: "OPENPGPKEY",
    CSYNC: "CSYNC",
    ZONEMD: "ZONEMD",
    SVCB: "SVCB",
    HTTPS: "HTTPS",
    SPF: "SPF",
    UINFO: "UINFO",
    UID: "UID",
    GID: "GID",
    UNSPEC: "UNSPEC",
    NID: "NID",
    L32: "L32",
    L64: "L64",
    LP: "LP",
    EUI48: "EUI48",
    EUI64: "EUI64",
    TKEY: "TKEY",
    TSIG: "TSIG",
    IXFR: "IXFR",
    AXFR: "AXFR",
    MAILB: "MAILB",
    MAILA: "MAILA",
    ANY: "ANY",
    URI: "URI",
    CAA: "CAA",
    AVC: "AVC",
    DOA: "DOA",
    AMTRELAY: "AMTRELAY",
    TA: "TA",
    DLV: "DLV",
}
_INTS = {v: k for k, v in _STRINGS.items()}


def to_str(type_: int) -> str:
    """Render a record type as its mnemonic, falling back to "TYPE(n)"."""
    try:
        return _STRINGS[type_]
    except KeyError:
        return f"TYPE({type_})"


def from_str(type_: str) -> int:
    """Parse a mnemonic or "TYPE(n)" string back into a record-type integer."""
    if type_ in _INTS:
        return _INTS[type_]
    return int(type_.removeprefix("TYPE(").removesuffix(")"))

View File

@@ -0,0 +1,235 @@
"""
Utility functions for decoding response bodies.
"""
import codecs
import collections
import gzip
import zlib
from io import BytesIO
from typing import overload
import brotli
import zstandard as zstd
# We have a shared single-element cache for encoding and decoding.
# This is quite useful in practice, e.g.
#   flow.request.content = flow.request.content.replace(b"foo", b"bar")
# does not require an .encode() call if content does not contain b"foo"
# Fields: the encoded bytes, the encoding name, the error mode, and the
# corresponding decoded object.
CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded")
# Module-global single-slot cache; starts out empty.
_cache = CachedDecode(None, None, None, None)
@overload
def decode(encoded: None, encoding: str, errors: str = "strict") -> None: ...
@overload
def decode(encoded: str, encoding: str, errors: str = "strict") -> str: ...
@overload
def decode(encoded: bytes, encoding: str, errors: str = "strict") -> str | bytes: ...
def decode(
    encoded: None | str | bytes, encoding: str, errors: str = "strict"
) -> None | str | bytes:
    """
    Decode the given input object

    Returns:
        The decoded value

    Raises:
        ValueError, if decoding fails.
    """
    if encoded is None:
        return None
    encoding = encoding.lower()
    global _cache
    # Fast path: reuse the single-slot cache when the exact same bytes were
    # just decoded with the same encoding and error mode.
    cached = (
        isinstance(encoded, bytes)
        and _cache.encoded == encoded
        and _cache.encoding == encoding
        and _cache.errors == errors
    )
    if cached:
        return _cache.decoded
    try:
        try:
            decoded = custom_decode[encoding](encoded)
        except KeyError:
            # Not a content encoding we implement ourselves: fall back to the
            # regular text codecs.
            decoded = codecs.decode(encoded, encoding, errors)  # type: ignore
        # Only cache the (potentially expensive) content encodings.
        if encoding in ("gzip", "deflate", "deflateraw", "br", "zstd"):
            _cache = CachedDecode(encoded, encoding, errors, decoded)
        return decoded
    except TypeError:
        # A wrong input type is a programming error: propagate unchanged
        # instead of wrapping it in ValueError below.
        raise
    except Exception as e:
        raise ValueError(
            "{} when decoding {} with {}: {}".format(
                type(e).__name__,
                repr(encoded)[:10],
                repr(encoding),
                repr(e),
            )
        )
@overload
def encode(decoded: None, encoding: str, errors: str = "strict") -> None: ...
@overload
def encode(decoded: str, encoding: str, errors: str = "strict") -> str | bytes: ...
@overload
def encode(decoded: bytes, encoding: str, errors: str = "strict") -> bytes: ...
def encode(
    decoded: None | str | bytes, encoding, errors="strict"
) -> None | str | bytes:
    """
    Encode the given input object

    Returns:
        The encoded value

    Raises:
        ValueError, if encoding fails.
    """
    if decoded is None:
        return None
    encoding = encoding.lower()
    global _cache
    # Fast path: reuse the single-slot cache when re-encoding the value we
    # just decoded with identical parameters.
    cached = (
        isinstance(decoded, bytes)
        and _cache.decoded == decoded
        and _cache.encoding == encoding
        and _cache.errors == errors
    )
    if cached:
        return _cache.encoded
    try:
        try:
            encoded = custom_encode[encoding](decoded)
        except KeyError:
            # Not a content encoding we implement ourselves: fall back to the
            # regular text codecs.
            encoded = codecs.encode(decoded, encoding, errors)  # type: ignore
        # Only cache the (potentially expensive) content encodings.
        if encoding in ("gzip", "deflate", "deflateraw", "br", "zstd"):
            _cache = CachedDecode(encoded, encoding, errors, decoded)
        return encoded
    except TypeError:
        # A wrong input type is a programming error: propagate unchanged
        # instead of wrapping it in ValueError below.
        raise
    except Exception as e:
        raise ValueError(
            "{} when encoding {} with {}: {}".format(
                type(e).__name__,
                repr(decoded)[:10],
                repr(encoding),
                repr(e),
            )
        )
def identity(content):
    """
    Return *content* unchanged: "identity" is the default value of
    Accept-Encoding headers, i.e. no transformation at all.
    """
    return content
def decode_gzip(content: bytes) -> bytes:
    """Decode gzip or zlib-compressed data using zlib's auto-detection."""
    if not content:
        return b""
    try:
        # wbits=47 (32 + 15) lets zlib auto-detect both gzip and zlib headers,
        # so a separate gzip.GzipFile fallback is unnecessary.
        # Reference: https://docs.python.org/3/library/zlib.html#zlib.decompress
        inflater = zlib.decompressobj(47)
        return inflater.decompress(content) + inflater.flush()
    except zlib.error as e:
        raise ValueError(f"Decompression failed: {e}")
def encode_gzip(content: bytes) -> bytes:
    """Gzip-compress *content*; mtime is pinned to 0 so output is deterministic."""
    sink = BytesIO()
    with gzip.GzipFile(fileobj=sink, mode="wb", mtime=0) as gz:
        gz.write(content)
    return sink.getvalue()
def decode_brotli(content: bytes) -> bytes:
    """Decompress Brotli-encoded data; empty input yields empty output."""
    return brotli.decompress(content) if content else b""
def encode_brotli(content: bytes) -> bytes:
    """Compress data with Brotli at default settings."""
    return brotli.compress(content)
def decode_zstd(content: bytes) -> bytes:
    """Decompress Zstandard data, reading across multiple frames if present."""
    if not content:
        return b""
    reader = zstd.ZstdDecompressor().stream_reader(
        BytesIO(content), read_across_frames=True
    )
    return reader.read()
def encode_zstd(content: bytes) -> bytes:
    """Compress data with Zstandard at default settings."""
    return zstd.ZstdCompressor().compress(content)
def decode_deflate(content: bytes) -> bytes:
    """
    Returns decompressed data for DEFLATE. Some servers may respond with
    compressed data without a zlib header or checksum. An undocumented
    feature of zlib permits the lenient decompression of data missing both
    values.

    http://bugs.python.org/issue5784
    """
    if not content:
        return b""
    try:
        return zlib.decompress(content)
    except zlib.error:
        # Retry as raw deflate (wbits=-15: no zlib header/checksum).
        return zlib.decompress(content, -15)
def encode_deflate(content: bytes) -> bytes:
    """
    Returns compressed content, always including zlib header and checksum.
    """
    compressed = zlib.compress(content)
    return compressed
# Dispatch tables mapping (lower-cased) content-encoding names to codec
# functions. "none"/"identity" are no-ops; "deflateraw" shares the lenient
# deflate handler.
custom_decode = {
    "none": identity,
    "identity": identity,
    "gzip": decode_gzip,
    "deflate": decode_deflate,
    "deflateraw": decode_deflate,
    "br": decode_brotli,
    "zstd": decode_zstd,
}
custom_encode = {
    "none": identity,
    "identity": identity,
    "gzip": encode_gzip,
    "deflate": encode_deflate,
    "deflateraw": encode_deflate,
    "br": encode_brotli,
    "zstd": encode_zstd,
}
# Public API: only the generic entry points are exported.
__all__ = ["encode", "decode"]

View File

@@ -0,0 +1,25 @@
import socket
def get_free_port() -> int:
    """
    Get a port that's free for both TCP and UDP.

    This method never raises. If no free port can be found, 0 is returned.
    """
    for _ in range(10):
        tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            tcp.bind(("", 0))
            port: int = tcp.getsockname()[1]
            # The port only qualifies if it can also be bound for UDP.
            udp.bind(("", port))
            return port
        except OSError:
            pass
        finally:
            # Close both sockets on every path; the previous version leaked
            # the UDP socket whenever its bind() failed.
            tcp.close()
            udp.close()
    return 0

View File

@@ -0,0 +1,387 @@
import email.utils
import re
import time
from collections.abc import Iterable
from mitmproxy.coretypes import multidict
"""
A flexible module for cookie parsing and manipulation.
This module differs from usual standards-compliant cookie modules in a number
of ways. We try to be as permissive as possible, and to retain even mal-formed
information. Duplicate cookies are preserved in parsing, and can be set in
formatting. We do attempt to escape and quote values where needed, but will not
reject data that violate the specs.
Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We
also parse the comma-separated variant of Set-Cookie that allows multiple
cookies to be set in a single header. Serialization follows RFC6265.
http://tools.ietf.org/html/rfc6265
http://tools.ietf.org/html/rfc2109
http://tools.ietf.org/html/rfc2965
"""
# Attribute names (lower-cased) that are treated as cookie *attributes*
# rather than the start of a new cookie when grouping Set-Cookie pairs.
_cookie_params = {
    "expires",
    "path",
    "comment",
    "max-age",
    "secure",
    "httponly",
    "version",
}
# Characters that must be backslash-escaped inside a quoted cookie value.
ESCAPE = re.compile(r"([\"\\])")
class CookieAttrs(multidict.MultiDict):
    """Multidict of cookie attributes with case-insensitive keys."""

    @staticmethod
    def _kconv(key):
        # Attribute names are matched case-insensitively.
        return key.lower()

    @staticmethod
    def _reduce_values(values):
        # See the StickyCookieTest for a weird cookie that only makes sense
        # if we take the last part.
        return values[-1]
# One parsed Set-Cookie: (name, value, attributes).
TSetCookie = tuple[str, str | None, CookieAttrs]
# A list of (name, value) pairs; value is None for valueless attributes.
TPairs = list[tuple[str, str | None]]
def _read_until(s, start, term):
    """
    Read characters from ``s`` starting at ``start`` until one of the
    characters in ``term`` is reached; return (token, end offset).
    """
    # Quirk preserved from the original: reading at the very end yields an
    # empty token and advances one position *past* the end.
    if start == len(s):
        return "", start + 1
    idx = start
    for idx in range(start, len(s)):
        if s[idx] in term:
            return s[start:idx], idx
    return s[start : idx + 1], idx + 1
def _read_quoted_string(s, start):
    """
    start: offset to the first quote of the string to be read

    A sort of loose super-set of the various quoted string specifications.
    RFC6265 disallows backslashes or double quotes within quoted strings.
    Prior RFCs use backslashes to escape. This leaves us free to apply
    backslash escaping by default and be compatible with everything.
    """
    chars = []
    escaped = False
    i = start  # initialize in case the loop doesn't run.
    # Skip the opening quote and scan until the closing one.
    for i in range(start + 1, len(s)):
        c = s[i]
        if escaped:
            chars.append(c)
            escaped = False
        elif c == "\\":
            escaped = True
        elif c == '"':
            break
        else:
            chars.append(c)
    return "".join(chars), i + 1
def _read_key(s, start, delims=";="):
    """Read the LHS (key) of a token/value pair in a cookie string."""
    return _read_until(s, start, delims)
def _read_value(s, start, delims):
    """Read the RHS (value) of a token/value pair in a cookie string."""
    if start >= len(s):
        return "", start
    if s[start] == '"':
        # Quoted value: consume up to the closing quote, unescaping as we go.
        return _read_quoted_string(s, start)
    return _read_until(s, start, delims)
def _read_cookie_pairs(s, off=0):
    """
    Read pairs of lhs=rhs values from Cookie headers.

    off: start offset
    """
    pairs = []
    while True:
        key, off = _read_key(s, off)
        key = key.lstrip()
        val = ""
        if off < len(s) and s[off] == "=":
            val, off = _read_value(s, off + 1, ";")
        # Keep the pair if either side is non-empty.
        if key or val:
            pairs.append([key, val])
        off += 1
        if off >= len(s):
            break
    return pairs, off
def _read_set_cookie_pairs(s: str, off=0) -> tuple[list[TPairs], int]:
    """
    Read pairs of lhs=rhs values from SetCookie headers while handling multiple cookies.

    off: start offset
    specials: attributes that are treated specially

    Returns:
        (list of per-cookie pair lists, end offset).
    """
    cookies: list[TPairs] = []
    pairs: TPairs = []
    while True:
        lhs, off = _read_key(s, off, ";=,")
        lhs = lhs.lstrip()
        rhs = ""
        if off < len(s) and s[off] == "=":
            rhs, off = _read_value(s, off + 1, ";,")
            # Special handling of attributes
            if lhs.lower() == "expires":
                # 'expires' values can contain commas in them so they need to
                # be handled separately.
                # We actually bank on the fact that the expires value WILL
                # contain a comma. Things will fail, if they don't.
                # '3' is just a heuristic we use to determine whether we've
                # only read a part of the expires value and we should read more.
                if len(rhs) <= 3:
                    trail, off = _read_value(s, off + 1, ";,")
                    rhs = rhs + "," + trail
            # as long as there's a "=", we consider it a pair
            pairs.append((lhs, rhs))
        elif lhs:
            # Valueless attribute such as "secure" or "httponly".
            pairs.append((lhs, None))
        # comma marks the beginning of a new cookie
        if off < len(s) and s[off] == ",":
            cookies.append(pairs)
            pairs = []
        off += 1
        if not off < len(s):
            break
    # Flush the trailing cookie (or an empty one for empty input).
    if pairs or not cookies:
        cookies.append(pairs)
    return cookies, off
def _has_special(s: str) -> bool:
    """True if *s* contains characters that require quoting in a cookie value."""
    # Quotes/commas/semicolons/backslashes, controls, space, and non-ASCII
    # all force the value to be quoted.
    return any(c in '",;\\' or not 0x21 <= ord(c) <= 0x7E for c in s)
def _format_pairs(pairs, specials=(), sep="; "):
    """
    Serialize (key, value) pairs, quoting and escaping values where required.

    specials: A lower-cased list of keys that will not be quoted.
    """
    rendered = []
    for key, value in pairs:
        if value is None:
            # Valueless attribute (e.g. "secure").
            rendered.append(key)
        elif key.lower() not in specials and _has_special(value):
            escaped = ESCAPE.sub(r"\\\1", value)
            rendered.append(f'{key}="{escaped}"')
        else:
            rendered.append(f"{key}={value}")
    return sep.join(rendered)
def _format_set_cookie_pairs(lst):
    """Serialize Set-Cookie pairs; expires/path values are never quoted."""
    return _format_pairs(lst, specials=("expires", "path"))
def parse_cookie_header(line):
    """
    Parse a Cookie header value.

    Returns a list of (lhs, rhs) tuples.
    """
    pairs, _off = _read_cookie_pairs(line)
    return pairs
def parse_cookie_headers(cookie_headers):
    """Parse several Cookie header values into one flat list of pairs."""
    parsed = []
    for header in cookie_headers:
        parsed.extend(parse_cookie_header(header))
    return parsed
def format_cookie_header(lst):
    """Serialize a list of cookie pairs into a Cookie header value."""
    return _format_pairs(lst)
def parse_set_cookie_header(line: str) -> list[TSetCookie]:
    """
    Parse a Set-Cookie header value

    Returns:
        A list of (name, value, attrs) tuples, where attrs is a
        CookieAttrs dict of attributes. No attempt is made to parse attribute
        values - they are treated purely as strings.
    """
    cookie_pairs, _off = _read_set_cookie_pairs(line)
    # The first pair of each group is the cookie itself, the rest attributes.
    return [
        (pairs[0][0], pairs[0][1], CookieAttrs(pairs[1:]))
        for pairs in cookie_pairs
        if pairs
    ]
def parse_set_cookie_headers(headers: Iterable[str]) -> list[TSetCookie]:
    """Parse several Set-Cookie header values into one flat cookie list."""
    result: list[TSetCookie] = []
    for header in headers:
        result.extend(parse_set_cookie_header(header))
    return result
def format_set_cookie_header(set_cookies: list[TSetCookie]) -> str:
    """
    Formats a Set-Cookie header value.
    """
    rendered = []
    for name, value, attrs in set_cookies:
        pairs = [(name, value)]
        # CookieAttrs exposes .fields; plain iterables of pairs also work.
        pairs.extend(attrs.fields if hasattr(attrs, "fields") else attrs)
        rendered.append(_format_set_cookie_pairs(pairs))
    return ", ".join(rendered)
def refresh_set_cookie_header(c: str, delta: int) -> str:
    """
    Shift the 'expires' attribute of every cookie in *c* by *delta* seconds.

    Args:
        c: A Set-Cookie string
        delta: Time delta in seconds

    Returns:
        A refreshed Set-Cookie string

    Raises:
        ValueError, if the cookie is invalid.
    """
    cookies = parse_set_cookie_header(c)
    for cookie in cookies:
        name, value, attrs = cookie
        if not name or not value:
            raise ValueError("Invalid Cookie")
        if "expires" in attrs:
            e = email.utils.parsedate_tz(attrs["expires"])
            if e:
                f = email.utils.mktime_tz(e) + delta
                attrs.set_all("expires", [email.utils.formatdate(f, usegmt=True)])
            else:
                # This can happen when the expires tag is invalid.
                # reddit.com sends an expires tag like this: "Thu, 31 Dec
                # 2037 23:59:59 GMT", which is valid RFC 1123, but not
                # strictly correct according to the cookie spec. Browsers
                # appear to parse this tolerantly - maybe we should too.
                # For now, we just ignore this.
                del attrs["expires"]
    return format_set_cookie_header(cookies)
def get_expiration_ts(cookie_attrs):
    """
    Determines the time when the cookie will be expired.

    Considering both 'expires' and 'max-age' parameters.

    Returns: timestamp of when the cookie will expire.
             None, if no expiration time is set.
    """
    if "expires" in cookie_attrs:
        e = email.utils.parsedate_tz(cookie_attrs["expires"])
        if e:
            return email.utils.mktime_tz(e)
    elif "max-age" in cookie_attrs:
        try:
            # Look up the same (lower-cased) key that was membership-tested:
            # the previous "Max-Age" spelling only worked because CookieAttrs
            # lower-cases keys, and raised KeyError for plain mappings.
            max_age = int(cookie_attrs["max-age"])
        except ValueError:
            # Non-numeric max-age: treat as if absent.
            pass
        else:
            return time.time() + max_age
    return None
def is_expired(cookie_attrs):
    """
    Determines whether a cookie has expired.

    Returns: boolean
    """
    expires_at = get_expiration_ts(cookie_attrs)
    if expires_at is None:
        # No expiration information: treat as a session cookie.
        return False
    return expires_at <= time.time()
def group_cookies(pairs):
    """
    Converts a list of pairs to a (name, value, attrs) for each cookie.
    """
    if not pairs:
        return []
    grouped = []
    # First pair is always a new cookie
    name, value = pairs[0]
    attrs = []
    for key, val in pairs[1:]:
        if key.lower() in _cookie_params:
            # Known attribute: belongs to the current cookie.
            attrs.append((key, val))
        else:
            # Anything else starts a new cookie.
            grouped.append((name, value, CookieAttrs(attrs)))
            name, value, attrs = key, val, []
    grouped.append((name, value, CookieAttrs(attrs)))
    return grouped

View File

@@ -0,0 +1,113 @@
import collections
import re
def parse_content_type(c: str) -> tuple[str, str, dict[str, str]] | None:
"""
A simple parser for content-type values. Returns a (type, subtype,
parameters) tuple, where type and subtype are strings, and parameters
is a dict. If the string could not be parsed, return None.
E.g. the following string:
text/html; charset=UTF-8
Returns:
("text", "html", {"charset": "UTF-8"})
"""
parts = c.split(";", 1)
ts = parts[0].split("/", 1)
if len(ts) != 2:
return None
d = collections.OrderedDict()
if len(parts) == 2:
for i in parts[1].split(";"):
clause = i.split("=", 1)
if len(clause) == 2:
d[clause[0].strip()] = clause[1].strip()
return ts[0].lower(), ts[1].lower(), d
def assemble_content_type(type, subtype, parameters):
    """Reassemble a content-type string from its parsed components."""
    base = f"{type}/{subtype}"
    if not parameters:
        return base
    return base + "; " + "; ".join(f"{k}={v}" for k, v in parameters.items())
def infer_content_encoding(content_type: str, content: bytes = b"") -> str:
    """
    Infer the encoding of content from the content-type header.

    Priority order: byte-order mark in *content*, then the charset parameter
    of *content_type*, then per-media-type sniffing/defaults, then latin-1.
    """
    enc = None
    # BOM has the highest priority
    if content.startswith(b"\x00\x00\xfe\xff"):
        enc = "utf-32be"
    elif content.startswith(b"\xff\xfe\x00\x00"):
        enc = "utf-32le"
    elif content.startswith(b"\xfe\xff"):
        enc = "utf-16be"
    elif content.startswith(b"\xff\xfe"):
        enc = "utf-16le"
    elif content.startswith(b"\xef\xbb\xbf"):
        # 'utf-8-sig' will strip the BOM on decode
        enc = "utf-8-sig"
    elif parsed_content_type := parse_content_type(content_type):
        # Use the charset from the header if possible
        enc = parsed_content_type[2].get("charset")
    # Otherwise, infer the encoding
    if not enc and "json" in content_type:
        enc = "utf8"
    if not enc and "html" in content_type:
        # Sniff a <meta charset=...> declaration from the document itself.
        meta_charset = re.search(
            rb"""<meta[^>]+charset=['"]?([^'">]+)""", content, re.IGNORECASE
        )
        if meta_charset:
            enc = meta_charset.group(1).decode("ascii", "ignore")
        else:
            # Fallback to utf8 for html
            # Ref: https://html.spec.whatwg.org/multipage/parsing.html#determining-the-character-encoding
            # > 9. [snip] the comprehensive UTF-8 encoding is suggested.
            enc = "utf8"
    if not enc and "xml" in content_type:
        # Sniff the encoding attribute of the XML declaration.
        if xml_encoding := re.search(
            rb"""<\?xml[^\?>]+encoding=['"]([^'"\?>]+)""", content, re.IGNORECASE
        ):
            enc = xml_encoding.group(1).decode("ascii", "ignore")
        else:
            # Fallback to utf8 for xml
            # Ref: https://datatracker.ietf.org/doc/html/rfc7303#section-8.5
            # > the XML processor [snip] to determine an encoding of UTF-8.
            enc = "utf8"
    if not enc and ("javascript" in content_type or "ecmascript" in content_type):
        # Fallback to utf8 for javascript
        # Ref: https://datatracker.ietf.org/doc/html/rfc9239#section-4.2
        # > 3. Else, the character encoding scheme is assumed to be UTF-8
        enc = "utf8"
    if not enc and "text/css" in content_type:
        # @charset rule must be the very first thing.
        css_charset = re.match(rb"""@charset "([^"]+)";""", content, re.IGNORECASE)
        if css_charset:
            enc = css_charset.group(1).decode("ascii", "ignore")
        else:
            # Fallback to utf8 for css
            # Ref: https://drafts.csswg.org/css-syntax/#determine-the-fallback-encoding
            # > 4. Otherwise, return utf-8
            enc = "utf8"
    # Fallback to latin-1
    if not enc:
        enc = "latin-1"
    # Use GB 18030 as the superset of GB2312 and GBK to fix common encoding problems on Chinese websites.
    if enc.lower() in ("gb2312", "gbk"):
        enc = "gb18030"
    return enc

View File

@@ -0,0 +1,21 @@
from .assemble import assemble_body
from .assemble import assemble_request
from .assemble import assemble_request_head
from .assemble import assemble_response
from .assemble import assemble_response_head
from .read import connection_close
from .read import expected_http_body_size
from .read import read_request_head
from .read import read_response_head
# Public API of the HTTP/1 read/assemble helpers.
__all__ = [
    "read_request_head",
    "read_response_head",
    "connection_close",
    "expected_http_body_size",
    "assemble_request",
    "assemble_request_head",
    "assemble_response",
    "assemble_response_head",
    "assemble_body",
]

View File

@@ -0,0 +1,99 @@
def assemble_request(request):
    """Serialize a complete HTTP/1 request (head plus body) to bytes."""
    if request.data.content is None:
        raise ValueError("Cannot assemble flow with missing content")
    body = b"".join(
        assemble_body(
            request.data.headers, [request.data.content], request.data.trailers
        )
    )
    return assemble_request_head(request) + body
def assemble_request_head(request):
    """Serialize the request line and headers (including the final CRLF CRLF)."""
    return b"%s\r\n%s\r\n" % (
        _assemble_request_line(request.data),
        _assemble_request_headers(request.data),
    )
def assemble_response(response):
    """Serialize a complete HTTP/1 response (head plus body) to raw bytes.

    Raises:
        ValueError, if the response content is missing (e.g. streamed away).
    """
    if response.data.content is None:
        raise ValueError("Cannot assemble flow with missing content")
    head = assemble_response_head(response)
    body_chunks = assemble_body(
        response.data.headers, [response.data.content], response.data.trailers
    )
    return head + b"".join(body_chunks)
def assemble_response_head(response):
    """Serialize the status line and headers (including the final CRLF CRLF)."""
    return b"%s\r\n%s\r\n" % (
        _assemble_response_line(response.data),
        _assemble_response_headers(response.data),
    )
def assemble_body(headers, body_chunks, trailers):
    """Yield the raw bytes of an HTTP/1 message body.

    If the transfer-encoding header contains "chunked", chunks are framed
    using chunked transfer coding and trailers (if any) are appended after
    the terminating zero-size chunk. Otherwise chunks are passed through
    unmodified, and trailers are rejected (HTTP/1 cannot carry them then).

    Raises:
        ValueError, if trailers are given without chunked transfer-encoding.
    """
    chunked = "chunked" in headers.get("transfer-encoding", "").lower()
    if not chunked:
        if trailers:
            raise ValueError(
                "Sending HTTP/1.1 trailer headers requires transfer-encoding: chunked"
            )
        yield from body_chunks
        return
    for chunk in body_chunks:
        # Zero-length chunks are skipped: a 0-size chunk would terminate the body.
        if chunk:
            yield b"%x\r\n%s\r\n" % (len(chunk), chunk)
    if trailers:
        yield b"0\r\n%s\r\n" % trailers
    else:
        yield b"0\r\n\r\n"
def _assemble_request_line(request_data):
    """Build the HTTP/1 request line for the given request data.

    Args:
        request_data (mitmproxy.net.http.request.RequestData)
    """
    method = request_data.method
    version = request_data.http_version
    # CONNECT uses the authority-form target, proxy requests the absolute-form,
    # and everything else the origin-form (path only).
    if method.upper() == b"CONNECT":
        return b"%s %s %s" % (method, request_data.authority, version)
    if request_data.authority:
        return b"%s %s://%s%s %s" % (
            method,
            request_data.scheme,
            request_data.authority,
            request_data.path,
            version,
        )
    return b"%s %s %s" % (method, request_data.path, version)
def _assemble_request_headers(request_data):
    """Serialize the request headers to raw bytes.

    Args:
        request_data (mitmproxy.net.http.request.RequestData)
    """
    return bytes(request_data.headers)
def _assemble_response_line(response_data):
    """Build the HTTP/1 status line, e.g. b"HTTP/1.1 200 OK"."""
    version = response_data.http_version
    code = response_data.status_code
    reason = response_data.reason
    return b"%s %d %s" % (version, code, reason)
def _assemble_response_headers(response):
    """Serialize the response headers to raw bytes."""
    return bytes(response.headers)

View File

@@ -0,0 +1,303 @@
import re
import time
import typing
from collections.abc import Iterable
from mitmproxy.http import Headers
from mitmproxy.http import Request
from mitmproxy.http import Response
from mitmproxy.net.http import url
from mitmproxy.net.http import validate
def get_header_tokens(headers, key):
    """
    Retrieve all tokens for a header key. A number of different headers
    follow a pattern where each header line can contain comma-separated
    tokens, and headers can be set multiple times.
    """
    if key not in headers:
        return []
    return [token.strip() for token in headers[key].split(",")]
def connection_close(http_version, headers):
    """
    Checks the message to see if the client connection should be closed
    according to RFC 2616 Section 8.1.
    If we don't have a Connection header, HTTP 1.1 connections are assumed
    to be persistent.
    """
    if "connection" in headers:
        tokens = [token.strip() for token in headers["connection"].split(",")]
        # "close" takes precedence over "keep-alive" if both are present.
        if "close" in tokens:
            return True
        if "keep-alive" in tokens:
            return False
    # No explicit directive: only HTTP/1.1 and HTTP/2.0 default to persistent.
    return http_version not in (
        "HTTP/1.1",
        b"HTTP/1.1",
        "HTTP/2.0",
        b"HTTP/2.0",
    )
def expected_http_body_size(
    request: Request, response: Response | None = None
) -> int | None:
    """
    Returns:
        The expected body length:
        - a positive integer, if the size is known in advance
        - None, if the size is unknown in advance (chunked encoding)
        - -1, if all data should be read until end of stream.

    Raises:
        ValueError, if the content-length or transfer-encoding header is invalid
    """
    # Determine response size according to http://tools.ietf.org/html/rfc7230#section-3.3, which is inlined below.
    if not response:
        # Sizing a request body: only the request headers matter.
        headers = request.headers
    else:
        headers = response.headers
        # 1. Any response to a HEAD request and any response with a 1xx
        #    (Informational), 204 (No Content), or 304 (Not Modified) status
        #    code is always terminated by the first empty line after the
        #    header fields, regardless of the header fields present in the
        #    message, and thus cannot contain a message body.
        if request.method.upper() == "HEAD":
            return 0
        if 100 <= response.status_code <= 199:
            return 0
        if response.status_code in (204, 304):
            return 0
        # 2. Any 2xx (Successful) response to a CONNECT request implies that
        #    the connection will become a tunnel immediately after the empty
        #    line that concludes the header fields. A client MUST ignore any
        #    Content-Length or Transfer-Encoding header fields received in
        #    such a message.
        if 200 <= response.status_code <= 299 and request.method.upper() == "CONNECT":
            return 0
    # 3. If a Transfer-Encoding header field is present and the chunked
    #    transfer coding (Section 4.1) is the final encoding, the message
    #    body length is determined by reading and decoding the chunked
    #    data until the transfer coding indicates the data is complete.
    #
    #    If a Transfer-Encoding header field is present in a response and
    #    the chunked transfer coding is not the final encoding, the
    #    message body length is determined by reading the connection until
    #    it is closed by the server. If a Transfer-Encoding header field
    #    is present in a request and the chunked transfer coding is not
    #    the final encoding, the message body length cannot be determined
    #    reliably; the server MUST respond with the 400 (Bad Request)
    #    status code and then close the connection.
    #
    #    If a message is received with both a Transfer-Encoding and a
    #    Content-Length header field, the Transfer-Encoding overrides the
    #    Content-Length. Such a message might indicate an attempt to
    #    perform request smuggling (Section 9.5) or response splitting
    #    (Section 9.4) and ought to be handled as an error. A sender MUST
    #    remove the received Content-Length field prior to forwarding such
    #    a message downstream.
    #
    if te_str := headers.get("transfer-encoding"):
        te = validate.parse_transfer_encoding(te_str)
        match te:
            case "chunked" | "compress,chunked" | "deflate,chunked" | "gzip,chunked":
                return None
            case "compress" | "deflate" | "gzip" | "identity":
                if response:
                    return -1
                # These values are valid for responses only (not requests), which is ensured in
                # mitmproxy.net.http.validate. If users have explicitly disabled header validation,
                # we strive for maximum compatibility with weird clients.
                if te == "identity" or "content-length" in headers:
                    pass  # Content-Length or 0
                else:
                    return (
                        -1
                    )  # compress/deflate/gzip with no content-length -> read until eof
            case other:  # pragma: no cover
                typing.assert_never(other)
    # 4. If a message is received without Transfer-Encoding and with
    #    either multiple Content-Length header fields having differing
    #    field-values or a single Content-Length header field having an
    #    invalid value, then the message framing is invalid and the
    #    recipient MUST treat it as an unrecoverable error. If this is a
    #    request message, the server MUST respond with a 400 (Bad Request)
    #    status code and then close the connection. If this is a response
    #    message received by a proxy, the proxy MUST close the connection
    #    to the server, discard the received response, and send a 502 (Bad
    #    Gateway) response to the client. If this is a response message
    #    received by a user agent, the user agent MUST close the
    #    connection to the server and discard the received response.
    #
    # 5. If a valid Content-Length header field is present without
    #    Transfer-Encoding, its decimal value defines the expected message
    #    body length in octets. If the sender closes the connection or
    #    the recipient times out before the indicated number of octets are
    #    received, the recipient MUST consider the message to be
    #    incomplete and close the connection.
    if cl := headers.get("content-length"):
        return validate.parse_content_length(cl)
    # 6. If this is a request message and none of the above are true, then
    #    the message body length is zero (no message body is present).
    if not response:
        return 0
    # 7. Otherwise, this is a response message without a declared message
    #    body length, so the message body length is determined by the
    #    number of octets received prior to the server closing the
    #    connection.
    return -1
def raise_if_http_version_unknown(http_version: bytes) -> None:
    """Raise ValueError unless *http_version* has the form b"HTTP/x.y"."""
    if re.match(rb"^HTTP/\d\.\d$", http_version) is None:
        raise ValueError(f"Unknown HTTP version: {http_version!r}")
def _read_request_line(
    line: bytes,
) -> tuple[str, int, bytes, bytes, bytes, bytes, bytes]:
    """
    Parse an HTTP/1 request line.

    Returns:
        A (host, port, method, scheme, authority, path, http_version) tuple.
        host/port are empty/zero for origin-form ("/path") and asterisk-form
        ("*") targets.

    Raises:
        ValueError, if the request line is malformed.
    """
    try:
        method, target, http_version = line.split()
        port: int | None
        if target == b"*" or target.startswith(b"/"):
            # origin-form or asterisk-form target: no host information present.
            scheme, authority, path = b"", b"", target
            host, port = "", 0
        elif method == b"CONNECT":
            # authority-form target: host and port are mandatory.
            scheme, authority, path = b"", target, b""
            host, port = url.parse_authority(authority, check=True)
            if not port:
                raise ValueError
        else:
            # absolute-form target, e.g. "GET http://example.com/ HTTP/1.1".
            scheme, rest = target.split(b"://", maxsplit=1)
            authority, _, path_ = rest.partition(b"/")
            path = b"/" + path_
            host, port = url.parse_authority(authority, check=True)
            # Fall back to the scheme's well-known port if none was given.
            port = port or url.default_port(scheme)
            if not port:
                raise ValueError
            # TODO: we can probably get rid of this check?
            url.parse(target)

        raise_if_http_version_unknown(http_version)
    except ValueError as e:
        raise ValueError(f"Bad HTTP request line: {line!r}") from e
    return host, port, method, scheme, authority, path, http_version
def _read_response_line(line: bytes) -> tuple[bytes, int, bytes]:
    """Parse an HTTP/1 status line into (http_version, status_code, reason).

    Raises:
        ValueError, if the status line is malformed.
    """
    try:
        fields = line.split(None, 2)
        if len(fields) == 2:
            # Tolerate a missing reason phrase, e.g. b"HTTP/1.1 200".
            fields.append(b"")
        http_version, raw_status, reason = fields
        status_code = int(raw_status)
        raise_if_http_version_unknown(http_version)
    except ValueError as e:
        raise ValueError(f"Bad HTTP response line: {line!r}") from e
    return http_version, status_code, reason
def _read_headers(lines: Iterable[bytes]) -> Headers:
    """
    Read a set of headers.
    Stop once a blank line is reached.

    Returns:
        A headers object

    Raises:
        ValueError, if a header line is malformed.
    """
    fields: list[tuple[bytes, bytes]] = []
    for line in lines:
        if line[0] in b" \t":
            # Obsolete line folding: append to the previous header's value.
            if not fields:
                raise ValueError("Invalid headers")
            last_name, last_value = fields[-1]
            fields[-1] = (last_name, last_value + b"\r\n " + line.strip())
            continue
        name, sep, value = line.partition(b":")
        if not sep or not name:
            raise ValueError(f"Invalid header line: {line!r}")
        fields.append((name, value.strip()))
    return Headers(fields)
def read_request_head(lines: list[bytes]) -> Request:
    """
    Parse an HTTP request head (request line + headers) from an iterable of lines

    Args:
        lines: The input lines

    Returns:
        The HTTP request object (without body)

    Raises:
        ValueError: The input is malformed.
    """
    host, port, method, scheme, authority, path, http_version = _read_request_line(
        lines[0]
    )
    return Request(
        host=host,
        port=port,
        method=method,
        scheme=scheme,
        authority=authority,
        path=path,
        http_version=http_version,
        headers=_read_headers(lines[1:]),
        content=None,
        trailers=None,
        timestamp_start=time.time(),
        timestamp_end=None,
    )
def read_response_head(lines: list[bytes]) -> Response:
    """
    Parse an HTTP response head (response line + headers) from an iterable of lines

    Args:
        lines: The input lines

    Returns:
        The HTTP response object (without body)

    Raises:
        ValueError: The input is malformed.
    """
    http_version, status_code, reason = _read_response_line(lines[0])
    return Response(
        http_version=http_version,
        status_code=status_code,
        reason=reason,
        headers=_read_headers(lines[1:]),
        content=None,
        trailers=None,
        timestamp_start=time.time(),
        timestamp_end=None,
    )

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
import mimetypes
import re
import warnings
from urllib.parse import quote
from mitmproxy.net.http import headers
def encode_multipart(content_type: str, parts: list[tuple[bytes, bytes]]) -> bytes:
    """
    Encode *parts* (a list of (name, value) byte tuples) as a multipart body,
    using the boundary parameter extracted from *content_type*.
    Returns b"" when no usable boundary can be derived from the content type.

    Raises:
        ValueError, if a value contains the boundary itself.
    """
    if content_type:
        ct = headers.parse_content_type(content_type)
        if ct is not None:
            try:
                raw_boundary = ct[2]["boundary"].encode("ascii")
                boundary = quote(raw_boundary)
            except (KeyError, UnicodeError):
                # No boundary parameter, or one that is not ASCII: give up.
                return b""
            hdrs = []
            for key, value in parts:
                file_type = (
                    mimetypes.guess_type(str(key))[0] or "text/plain; charset=utf-8"
                )
                if key:
                    hdrs.append(b"--%b" % boundary.encode("utf-8"))
                    disposition = b'form-data; name="%b"' % key
                    hdrs.append(b"Content-Disposition: %b" % disposition)
                    hdrs.append(b"Content-Type: %b" % file_type.encode("utf-8"))
                    hdrs.append(b"")
                    hdrs.append(value)
                    hdrs.append(b"")
                if value is not None:
                    # If boundary is found in value then raise ValueError
                    # NOTE(review): without re.MULTILINE, "^...$" only matches
                    # when *value* consists solely of the boundary line —
                    # confirm this narrower check is intended.
                    if re.search(
                        rb"^--%b$" % re.escape(boundary.encode("utf-8")), value
                    ):
                        # NOTE(review): ValueError is given bytes, not str.
                        raise ValueError(b"boundary found in encoded string")
            hdrs.append(b"--%b--\r\n" % boundary.encode("utf-8"))
            temp = b"\r\n".join(hdrs)
            return temp
    return b""
def decode_multipart(
    content_type: str | None, content: bytes
) -> list[tuple[bytes, bytes]]:
    """
    Takes a multipart boundary encoded string and returns list of (key, value) tuples.
    Returns an empty list when the content type is missing, unparsable, or has
    no ASCII boundary parameter.
    """
    if content_type:
        ct = headers.parse_content_type(content_type)
        if not ct:
            return []
        try:
            boundary = ct[2]["boundary"].encode("ascii")
        except (KeyError, UnicodeError):
            return []
        # Extracts the field name from a Content-Disposition header line.
        rx = re.compile(rb'\bname="([^"]+)"')
        r = []
        if content is not None:
            for i in content.split(b"--" + boundary):
                parts = i.splitlines()
                # Skip the closing "--" fragment and anything without headers.
                if len(parts) > 1 and parts[0][0:2] != b"--":
                    match = rx.search(parts[1])
                    if match:
                        key = match.group(1)
                        # The value starts after the first blank line that
                        # separates the part headers from the part body.
                        value = b"".join(parts[3 + parts[2:].index(b"") :])
                        r.append((key, value))
        return r
    return []
def encode(ct, parts):  # pragma: no cover
    # 2023-02
    # Deprecated alias kept for backwards compatibility; delegates to
    # encode_multipart().
    warnings.warn(
        "multipart.encode is deprecated, use multipart.encode_multipart instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return encode_multipart(ct, parts)
def decode(ct, content):  # pragma: no cover
    """Deprecated alias for decode_multipart()."""
    # 2023-02
    warnings.warn(
        "multipart.decode is deprecated, use multipart.decode_multipart instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Bug fix: this shim previously delegated to encode_multipart(), so the
    # deprecated decode() silently returned b"" instead of the decoded parts.
    return decode_multipart(ct, content)

View File

@@ -0,0 +1,146 @@
# Covered status codes:
# - official HTTP status codes: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
# - custom codes:
#   - 444 No Response
#   - 499 Client Closed Request

# 1xx: Informational
CONTINUE = 100
SWITCHING = 101
PROCESSING = 102
EARLY_HINTS = 103
# 2xx: Success
OK = 200
CREATED = 201
ACCEPTED = 202
NON_AUTHORITATIVE_INFORMATION = 203
NO_CONTENT = 204
RESET_CONTENT = 205
PARTIAL_CONTENT = 206
MULTI_STATUS = 207
ALREADY_REPORTED = 208
IM_USED = 226
# 3xx: Redirection
MULTIPLE_CHOICE = 300
MOVED_PERMANENTLY = 301
FOUND = 302
SEE_OTHER = 303
NOT_MODIFIED = 304
USE_PROXY = 305
TEMPORARY_REDIRECT = 307
PERMANENT_REDIRECT = 308
# 4xx: Client errors
BAD_REQUEST = 400
UNAUTHORIZED = 401
PAYMENT_REQUIRED = 402
FORBIDDEN = 403
NOT_FOUND = 404
NOT_ALLOWED = 405
NOT_ACCEPTABLE = 406
PROXY_AUTH_REQUIRED = 407
REQUEST_TIMEOUT = 408
CONFLICT = 409
GONE = 410
LENGTH_REQUIRED = 411
PRECONDITION_FAILED = 412
PAYLOAD_TOO_LARGE = 413
REQUEST_URI_TOO_LONG = 414
UNSUPPORTED_MEDIA_TYPE = 415
REQUESTED_RANGE_NOT_SATISFIABLE = 416
EXPECTATION_FAILED = 417
IM_A_TEAPOT = 418
MISDIRECTED_REQUEST = 421
UNPROCESSABLE_CONTENT = 422
LOCKED = 423
FAILED_DEPENDENCY = 424
TOO_EARLY = 425
UPGRADE_REQUIRED = 426
PRECONDITION_REQUIRED = 428
TOO_MANY_REQUESTS = 429
REQUEST_HEADER_FIELDS_TOO_LARGE = 431
UNAVAILABLE_FOR_LEGAL_REASONS = 451
# Non-standard codes (nginx conventions)
NO_RESPONSE = 444
CLIENT_CLOSED_REQUEST = 499
# 5xx: Server errors
INTERNAL_SERVER_ERROR = 500
NOT_IMPLEMENTED = 501
BAD_GATEWAY = 502
SERVICE_UNAVAILABLE = 503
GATEWAY_TIMEOUT = 504
HTTP_VERSION_NOT_SUPPORTED = 505
VARIANT_ALSO_NEGOTIATES = 506
INSUFFICIENT_STORAGE_SPACE = 507
LOOP_DETECTED = 508
NOT_EXTENDED = 510
NETWORK_AUTHENTICATION_REQUIRED = 511

# Maps each status code above to its canonical reason phrase.
RESPONSES = {
    # 100
    CONTINUE: "Continue",
    SWITCHING: "Switching Protocols",
    PROCESSING: "Processing",
    EARLY_HINTS: "Early Hints",
    # 200
    OK: "OK",
    CREATED: "Created",
    ACCEPTED: "Accepted",
    NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information",
    NO_CONTENT: "No Content",
    RESET_CONTENT: "Reset Content",
    PARTIAL_CONTENT: "Partial Content",
    MULTI_STATUS: "Multi-Status",
    ALREADY_REPORTED: "Already Reported",
    IM_USED: "IM Used",
    # 300
    MULTIPLE_CHOICE: "Multiple Choices",
    MOVED_PERMANENTLY: "Moved Permanently",
    FOUND: "Found",
    SEE_OTHER: "See Other",
    NOT_MODIFIED: "Not Modified",
    USE_PROXY: "Use Proxy",
    # 306 not defined??
    TEMPORARY_REDIRECT: "Temporary Redirect",
    PERMANENT_REDIRECT: "Permanent Redirect",
    # 400
    BAD_REQUEST: "Bad Request",
    UNAUTHORIZED: "Unauthorized",
    PAYMENT_REQUIRED: "Payment Required",
    FORBIDDEN: "Forbidden",
    NOT_FOUND: "Not Found",
    NOT_ALLOWED: "Method Not Allowed",
    NOT_ACCEPTABLE: "Not Acceptable",
    PROXY_AUTH_REQUIRED: "Proxy Authentication Required",
    REQUEST_TIMEOUT: "Request Time-out",
    CONFLICT: "Conflict",
    GONE: "Gone",
    LENGTH_REQUIRED: "Length Required",
    PRECONDITION_FAILED: "Precondition Failed",
    PAYLOAD_TOO_LARGE: "Payload Too Large",
    REQUEST_URI_TOO_LONG: "Request-URI Too Long",
    UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type",
    REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable",
    EXPECTATION_FAILED: "Expectation Failed",
    IM_A_TEAPOT: "I'm a teapot",
    MISDIRECTED_REQUEST: "Misdirected Request",
    UNPROCESSABLE_CONTENT: "Unprocessable Content",
    LOCKED: "Locked",
    FAILED_DEPENDENCY: "Failed Dependency",
    TOO_EARLY: "Too Early",
    UPGRADE_REQUIRED: "Upgrade Required",
    PRECONDITION_REQUIRED: "Precondition Required",
    TOO_MANY_REQUESTS: "Too Many Requests",
    REQUEST_HEADER_FIELDS_TOO_LARGE: "Request Header Fields Too Large",
    UNAVAILABLE_FOR_LEGAL_REASONS: "Unavailable For Legal Reasons",
    NO_RESPONSE: "No Response",
    CLIENT_CLOSED_REQUEST: "Client Closed Request",
    # 500
    INTERNAL_SERVER_ERROR: "Internal Server Error",
    NOT_IMPLEMENTED: "Not Implemented",
    BAD_GATEWAY: "Bad Gateway",
    SERVICE_UNAVAILABLE: "Service Unavailable",
    GATEWAY_TIMEOUT: "Gateway Time-out",
    HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported",
    VARIANT_ALSO_NEGOTIATES: "Variant Also Negotiates",
    INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space",
    LOOP_DETECTED: "Loop Detected",
    NOT_EXTENDED: "Not Extended",
    NETWORK_AUTHENTICATION_REQUIRED: "Network Authentication Required",
}

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
import re
import urllib.parse
from collections.abc import Sequence
from typing import AnyStr
from typing import overload
from mitmproxy.net import check
from mitmproxy.net.check import is_valid_host
from mitmproxy.net.check import is_valid_port
from mitmproxy.utils.strutils import always_str
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
# https://bugzilla.mozilla.org/show_bug.cgi?id=45891
# e.g. "example.com:8080" -> host="example.com", port="8080";
#      "[::1]:80"         -> host="[::1]" (brackets stripped by callers).
_authority_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
def parse(url: str | bytes) -> tuple[bytes, bytes, int, bytes]:
    """
    URL-parsing function that checks that
        - port is an integer 0-65535
        - host is a valid IDNA-encoded hostname with no null-bytes
        - path is valid ASCII

    Args:
        A URL (as bytes or as unicode)

    Returns:
        A (scheme, host, port, path) tuple

    Raises:
        ValueError, if the URL is not properly formatted.
    """
    # FIXME: We shouldn't rely on urllib here.

    # Size of Ascii character after encoding is 1 byte which is same as its size
    # But non-Ascii character's size after encoding will be more than its size
    def ascii_check(x):
        if len(x) == len(str(x).encode()):
            return True
        return False

    if isinstance(url, bytes):
        url = url.decode()
    if not ascii_check(url):
        # Percent-encode the non-ASCII component (index 3 of the split result)
        # so the URL round-trips through the ASCII-only parser below.
        # NOTE(review): index 3 of urlsplit is the query component — confirm
        # that quoting only this field is intended.
        url = urllib.parse.urlsplit(url)  # type: ignore
        url = list(url)  # type: ignore
        url[3] = urllib.parse.quote(url[3])  # type: ignore
        url = urllib.parse.urlunsplit(url)  # type: ignore

    parsed: urllib.parse.ParseResult = urllib.parse.urlparse(url)
    if not parsed.hostname:
        raise ValueError("No hostname given")
    else:
        host = parsed.hostname.encode("idna")
        parsed_b: urllib.parse.ParseResultBytes = parsed.encode("ascii")  # type: ignore
    # .port raises ValueError for non-numeric or out-of-range ports,
    # which satisfies the documented contract.
    port = parsed_b.port
    if not port:
        port = 443 if parsed_b.scheme == b"https" else 80
    full_path: bytes = urllib.parse.urlunparse(
        (b"", b"", parsed_b.path, parsed_b.params, parsed_b.query, parsed_b.fragment)  # type: ignore
    )
    if not full_path.startswith(b"/"):
        full_path = b"/" + full_path  # type: ignore
    if not check.is_valid_host(host):
        raise ValueError("Invalid Host")
    return parsed_b.scheme, host, port, full_path
@overload
def unparse(scheme: str, host: str, port: int, path) -> str: ...
@overload
def unparse(scheme: bytes, host: bytes, port: int, path) -> bytes: ...
def unparse(scheme, host, port, path):
    """
    Returns a URL string, constructed from the specified components.
    The output type (str/bytes) follows the type of *scheme*.
    """
    netloc = hostport(scheme, host, port)
    if isinstance(scheme, bytes):
        return b"%s://%s%s" % (scheme, netloc, path)
    return f"{scheme}://{netloc}{path}"
def encode(s: Sequence[tuple[str, str]], similar_to: str | None = None) -> str:
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
If similar_to is passed, the output is formatted similar to the provided urlencoded string.
"""
remove_trailing_equal = False
if similar_to:
remove_trailing_equal = any("=" not in param for param in similar_to.split("&"))
encoded = urllib.parse.urlencode(s, False, errors="surrogateescape")
if encoded and remove_trailing_equal:
encoded = encoded.replace("=&", "&")
if encoded[-1] == "=":
encoded = encoded[:-1]
return encoded
def decode(s):
    """
    Takes a urlencoded string and returns a list of surrogate-escaped
    (key, value) tuples.
    """
    return urllib.parse.parse_qsl(
        s, keep_blank_values=True, errors="surrogateescape"
    )
def quote(b: str, safe: str = "/") -> str:
    """
    Percent-encode *b*, leaving characters listed in *safe* untouched.

    Returns:
        An ascii-encodable str.
    """
    return urllib.parse.quote(b, safe=safe, errors="surrogateescape")
def unquote(s: str) -> str:
    """
    Percent-decode *s*.

    Args:
        s: A surrogate-escaped str

    Returns:
        A surrogate-escaped str
    """
    return urllib.parse.unquote(s, errors="surrogateescape")
def hostport(scheme: AnyStr, host: AnyStr, port: int) -> AnyStr:
    """
    Returns the host component, with a port specification if needed.
    The port is omitted when it matches the scheme's default port.
    """
    if port == default_port(scheme):
        return host
    if isinstance(host, bytes):
        return b"%s:%d" % (host, port)
    return "%s:%d" % (host, port)
def default_port(scheme: AnyStr) -> int | None:
return {
"http": 80,
b"http": 80,
"https": 443,
b"https": 443,
}.get(scheme, None)
def parse_authority(authority: AnyStr, check: bool) -> tuple[str, int | None]:
    """Extract the host and port from host header/authority information

    Args:
        authority: the authority component, as str or bytes
        check: if True, malformed input raises; if False, the raw authority
            is returned as a best-effort (host, None) tuple instead.

    Raises:
        ValueError, if check is True and the authority information is malformed.
    """
    try:
        if isinstance(authority, bytes):
            m = _authority_re.match(authority.decode("utf-8"))
            if not m:
                raise ValueError
            # Bytes input may carry punycode; decode it to a readable hostname.
            host = m["host"].encode("utf-8").decode("idna")
        else:
            m = _authority_re.match(authority)
            if not m:
                raise ValueError
            host = m.group("host")
        # Strip IPv6 brackets, e.g. "[::1]" -> "::1".
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        if not is_valid_host(host):
            raise ValueError

        if m.group("port"):
            port = int(m.group("port"))
            if not is_valid_port(port):
                raise ValueError
            return host, port
        else:
            return host, None
    except ValueError:
        if check:
            raise
        else:
            # Best-effort fallback: return the authority verbatim, no port.
            return always_str(authority, "utf-8", "surrogateescape"), None

View File

@@ -0,0 +1,60 @@
"""
A small collection of useful user-agent header strings. These should be
kept reasonably current to reflect common usage.
"""

# pylint: line-too-long

# A collection of (name, shortcut, string) tuples; the shortcut is a
# single-character key used for quick selection (see get_by_shortcut).
UASTRINGS = [
    (
        "android",
        "a",
        "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02",
    ),
    (
        "blackberry",
        "l",
        "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+",
    ),
    (
        "bingbot",
        "b",
        "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)",
    ),
    (
        "chrome",
        "c",
        "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1",
    ),
    (
        "firefox",
        "f",
        "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1",
    ),
    ("googlebot", "g", "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"),
    ("ie9", "i", "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"),
    (
        "ipad",
        "p",
        "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3",
    ),
    (
        "iphone",
        "h",
        "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5",  # noqa
    ),
    (
        "safari",
        "s",
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10",
    ),
]
def get_by_shortcut(s):
    """
    Retrieve a user agent entry by shortcut.
    Returns the full (name, shortcut, string) tuple, or None if no entry
    matches.
    """
    return next((entry for entry in UASTRINGS if entry[1] == s), None)

View File

@@ -0,0 +1,141 @@
import logging
import re
import typing
from mitmproxy.http import Message
from mitmproxy.http import Request
from mitmproxy.http import Response
logger = logging.getLogger(__name__)

# https://datatracker.ietf.org/doc/html/rfc7230#section-3.2: Header fields are tokens.
# "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
_valid_header_name = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-zA-Z]+$")
# Content-Length must be an unsigned decimal with no leading zeros
# (str and bytes variants of the same pattern).
_valid_content_length = re.compile(rb"^(?:0|[1-9][0-9]*)$")
_valid_content_length_str = re.compile(r"^(?:0|[1-9][0-9]*)$")

# https://datatracker.ietf.org/doc/html/rfc9112#section-6.1:
# > A sender MUST NOT apply the chunked transfer coding more than once to a message body (i.e., chunking an already
# > chunked message is not allowed). If any transfer coding other than chunked is applied to a request's content, the
# > sender MUST apply chunked as the final transfer coding to ensure that the message is properly framed. If any
# > transfer coding other than chunked is applied to a response's content, the sender MUST either apply chunked as the
# > final transfer coding or terminate the message by closing the connection.
#
# The RFC technically still allows for fun encodings, we are a bit stricter and only accept a known subset by default.
TransferEncoding = typing.Literal[
    "chunked",
    "compress,chunked",
    "deflate,chunked",
    "gzip,chunked",
    "compress",
    "deflate",
    "gzip",
    "identity",
]
# Set view of the Literal above, used for O(1) membership checks.
_HTTP_1_1_TRANSFER_ENCODINGS = frozenset(typing.get_args(TransferEncoding))
def parse_content_length(value: str | bytes) -> int:
    """Parse a content-length header value, or raise a ValueError if it is invalid."""
    pattern = _valid_content_length_str if isinstance(value, str) else _valid_content_length
    if pattern.match(value) is None:
        raise ValueError(f"invalid content-length header: {value!r}")
    return int(value)
def parse_transfer_encoding(value: str | bytes) -> TransferEncoding:
    """Parse a transfer-encoding header value, or raise a ValueError if it is invalid or unknown."""
    # guard against .lower() transforming non-ascii to ascii
    if not value.isascii():
        raise ValueError(f"invalid transfer-encoding header: {value!r}")
    text = value if isinstance(value, str) else value.decode()
    # Normalize case and collapse whitespace around commas before matching.
    normalized = re.sub(r"[\t ]*,[\t ]*", ",", text.lower())
    if normalized not in _HTTP_1_1_TRANSFER_ENCODINGS:
        raise ValueError(f"unknown transfer-encoding header: {value!r}")
    return typing.cast(TransferEncoding, normalized)
def validate_headers(message: Message) -> None:
    """
    Validate HTTP message headers to avoid request smuggling attacks.
    Raises a ValueError if they are malformed.
    """
    # Collect all transfer-encoding and content-length values in one pass,
    # validating every header name along the way.
    te = []
    cl = []
    for name, value in message.headers.fields:
        if not _valid_header_name.match(name):
            raise ValueError(f"invalid header name: {name!r}")
        match name.lower():
            case b"transfer-encoding":
                te.append(value)
            case b"content-length":
                cl.append(value)
    if te and cl:
        # > A server MAY reject a request that contains both Content-Length and Transfer-Encoding or process such a
        # > request in accordance with the Transfer-Encoding alone.
        # > A sender MUST NOT send a Content-Length header field in any message that contains a Transfer-Encoding header
        # > field.
        raise ValueError(
            "message with both transfer-encoding and content-length headers"
        )
    elif te:
        if len(te) > 1:
            raise ValueError(f"multiple transfer-encoding headers: {te!r}")
        # > Transfer-Encoding was added in HTTP/1.1. It is generally assumed that implementations advertising only
        # > HTTP/1.0 support will not understand how to process transfer-encoded content, and that an HTTP/1.0 message
        # > received with a Transfer-Encoding is likely to have been forwarded without proper handling of the chunked
        # > transfer coding in transit.
        #
        # > A client MUST NOT send a request containing Transfer-Encoding unless it knows the server will handle
        # > HTTP/1.1 requests (or later minor revisions); such knowledge might be in the form of specific user
        # > configuration or by remembering the version of a prior received response. A server MUST NOT send a response
        # > containing Transfer-Encoding unless the corresponding request indicates HTTP/1.1 (or later minor revisions).
        if not message.is_http11:
            raise ValueError(
                f"unexpected HTTP transfer-encoding {te[0]!r} for {message.http_version}"
            )
        # > A server MUST NOT send a Transfer-Encoding header field in any response with a status code of 1xx
        # > (Informational) or 204 (No Content).
        if isinstance(message, Response) and (
            100 <= message.status_code <= 199 or message.status_code == 204
        ):
            raise ValueError(
                f"unexpected HTTP transfer-encoding {te[0]!r} for response with status code {message.status_code}"
            )
        # > If a Transfer-Encoding header field is present in a request and the chunked transfer coding is not the final
        # > encoding, the message body length cannot be determined reliably; the server MUST respond with the 400 (Bad
        # > Request) status code and then close the connection.
        te_parsed = parse_transfer_encoding(te[0])
        match te_parsed:
            case "chunked" | "compress,chunked" | "deflate,chunked" | "gzip,chunked":
                pass
            case "compress" | "deflate" | "gzip" | "identity":
                if isinstance(message, Request):
                    raise ValueError(
                        f"unexpected HTTP transfer-encoding {te_parsed!r} for request"
                    )
            case other:  # pragma: no cover
                typing.assert_never(other)
    elif cl:
        # > If a message is received without Transfer-Encoding and with an invalid Content-Length header field, then the
        # > message framing is invalid and the recipient MUST treat it as an unrecoverable error, unless the field value
        # > can be successfully parsed as a comma-separated list (Section 5.6.1 of [HTTP]), all values in the list are
        # > valid, and all values in the list are the same (in which case, the message is processed with that single
        # > value used as the Content-Length field value).
        # We are stricter here and reject comma-separated lists.
        if len(cl) > 1:
            raise ValueError(f"multiple content-length headers: {cl!r}")
        parse_content_length(cl[0])

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import socket
def get_local_ip(reachable: str = "8.8.8.8") -> str | None:
    """
    Get the default local outgoing IPv4 address without sending any packets.
    This will fail if the target address is known to be unreachable.
    We use Google DNS's IPv4 address as the default.
    """
    # https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib
    # Connecting a UDP socket performs no network I/O; it only makes the OS
    # pick the local interface it would route through.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.connect((reachable, 80))
            return sock.getsockname()[0]  # pragma: no cover
    except OSError:
        return None  # pragma: no cover
def get_local_ip6(reachable: str = "2001:4860:4860::8888") -> str | None:
    """
    Get the default local outgoing IPv6 address without sending any packets.
    This will fail if the target address is known to be unreachable.
    We use Google DNS's IPv6 address as the default.
    """
    # Connecting a UDP socket performs no network I/O; it only makes the OS
    # pick the local interface it would route through.
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock:
            sock.connect((reachable, 80))
            return sock.getsockname()[0]  # pragma: no cover
    except OSError:  # pragma: no cover
        return None

View File

@@ -0,0 +1,85 @@
"""
Server specs are used to describe an upstream proxy or server.
"""
import re
from functools import cache
from typing import Literal
from mitmproxy.net import check
# A parsed server spec: the scheme plus a (host, port) address tuple.
ServerSpec = tuple[
    Literal["http", "https", "http3", "tls", "dtls", "tcp", "udp", "dns", "quic"],
    tuple[str, int],
]

server_spec_re = re.compile(
    r"""
        ^
        (?:(?P<scheme>\w+)://)?  # scheme is optional
        (?P<host>[^:/]+|\[.+\])  # hostname can be DNS name, IPv4, or IPv6 address.
        (?::(?P<port>\d+))?      # port is optional
        /?                       # we allow a trailing slash, but no path
        $
        """,
    re.VERBOSE,
)
@cache
def parse(server_spec: str, default_scheme: str) -> ServerSpec:
    """
    Parses a server mode specification, e.g.:

    - http://example.com/
    - example.org
    - example.com:443

    The scheme defaults to *default_scheme* when absent; the port defaults to
    the scheme's well-known port where one exists (http/https/quic/http3/dns),
    and is mandatory otherwise.

    *Raises:*
     - ValueError, if the server specification is invalid.
    """
    m = server_spec_re.match(server_spec)
    if not m:
        raise ValueError(f"Invalid server specification: {server_spec}")

    if m.group("scheme"):
        scheme = m.group("scheme")
    else:
        scheme = default_scheme
    if scheme not in (
        "http",
        "https",
        "http3",
        "tls",
        "dtls",
        "tcp",
        "udp",
        "dns",
        "quic",
    ):
        raise ValueError(f"Invalid server scheme: {scheme}")

    host = m.group("host")
    # IPv6 brackets
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    if not check.is_valid_host(host):
        raise ValueError(f"Invalid hostname: {host}")

    if m.group("port"):
        port = int(m.group("port"))
    else:
        try:
            port = {
                "http": 80,
                "https": 443,
                "quic": 443,
                "http3": 443,
                "dns": 53,
            }[scheme]
        except KeyError:
            # Fix: this was an f-string with no placeholders (lint F541).
            raise ValueError("Port specification missing.")
    if not check.is_valid_port(port):
        raise ValueError(f"Invalid port: {port}")

    return scheme, (host, port)  # type: ignore

View File

@@ -0,0 +1,330 @@
import os
import threading
import typing
from collections.abc import Callable
from collections.abc import Iterable
from enum import Enum
from functools import cache
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import BinaryIO
import certifi
import OpenSSL
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurveOID
from cryptography.hazmat.primitives.asymmetric.ec import get_curve_for_oid
from cryptography.x509 import ObjectIdentifier
from OpenSSL import SSL
from mitmproxy import certs
# Remove once pyOpenSSL 23.3.0 is released and bump version in pyproject.toml.
try:  # pragma: no cover
    from OpenSSL.SSL import OP_LEGACY_SERVER_CONNECT  # type: ignore
except ImportError:
    # Older pyOpenSSL releases don't export this constant; fall back to the
    # raw OpenSSL option value.
    OP_LEGACY_SERVER_CONNECT = 0x4
# redeclared here for strict type checking
class Method(Enum):
    """TLS/DTLS protocol methods, mirroring pyOpenSSL's constants."""

    TLS_SERVER_METHOD = SSL.TLS_SERVER_METHOD
    TLS_CLIENT_METHOD = SSL.TLS_CLIENT_METHOD
    # Type-pyopenssl does not know about these DTLS constants.
    DTLS_SERVER_METHOD = SSL.DTLS_SERVER_METHOD  # type: ignore
    DTLS_CLIENT_METHOD = SSL.DTLS_CLIENT_METHOD  # type: ignore
# Fail fast at import time if the installed cryptography/OpenSSL build is
# too old to expose TLS_server_method.
try:
    SSL._lib.TLS_server_method  # type: ignore
except AttributeError as e:  # pragma: no cover
    raise RuntimeError(
        "Your installation of the cryptography Python package is outdated."
    ) from e
class Version(Enum):
    """TLS protocol versions; UNBOUNDED (0) means no explicit version bound."""

    UNBOUNDED = 0
    SSL3 = SSL.SSL3_VERSION
    TLS1 = SSL.TLS1_VERSION
    TLS1_1 = SSL.TLS1_1_VERSION
    TLS1_2 = SSL.TLS1_2_VERSION
    TLS1_3 = SSL.TLS1_3_VERSION
# Minimum-version settings regarded as insecure: everything below TLS 1.2,
# plus UNBOUNDED (no lower bound at all).
INSECURE_TLS_MIN_VERSIONS: tuple[Version, ...] = (
    Version.UNBOUNDED,
    Version.SSL3,
    Version.TLS1,
    Version.TLS1_1,
)
class Verify(Enum):
    """Peer certificate verification modes (subset of OpenSSL's verify flags)."""

    VERIFY_NONE = SSL.VERIFY_NONE
    VERIFY_PEER = SSL.VERIFY_PEER
# Defaults used when the configuration does not specify otherwise.
DEFAULT_MIN_VERSION = Version.TLS1_2
DEFAULT_MAX_VERSION = Version.UNBOUNDED
# Baseline context options: prefer the server's cipher order and disable
# TLS-level compression.
DEFAULT_OPTIONS = SSL.OP_CIPHER_SERVER_PREFERENCE | SSL.OP_NO_COMPRESSION
@cache
def is_supported_version(version: Version) -> bool:
    """
    Check whether the linked libssl supports the given TLS protocol version
    by pinning a client context to exactly that version and starting a
    handshake. Cached: the answer only depends on the local OpenSSL build.
    """
    client_ctx = SSL.Context(SSL.TLS_CLIENT_METHOD)
    # Without SECLEVEL, recent OpenSSL versions forbid old TLS versions.
    # https://github.com/pyca/cryptography/issues/9523
    client_ctx.set_cipher_list(b"@SECLEVEL=0:ALL")
    client_ctx.set_min_proto_version(version.value)
    client_ctx.set_max_proto_version(version.value)
    client_conn = SSL.Connection(client_ctx)
    client_conn.set_connect_state()
    try:
        # recv() drives the handshake machinery without any real peer.
        client_conn.recv(4096)
    except SSL.WantReadError:
        # Handshake started and is waiting for peer data: version accepted.
        return True
    except SSL.Error:
        return False
    # recv() cannot succeed without a peer; make the fall-through explicit
    # instead of implicitly returning None.
    return False  # pragma: no cover
# Map curve name -> instantiated EllipticCurve for every curve OID that
# cryptography declares on EllipticCurveOID.
EC_CURVES: dict[str, EllipticCurve] = {
    c.name: c
    for c in (
        get_curve_for_oid(oid)()
        for oid in vars(EllipticCurveOID).values()
        if isinstance(oid, ObjectIdentifier)
    )
}
@typing.overload
def get_curve(name: str) -> EllipticCurve: ...
@typing.overload
def get_curve(name: None) -> None: ...
def get_curve(name: str | None) -> EllipticCurve | None:
    """Return the EllipticCurve registered under *name*; None passes through.

    Raises KeyError for unknown curve names.
    """
    return None if name is None else EC_CURVES[name]
class MasterSecretLogger:
    """Callable that appends TLS key material to a key-log file.

    Instances are handed to pyOpenSSL's ``Context.set_keylog_callback``.
    The file is opened lazily on first use and all writes are serialized
    with a lock so multiple connections can share one logger.
    """

    def __init__(self, filename: Path):
        # Expand "~" once up front; the parent directory is created lazily.
        self.filename = filename.expanduser()
        # Opened on first __call__ so constructing the logger has no
        # filesystem side effects.
        self.f: BinaryIO | None = None
        self.lock = threading.Lock()

    # required for functools.wraps, which pyOpenSSL uses.
    __name__ = "MasterSecretLogger"

    def __call__(self, connection: SSL.Connection, keymaterial: bytes) -> None:
        # Invoked by OpenSSL once per logged secret.
        with self.lock:
            if self.f is None:
                self.filename.parent.mkdir(parents=True, exist_ok=True)
                # Append mode: never clobber secrets from earlier runs; the
                # leading newline separates this run's entries.
                self.f = self.filename.open("ab")
                self.f.write(b"\n")
            self.f.write(keymaterial + b"\n")
            # Flush immediately so external tools see secrets right away.
            self.f.flush()

    def close(self) -> None:
        # Safe to call even if the file was never opened.
        with self.lock:
            if self.f is not None:
                self.f.close()
def make_master_secret_logger(filename: str | None) -> MasterSecretLogger | None:
    """Build a MasterSecretLogger for *filename*; None for an unset/empty path."""
    if not filename:
        return None
    return MasterSecretLogger(Path(filename))
# Shared module-level key logger: enabled when MITMPROXY_SSLKEYLOGFILE or
# SSLKEYLOGFILE is set in the environment, None otherwise.
log_master_secret = make_master_secret_logger(
    os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")
)
def _create_ssl_context(
    *,
    method: Method,
    min_version: Version,
    max_version: Version,
    cipher_list: Iterable[str] | None,
    ecdh_curve: EllipticCurve | None,
) -> SSL.Context:
    """Create a pyOpenSSL context with the setup shared by client and server sides.

    Configures protocol version bounds, DEFAULT_OPTIONS, an optional ECDH
    curve and cipher list, and the SSLKEYLOGFILE callback when enabled.

    Raises RuntimeError if the requested versions, curve, or ciphers are
    rejected by the linked libssl.
    """
    context = SSL.Context(method.value)

    # Version bounds are set via cryptography's FFI bindings; each setter is
    # expected to return 1 on success, hence the "!= 2" check on the sum.
    ok = SSL._lib.SSL_CTX_set_min_proto_version(context._context, min_version.value)  # type: ignore
    ok += SSL._lib.SSL_CTX_set_max_proto_version(context._context, max_version.value)  # type: ignore
    if ok != 2:
        raise RuntimeError(
            f"Error setting TLS versions ({min_version=}, {max_version=}). "
            "The version you specified may be unavailable in your libssl."
        )

    # Options
    context.set_options(DEFAULT_OPTIONS)

    # ECDHE for Key exchange
    if ecdh_curve is not None:
        try:
            context.set_tmp_ecdh(ecdh_curve)
        except ValueError as e:
            raise RuntimeError(f"Elliptic curve specification error: {e}") from e

    # Cipher List
    if cipher_list is not None:
        try:
            context.set_cipher_list(b":".join(x.encode() for x in cipher_list))
        except SSL.Error as e:
            raise RuntimeError(f"SSL cipher specification error: {e}") from e

    # SSLKEYLOGFILE
    if log_master_secret:
        context.set_keylog_callback(log_master_secret)

    return context
@lru_cache(256)
def create_proxy_server_context(
    *,
    method: Method,
    min_version: Version,
    max_version: Version,
    cipher_list: tuple[str, ...] | None,
    ecdh_curve: EllipticCurve | None,
    verify: Verify,
    ca_path: str | None,
    ca_pemfile: str | None,
    client_cert: str | None,
    legacy_server_connect: bool,
) -> SSL.Context:
    """Create an SSL.Context for connections from the proxy to an upstream server.

    Contexts are cached (lru_cache) and reused for identical settings, so
    all arguments must be hashable.

    Raises:
        RuntimeError, if the trusted CA store or the client certificate
        cannot be loaded.
    """
    context: SSL.Context = _create_ssl_context(
        method=method,
        min_version=min_version,
        max_version=max_version,
        cipher_list=cipher_list,
        ecdh_curve=ecdh_curve,
    )
    context.set_verify(verify.value, None)

    if ca_path is None and ca_pemfile is None:
        # No custom trust store configured: fall back to certifi's CA bundle.
        ca_pemfile = certifi.where()
    try:
        context.load_verify_locations(ca_pemfile, ca_path)
    except SSL.Error as e:
        raise RuntimeError(
            f"Cannot load trusted certificates ({ca_pemfile=}, {ca_path=})."
        ) from e

    # Client Certs
    if client_cert:
        try:
            # The same PEM file is used for both the private key and the
            # certificate chain.
            context.use_privatekey_file(client_cert)
            context.use_certificate_chain_file(client_cert)
        except SSL.Error as e:
            raise RuntimeError(f"Cannot load TLS client certificate: {e}") from e
        # https://github.com/mitmproxy/mitmproxy/discussions/7550
        SSL._lib.SSL_CTX_set_post_handshake_auth(context._context, 1)  # type: ignore

    if legacy_server_connect:
        # Tolerate servers that lack secure-renegotiation support.
        context.set_options(OP_LEGACY_SERVER_CONNECT)

    return context
@lru_cache(256)
def create_client_proxy_context(
    *,
    method: Method,
    min_version: Version,
    max_version: Version,
    cipher_list: tuple[str, ...] | None,
    ecdh_curve: EllipticCurve | None,
    chain_file: Path | None,
    alpn_select_callback: Callable[[SSL.Connection, list[bytes]], Any] | None,
    request_client_cert: bool,
    extra_chain_certs: tuple[certs.Cert, ...],
    dhparams: certs.DHParams,
) -> SSL.Context:
    """Create an SSL.Context for connections from a client to the proxy.

    Contexts are cached (lru_cache) and reused for identical settings, so
    all arguments must be hashable.

    Raises:
        RuntimeError, if the certificate chain file cannot be loaded.
    """
    context: SSL.Context = _create_ssl_context(
        method=method,
        min_version=min_version,
        max_version=max_version,
        cipher_list=cipher_list,
        ecdh_curve=ecdh_curve,
    )

    if chain_file is not None:
        try:
            context.load_verify_locations(str(chain_file), None)
        except SSL.Error as e:
            raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e

    if alpn_select_callback is not None:
        assert callable(alpn_select_callback)
        context.set_alpn_select_callback(alpn_select_callback)

    if request_client_cert:
        # The request_client_cert argument requires some explanation. We're
        # supposed to be able to do this with no negative effects - if the
        # client has no cert to present, we're notified and proceed as usual.
        # Unfortunately, Android seems to have a bug (tested on 4.2.2) - when
        # an Android client is asked to present a certificate it does not
        # have, it hangs up, which is frankly bogus. Some time down the track
        # we may be able to make the proper behaviour the default again, but
        # until then we're conservative.
        context.set_verify(Verify.VERIFY_PEER.value, accept_all)
    else:
        context.set_verify(Verify.VERIFY_NONE.value, None)

    for i in extra_chain_certs:
        context.add_extra_chain_cert(i.to_cryptography())

    if dhparams:
        # Ephemeral DH parameters are set via the low-level bindings.
        res = SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)  # type: ignore
        SSL._openssl_assert(res == 1)  # type: ignore

    return context
def accept_all(
    conn_: SSL.Connection,
    x509: OpenSSL.crypto.X509,
    errno: int,
    err_depth: int,
    is_cert_verified: int,
) -> bool:
    """OpenSSL verify callback that unconditionally accepts the peer certificate."""
    # Return true to prevent cert verification error
    return True
def starts_like_tls_record(d: bytes) -> bool:
    """
    Returns:
        True, if the passed bytes could be the start of a TLS record
        False, otherwise.
    """
    # TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2, and TLSv1.3
    # http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
    # https://tls13.ulfheim.net/
    # We assume that a client sending less than 3 bytes initially is not a TLS client.
    if len(d) <= 2:
        return False
    content_type, version_major, version_minor = d[0], d[1], d[2]
    # 0x16 = handshake record; version 0x03/0x00..0x03 covers SSL3 through TLS1.3.
    return content_type == 0x16 and version_major == 0x03 and version_minor <= 0x03
def starts_like_dtls_record(d: bytes) -> bool:
    """
    Returns:
        True, if the passed bytes could be the start of a DTLS record
        False, otherwise.
    """
    # TLS ClientHello magic, works for DTLS 1.1, DTLS 1.2, and DTLS 1.3.
    # https://www.rfc-editor.org/rfc/rfc4347#section-4.1
    # https://www.rfc-editor.org/rfc/rfc6347#section-4.1
    # https://www.rfc-editor.org/rfc/rfc9147#section-4-6.2
    # We assume that a client sending less than 3 bytes initially is not a DTLS client.
    if len(d) <= 2:
        return False
    # 0x16 = handshake record; DTLS versions are encoded as 0xFE 0xFD/0xFE.
    return d[0] == 0x16 and d[1] == 0xFE and d[2] in (0xFD, 0xFE)