101 lines
2.7 KiB
Python
101 lines
2.7 KiB
Python
import asyncio
|
||
import os
|
||
import sys
|
||
import time
|
||
from collections.abc import Coroutine
|
||
from collections.abc import Iterator
|
||
from contextlib import contextmanager
|
||
|
||
from mitmproxy.utils import human
|
||
|
||
# Strong references to tasks created with `create_task(..., keep_ref=True)`.
# The event loop only holds weak references to tasks, so an otherwise
# unreferenced task could be garbage collected before it finishes; each task
# removes itself again through a done callback registered in `create_task`.
_KEEP_ALIVE: set[asyncio.Task] = set()
|
||
|
||
|
||
def create_task(
    coro: Coroutine,
    *,
    name: str,
    keep_ref: bool,
    client: tuple | None = None,
) -> asyncio.Task:
    """
    Schedule *coro* as a task and attach debug info to it.

    Thin wrapper around `asyncio.create_task`:

    - `name` (and optionally the `client` address) are recorded on the task
      via `set_task_debug_info`, so `task_repr` can describe it later.
    - If `keep_ref` is true, a strong reference is held in `_KEEP_ALIVE` until
      the task finishes. The event loop only keeps weak references to tasks,
      so a task nobody else references could otherwise be garbage collected
      before it completes.
    """
    task = asyncio.create_task(coro)  # noqa: TID251
    set_task_debug_info(task, name=name, client=client)
    # The task may already be finished here (e.g. under an eager task
    # factory); in that case there is nothing left to keep alive.
    if not keep_ref or task.done():
        return task
    _KEEP_ALIVE.add(task)
    # Release the strong reference as soon as the task completes.
    task.add_done_callback(_KEEP_ALIVE.discard)
    return task
|
||
|
||
|
||
def set_task_debug_info(
|
||
task: asyncio.Task,
|
||
*,
|
||
name: str,
|
||
client: tuple | None = None,
|
||
) -> None:
|
||
"""Set debug info for an externally-spawned task."""
|
||
task.created = time.time() # type: ignore
|
||
if __debug__ is True and (test := os.environ.get("PYTEST_CURRENT_TEST", None)):
|
||
name = f"{name} [created in {test}]"
|
||
task.set_name(name)
|
||
if client:
|
||
task.client = client # type: ignore
|
||
|
||
|
||
def set_current_task_debug_info(
|
||
*,
|
||
name: str,
|
||
client: tuple | None = None,
|
||
) -> None:
|
||
"""Set debug info for the current task."""
|
||
task = asyncio.current_task()
|
||
assert task
|
||
set_task_debug_info(task, name=name, client=client)
|
||
|
||
|
||
def task_repr(task: asyncio.Task) -> str:
    """
    Describe *task* for debugging output.

    The result has the form ``"<client>: <name> (age: <seconds>s)"``. The age
    suffix is omitted when the task has no truthy `created` timestamp, and the
    client prefix when it has no truthy `client` address.
    """
    task_name = task.get_name()
    created_at: float = getattr(task, "created", 0)
    age_suffix = f" (age: {time.time() - created_at:.0f}s)" if created_at else ""
    client = getattr(task, "client", "")
    # A falsy `client` is interpolated as-is (normally the "" default).
    prefix = f"{human.format_address(client)}: " if client else client
    return f"{prefix}{task_name}{age_suffix}"
|
||
|
||
|
||
@contextmanager
def install_exception_handler(handler) -> Iterator[None]:
    """
    Temporarily replace the running loop's exception handler.

    `handler` is anything accepted by `loop.set_exception_handler`. The
    previously installed handler is restored on exit, including when the
    body raises. Requires a running event loop.
    """
    running_loop = asyncio.get_running_loop()
    previous_handler = running_loop.get_exception_handler()
    running_loop.set_exception_handler(handler)
    try:
        yield
    finally:
        # Restore unconditionally so an exception in the body cannot leak
        # our handler into the rest of the loop's lifetime.
        running_loop.set_exception_handler(previous_handler)
|
||
|
||
|
||
@contextmanager
def set_eager_task_factory() -> Iterator[None]:
    """
    Temporarily install `asyncio.eager_task_factory` on the running loop.

    On Python 3.12+ the loop's previous task factory is restored on exit,
    including when the body raises. On older versions this is a no-op, but
    it still requires a running event loop.
    """
    running_loop = asyncio.get_running_loop()
    if sys.version_info >= (3, 12):
        previous_factory = running_loop.get_task_factory()
        running_loop.set_task_factory(asyncio.eager_task_factory)  # type: ignore
        try:
            yield
        finally:
            running_loop.set_task_factory(previous_factory)
    else:  # pragma: no cover
        # eager_task_factory does not exist before 3.12.
        yield
|