Files
“shengyudong” 322ac74336 2025-12-25 upload
2025-12-25 11:16:59 +08:00

101 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
import sys
import time
from collections.abc import Coroutine
from collections.abc import Iterator
from contextlib import contextmanager
from mitmproxy.utils import human
_KEEP_ALIVE = set()
def create_task(
    coro: Coroutine,
    *,
    name: str,
    keep_ref: bool,
    client: tuple | None = None,
) -> asyncio.Task:
    """
    Wrapper around `asyncio.create_task` that also attaches debug info.

    - `keep_ref`: if true, keep an internal strong reference until the task
      finishes, so it is not garbage collected mid-execution when the caller
      holds no reference of its own.
    - `client`: client address, stored on the task as additional debug info.
    """
    task = asyncio.create_task(coro)  # noqa: TID251
    set_task_debug_info(task, name=name, client=client)
    if not keep_ref or task.done():
        # Nothing to pin: the caller opted out, or the task has already
        # finished (possible e.g. when an eager task factory is installed).
        return task
    # The event loop only keeps weak references to tasks.
    # A task that isn't referenced elsewhere may get garbage collected at any
    # time, even before it's done.
    _KEEP_ALIVE.add(task)
    task.add_done_callback(_KEEP_ALIVE.discard)
    return task
def set_task_debug_info(
task: asyncio.Task,
*,
name: str,
client: tuple | None = None,
) -> None:
"""Set debug info for an externally-spawned task."""
task.created = time.time() # type: ignore
if __debug__ is True and (test := os.environ.get("PYTEST_CURRENT_TEST", None)):
name = f"{name} [created in {test}]"
task.set_name(name)
if client:
task.client = client # type: ignore
def set_current_task_debug_info(
    *,
    name: str,
    client: tuple | None = None,
) -> None:
    """
    Set debug info for the task that is currently running.

    Same as calling `set_task_debug_info` on `asyncio.current_task()`.
    Fails the assertion if there is no current task.
    """
    current = asyncio.current_task()
    # Narrows `Task | None` to `Task` and fails fast outside a task.
    assert current
    set_task_debug_info(current, client=client, name=name)
def task_repr(task: asyncio.Task) -> str:
    """
    Get a task representation with debug info.

    The result has the form "<client address>: <name> (age: <N>s)".
    The client prefix is only included when the task carries a truthy `client`
    attribute, and the age suffix only when it carries a truthy `created`
    timestamp. `set_task_debug_info` sets both attributes.
    """
    name = task.get_name()
    created: float = getattr(task, "created", 0)
    age = f" (age: {time.time() - created:.0f}s)" if created else ""
    # Keep the raw attribute and the rendered prefix in separate variables.
    # Reusing one variable meant a falsy non-string value (e.g.
    # `task.client = None`) was interpolated verbatim, producing "Nonename".
    client = getattr(task, "client", None)
    prefix = f"{human.format_address(client)}: " if client else ""
    return f"{prefix}{name}{age}"
@contextmanager
def install_exception_handler(handler) -> Iterator[None]:
    """
    Temporarily install `handler` as the running loop's exception handler.

    The previously configured handler (possibly None, meaning the loop default)
    is restored when the `with` block exits, including when it exits via an
    exception. Must be entered while an event loop is running.
    """
    running_loop = asyncio.get_running_loop()
    previous_handler = running_loop.get_exception_handler()
    running_loop.set_exception_handler(handler)
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        running_loop.set_exception_handler(previous_handler)
@contextmanager
def set_eager_task_factory() -> Iterator[None]:
    """
    Temporarily install `asyncio.eager_task_factory` on the running loop.

    On Python 3.12+ the loop's previous task factory is restored when the
    `with` block exits, even via an exception. On older versions,
    `eager_task_factory` does not exist and this is a no-op. In all cases it
    must be entered while an event loop is running.
    """
    running_loop = asyncio.get_running_loop()
    if sys.version_info >= (3, 12):
        previous_factory = running_loop.get_task_factory()
        running_loop.set_task_factory(asyncio.eager_task_factory)  # type: ignore
        try:
            yield
        finally:
            running_loop.set_task_factory(previous_factory)
    else:  # pragma: no cover
        yield