2025-12-25 upload

shengyudong
2025-12-25 11:16:59 +08:00
commit 322ac74336
2241 changed files with 639966 additions and 0 deletions


@@ -0,0 +1,6 @@
from .io import FilteredFlowWriter
from .io import FlowReader
from .io import FlowWriter
from .io import read_flows_from_paths
__all__ = ["FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths"]


@@ -0,0 +1,528 @@
"""
This module handles the import of mitmproxy flows generated by old versions.

The flow file version is decoupled from the mitmproxy release cycle (since
v3.0.0dev): every change or migration gets a new flow file version number,
which prevents issues with developer builds and snapshots.
"""
import copy
import uuid
from typing import Any
from mitmproxy import version
from mitmproxy.utils import strutils
def convert_011_012(data):
data[b"version"] = (0, 12)
return data
def convert_012_013(data):
data[b"version"] = (0, 13)
return data
def convert_013_014(data):
data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in")
data[b"request"][b"http_version"] = (
b"HTTP/"
+ ".".join(str(x) for x in data[b"request"].pop(b"httpversion")).encode()
)
data[b"response"][b"http_version"] = (
b"HTTP/"
+ ".".join(str(x) for x in data[b"response"].pop(b"httpversion")).encode()
)
data[b"response"][b"status_code"] = data[b"response"].pop(b"code")
data[b"response"][b"body"] = data[b"response"].pop(b"content")
data[b"server_conn"].pop(b"state")
data[b"server_conn"][b"via"] = None
data[b"version"] = (0, 14)
return data
def convert_014_015(data):
data[b"version"] = (0, 15)
return data
def convert_015_016(data):
for m in (b"request", b"response"):
if b"body" in data[m]:
data[m][b"content"] = data[m].pop(b"body")
if b"msg" in data[b"response"]:
data[b"response"][b"reason"] = data[b"response"].pop(b"msg")
data[b"request"].pop(b"form_out", None)
data[b"version"] = (0, 16)
return data
def convert_016_017(data):
data[b"server_conn"][b"peer_address"] = None
data[b"version"] = (0, 17)
return data
def convert_017_018(data):
# convert_unicode needs to be called for every dual release and the first py3-only release
data = convert_unicode(data)
data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address", None)
data["marked"] = False
data["version"] = (0, 18)
return data
def convert_018_019(data):
# convert_unicode needs to be called for every dual release and the first py3-only release
data = convert_unicode(data)
data["request"].pop("stickyauth", None)
data["request"].pop("stickycookie", None)
data["client_conn"]["sni"] = None
data["client_conn"]["alpn_proto_negotiated"] = None
data["client_conn"]["cipher_name"] = None
data["client_conn"]["tls_version"] = None
data["server_conn"]["alpn_proto_negotiated"] = None
if data["server_conn"]["via"]:
data["server_conn"]["via"]["alpn_proto_negotiated"] = None
data["mode"] = "regular"
data["metadata"] = dict()
data["version"] = (0, 19)
return data
def convert_019_100(data):
# convert_unicode needs to be called for every dual release and the first py3-only release
data = convert_unicode(data)
data["version"] = (1, 0, 0)
return data
def convert_100_200(data):
data["version"] = (2, 0, 0)
data["client_conn"]["address"] = data["client_conn"]["address"]["address"]
data["server_conn"]["address"] = data["server_conn"]["address"]["address"]
data["server_conn"]["source_address"] = data["server_conn"]["source_address"][
"address"
]
if data["server_conn"]["ip_address"]:
data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"]
if data["server_conn"]["via"]:
data["server_conn"]["via"]["address"] = data["server_conn"]["via"]["address"][
"address"
]
data["server_conn"]["via"]["source_address"] = data["server_conn"]["via"][
"source_address"
]["address"]
if data["server_conn"]["via"]["ip_address"]:
data["server_conn"]["via"]["ip_address"] = data["server_conn"]["via"][
"ip_address"
]["address"]
return data
def convert_200_300(data):
data["version"] = (3, 0, 0)
data["client_conn"]["mitmcert"] = None
data["server_conn"]["tls_version"] = None
if data["server_conn"]["via"]:
data["server_conn"]["via"]["tls_version"] = None
return data
def convert_300_4(data):
data["version"] = 4
# This is an empty migration to transition to the new versioning scheme.
return data
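# Snapshots older than flow format 5 do not carry connection IDs. These
# module-level caches let convert_4_5 hand out one UUID per distinct
# connection, so flows that shared a connection keep sharing it after
# migration.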
client_connections: dict[tuple[str, ...], str] = {}
server_connections: dict[tuple[str, ...], str] = {}
def convert_4_5(data):
data["version"] = 5
client_conn_key = (
data["client_conn"]["timestamp_start"],
*data["client_conn"]["address"],
)
server_conn_key = (
data["server_conn"]["timestamp_start"],
*data["server_conn"]["source_address"],
)
data["client_conn"]["id"] = client_connections.setdefault(
client_conn_key, str(uuid.uuid4())
)
data["server_conn"]["id"] = server_connections.setdefault(
server_conn_key, str(uuid.uuid4())
)
if data["server_conn"]["via"]:
server_conn_key = (
data["server_conn"]["via"]["timestamp_start"],
*data["server_conn"]["via"]["source_address"],
)
data["server_conn"]["via"]["id"] = server_connections.setdefault(
server_conn_key, str(uuid.uuid4())
)
return data
def convert_5_6(data):
data["version"] = 6
data["client_conn"]["tls_established"] = data["client_conn"].pop("ssl_established")
data["client_conn"]["timestamp_tls_setup"] = data["client_conn"].pop(
"timestamp_ssl_setup"
)
data["server_conn"]["tls_established"] = data["server_conn"].pop("ssl_established")
data["server_conn"]["timestamp_tls_setup"] = data["server_conn"].pop(
"timestamp_ssl_setup"
)
if data["server_conn"]["via"]:
data["server_conn"]["via"]["tls_established"] = data["server_conn"]["via"].pop(
"ssl_established"
)
data["server_conn"]["via"]["timestamp_tls_setup"] = data["server_conn"][
"via"
].pop("timestamp_ssl_setup")
return data
def convert_6_7(data):
data["version"] = 7
data["client_conn"]["tls_extensions"] = None
return data
def convert_7_8(data):
data["version"] = 8
if "request" in data and data["request"] is not None:
data["request"]["trailers"] = None
if "response" in data and data["response"] is not None:
data["response"]["trailers"] = None
return data
def convert_8_9(data):
data["version"] = 9
is_request_replay = False
if "request" in data:
data["request"].pop("first_line_format")
data["request"]["authority"] = b""
is_request_replay = data["request"].pop("is_replay", False)
is_response_replay = False
if "response" in data and data["response"] is not None:
is_response_replay = data["response"].pop("is_replay", False)
if is_request_replay: # pragma: no cover
data["is_replay"] = "request"
elif is_response_replay: # pragma: no cover
data["is_replay"] = "response"
else:
data["is_replay"] = None
return data
def convert_9_10(data):
data["version"] = 10
def conv_conn(conn):
conn["state"] = 0
conn["error"] = None
conn["tls"] = conn["tls_established"]
alpn = conn["alpn_proto_negotiated"]
conn["alpn_offers"] = [alpn] if alpn else None
cipher = conn["cipher_name"]
conn["cipher_list"] = [cipher] if cipher else None
def conv_cconn(conn):
conn["sockname"] = ("", 0)
cc = conn.pop("clientcert", None)
conn["certificate_list"] = [cc] if cc else []
conv_conn(conn)
def conv_sconn(conn):
crt = conn.pop("cert", None)
conn["certificate_list"] = [crt] if crt else []
conn["cipher_name"] = None
conn["via2"] = None
conv_conn(conn)
conv_cconn(data["client_conn"])
conv_sconn(data["server_conn"])
if data["server_conn"]["via"]:
conv_sconn(data["server_conn"]["via"])
return data
def convert_10_11(data):
data["version"] = 11
def conv_conn(conn):
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
conn["alpn"] = conn.pop("alpn_proto_negotiated")
conn["alpn_offers"] = conn["alpn_offers"] or []
conn["cipher_list"] = conn["cipher_list"] or []
conv_conn(data["client_conn"])
conv_conn(data["server_conn"])
if data["server_conn"]["via"]:
conv_conn(data["server_conn"]["via"])
return data
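# Up to flow format 11, a WebSocket connection was stored as two flows: the
# HTTP handshake flow (metadata key "websocket") and a separate WebSocket
# flow (metadata key "websocket_handshake" holding the handshake flow's id).
# convert_11_12 merges each pair, buffering handshakes here until the
# matching WebSocket flow is read.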
_websocket_handshakes = {}
def convert_11_12(data):
data["version"] = 12
if "websocket" in data["metadata"]:
_websocket_handshakes[data["id"]] = copy.deepcopy(data)
if "websocket_handshake" in data["metadata"]:
ws_flow = data
try:
data = _websocket_handshakes.pop(data["metadata"]["websocket_handshake"])
except KeyError:
# The handshake flow is missing, which should never really happen. We make up a dummy.
data = {
"client_conn": data["client_conn"],
"error": data["error"],
"id": data["id"],
"intercepted": data["intercepted"],
"is_replay": data["is_replay"],
"marked": data["marked"],
"metadata": {},
"mode": "transparent",
"request": {
"authority": b"",
"content": None,
"headers": [],
"host": b"unknown",
"http_version": b"HTTP/1.1",
"method": b"GET",
"path": b"/",
"port": 80,
"scheme": b"http",
"timestamp_end": 0,
"timestamp_start": 0,
"trailers": None,
},
"response": None,
"server_conn": data["server_conn"],
"type": "http",
"version": 12,
}
data["metadata"]["duplicated"] = (
"This WebSocket flow has been migrated from an old file format version "
"and may appear duplicated."
)
data["websocket"] = {
"messages": ws_flow["messages"],
"closed_by_client": ws_flow["close_sender"] == "client",
"close_code": ws_flow["close_code"],
"close_reason": ws_flow["close_reason"],
"timestamp_end": data.get("server_conn", {}).get("timestamp_end", None),
}
else:
data["websocket"] = None
return data
def convert_12_13(data):
data["version"] = 13
if data["marked"]:
data["marked"] = ":default:"
else:
data["marked"] = ""
return data
def convert_13_14(data):
data["version"] = 14
data["comment"] = ""
# bugfix for https://github.com/mitmproxy/mitmproxy/issues/4576
if data.get("response", None) and data["response"]["timestamp_start"] is None:
data["response"]["timestamp_start"] = data["request"]["timestamp_end"]
data["response"]["timestamp_end"] = data["request"]["timestamp_end"] + 1
return data
def convert_14_15(data):
data["version"] = 15
if data.get("websocket", None):
# Add "injected" attribute.
data["websocket"]["messages"] = [
msg + [False] for msg in data["websocket"]["messages"]
]
return data
def convert_15_16(data):
data["version"] = 16
data["timestamp_created"] = data.get("request", data["client_conn"])[
"timestamp_start"
]
return data
def convert_16_17(data):
data["version"] = 17
data.pop("mode", None)
return data
def convert_17_18(data):
data["version"] = 18
data["client_conn"]["proxy_mode"] = "regular"
return data
def convert_18_19(data):
data["version"] = 19
data["client_conn"]["peername"] = data["client_conn"].pop("address", None)
if data["client_conn"].get("timestamp_start") is None:
data["client_conn"]["timestamp_start"] = 0.0
data["client_conn"].pop("tls_extensions")
data["server_conn"]["peername"] = data["server_conn"].pop("ip_address", None)
data["server_conn"]["sockname"] = data["server_conn"].pop("source_address", None)
data["server_conn"]["via"] = data["server_conn"].pop("via2", None)
for conn in ["client_conn", "server_conn"]:
data[conn].pop("tls_established")
data[conn]["cipher"] = data[conn].pop("cipher_name", None)
data[conn].setdefault("transport_protocol", "tcp")
for name in ["peername", "sockname", "address"]:
if data[conn].get(name) and isinstance(data[conn][name][0], bytes):
data[conn][name][0] = data[conn][name][0].decode(
errors="backslashreplace"
)
if data["server_conn"]["sni"] is True:
data["server_conn"]["sni"] = data["server_conn"]["address"][0]
return data
def convert_19_20(data):
data["version"] = 20
data["client_conn"].pop("state", None)
data["server_conn"].pop("state", None)
return data
def convert_20_21(data):
data["version"] = 21
if data["client_conn"]["tls_version"] == "QUIC":
data["client_conn"]["tls_version"] = "QUICv1"
if data["server_conn"]["tls_version"] == "QUIC":
data["server_conn"]["tls_version"] = "QUICv1"
return data
def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
else:
return o
def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict:
for k, v in values_to_convert.items():
if not o or k not in o:
continue # pragma: no cover
if v is True:
o[k] = strutils.always_str(o[k])
else:
_convert_dict_vals(o[k], v)
return o
def convert_unicode(data: dict) -> dict:
"""
This method converts between Python 3 and Python 2 dumpfiles.
"""
data = _convert_dict_keys(data)
data = _convert_dict_vals(
data,
{
"type": True,
"id": True,
"request": {"first_line_format": True},
"error": {"msg": True},
},
)
return data
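# Illustrative example (hypothetical snippet) of what convert_unicode does to
# a Python 2 dumpfile record:
#   {b"type": b"http", b"id": b"42", b"request": {b"first_line_format": b"relative"}}
# becomes
#   {"type": "http", "id": "42", "request": {"first_line_format": "relative"}}
# i.e. all dict keys plus the whitelisted values above are decoded to str;
# everything else stays bytes.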
converters = {
(0, 11): convert_011_012,
(0, 12): convert_012_013,
(0, 13): convert_013_014,
(0, 14): convert_014_015,
(0, 15): convert_015_016,
(0, 16): convert_016_017,
(0, 17): convert_017_018,
(0, 18): convert_018_019,
(0, 19): convert_019_100,
(1, 0): convert_100_200,
(2, 0): convert_200_300,
(3, 0): convert_300_4,
4: convert_4_5,
5: convert_5_6,
6: convert_6_7,
7: convert_7_8,
8: convert_8_9,
9: convert_9_10,
10: convert_10_11,
11: convert_11_12,
12: convert_12_13,
13: convert_13_14,
14: convert_14_15,
15: convert_15_16,
16: convert_16_17,
17: convert_17_18,
18: convert_18_19,
19: convert_19_20,
20: convert_20_21,
}
def migrate_flow(flow_data: dict[bytes | str, Any]) -> dict[bytes | str, Any]:
while True:
flow_version = flow_data.get(b"version", flow_data.get("version"))
# Historically, we used the mitmproxy minor version tuple as the flow format version.
if not isinstance(flow_version, int):
flow_version = tuple(flow_version)[:2] # type: ignore
if flow_version == version.FLOW_FORMAT_VERSION:
break
elif flow_version in converters:
flow_data = converters[flow_version](flow_data)
else:
should_upgrade = (
isinstance(flow_version, int)
and flow_version > version.FLOW_FORMAT_VERSION
)
raise ValueError(
"{} cannot read files with flow format version {}{}.".format(
version.MITMPROXY,
flow_version,
", please update mitmproxy" if should_upgrade else "",
)
)
return flow_data
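if __name__ == "__main__":
    # Illustrative sketch only: a hypothetical, heavily truncated snapshot at
    # flow format version 20. Real snapshots carry many more keys, but the
    # 20 -> 21 step only touches the two tls_version fields, so this is enough
    # to exercise the last link of the converter chain (assuming the current
    # FLOW_FORMAT_VERSION is 21, i.e. the version produced by convert_20_21).
    # Historical tuple versions such as (0, 17, 2) would be truncated to
    # (0, 17) before the converters dict is consulted.
    stub = {
        "version": 20,
        "client_conn": {"tls_version": "QUIC"},
        "server_conn": {"tls_version": "TLSv1.3"},
    }
    migrated = migrate_flow(stub)
    assert migrated["version"] == version.FLOW_FORMAT_VERSION
    assert migrated["client_conn"]["tls_version"] == "QUICv1"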


@@ -0,0 +1,159 @@
"""Reads HAR files into flow objects"""
import base64
import logging
import time
from datetime import datetime
from mitmproxy import connection
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy.net.http.headers import infer_content_encoding
logger = logging.getLogger(__name__)
def fix_headers(
request_headers: list[dict[str, str]] | list[tuple[str, str]],
) -> http.Headers:
"""Converts provided headers into (b"header-name", b"header-value") tuples"""
flow_headers: list[tuple[bytes, bytes]] = []
for header in request_headers:
        # Applications that use the {"name": item, "value": item} notation: Brave, Chrome, Edge, Firefox, Charles, Fiddler, Insomnia, Safari
if isinstance(header, dict):
key = header["name"]
value = header["value"]
# Application that uses the [name, value] notation is Slack
else:
try:
key = header[0]
value = header[1]
except IndexError as e:
raise exceptions.OptionsError(str(e)) from e
flow_headers.append((key.encode(), value.encode()))
return http.Headers(flow_headers)
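# Illustrative example: both header notations normalise to the same result,
# e.g. fix_headers([{"name": "Host", "value": "example.com"}]) and
# fix_headers([("Host", "example.com")]) each yield a Headers object
# containing (b"Host", b"example.com").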
def request_to_flow(request_json: dict) -> http.HTTPFlow:
"""
Creates a HTTPFlow object from a given entry in HAR file
"""
timestamp_start = datetime.fromisoformat(
request_json["startedDateTime"].replace("Z", "+00:00")
).timestamp()
timestamp_end = timestamp_start + request_json["time"] / 1000.0
request_method = request_json["request"]["method"]
request_url = request_json["request"]["url"]
server_address = request_json.get("serverIPAddress", None)
request_headers = fix_headers(request_json["request"]["headers"])
http_version_req = request_json["request"]["httpVersion"]
http_version_resp = request_json["response"]["httpVersion"]
request_content = ""
    # HAR entries do not record the port explicitly, so infer it from the URL scheme
if request_url.startswith("http://"):
port = 80
else:
port = 443
client_conn = connection.Client(
peername=("127.0.0.1", 0),
sockname=("127.0.0.1", 0),
# TODO Get time info from HAR File
timestamp_start=time.time(),
)
if server_address:
server_conn = connection.Server(address=(server_address, port))
else:
server_conn = connection.Server(address=None)
new_flow = http.HTTPFlow(client_conn, server_conn)
if "postData" in request_json["request"]:
request_content = request_json["request"]["postData"]["text"]
new_flow.request = http.Request.make(
request_method, request_url, request_content, request_headers
)
response_code = request_json["response"]["status"]
# In Firefox HAR files images don't include response bodies
response_content = request_json["response"]["content"].get("text", "")
content_encoding = request_json["response"]["content"].get("encoding", None)
response_headers = fix_headers(request_json["response"]["headers"])
if content_encoding == "base64":
response_content = base64.b64decode(response_content)
elif isinstance(response_content, str):
# Convert text to bytes, as in `Response.set_text`
try:
response_content = http.encoding.encode(
response_content,
(
content_encoding
or infer_content_encoding(response_headers.get("content-type", ""))
),
)
except ValueError:
# Fallback to UTF-8
response_content = response_content.encode(
"utf-8", errors="surrogateescape"
)
# Then encode the content, as in `Response.set_content`
response_content = http.encoding.encode(
response_content, response_headers.get("content-encoding") or "identity"
)
new_flow.response = http.Response(
b"HTTP/1.1",
response_code,
http.status_codes.RESPONSES.get(response_code, "").encode(),
response_headers,
response_content,
None,
timestamp_start,
timestamp_end,
)
# Update timestamps
new_flow.request.timestamp_start = timestamp_start
new_flow.request.timestamp_end = timestamp_end
new_flow.client_conn.timestamp_start = timestamp_start
new_flow.client_conn.timestamp_end = timestamp_end
# Update HTTP version
match http_version_req:
case "http/2.0":
new_flow.request.http_version = "HTTP/2"
case "HTTP/2":
new_flow.request.http_version = "HTTP/2"
case "HTTP/3":
new_flow.request.http_version = "HTTP/3"
case _:
new_flow.request.http_version = "HTTP/1.1"
match http_version_resp:
case "http/2.0":
new_flow.response.http_version = "HTTP/2"
case "HTTP/2":
new_flow.response.http_version = "HTTP/2"
case "HTTP/3":
new_flow.response.http_version = "HTTP/3"
case _:
new_flow.response.http_version = "HTTP/1.1"
# Remove compression because that may generate different sizes between versions
new_flow.request.decode()
new_flow.response.decode()
return new_flow
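if __name__ == "__main__":
    # Minimal, hypothetical HAR entry used as an illustrative smoke test; real
    # HAR files carry many more fields, but these are the keys request_to_flow
    # actually reads.
    example_entry = {
        "startedDateTime": "2024-01-01T00:00:00Z",
        "time": 120,
        "request": {
            "method": "GET",
            "url": "http://example.com/",
            "httpVersion": "HTTP/1.1",
            "headers": [{"name": "Host", "value": "example.com"}],
        },
        "response": {
            "status": 200,
            "httpVersion": "HTTP/1.1",
            "headers": [{"name": "Content-Type", "value": "text/plain"}],
            "content": {"text": "hello"},
        },
    }
    example_flow = request_to_flow(example_entry)
    assert example_flow.request.host == "example.com"
    assert example_flow.response.status_code == 200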


@@ -0,0 +1,114 @@
import json
import os
from collections.abc import Iterable
from io import BufferedReader
from typing import Any
from typing import BinaryIO
from typing import cast
from typing import Union
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy import flowfilter
from mitmproxy.io import compat
from mitmproxy.io import tnetstring
from mitmproxy.io.har import request_to_flow
class FlowWriter:
def __init__(self, fo):
self.fo = fo
def add(self, f: flow.Flow) -> None:
d = f.get_state()
tnetstring.dump(d, self.fo)
class FlowReader:
fo: BinaryIO
def __init__(self, fo: BinaryIO):
self.fo = fo
def peek(self, n: int) -> bytes:
try:
return cast(BufferedReader, self.fo).peek(n)
except AttributeError:
# https://github.com/python/cpython/issues/90533: io.BytesIO does not have peek()
pos = self.fo.tell()
ret = self.fo.read(n)
self.fo.seek(pos)
return ret
def stream(self) -> Iterable[flow.Flow]:
"""
Yields Flow objects from the dump.
"""
if self.peek(4).startswith(
b"\xef\xbb\xbf{"
): # skip BOM, usually added by Fiddler
self.fo.read(3)
if self.peek(1).startswith(b"{"):
try:
har_file = json.loads(self.fo.read().decode("utf-8"))
for request_json in har_file["log"]["entries"]:
yield request_to_flow(request_json)
except Exception:
raise exceptions.FlowReadException(
"Unable to read HAR file. Please provide a valid HAR file"
)
else:
try:
while True:
# FIXME: This cast hides a lack of dynamic type checking
loaded = cast(
dict[Union[bytes, str], Any],
tnetstring.load(self.fo),
)
try:
if not isinstance(loaded, dict):
raise ValueError(f"Invalid flow: {loaded=}")
yield flow.Flow.from_state(compat.migrate_flow(loaded))
except ValueError as e:
raise exceptions.FlowReadException(e) from e
except (ValueError, TypeError, IndexError) as e:
if str(e) == "not a tnetstring: empty file":
return # Error is due to EOF
raise exceptions.FlowReadException("Invalid data format.") from e
class FilteredFlowWriter:
def __init__(self, fo: BinaryIO, flt: flowfilter.TFilter | None):
self.fo = fo
self.flt = flt
def add(self, f: flow.Flow) -> None:
if self.flt and not flowfilter.match(self.flt, f):
return
d = f.get_state()
tnetstring.dump(d, self.fo)
self.fo.flush()
def read_flows_from_paths(paths) -> list[flow.Flow]:
"""
Given a list of filepaths, read all flows and return a list of them.
From a performance perspective, streaming would be advisable -
however, if there's an error with one of the files, we want it to be raised immediately.
Raises:
FlowReadException, if any error occurs.
"""
try:
flows: list[flow.Flow] = []
for path in paths:
path = os.path.expanduser(path)
with open(path, "rb") as f:
flows.extend(FlowReader(f).stream())
except OSError as e:
raise exceptions.FlowReadException(e.strerror)
return flows
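# Typical usage (illustrative sketch; "dump.mitm" and "filtered.mitm" are
# hypothetical paths):
#
#   flows = read_flows_from_paths(["dump.mitm"])  # .mitm captures or HAR files
#   with open("filtered.mitm", "wb") as fo:
#       writer = FilteredFlowWriter(fo, flowfilter.parse("~u example.com"))
#       for f in flows:
#           writer.add(f)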


@@ -0,0 +1,261 @@
"""
tnetstring: data serialization using typed netstrings
======================================================
This is a custom Python 3 implementation of tnetstrings.
Compared to other implementations, the main difference
is that this implementation supports a custom unicode datatype.
An ordinary tnetstring is a blob of data prefixed with its length and postfixed
with its type. Here are some examples:
>>> tnetstring.dumps("hello world")
11:hello world,
>>> tnetstring.dumps(12345)
5:12345#
>>> tnetstring.dumps([12345, True, 0])
19:5:12345#4:true!1:0#]
This module gives you the following functions:
:dump: dump an object as a tnetstring to a file
:dumps: dump an object as a tnetstring to a string
:load: load a tnetstring-encoded object from a file
:loads: load a tnetstring-encoded object from a string
Note that since parsing a tnetstring requires reading all the data into memory
at once, there's no efficiency gain from using the file-based versions of these
functions. They're only here so you can use load() to read precisely one
item from a file or socket without consuming any extra data.
The tnetstrings specification explicitly states that strings are binary blobs
and forbids the use of unicode at the protocol level.
**This implementation decodes dictionary keys as surrogate-escaped ASCII**,
all other strings are returned as plain bytes.
:Copyright: (c) 2012-2013 by Ryan Kelly <ryan@rfk.id.au>.
:Copyright: (c) 2014 by Carlo Pires <carlopires@gmail.com>.
:Copyright: (c) 2016 by Maximilian Hils <tnetstring3@maximilianhils.com>.
:License: MIT
"""
import collections
from typing import BinaryIO
from typing import Union
TSerializable = Union[None, str, bool, int, float, bytes, list, tuple, dict]
def dumps(value: TSerializable) -> bytes:
"""
This function dumps a python object as a tnetstring.
"""
# This uses a deque to collect output fragments in reverse order,
# then joins them together at the end. It's measurably faster
# than creating all the intermediate strings.
q: collections.deque = collections.deque()
_rdumpq(q, 0, value)
return b"".join(q)
def dump(value: TSerializable, file_handle: BinaryIO) -> None:
"""
This function dumps a python object as a tnetstring and
writes it to the given file.
"""
file_handle.write(dumps(value))
def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int:
"""
Dump value as a tnetstring, to a deque instance, last chunks first.
This function generates the tnetstring representation of the given value,
pushing chunks of the output onto the given deque instance. It pushes
the last chunk first, then recursively generates more chunks.
When passed in the current size of the string in the queue, it will return
the new size of the string in the queue.
Operating last-chunk-first makes it easy to calculate the size written
for recursive structures without having to build their representation as
a string. This is measurably faster than generating the intermediate
strings, especially on deeply nested structures.
"""
write = q.appendleft
if value is None:
write(b"0:~")
return size + 3
elif value is True:
write(b"4:true!")
return size + 7
elif value is False:
write(b"5:false!")
return size + 8
elif isinstance(value, int):
data = str(value).encode()
ldata = len(data)
span = str(ldata).encode()
write(b"%s:%s#" % (span, data))
return size + 2 + len(span) + ldata
elif isinstance(value, float):
# Use repr() for float rather than str().
# It round-trips more accurately.
# Probably unnecessary in later python versions that
# use David Gay's ftoa routines.
data = repr(value).encode()
ldata = len(data)
span = str(ldata).encode()
write(b"%s:%s^" % (span, data))
return size + 2 + len(span) + ldata
elif isinstance(value, bytes):
data = value
ldata = len(data)
span = str(ldata).encode()
write(b",")
write(data)
write(b":")
write(span)
return size + 2 + len(span) + ldata
elif isinstance(value, str):
data = value.encode("utf8")
ldata = len(data)
span = str(ldata).encode()
write(b";")
write(data)
write(b":")
write(span)
return size + 2 + len(span) + ldata
elif isinstance(value, (list, tuple)):
write(b"]")
init_size = size = size + 1
for item in reversed(value):
size = _rdumpq(q, size, item)
span = str(size - init_size).encode()
write(b":")
write(span)
return size + 1 + len(span)
elif isinstance(value, dict):
write(b"}")
init_size = size = size + 1
for k, v in value.items():
size = _rdumpq(q, size, v)
size = _rdumpq(q, size, k)
span = str(size - init_size).encode()
write(b":")
write(span)
return size + 1 + len(span)
else:
raise ValueError(f"unserializable object: {value} ({type(value)})")
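# Worked example (illustrative) of the last-chunk-first strategy for
# dumps([12345, True]): _rdumpq pushes b"]" first, serialises the items in
# reverse order, and only then knows the payload length, so the deque ends up
# as [b"15", b":", b"5:12345#", b"4:true!", b"]"] and joins to
# b"15:5:12345#4:true!]".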
def loads(string: bytes) -> TSerializable:
"""
This function parses a tnetstring into a python object.
"""
return pop(memoryview(string))[0]
def load(file_handle: BinaryIO) -> TSerializable:
"""load(file) -> object
This function reads a tnetstring from a file and parses it into a
python object. The file must support the read() method, and this
function promises not to read more data than necessary.
"""
# Read the length prefix one char at a time.
# Note that the netstring spec explicitly forbids padding zeros.
c = file_handle.read(1)
if c == b"": # we want to detect this special case.
raise ValueError("not a tnetstring: empty file")
data_length = b""
while c.isdigit():
data_length += c
if len(data_length) > 12:
raise ValueError("not a tnetstring: absurdly large length prefix")
c = file_handle.read(1)
if c != b":":
raise ValueError("not a tnetstring: missing or invalid length prefix")
data = memoryview(file_handle.read(int(data_length)))
data_type = file_handle.read(1)[0]
return parse(data_type, data)
def parse(data_type: int, data: memoryview) -> TSerializable:
if data_type == ord(b","):
return data.tobytes()
if data_type == ord(b";"):
return str(data, "utf8")
if data_type == ord(b"#"):
try:
return int(data)
except ValueError:
raise ValueError(f"not a tnetstring: invalid integer literal: {data!r}")
if data_type == ord(b"^"):
try:
return float(data)
except ValueError:
raise ValueError(f"not a tnetstring: invalid float literal: {data!r}")
if data_type == ord(b"!"):
if data == b"true":
return True
elif data == b"false":
return False
else:
raise ValueError(f"not a tnetstring: invalid boolean literal: {data!r}")
if data_type == ord(b"~"):
if data:
raise ValueError(f"not a tnetstring: invalid null literal: {data!r}")
return None
if data_type == ord(b"]"):
lst = []
while data:
item, data = pop(data)
lst.append(item) # type: ignore
return lst
if data_type == ord(b"}"):
d = {}
while data:
key, data = pop(data)
val, data = pop(data)
d[key] = val # type: ignore
return d
raise ValueError(f"unknown type tag: {data_type}")
def split(data: memoryview, sep: bytes) -> tuple[int, memoryview]:
i = 0
try:
ord_sep = ord(sep)
while data[i] != ord_sep:
i += 1
# here i is the position of b":" in the memoryview
return int(data[:i]), data[i + 1 :]
except (IndexError, ValueError):
raise ValueError(
f"not a tnetstring: missing or invalid length prefix: {data.tobytes()!r}"
)
def pop(data: memoryview) -> tuple[TSerializable, memoryview]:
"""
This function parses a tnetstring into a python object.
It returns a tuple giving the parsed object and a string
containing any unparsed data from the end of the string.
"""
# Parse out data length, type and remaining string.
length, data = split(data, b":")
try:
data, data_type, remain = data[:length], data[length], data[length + 1 :]
except IndexError:
        # This fires if len(data) < length, meaning we don't need
        # to further validate that data is the right length.
raise ValueError(f"not a tnetstring: invalid length prefix: {length}")
# Parse the data based on the type tag.
return parse(data_type, data), remain
__all__ = ["dump", "dumps", "load", "loads", "pop"]
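if __name__ == "__main__":
    import io

    # Illustrative round trip with arbitrary values: dump/load read exactly
    # one item each, so several objects can share a single stream. Note that
    # bytes use the standard "," tag while str uses the custom ";" tag.
    buf = io.BytesIO()
    dump({"key": [1, 2.5, True, None, b"raw"]}, buf)
    dump("second item", buf)
    buf.seek(0)
    assert load(buf) == {"key": [1, 2.5, True, None, b"raw"]}
    assert load(buf) == "second item"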