import contextlib
import datetime
import ipaddress
import logging
import os
import sys
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import cast
from typing import NewType
from typing import Optional
from typing import Union

import OpenSSL
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import CertificatePublicKeyTypes
from cryptography.hazmat.primitives.serialization import pkcs12
from cryptography.x509 import ExtendedKeyUsageOID
from cryptography.x509 import NameOID

from mitmproxy.coretypes import serializable

if sys.version_info < (3, 13):  # pragma: no cover
    from typing_extensions import deprecated
else:
    from warnings import deprecated

logger = logging.getLogger(__name__)

# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
CA_EXPIRY = datetime.timedelta(days=10 * 365)
CERT_EXPIRY = datetime.timedelta(days=365)
CRL_EXPIRY = datetime.timedelta(days=7)

# Generated with "openssl dhparam". It's too slow to generate this on startup.
DEFAULT_DHPARAM = b"""
-----BEGIN DH PARAMETERS-----
MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3
O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv
j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ
Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB
chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC
ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq
o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX
IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv
A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8
6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I
rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
-----END DH PARAMETERS-----
"""


class Cert(serializable.Serializable):
    """Representation of a (TLS) certificate."""

    _cert: x509.Certificate

    def __init__(self, cert: x509.Certificate):
        """Wrap an already parsed `cryptography` certificate object."""
        assert isinstance(cert, x509.Certificate)
        self._cert = cert

    def __eq__(self, other):
        """Two certificates are equal if their SHA-256 fingerprints match."""
        if not isinstance(other, Cert):
            # Let Python try the reflected comparison instead of failing with
            # an AttributeError on objects that have no fingerprint().
            return NotImplemented
        return self.fingerprint() == other.fingerprint()

    def __repr__(self):
        altnames = [str(x.value) for x in self.altnames]
        return f"<Cert(cn={self.cn!r}, altnames={altnames!r})>"

    def __hash__(self):
        return hash(self._cert)

    @classmethod
    def from_state(cls, state):
        """Rebuild a certificate from the PEM bytes produced by `get_state`."""
        return cls.from_pem(state)

    def get_state(self):
        """Serialize the certificate as PEM bytes."""
        return self.to_pem()

    def set_state(self, state):
        """Replace the wrapped certificate with the one parsed from PEM bytes."""
        self._cert = x509.load_pem_x509_certificate(state)

    @classmethod
    def from_pem(cls, data: bytes) -> "Cert":
        """Parse a PEM-encoded certificate."""
        cert = x509.load_pem_x509_certificate(data)  # type: ignore
        return cls(cert)

    def to_pem(self) -> bytes:
        """Return the certificate in PEM encoding."""
        return self._cert.public_bytes(serialization.Encoding.PEM)

    @classmethod
    def from_pyopenssl(cls, x509: OpenSSL.crypto.X509) -> "Cert":
        """
        Create a certificate from a pyOpenSSL X509 object.

        Note that the `x509` parameter shadows the `cryptography.x509` module
        inside this method; the name is kept for backwards compatibility.
        """
        return cls(x509.to_cryptography())

    @deprecated("Use `to_cryptography` instead.")
    def to_pyopenssl(self) -> OpenSSL.crypto.X509:  # pragma: no cover
        return OpenSSL.crypto.X509.from_cryptography(self._cert)

    def to_cryptography(self) -> x509.Certificate:
        """Return the underlying `cryptography` certificate object."""
        return self._cert

    def public_key(self) -> CertificatePublicKeyTypes:
        """Return the certificate's public key."""
        return self._cert.public_key()

    def fingerprint(self) -> bytes:
        """Return the SHA-256 fingerprint of the certificate."""
        return self._cert.fingerprint(hashes.SHA256())

    @property
    def issuer(self) -> list[tuple[str, str]]:
        """The issuer name as a list of (attribute, value) pairs."""
        return _name_to_keyval(self._cert.issuer)

    @property
    def notbefore(self) -> datetime.datetime:
        """Start of the validity period as a timezone-aware UTC datetime."""
        try:
            # type definitions haven't caught up with new API yet.
            return self._cert.not_valid_before_utc  # type: ignore
        except AttributeError:  # pragma: no cover
            # cryptography < 42.0
            return self._cert.not_valid_before.replace(tzinfo=datetime.UTC)

    @property
    def notafter(self) -> datetime.datetime:
        """End of the validity period as a timezone-aware UTC datetime."""
        try:
            return self._cert.not_valid_after_utc  # type: ignore
        except AttributeError:  # pragma: no cover
            return self._cert.not_valid_after.replace(tzinfo=datetime.UTC)

    def has_expired(self) -> bool:
        """Return True if the current time is past `notafter`."""
        return datetime.datetime.now(datetime.UTC) > self.notafter

    @property
    def subject(self) -> list[tuple[str, str]]:
        """The subject name as a list of (attribute, value) pairs."""
        return _name_to_keyval(self._cert.subject)

    @property
    def serial(self) -> int:
        """The certificate's serial number."""
        return self._cert.serial_number

    @property
    def is_ca(self) -> bool:
        """
        True if the BasicConstraints extension marks this certificate as a CA,
        False if it does not or if the extension is absent.
        """
        constraints: x509.BasicConstraints
        try:
            constraints = self._cert.extensions.get_extension_for_class(
                x509.BasicConstraints
            ).value
            return constraints.ca
        except x509.ExtensionNotFound:
            return False

    @property
    def keyinfo(self) -> tuple[str, int]:
        """
        A (key type, key size) tuple describing the public key.

        Key types other than RSA, DSA and EC are named after their class,
        with a size of -1 if the key object has no `key_size` attribute.
        """
        public_key = self._cert.public_key()
        if isinstance(public_key, rsa.RSAPublicKey):
            return "RSA", public_key.key_size
        if isinstance(public_key, dsa.DSAPublicKey):
            return "DSA", public_key.key_size
        if isinstance(public_key, ec.EllipticCurvePublicKey):
            return f"EC ({public_key.curve.name})", public_key.key_size
        return (
            public_key.__class__.__name__.replace("PublicKey", "").replace("_", ""),
            getattr(public_key, "key_size", -1),
        )  # pragma: no cover

    @property
    def cn(self) -> str | None:
        """The first common name in the subject, or None if there is none."""
        attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
        if attrs:
            return cast(str, attrs[0].value)
        return None

    @property
    def organization(self) -> str | None:
        """The first organization name in the subject, or None if there is none."""
        attrs = self._cert.subject.get_attributes_for_oid(
            x509.NameOID.ORGANIZATION_NAME
        )
        if attrs:
            return cast(str, attrs[0].value)
        return None

    @property
    def altnames(self) -> x509.GeneralNames:
        """
        Get all SubjectAlternativeName entries (of any general name type).

        Returns empty GeneralNames if the certificate has no SAN extension.
        """
        try:
            sans = self._cert.extensions.get_extension_for_class(
                x509.SubjectAlternativeName
            ).value
        except x509.ExtensionNotFound:
            return x509.GeneralNames([])
        else:
            return x509.GeneralNames(sans)

    @property
    def crl_distribution_points(self) -> list[str]:
        """
        The URLs of the certificate's CRL distribution points.

        Only distribution points whose first full name is a URI are included.
        """
        try:
            ext = self._cert.extensions.get_extension_for_class(
                x509.CRLDistributionPoints
            ).value
        except x509.ExtensionNotFound:
            return []
        else:
            return [
                dist_point.full_name[0].value
                for dist_point in ext
                if dist_point.full_name
                and isinstance(dist_point.full_name[0], x509.UniformResourceIdentifier)
            ]


def _name_to_keyval(name: x509.Name) -> list[tuple[str, str]]:
    """
    Convert an X.509 name into a list of (attribute, value) pairs.

    The attribute is the part of each attribute's RFC 4514 string that
    precedes the first "=".
    """
    return [
        (attr.rfc4514_string().partition("=")[0], cast(str, attr.value))
        for attr in name
    ]


def create_ca(
    organization: str,
    cn: str,
    key_size: int,
) -> tuple[rsa.RSAPrivateKeyWithSerialization, x509.Certificate]:
    """
    Generate a new self-signed CA certificate with a fresh RSA key.

    organization: Organization name for the CA subject.
    cn: Common name for the CA subject.
    key_size: Size of the RSA key in bits.

    Returns a (private key, certificate) tuple.
    """
    # Builders interpret naive datetimes as UTC, so use an aware UTC timestamp
    # rather than naive local time.
    now = datetime.datetime.now(datetime.UTC)

    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=key_size,
    )  # type: ignore
    name = x509.Name(
        [
            x509.NameAttribute(NameOID.COMMON_NAME, cn),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
        ]
    )
    builder = x509.CertificateBuilder()
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.subject_name(name)
    # The validity period is backdated by two days.
    builder = builder.not_valid_before(now - datetime.timedelta(days=2))
    builder = builder.not_valid_after(now + CA_EXPIRY)
    builder = builder.issuer_name(name)
    builder = builder.public_key(private_key.public_key())
    builder = builder.add_extension(
        x509.BasicConstraints(ca=True, path_length=None), critical=True
    )
    builder = builder.add_extension(
        x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False
    )
    builder = builder.add_extension(
        x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=True,
            encipher_only=False,
            decipher_only=False,
        ),
        critical=True,
    )
    builder = builder.add_extension(
        x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
        critical=False,
    )
    cert = builder.sign(private_key=private_key, algorithm=hashes.SHA256())  # type: ignore
    return private_key, cert


def _fix_legacy_sans(sans: Iterable[x509.GeneralName] | list[str]) -> x509.GeneralNames:
    """
    SANs used to be a list of strings in mitmproxy 10.1 and below, but now they're a list of GeneralNames.
    This function converts the old format to the new one.
    """
    if isinstance(sans, x509.GeneralNames):
        return sans
    elif (
        isinstance(sans, list) and len(sans) > 0 and isinstance(sans[0], str)
    ):  # pragma: no cover
        warnings.warn(
            "Passing SANs as a list of strings is deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )

        ss: list[x509.GeneralName] = []
        for x in cast(list[str], sans):
            try:
                ip = ipaddress.ip_address(x)
            except ValueError:
                # Not an IP address: treat it as a host name and IDNA-encode it.
                x = x.encode("idna").decode()
                ss.append(x509.DNSName(x))
            else:
                ss.append(x509.IPAddress(ip))
        return x509.GeneralNames(ss)
    else:
        return x509.GeneralNames(cast(Iterable[x509.GeneralName], sans))


def dummy_cert(
    privkey: rsa.RSAPrivateKey,
    cacert: x509.Certificate,
    commonname: str | None,
    sans: Iterable[x509.GeneralName],
    organization: str | None = None,
    crl_url: str | None = None,
) -> Cert:
    """
    Generates a dummy certificate.

    privkey: CA private key
    cacert: CA certificate
    commonname: Common name for the generated certificate.
        Omitted from the subject if None, empty, or 64 or more characters long.
    sans: A list of Subject Alternate Names.
    organization: Organization name for the generated certificate.
    crl_url: URL of CRL distribution point

    Returns the generated certificate. Errors raised by `cryptography`
    while building or signing are propagated.
    """
    builder = x509.CertificateBuilder()
    builder = builder.issuer_name(cacert.subject)
    builder = builder.add_extension(
        x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False
    )
    # The generated certificate carries the CA certificate's public key.
    builder = builder.public_key(cacert.public_key())

    # Builders interpret naive datetimes as UTC, so use an aware UTC timestamp
    # rather than naive local time.
    now = datetime.datetime.now(datetime.UTC)
    builder = builder.not_valid_before(now - datetime.timedelta(days=2))
    builder = builder.not_valid_after(now + CERT_EXPIRY)

    subject = []
    # An empty common name cannot be encoded as a name attribute, so it is
    # treated like a missing one.
    is_valid_commonname = commonname is not None and 0 < len(commonname) < 64
    if is_valid_commonname:
        assert commonname is not None  # narrows the type for the type checker
        subject.append(x509.NameAttribute(NameOID.COMMON_NAME, commonname))
    if organization is not None:
        subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization))
    builder = builder.subject_name(x509.Name(subject))
    builder = builder.serial_number(x509.random_serial_number())

    # RFC 5280 §4.2.1.6: subjectAltName is critical if subject is empty.
    builder = builder.add_extension(
        x509.SubjectAlternativeName(_fix_legacy_sans(sans)),
        critical=not is_valid_commonname,
    )

    # https://datatracker.ietf.org/doc/html/rfc5280#section-4.2.1.1
    builder = builder.add_extension(
        x509.AuthorityKeyIdentifier.from_issuer_public_key(cacert.public_key()),  # type: ignore
        critical=False,
    )
    # If CA and leaf cert have the same Subject Key Identifier, SChannel breaks in funny ways,
    # see https://github.com/mitmproxy/mitmproxy/issues/6494.
    # https://datatracker.ietf.org/doc/html/rfc5280#section-4.2.1.2 states
    # that SKI is optional for the leaf cert, so we skip that.

    if crl_url:
        builder = builder.add_extension(
            x509.CRLDistributionPoints(
                [
                    x509.DistributionPoint(
                        [x509.UniformResourceIdentifier(crl_url)],
                        relative_name=None,
                        crl_issuer=None,
                        reasons=None,
                    )
                ]
            ),
            critical=False,
        )

    cert = builder.sign(private_key=privkey, algorithm=hashes.SHA256())  # type: ignore
    return Cert(cert)


def dummy_crl(
    privkey: rsa.RSAPrivateKey,
    cacert: x509.Certificate,
) -> bytes:
    """
    Generates an empty CRL signed with the CA key

    privkey: CA private key
    cacert: CA certificate

    Returns a CRL DER encoded
    """
    builder = x509.CertificateRevocationListBuilder()
    # The CRL is signed by the CA itself, so its issuer is the CA's subject.
    # This matches the issuer that dummy_cert() puts into generated certificates
    # and differs from cacert.issuer when the CA is not self-signed.
    builder = builder.issuer_name(cacert.subject)

    # Builders interpret naive datetimes as UTC, so use an aware UTC timestamp
    # rather than naive local time.
    now = datetime.datetime.now(datetime.UTC)
    builder = builder.last_update(now - datetime.timedelta(days=2))
    builder = builder.next_update(now + CRL_EXPIRY)

    builder = builder.add_extension(x509.CRLNumber(1000), False)  # meaningless number
    crl = builder.sign(private_key=privkey, algorithm=hashes.SHA256())
    return crl.public_bytes(serialization.Encoding.DER)


@dataclass(frozen=True)
class CertStoreEntry:
    """A certificate together with its private key and certificate chain."""

    cert: Cert
    privatekey: rsa.RSAPrivateKey
    chain_file: Path | None
    chain_certs: list[Cert]


TCustomCertId = str  # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = tuple[Optional[str], x509.GeneralNames]  # (common_name, sans)
TCertId = Union[TCustomCertId, TGeneratedCertId]

DHParams = NewType("DHParams", bytes)


class CertStore:
    """
    Implements an in-memory certificate store.
    """

    # Maximum number of generated certificates kept before the oldest is dropped.
    STORE_CAP = 100
    default_privatekey: rsa.RSAPrivateKey
    default_ca: Cert
    default_chain_file: Path | None
    default_crl: bytes
    default_chain_certs: list[Cert]
    dhparams: DHParams
    certs: dict[TCertId, CertStoreEntry]
    expire_queue: list[CertStoreEntry]

    def __init__(
        self,
        default_privatekey: rsa.RSAPrivateKey,
        default_ca: Cert,
        default_chain_file: Path | None,
        default_crl: bytes,
        dhparams: DHParams,
    ):
        """
        default_privatekey: CA private key used to sign generated certificates.
        default_ca: CA certificate.
        default_chain_file: Optional PEM file containing the certificate chain.
            If given, it is read here to populate `default_chain_certs`.
        default_crl: CRL bytes stored on the instance.
        dhparams: DH parameters stored on the instance.
        """
        self.default_privatekey = default_privatekey
        self.default_ca = default_ca
        self.default_chain_file = default_chain_file
        self.default_crl = default_crl
        if self.default_chain_file:
            chain_pem = self.default_chain_file.read_bytes()
            self.default_chain_certs = [
                Cert(c) for c in x509.load_pem_x509_certificates(chain_pem)
            ]
        else:
            self.default_chain_certs = [default_ca]
        self.dhparams = dhparams
        self.certs = {}
        self.expire_queue = []

    def expire(self, entry: CertStoreEntry) -> None:
        """
        Queue a generated entry for expiry.

        Once more than STORE_CAP entries are queued, the oldest one is dropped
        from the queue and from every key in `certs` that maps to it.
        """
        self.expire_queue.append(entry)
        if len(self.expire_queue) > self.STORE_CAP:
            d = self.expire_queue.pop(0)
            self.certs = {k: v for k, v in self.certs.items() if v != d}

    @staticmethod
    def load_dhparam(path: Path) -> DHParams:
        """
        Load DH parameters from a PEM file, creating the file with
        DEFAULT_DHPARAM first if it does not exist.

        Raises RuntimeError if the file cannot be opened or parsed.

        NOTE(review): the value returned is the cffi handle produced by
        PEM_read_bio_DHparams, although the annotation says DHParams (bytes).
        """
        # mitmproxy<=0.10 doesn't generate a dhparam file.
        # Create it now if necessary.
        if not path.exists():
            path.write_bytes(DEFAULT_DHPARAM)

        # we could use cryptography for this, but it's unclear how to convert cryptography's object to pyOpenSSL's
        # expected format.
        bio = OpenSSL.SSL._lib.BIO_new_file(  # type: ignore
            str(path).encode(sys.getfilesystemencoding()), b"r"
        )
        if bio != OpenSSL.SSL._ffi.NULL:  # type: ignore
            bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)  # type: ignore
            dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(  # type: ignore
                bio,
                OpenSSL.SSL._ffi.NULL,  # type: ignore
                OpenSSL.SSL._ffi.NULL,  # type: ignore
                OpenSSL.SSL._ffi.NULL,  # type: ignore
            )
            # Do not hand out (or register a destructor for) a NULL pointer
            # if the file did not contain parseable DH parameters.
            if dh == OpenSSL.SSL._ffi.NULL:  # type: ignore  # pragma: no cover
                raise RuntimeError("Error loading DH Params.")
            dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)  # type: ignore
            return dh
        raise RuntimeError("Error loading DH Params.")  # pragma: no cover

    @classmethod
    def from_store(
        cls,
        path: Path | str,
        basename: str,
        key_size: int,
        passphrase: bytes | None = None,
    ) -> "CertStore":
        """
        Load the store found in directory `path`, creating it first if the
        "<basename>-ca.pem" file does not exist.

        key_size: RSA key size in bits, only used when a new store is created.
        passphrase: Passphrase for the CA private key, if any.
        """
        path = Path(path)
        ca_file = path / f"{basename}-ca.pem"
        dhparam_file = path / f"{basename}-dhparam.pem"
        if not ca_file.exists():
            cls.create_store(path, basename, key_size)
        return cls.from_files(ca_file, dhparam_file, passphrase)

    @classmethod
    def from_files(
        cls, ca_file: Path, dhparam_file: Path, passphrase: bytes | None = None
    ) -> "CertStore":
        """
        Create a store from a PEM file holding the CA key and certificate(s)
        and a DH parameter file.

        The first certificate in `ca_file` is used as the CA. If the file holds
        more than one certificate, it is also used as the chain file.
        """
        raw = ca_file.read_bytes()
        key = load_pem_private_key(raw, passphrase)
        dh = cls.load_dhparam(dhparam_file)
        certs = x509.load_pem_x509_certificates(raw)
        ca = Cert(certs[0])
        crl = dummy_crl(key, ca._cert)
        if len(certs) > 1:
            chain_file: Path | None = ca_file
        else:
            chain_file = None
        return cls(key, ca, chain_file, crl, dh)

    @staticmethod
    @contextlib.contextmanager
    def umask_secret():
        """
        Context to temporarily set umask to its original value bitor 0o77.
        Useful when writing private keys to disk so that only the owner
        will be able to read them.
        """
        # os.umask() can only be read by setting it, hence the two calls.
        original_umask = os.umask(0)
        os.umask(original_umask | 0o77)
        try:
            yield
        finally:
            os.umask(original_umask)

    @staticmethod
    def create_store(
        path: Path,
        basename: str,
        key_size: int,
        organization: str | None = None,
        cn: str | None = None,
    ) -> None:
        """
        Create a new CA and write it to directory `path` in several formats,
        together with the default DH parameters.

        basename: Prefix for all file names; also the default for
            `organization` and `cn` when those are not given.
        key_size: RSA key size in bits for the CA key.
        """
        path.mkdir(parents=True, exist_ok=True)

        organization = organization or basename
        cn = cn or basename

        key: rsa.RSAPrivateKeyWithSerialization
        ca: x509.Certificate
        key, ca = create_ca(organization=organization, cn=cn, key_size=key_size)

        # Dump the CA plus private key.
        with CertStore.umask_secret():
            # PEM format
            (path / f"{basename}-ca.pem").write_bytes(
                key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                )
                + ca.public_bytes(serialization.Encoding.PEM)
            )

            # PKCS12 format for Windows devices
            (path / f"{basename}-ca.p12").write_bytes(
                pkcs12.serialize_key_and_certificates(  # type: ignore
                    name=basename.encode(),
                    key=key,
                    cert=ca,
                    cas=None,
                    encryption_algorithm=serialization.NoEncryption(),
                )
            )

        # Dump the certificate in PEM format
        pem_cert = ca.public_bytes(serialization.Encoding.PEM)
        (path / f"{basename}-ca-cert.pem").write_bytes(pem_cert)
        # Create a .cer file with the same contents for Android
        (path / f"{basename}-ca-cert.cer").write_bytes(pem_cert)

        # Dump the certificate in PKCS12 format for Windows devices
        (path / f"{basename}-ca-cert.p12").write_bytes(
            pkcs12.serialize_key_and_certificates(
                name=basename.encode(),
                key=None,  # type: ignore
                cert=ca,
                cas=None,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )

        (path / f"{basename}-dhparam.pem").write_bytes(DEFAULT_DHPARAM)

    def add_cert_file(
        self, spec: str, path: Path, passphrase: bytes | None = None
    ) -> None:
        """
        Load a certificate (and optionally its private key and chain) from a
        PEM file and register it under `spec`.

        If the file contains no private key, the store's default private key
        is used, provided it matches the certificate's public key.

        Raises ValueError if no matching private key is available.
        """
        raw = path.read_bytes()
        cert = Cert.from_pem(raw)
        try:
            private_key = load_pem_private_key(raw, password=passphrase)
        except ValueError as e:
            # No usable key in the file: fall back to the default key.
            private_key = self.default_privatekey
            if cert.public_key() != private_key.public_key():
                raise ValueError(
                    f'Unable to find private key in "{path.absolute()}": {e}'
                ) from e
        else:
            if cert.public_key() != private_key.public_key():
                raise ValueError(
                    f'Private and public keys in "{path.absolute()}" do not match:\n'
                    f"{cert.public_key()=}\n"
                    f"{private_key.public_key()=}"
                )

        try:
            chain = [Cert(x) for x in x509.load_pem_x509_certificates(raw)]
        except ValueError as e:
            logger.warning(f"Failed to read certificate chain: {e}")
            chain = [cert]

        if cert.is_ca:
            logger.warning(
                f'"{path.absolute()}" is a certificate authority and not a leaf certificate. '
                f"This indicates a misconfiguration, see https://docs.mitmproxy.org/stable/concepts-certificates/."
            )

        self.add_cert(CertStoreEntry(cert, private_key, path, chain), spec)

    def add_cert(self, entry: CertStoreEntry, *names: str) -> None:
        """
        Adds a cert to the certstore. We register the CN in the cert plus
        any SANs, and also the list of names provided as an argument.
        """
        if entry.cert.cn:
            self.certs[entry.cert.cn] = entry
        for i in entry.cert.altnames:
            self.certs[str(i.value)] = entry
        for i in names:
            self.certs[i] = entry

    @staticmethod
    def asterisk_forms(dn: str | x509.GeneralName) -> list[str]:
        """
        Return all asterisk forms for a domain. For example, for www.example.com this will return
        ["www.example.com", "*.example.com", "*.com"]. The single wildcard "*" is omitted.

        General names other than DNS names yield a single-element list with
        the string form of their value.
        """
        if isinstance(dn, str):
            parts = dn.split(".")
            ret = [dn]
            for i in range(1, len(parts)):
                ret.append("*." + ".".join(parts[i:]))
            return ret
        elif isinstance(dn, x509.DNSName):
            return CertStore.asterisk_forms(dn.value)
        else:
            return [str(dn.value)]

    def get_cert(
        self,
        commonname: str | None,
        sans: Iterable[x509.GeneralName],
        organization: str | None = None,
        crl_url: str | None = None,
    ) -> CertStoreEntry:
        """
        Return a matching entry from the store, or generate, store and return
        a new dummy certificate if there is none.

        commonname: Common name for the generated certificate. Must be a
        valid, plain-ASCII, IDNA-encoded domain name.

        sans: A list of Subject Alternate Names.

        organization: Organization name for the generated certificate.

        crl_url: URL of CRL distribution point

        Note that `organization` and `crl_url` are not part of the lookup key.
        """
        sans = _fix_legacy_sans(sans)

        # Lookup order: names derived from the CN, then from the SANs, then the
        # catch-all "*", and finally the key used for generated certificates.
        potential_keys: list[TCertId] = []
        if commonname:
            potential_keys.extend(self.asterisk_forms(commonname))
        for s in sans:
            potential_keys.extend(self.asterisk_forms(s))
        potential_keys.append("*")
        potential_keys.append((commonname, sans))

        name = next(filter(lambda key: key in self.certs, potential_keys), None)
        # Compare against None explicitly: a found key must not be ignored
        # just because it happens to be falsy.
        if name is not None:
            entry = self.certs[name]
        else:
            entry = CertStoreEntry(
                cert=dummy_cert(
                    self.default_privatekey,
                    self.default_ca._cert,
                    commonname,
                    sans,
                    organization,
                    crl_url,
                ),
                privatekey=self.default_privatekey,
                chain_file=self.default_chain_file,
                chain_certs=self.default_chain_certs,
            )
            self.certs[(commonname, sans)] = entry
            self.expire(entry)

        return entry


def load_pem_private_key(data: bytes, password: bytes | None) -> rsa.RSAPrivateKey:
    """
    like cryptography's load_pem_private_key, but silently falls back to not using a password
    if the private key is unencrypted.
    """
    try:
        return serialization.load_pem_private_key(data, password)  # type: ignore
    except TypeError:
        # Retry once without a password; re-raise if none was supplied.
        if password is not None:
            return load_pem_private_key(data, None)
        raise