2025-12-25 upload
This commit is contained in:
1
venv/Lib/site-packages/aioquic/__init__.py
Normal file
1
venv/Lib/site-packages/aioquic/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "1.2.0"
|
||||
422
venv/Lib/site-packages/aioquic/_buffer.c
Normal file
422
venv/Lib/site-packages/aioquic/_buffer.c
Normal file
@@ -0,0 +1,422 @@
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
|
||||
#include <Python.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define MODULE_NAME "aioquic._buffer"
|
||||
|
||||
static PyObject *BufferReadError;
|
||||
static PyObject *BufferWriteError;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
uint8_t *base;
|
||||
uint8_t *end;
|
||||
uint8_t *pos;
|
||||
} BufferObject;
|
||||
|
||||
static PyObject *BufferType;
|
||||
|
||||
#define CHECK_READ_BOUNDS(self, len) \
|
||||
if (len < 0 || self->pos + len > self->end) { \
|
||||
PyErr_SetString(BufferReadError, "Read out of bounds"); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
#define CHECK_WRITE_BOUNDS(self, len) \
|
||||
if (self->pos + len > self->end) { \
|
||||
PyErr_SetString(BufferWriteError, "Write out of bounds"); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
static int
|
||||
Buffer_init(BufferObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
const char *kwlist[] = {"capacity", "data", NULL};
|
||||
Py_ssize_t capacity = 0;
|
||||
const unsigned char *data = NULL;
|
||||
Py_ssize_t data_len = 0;
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ny#", (char**)kwlist, &capacity, &data, &data_len))
|
||||
return -1;
|
||||
|
||||
if (data != NULL) {
|
||||
self->base = malloc(data_len);
|
||||
self->end = self->base + data_len;
|
||||
memcpy(self->base, data, data_len);
|
||||
} else {
|
||||
self->base = malloc(capacity);
|
||||
self->end = self->base + capacity;
|
||||
}
|
||||
self->pos = self->base;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
Buffer_dealloc(BufferObject *self)
|
||||
{
|
||||
free(self->base);
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
freefunc free = PyType_GetSlot(tp, Py_tp_free);
|
||||
free(self);
|
||||
Py_DECREF(tp);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_data_slice(BufferObject *self, PyObject *args)
|
||||
{
|
||||
Py_ssize_t start, stop;
|
||||
if (!PyArg_ParseTuple(args, "nn", &start, &stop))
|
||||
return NULL;
|
||||
|
||||
if (start < 0 || self->base + start > self->end ||
|
||||
stop < 0 || self->base + stop > self->end ||
|
||||
stop < start) {
|
||||
PyErr_SetString(BufferReadError, "Read out of bounds");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)(self->base + start), (stop - start));
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_eof(BufferObject *self, PyObject *args)
|
||||
{
|
||||
if (self->pos == self->end)
|
||||
Py_RETURN_TRUE;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_bytes(BufferObject *self, PyObject *args)
|
||||
{
|
||||
Py_ssize_t len;
|
||||
if (!PyArg_ParseTuple(args, "n", &len))
|
||||
return NULL;
|
||||
|
||||
CHECK_READ_BOUNDS(self, len);
|
||||
|
||||
PyObject *o = PyBytes_FromStringAndSize((const char*)self->pos, len);
|
||||
self->pos += len;
|
||||
return o;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint8(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 1)
|
||||
|
||||
return PyLong_FromUnsignedLong(
|
||||
(uint8_t)(*(self->pos++))
|
||||
);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint16(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 2)
|
||||
|
||||
uint16_t value = (uint16_t)(*(self->pos)) << 8 |
|
||||
(uint16_t)(*(self->pos + 1));
|
||||
self->pos += 2;
|
||||
return PyLong_FromUnsignedLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint32(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 4)
|
||||
|
||||
uint32_t value = (uint32_t)(*(self->pos)) << 24 |
|
||||
(uint32_t)(*(self->pos + 1)) << 16 |
|
||||
(uint32_t)(*(self->pos + 2)) << 8 |
|
||||
(uint32_t)(*(self->pos + 3));
|
||||
self->pos += 4;
|
||||
return PyLong_FromUnsignedLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint64(BufferObject *self, PyObject *args)
|
||||
{
|
||||
CHECK_READ_BOUNDS(self, 8)
|
||||
|
||||
uint64_t value = (uint64_t)(*(self->pos)) << 56 |
|
||||
(uint64_t)(*(self->pos + 1)) << 48 |
|
||||
(uint64_t)(*(self->pos + 2)) << 40 |
|
||||
(uint64_t)(*(self->pos + 3)) << 32 |
|
||||
(uint64_t)(*(self->pos + 4)) << 24 |
|
||||
(uint64_t)(*(self->pos + 5)) << 16 |
|
||||
(uint64_t)(*(self->pos + 6)) << 8 |
|
||||
(uint64_t)(*(self->pos + 7));
|
||||
self->pos += 8;
|
||||
return PyLong_FromUnsignedLongLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_pull_uint_var(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint64_t value;
|
||||
CHECK_READ_BOUNDS(self, 1)
|
||||
switch (*(self->pos) >> 6) {
|
||||
case 0:
|
||||
value = *(self->pos++) & 0x3F;
|
||||
break;
|
||||
case 1:
|
||||
CHECK_READ_BOUNDS(self, 2)
|
||||
value = (uint16_t)(*(self->pos) & 0x3F) << 8 |
|
||||
(uint16_t)(*(self->pos + 1));
|
||||
self->pos += 2;
|
||||
break;
|
||||
case 2:
|
||||
CHECK_READ_BOUNDS(self, 4)
|
||||
value = (uint32_t)(*(self->pos) & 0x3F) << 24 |
|
||||
(uint32_t)(*(self->pos + 1)) << 16 |
|
||||
(uint32_t)(*(self->pos + 2)) << 8 |
|
||||
(uint32_t)(*(self->pos + 3));
|
||||
self->pos += 4;
|
||||
break;
|
||||
default:
|
||||
CHECK_READ_BOUNDS(self, 8)
|
||||
value = (uint64_t)(*(self->pos) & 0x3F) << 56 |
|
||||
(uint64_t)(*(self->pos + 1)) << 48 |
|
||||
(uint64_t)(*(self->pos + 2)) << 40 |
|
||||
(uint64_t)(*(self->pos + 3)) << 32 |
|
||||
(uint64_t)(*(self->pos + 4)) << 24 |
|
||||
(uint64_t)(*(self->pos + 5)) << 16 |
|
||||
(uint64_t)(*(self->pos + 6)) << 8 |
|
||||
(uint64_t)(*(self->pos + 7));
|
||||
self->pos += 8;
|
||||
break;
|
||||
}
|
||||
return PyLong_FromUnsignedLongLong(value);
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_bytes(BufferObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *data;
|
||||
Py_ssize_t data_len;
|
||||
if (!PyArg_ParseTuple(args, "y#", &data, &data_len))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, data_len)
|
||||
|
||||
memcpy(self->pos, data, data_len);
|
||||
self->pos += data_len;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint8(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint8_t value;
|
||||
if (!PyArg_ParseTuple(args, "B", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 1)
|
||||
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint16(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint16_t value;
|
||||
if (!PyArg_ParseTuple(args, "H", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 2)
|
||||
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint32(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint32_t value;
|
||||
if (!PyArg_ParseTuple(args, "I", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 4)
|
||||
*(self->pos++) = (value >> 24);
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint64(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint64_t value;
|
||||
if (!PyArg_ParseTuple(args, "K", &value))
|
||||
return NULL;
|
||||
|
||||
CHECK_WRITE_BOUNDS(self, 8)
|
||||
*(self->pos++) = (value >> 56);
|
||||
*(self->pos++) = (value >> 48);
|
||||
*(self->pos++) = (value >> 40);
|
||||
*(self->pos++) = (value >> 32);
|
||||
*(self->pos++) = (value >> 24);
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_push_uint_var(BufferObject *self, PyObject *args)
|
||||
{
|
||||
uint64_t value;
|
||||
if (!PyArg_ParseTuple(args, "K", &value))
|
||||
return NULL;
|
||||
|
||||
if (value <= 0x3F) {
|
||||
CHECK_WRITE_BOUNDS(self, 1)
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else if (value <= 0x3FFF) {
|
||||
CHECK_WRITE_BOUNDS(self, 2)
|
||||
*(self->pos++) = (value >> 8) | 0x40;
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else if (value <= 0x3FFFFFFF) {
|
||||
CHECK_WRITE_BOUNDS(self, 4)
|
||||
*(self->pos++) = (value >> 24) | 0x80;
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else if (value <= 0x3FFFFFFFFFFFFFFF) {
|
||||
CHECK_WRITE_BOUNDS(self, 8)
|
||||
*(self->pos++) = (value >> 56) | 0xC0;
|
||||
*(self->pos++) = (value >> 48);
|
||||
*(self->pos++) = (value >> 40);
|
||||
*(self->pos++) = (value >> 32);
|
||||
*(self->pos++) = (value >> 24);
|
||||
*(self->pos++) = (value >> 16);
|
||||
*(self->pos++) = (value >> 8);
|
||||
*(self->pos++) = value;
|
||||
Py_RETURN_NONE;
|
||||
} else {
|
||||
PyErr_SetString(PyExc_ValueError, "Integer is too big for a variable-length integer");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_seek(BufferObject *self, PyObject *args)
|
||||
{
|
||||
Py_ssize_t pos;
|
||||
if (!PyArg_ParseTuple(args, "n", &pos))
|
||||
return NULL;
|
||||
|
||||
if (pos < 0 || self->base + pos > self->end) {
|
||||
PyErr_SetString(BufferReadError, "Seek out of bounds");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
self->pos = self->base + pos;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
Buffer_tell(BufferObject *self, PyObject *args)
|
||||
{
|
||||
return PyLong_FromSsize_t(self->pos - self->base);
|
||||
}
|
||||
|
||||
static PyMethodDef Buffer_methods[] = {
|
||||
{"data_slice", (PyCFunction)Buffer_data_slice, METH_VARARGS, ""},
|
||||
{"eof", (PyCFunction)Buffer_eof, METH_VARARGS, ""},
|
||||
{"pull_bytes", (PyCFunction)Buffer_pull_bytes, METH_VARARGS, "Pull bytes."},
|
||||
{"pull_uint8", (PyCFunction)Buffer_pull_uint8, METH_VARARGS, "Pull an 8-bit unsigned integer."},
|
||||
{"pull_uint16", (PyCFunction)Buffer_pull_uint16, METH_VARARGS, "Pull a 16-bit unsigned integer."},
|
||||
{"pull_uint32", (PyCFunction)Buffer_pull_uint32, METH_VARARGS, "Pull a 32-bit unsigned integer."},
|
||||
{"pull_uint64", (PyCFunction)Buffer_pull_uint64, METH_VARARGS, "Pull a 64-bit unsigned integer."},
|
||||
{"pull_uint_var", (PyCFunction)Buffer_pull_uint_var, METH_VARARGS, "Pull a QUIC variable-length unsigned integer."},
|
||||
{"push_bytes", (PyCFunction)Buffer_push_bytes, METH_VARARGS, "Push bytes."},
|
||||
{"push_uint8", (PyCFunction)Buffer_push_uint8, METH_VARARGS, "Push an 8-bit unsigned integer."},
|
||||
{"push_uint16", (PyCFunction)Buffer_push_uint16, METH_VARARGS, "Push a 16-bit unsigned integer."},
|
||||
{"push_uint32", (PyCFunction)Buffer_push_uint32, METH_VARARGS, "Push a 32-bit unsigned integer."},
|
||||
{"push_uint64", (PyCFunction)Buffer_push_uint64, METH_VARARGS, "Push a 64-bit unsigned integer."},
|
||||
{"push_uint_var", (PyCFunction)Buffer_push_uint_var, METH_VARARGS, "Push a QUIC variable-length unsigned integer."},
|
||||
{"seek", (PyCFunction)Buffer_seek, METH_VARARGS, ""},
|
||||
{"tell", (PyCFunction)Buffer_tell, METH_VARARGS, ""},
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyObject*
|
||||
Buffer_capacity_getter(BufferObject* self, void *closure) {
|
||||
return PyLong_FromSsize_t(self->end - self->base);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
Buffer_data_getter(BufferObject* self, void *closure) {
|
||||
return PyBytes_FromStringAndSize((const char*)self->base, self->pos - self->base);
|
||||
}
|
||||
|
||||
static PyGetSetDef Buffer_getset[] = {
|
||||
{"capacity", (getter) Buffer_capacity_getter, NULL, "", NULL },
|
||||
{"data", (getter) Buffer_data_getter, NULL, "", NULL },
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyType_Slot BufferType_slots[] = {
|
||||
{Py_tp_dealloc, Buffer_dealloc},
|
||||
{Py_tp_methods, Buffer_methods},
|
||||
{Py_tp_doc, "Buffer objects"},
|
||||
{Py_tp_getset, Buffer_getset},
|
||||
{Py_tp_init, Buffer_init},
|
||||
{0, 0},
|
||||
};
|
||||
|
||||
static PyType_Spec BufferType_spec = {
|
||||
MODULE_NAME ".Buffer",
|
||||
sizeof(BufferObject),
|
||||
0,
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
BufferType_slots
|
||||
};
|
||||
|
||||
|
||||
static struct PyModuleDef moduledef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME, /* m_name */
|
||||
"Serialization utilities.", /* m_doc */
|
||||
-1, /* m_size */
|
||||
NULL, /* m_methods */
|
||||
NULL, /* m_reload */
|
||||
NULL, /* m_traverse */
|
||||
NULL, /* m_clear */
|
||||
NULL, /* m_free */
|
||||
};
|
||||
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit__buffer(void)
|
||||
{
|
||||
PyObject* m;
|
||||
|
||||
m = PyModule_Create(&moduledef);
|
||||
if (m == NULL)
|
||||
return NULL;
|
||||
|
||||
BufferReadError = PyErr_NewException(MODULE_NAME ".BufferReadError", PyExc_ValueError, NULL);
|
||||
Py_INCREF(BufferReadError);
|
||||
PyModule_AddObject(m, "BufferReadError", BufferReadError);
|
||||
|
||||
BufferWriteError = PyErr_NewException(MODULE_NAME ".BufferWriteError", PyExc_ValueError, NULL);
|
||||
Py_INCREF(BufferWriteError);
|
||||
PyModule_AddObject(m, "BufferWriteError", BufferWriteError);
|
||||
|
||||
BufferType = PyType_FromSpec(&BufferType_spec);
|
||||
if (BufferType == NULL)
|
||||
return NULL;
|
||||
PyModule_AddObject(m, "Buffer", BufferType);
|
||||
|
||||
return m;
|
||||
}
|
||||
27
venv/Lib/site-packages/aioquic/_buffer.pyi
Normal file
27
venv/Lib/site-packages/aioquic/_buffer.pyi
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
class BufferReadError(ValueError): ...
|
||||
class BufferWriteError(ValueError): ...
|
||||
|
||||
class Buffer:
    """Typed stub for the C-implemented ``aioquic._buffer.Buffer``.

    Fixes over the previous stub:
    - ``capacity`` is typed ``int`` (the C ``"n"`` converter rejects ``None``,
      so ``Optional[int]`` was wrong);
    - ``push_uint32`` / ``push_uint64`` parameters renamed ``v`` -> ``value``
      for consistency with every other ``push_*`` method (the C implementation
      is positional-only, so no keyword callers can break).
    """

    def __init__(self, capacity: int = 0, data: Optional[bytes] = None) -> None: ...
    @property
    def capacity(self) -> int: ...
    @property
    def data(self) -> bytes: ...
    def data_slice(self, start: int, end: int) -> bytes: ...
    def eof(self) -> bool: ...
    def seek(self, pos: int) -> None: ...
    def tell(self) -> int: ...
    def pull_bytes(self, length: int) -> bytes: ...
    def pull_uint8(self) -> int: ...
    def pull_uint16(self) -> int: ...
    def pull_uint32(self) -> int: ...
    def pull_uint64(self) -> int: ...
    def pull_uint_var(self) -> int: ...
    def push_bytes(self, value: bytes) -> None: ...
    def push_uint8(self, value: int) -> None: ...
    def push_uint16(self, value: int) -> None: ...
    def push_uint32(self, value: int) -> None: ...
    def push_uint64(self, value: int) -> None: ...
    def push_uint_var(self, value: int) -> None: ...
|
||||
416
venv/Lib/site-packages/aioquic/_crypto.c
Normal file
416
venv/Lib/site-packages/aioquic/_crypto.c
Normal file
@@ -0,0 +1,416 @@
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
|
||||
#include <Python.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
|
||||
#define MODULE_NAME "aioquic._crypto"
|
||||
|
||||
#define AEAD_KEY_LENGTH_MAX 32
|
||||
#define AEAD_NONCE_LENGTH 12
|
||||
#define AEAD_TAG_LENGTH 16
|
||||
|
||||
#define PACKET_LENGTH_MAX 1500
|
||||
#define PACKET_NUMBER_LENGTH_MAX 4
|
||||
#define SAMPLE_LENGTH 16
|
||||
|
||||
#define CHECK_RESULT(expr) \
|
||||
if (!(expr)) { \
|
||||
ERR_clear_error(); \
|
||||
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
#define CHECK_RESULT_CTOR(expr) \
|
||||
if (!(expr)) { \
|
||||
ERR_clear_error(); \
|
||||
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
|
||||
return -1; \
|
||||
}
|
||||
|
||||
static PyObject *CryptoError;
|
||||
|
||||
/* AEAD */
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
EVP_CIPHER_CTX *decrypt_ctx;
|
||||
EVP_CIPHER_CTX *encrypt_ctx;
|
||||
unsigned char buffer[PACKET_LENGTH_MAX];
|
||||
unsigned char key[AEAD_KEY_LENGTH_MAX];
|
||||
unsigned char iv[AEAD_NONCE_LENGTH];
|
||||
unsigned char nonce[AEAD_NONCE_LENGTH];
|
||||
} AEADObject;
|
||||
|
||||
static PyObject *AEADType;
|
||||
|
||||
static EVP_CIPHER_CTX *
|
||||
create_ctx(const EVP_CIPHER *cipher, int key_length, int operation)
|
||||
{
|
||||
EVP_CIPHER_CTX *ctx;
|
||||
int res;
|
||||
|
||||
ctx = EVP_CIPHER_CTX_new();
|
||||
CHECK_RESULT(ctx != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_set_key_length(ctx, key_length);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
static int
|
||||
AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
const char *cipher_name;
|
||||
const unsigned char *key, *iv;
|
||||
Py_ssize_t cipher_name_len, key_len, iv_len;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, &cipher_name_len, &key, &key_len, &iv, &iv_len))
|
||||
return -1;
|
||||
|
||||
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
|
||||
if (evp_cipher == 0) {
|
||||
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
|
||||
return -1;
|
||||
}
|
||||
if (key_len > AEAD_KEY_LENGTH_MAX) {
|
||||
PyErr_SetString(CryptoError, "Invalid key length");
|
||||
return -1;
|
||||
}
|
||||
if (iv_len > AEAD_NONCE_LENGTH) {
|
||||
PyErr_SetString(CryptoError, "Invalid iv length");
|
||||
return -1;
|
||||
}
|
||||
|
||||
memcpy(self->key, key, key_len);
|
||||
memcpy(self->iv, iv, iv_len);
|
||||
|
||||
self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0);
|
||||
CHECK_RESULT_CTOR(self->decrypt_ctx != 0);
|
||||
|
||||
self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1);
|
||||
CHECK_RESULT_CTOR(self->encrypt_ctx != 0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
AEAD_dealloc(AEADObject *self)
|
||||
{
|
||||
EVP_CIPHER_CTX_free(self->decrypt_ctx);
|
||||
EVP_CIPHER_CTX_free(self->encrypt_ctx);
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
freefunc free = PyType_GetSlot(tp, Py_tp_free);
|
||||
free(self);
|
||||
Py_DECREF(tp);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
AEAD_decrypt(AEADObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *data, *associated;
|
||||
Py_ssize_t data_len, associated_len;
|
||||
int outlen, outlen2, res;
|
||||
uint64_t pn;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
|
||||
return NULL;
|
||||
|
||||
if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) {
|
||||
PyErr_SetString(CryptoError, "Invalid payload length");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
|
||||
}
|
||||
|
||||
res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH)));
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2);
|
||||
if (res == 0) {
|
||||
PyErr_SetString(CryptoError, "Payload decryption failed");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
AEAD_encrypt(AEADObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *data, *associated;
|
||||
Py_ssize_t data_len, associated_len;
|
||||
int outlen, outlen2, res;
|
||||
uint64_t pn;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
|
||||
return NULL;
|
||||
|
||||
if (data_len > PACKET_LENGTH_MAX) {
|
||||
PyErr_SetString(CryptoError, "Invalid payload length");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
|
||||
}
|
||||
|
||||
res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2);
|
||||
CHECK_RESULT(res != 0 && outlen2 == 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH);
|
||||
}
|
||||
|
||||
static PyMethodDef AEAD_methods[] = {
|
||||
{"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""},
|
||||
{"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""},
|
||||
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyType_Slot AEADType_slots[] = {
|
||||
{Py_tp_dealloc, AEAD_dealloc},
|
||||
{Py_tp_methods, AEAD_methods},
|
||||
{Py_tp_doc, "AEAD objects"},
|
||||
{Py_tp_init, AEAD_init},
|
||||
{0, 0},
|
||||
};
|
||||
|
||||
static PyType_Spec AEADType_spec = {
|
||||
MODULE_NAME ".AEADType",
|
||||
sizeof(AEADObject),
|
||||
0,
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
AEADType_slots
|
||||
};
|
||||
|
||||
/* HeaderProtection */
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
EVP_CIPHER_CTX *ctx;
|
||||
int is_chacha20;
|
||||
unsigned char buffer[PACKET_LENGTH_MAX];
|
||||
unsigned char mask[31];
|
||||
unsigned char zero[5];
|
||||
} HeaderProtectionObject;
|
||||
|
||||
static PyObject *HeaderProtectionType;
|
||||
|
||||
static int
|
||||
HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
const char *cipher_name;
|
||||
const unsigned char *key;
|
||||
Py_ssize_t cipher_name_len, key_len;
|
||||
int res;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len))
|
||||
return -1;
|
||||
|
||||
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
|
||||
if (evp_cipher == 0) {
|
||||
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
|
||||
return -1;
|
||||
}
|
||||
|
||||
memset(self->mask, 0, sizeof(self->mask));
|
||||
memset(self->zero, 0, sizeof(self->zero));
|
||||
self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0;
|
||||
|
||||
self->ctx = EVP_CIPHER_CTX_new();
|
||||
CHECK_RESULT_CTOR(self->ctx != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1);
|
||||
CHECK_RESULT_CTOR(res != 0);
|
||||
|
||||
res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len);
|
||||
CHECK_RESULT_CTOR(res != 0);
|
||||
|
||||
res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1);
|
||||
CHECK_RESULT_CTOR(res != 0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
HeaderProtection_dealloc(HeaderProtectionObject *self)
|
||||
{
|
||||
EVP_CIPHER_CTX_free(self->ctx);
|
||||
PyTypeObject *tp = Py_TYPE(self);
|
||||
freefunc free = PyType_GetSlot(tp, Py_tp_free);
|
||||
free(self);
|
||||
Py_DECREF(tp);
|
||||
}
|
||||
|
||||
static int HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char* sample)
|
||||
{
|
||||
int outlen;
|
||||
if (self->is_chacha20) {
|
||||
return EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1) &&
|
||||
EVP_CipherUpdate(self->ctx, self->mask, &outlen, self->zero, sizeof(self->zero));
|
||||
} else {
|
||||
return EVP_CipherUpdate(self->ctx, self->mask, &outlen, sample, SAMPLE_LENGTH);
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *header, *payload;
|
||||
Py_ssize_t header_len, payload_len;
|
||||
int res;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len))
|
||||
return NULL;
|
||||
|
||||
int pn_length = (header[0] & 0x03) + 1;
|
||||
int pn_offset = header_len - pn_length;
|
||||
|
||||
res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
memcpy(self->buffer, header, header_len);
|
||||
memcpy(self->buffer + header_len, payload, payload_len);
|
||||
|
||||
if (self->buffer[0] & 0x80) {
|
||||
self->buffer[0] ^= self->mask[0] & 0x0F;
|
||||
} else {
|
||||
self->buffer[0] ^= self->mask[0] & 0x1F;
|
||||
}
|
||||
|
||||
for (int i = 0; i < pn_length; ++i) {
|
||||
self->buffer[pn_offset + i] ^= self->mask[1 + i];
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args)
|
||||
{
|
||||
const unsigned char *packet;
|
||||
Py_ssize_t packet_len;
|
||||
int pn_offset, res;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset))
|
||||
return NULL;
|
||||
|
||||
res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX);
|
||||
CHECK_RESULT(res != 0);
|
||||
|
||||
memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX);
|
||||
|
||||
if (self->buffer[0] & 0x80) {
|
||||
self->buffer[0] ^= self->mask[0] & 0x0F;
|
||||
} else {
|
||||
self->buffer[0] ^= self->mask[0] & 0x1F;
|
||||
}
|
||||
|
||||
int pn_length = (self->buffer[0] & 0x03) + 1;
|
||||
uint32_t pn_truncated = 0;
|
||||
for (int i = 0; i < pn_length; ++i) {
|
||||
self->buffer[pn_offset + i] ^= self->mask[1 + i];
|
||||
pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8);
|
||||
}
|
||||
|
||||
return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated);
|
||||
}
|
||||
|
||||
static PyMethodDef HeaderProtection_methods[] = {
|
||||
{"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""},
|
||||
{"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""},
|
||||
{NULL}
|
||||
};
|
||||
|
||||
static PyType_Slot HeaderProtectionType_slots[] = {
|
||||
{Py_tp_dealloc, HeaderProtection_dealloc},
|
||||
{Py_tp_methods, HeaderProtection_methods},
|
||||
{Py_tp_doc, "HeaderProtection objects"},
|
||||
{Py_tp_init, HeaderProtection_init},
|
||||
{0, 0},
|
||||
};
|
||||
|
||||
static PyType_Spec HeaderProtectionType_spec = {
|
||||
MODULE_NAME ".HeaderProtectionType",
|
||||
sizeof(HeaderProtectionObject),
|
||||
0,
|
||||
Py_TPFLAGS_DEFAULT,
|
||||
HeaderProtectionType_slots
|
||||
};
|
||||
|
||||
static struct PyModuleDef moduledef = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
MODULE_NAME, /* m_name */
|
||||
"Cryptography utilities.", /* m_doc */
|
||||
-1, /* m_size */
|
||||
NULL, /* m_methods */
|
||||
NULL, /* m_reload */
|
||||
NULL, /* m_traverse */
|
||||
NULL, /* m_clear */
|
||||
NULL, /* m_free */
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit__crypto(void)
|
||||
{
|
||||
PyObject* m;
|
||||
|
||||
m = PyModule_Create(&moduledef);
|
||||
if (m == NULL)
|
||||
return NULL;
|
||||
|
||||
CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL);
|
||||
Py_INCREF(CryptoError);
|
||||
PyModule_AddObject(m, "CryptoError", CryptoError);
|
||||
|
||||
AEADType = PyType_FromSpec(&AEADType_spec);
|
||||
if (AEADType == NULL)
|
||||
return NULL;
|
||||
PyModule_AddObject(m, "AEAD", AEADType);
|
||||
|
||||
HeaderProtectionType = PyType_FromSpec(&HeaderProtectionType_spec);
|
||||
if (HeaderProtectionType == NULL)
|
||||
return NULL;
|
||||
PyModule_AddObject(m, "HeaderProtection", HeaderProtectionType);
|
||||
|
||||
// ensure required ciphers are initialised
|
||||
EVP_add_cipher(EVP_aes_128_ecb());
|
||||
EVP_add_cipher(EVP_aes_128_gcm());
|
||||
EVP_add_cipher(EVP_aes_256_ecb());
|
||||
EVP_add_cipher(EVP_aes_256_gcm());
|
||||
|
||||
return m;
|
||||
}
|
||||
17
venv/Lib/site-packages/aioquic/_crypto.pyi
Normal file
17
venv/Lib/site-packages/aioquic/_crypto.pyi
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Tuple
|
||||
|
||||
class AEAD:
|
||||
def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): ...
|
||||
def decrypt(
|
||||
self, data: bytes, associated_data: bytes, packet_number: int
|
||||
) -> bytes: ...
|
||||
def encrypt(
|
||||
self, data: bytes, associated_data: bytes, packet_number: int
|
||||
) -> bytes: ...
|
||||
|
||||
class CryptoError(ValueError): ...
|
||||
|
||||
class HeaderProtection:
|
||||
def __init__(self, cipher_name: bytes, key: bytes): ...
|
||||
def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ...
|
||||
def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: ...
|
||||
3
venv/Lib/site-packages/aioquic/asyncio/__init__.py
Normal file
3
venv/Lib/site-packages/aioquic/asyncio/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .client import connect # noqa
|
||||
from .protocol import QuicConnectionProtocol # noqa
|
||||
from .server import serve # noqa
|
||||
98
venv/Lib/site-packages/aioquic/asyncio/client.py
Normal file
98
venv/Lib/site-packages/aioquic/asyncio/client.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Callable, Optional, cast
|
||||
|
||||
from ..quic.configuration import QuicConfiguration
|
||||
from ..quic.connection import QuicConnection, QuicTokenHandler
|
||||
from ..tls import SessionTicketHandler
|
||||
from .protocol import QuicConnectionProtocol, QuicStreamHandler
|
||||
|
||||
__all__ = ["connect"]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
configuration: Optional[QuicConfiguration] = None,
|
||||
create_protocol: Optional[Callable] = QuicConnectionProtocol,
|
||||
session_ticket_handler: Optional[SessionTicketHandler] = None,
|
||||
stream_handler: Optional[QuicStreamHandler] = None,
|
||||
token_handler: Optional[QuicTokenHandler] = None,
|
||||
wait_connected: bool = True,
|
||||
local_port: int = 0,
|
||||
) -> AsyncGenerator[QuicConnectionProtocol, None]:
|
||||
"""
|
||||
Connect to a QUIC server at the given `host` and `port`.
|
||||
|
||||
:meth:`connect()` returns an awaitable. Awaiting it yields a
|
||||
:class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to
|
||||
create streams.
|
||||
|
||||
:func:`connect` also accepts the following optional arguments:
|
||||
|
||||
* ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
|
||||
configuration object.
|
||||
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
|
||||
manages the connection. It should be a callable or class accepting the same
|
||||
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
|
||||
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
|
||||
* ``session_ticket_handler`` is a callback which is invoked by the TLS
|
||||
engine when a new session ticket is received.
|
||||
* ``stream_handler`` is a callback which is invoked whenever a stream is
|
||||
created. It must accept two arguments: a :class:`asyncio.StreamReader`
|
||||
and a :class:`asyncio.StreamWriter`.
|
||||
* ``wait_connected`` indicates whether the context manager should wait for the
|
||||
connection to be established before yielding the
|
||||
:class:`~aioquic.asyncio.QuicConnectionProtocol`. By default this is `True` but
|
||||
you can set it to `False` if you want to immediately start sending data using
|
||||
0-RTT.
|
||||
* ``local_port`` is the UDP port number that this client wants to bind.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
local_host = "::"
|
||||
|
||||
# lookup remote address
|
||||
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
|
||||
addr = infos[0][4]
|
||||
if len(addr) == 2:
|
||||
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
|
||||
|
||||
# prepare QUIC connection
|
||||
if configuration is None:
|
||||
configuration = QuicConfiguration(is_client=True)
|
||||
if configuration.server_name is None:
|
||||
configuration.server_name = host
|
||||
connection = QuicConnection(
|
||||
configuration=configuration,
|
||||
session_ticket_handler=session_ticket_handler,
|
||||
token_handler=token_handler,
|
||||
)
|
||||
|
||||
# explicitly enable IPv4/IPv6 dual stack
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||
completed = False
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
|
||||
sock.bind((local_host, local_port, 0, 0))
|
||||
completed = True
|
||||
finally:
|
||||
if not completed:
|
||||
sock.close()
|
||||
# connect
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: create_protocol(connection, stream_handler=stream_handler),
|
||||
sock=sock,
|
||||
)
|
||||
protocol = cast(QuicConnectionProtocol, protocol)
|
||||
try:
|
||||
protocol.connect(addr, transmit=wait_connected)
|
||||
if wait_connected:
|
||||
await protocol.wait_connected()
|
||||
yield protocol
|
||||
finally:
|
||||
protocol.close()
|
||||
await protocol.wait_closed()
|
||||
transport.close()
|
||||
272
venv/Lib/site-packages/aioquic/asyncio/protocol.py
Normal file
272
venv/Lib/site-packages/aioquic/asyncio/protocol.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
|
||||
|
||||
from ..quic import events
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
from ..quic.packet import QuicErrorCode
|
||||
|
||||
QuicConnectionIdHandler = Callable[[bytes], None]
|
||||
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
|
||||
|
||||
|
||||
class QuicConnectionProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(
|
||||
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
|
||||
):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._closed = asyncio.Event()
|
||||
self._connected = False
|
||||
self._connected_waiter: Optional[asyncio.Future[None]] = None
|
||||
self._loop = loop
|
||||
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
|
||||
self._quic = quic
|
||||
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
|
||||
self._timer: Optional[asyncio.TimerHandle] = None
|
||||
self._timer_at: Optional[float] = None
|
||||
self._transmit_task: Optional[asyncio.Handle] = None
|
||||
self._transport: Optional[asyncio.DatagramTransport] = None
|
||||
|
||||
# callbacks
|
||||
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_terminated_handler: Callable[[], None] = lambda: None
|
||||
if stream_handler is not None:
|
||||
self._stream_handler = stream_handler
|
||||
else:
|
||||
self._stream_handler = lambda r, w: None
|
||||
|
||||
def change_connection_id(self) -> None:
|
||||
"""
|
||||
Change the connection ID used to communicate with the peer.
|
||||
|
||||
The previous connection ID will be retired.
|
||||
"""
|
||||
self._quic.change_connection_id()
|
||||
self.transmit()
|
||||
|
||||
def close(
|
||||
self,
|
||||
error_code: int = QuicErrorCode.NO_ERROR,
|
||||
reason_phrase: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Close the connection.
|
||||
|
||||
:param error_code: An error code indicating why the connection is
|
||||
being closed.
|
||||
:param reason_phrase: A human-readable explanation of why the
|
||||
connection is being closed.
|
||||
"""
|
||||
self._quic.close(
|
||||
error_code=error_code,
|
||||
reason_phrase=reason_phrase,
|
||||
)
|
||||
self.transmit()
|
||||
|
||||
def connect(self, addr: NetworkAddress, transmit=True) -> None:
|
||||
"""
|
||||
Initiate the TLS handshake.
|
||||
|
||||
This method can only be called for clients and a single time.
|
||||
"""
|
||||
self._quic.connect(addr, now=self._loop.time())
|
||||
if transmit:
|
||||
self.transmit()
|
||||
|
||||
async def create_stream(
|
||||
self, is_unidirectional: bool = False
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
"""
|
||||
Create a QUIC stream and return a pair of (reader, writer) objects.
|
||||
|
||||
The returned reader and writer objects are instances of
|
||||
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
|
||||
"""
|
||||
stream_id = self._quic.get_next_available_stream_id(
|
||||
is_unidirectional=is_unidirectional
|
||||
)
|
||||
return self._create_stream(stream_id)
|
||||
|
||||
def request_key_update(self) -> None:
|
||||
"""
|
||||
Request an update of the encryption keys.
|
||||
"""
|
||||
self._quic.request_key_update()
|
||||
self.transmit()
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""
|
||||
Ping the peer and wait for the response.
|
||||
"""
|
||||
waiter = self._loop.create_future()
|
||||
uid = id(waiter)
|
||||
self._ping_waiters[uid] = waiter
|
||||
self._quic.send_ping(uid)
|
||||
self.transmit()
|
||||
await asyncio.shield(waiter)
|
||||
|
||||
def transmit(self) -> None:
|
||||
"""
|
||||
Send pending datagrams to the peer and arm the timer if needed.
|
||||
|
||||
This method is called automatically when data is received from the peer
|
||||
or when a timer goes off. If you interact directly with the underlying
|
||||
:class:`~aioquic.quic.connection.QuicConnection`, make sure you call this
|
||||
method whenever data needs to be sent out to the network.
|
||||
"""
|
||||
self._transmit_task = None
|
||||
|
||||
# send datagrams
|
||||
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
|
||||
self._transport.sendto(data, addr)
|
||||
|
||||
# re-arm timer
|
||||
timer_at = self._quic.get_timer()
|
||||
if self._timer is not None and self._timer_at != timer_at:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
if self._timer is None and timer_at is not None:
|
||||
self._timer = self._loop.call_at(timer_at, self._handle_timer)
|
||||
self._timer_at = timer_at
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
"""
|
||||
Wait for the connection to be closed.
|
||||
"""
|
||||
await self._closed.wait()
|
||||
|
||||
async def wait_connected(self) -> None:
|
||||
"""
|
||||
Wait for the TLS handshake to complete.
|
||||
"""
|
||||
assert self._connected_waiter is None, "already awaiting connected"
|
||||
if not self._connected:
|
||||
self._connected_waiter = self._loop.create_future()
|
||||
await asyncio.shield(self._connected_waiter)
|
||||
|
||||
# asyncio.Transport
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
""":meta private:"""
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
|
||||
""":meta private:"""
|
||||
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
# overridable
|
||||
|
||||
def quic_event_received(self, event: events.QuicEvent) -> None:
|
||||
"""
|
||||
Called when a QUIC event is received.
|
||||
|
||||
Reimplement this in your subclass to handle the events.
|
||||
"""
|
||||
# FIXME: move this to a subclass
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
for reader in self._stream_readers.values():
|
||||
reader.feed_eof()
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
reader = self._stream_readers.get(event.stream_id, None)
|
||||
if reader is None:
|
||||
reader, writer = self._create_stream(event.stream_id)
|
||||
self._stream_handler(reader, writer)
|
||||
reader.feed_data(event.data)
|
||||
if event.end_stream:
|
||||
reader.feed_eof()
|
||||
|
||||
# private
|
||||
|
||||
def _create_stream(
|
||||
self, stream_id: int
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
adapter = QuicStreamAdapter(self, stream_id)
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.streams.StreamReaderProtocol(reader)
|
||||
writer = asyncio.StreamWriter(adapter, protocol, reader, self._loop)
|
||||
self._stream_readers[stream_id] = reader
|
||||
return reader, writer
|
||||
|
||||
def _handle_timer(self) -> None:
|
||||
now = max(self._timer_at, self._loop.time())
|
||||
self._timer = None
|
||||
self._timer_at = None
|
||||
self._quic.handle_timer(now=now)
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
def _process_events(self) -> None:
|
||||
event = self._quic.next_event()
|
||||
while event is not None:
|
||||
if isinstance(event, events.ConnectionIdIssued):
|
||||
self._connection_id_issued_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionIdRetired):
|
||||
self._connection_id_retired_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionTerminated):
|
||||
self._connection_terminated_handler()
|
||||
|
||||
# abort connection waiter
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected_waiter = None
|
||||
waiter.set_exception(ConnectionError)
|
||||
|
||||
# abort ping waiters
|
||||
for waiter in self._ping_waiters.values():
|
||||
waiter.set_exception(ConnectionError)
|
||||
self._ping_waiters.clear()
|
||||
|
||||
self._closed.set()
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected = True
|
||||
self._connected_waiter = None
|
||||
waiter.set_result(None)
|
||||
elif isinstance(event, events.PingAcknowledged):
|
||||
waiter = self._ping_waiters.pop(event.uid, None)
|
||||
if waiter is not None:
|
||||
waiter.set_result(None)
|
||||
self.quic_event_received(event)
|
||||
event = self._quic.next_event()
|
||||
|
||||
def _transmit_soon(self) -> None:
|
||||
if self._transmit_task is None:
|
||||
self._transmit_task = self._loop.call_soon(self.transmit)
|
||||
|
||||
|
||||
class QuicStreamAdapter(asyncio.Transport):
|
||||
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
|
||||
self.protocol = protocol
|
||||
self.stream_id = stream_id
|
||||
self._closing = False
|
||||
|
||||
def can_write_eof(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_extra_info(self, name: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get information about the underlying QUIC stream.
|
||||
"""
|
||||
if name == "stream_id":
|
||||
return self.stream_id
|
||||
|
||||
def write(self, data):
|
||||
self.protocol._quic.send_stream_data(self.stream_id, data)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def write_eof(self):
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def close(self):
|
||||
self.write_eof()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self._closing
|
||||
215
venv/Lib/site-packages/aioquic/asyncio/server.py
Normal file
215
venv/Lib/site-packages/aioquic/asyncio/server.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional, Text, Union, cast
|
||||
|
||||
from ..buffer import Buffer
|
||||
from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
from ..quic.packet import (
|
||||
QuicPacketType,
|
||||
encode_quic_retry,
|
||||
encode_quic_version_negotiation,
|
||||
pull_quic_header,
|
||||
)
|
||||
from ..quic.retry import QuicRetryTokenHandler
|
||||
from ..tls import SessionTicketFetcher, SessionTicketHandler
|
||||
from .protocol import QuicConnectionProtocol, QuicStreamHandler
|
||||
|
||||
__all__ = ["serve"]
|
||||
|
||||
|
||||
class QuicServer(asyncio.DatagramProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
configuration: QuicConfiguration,
|
||||
create_protocol: Callable = QuicConnectionProtocol,
|
||||
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
|
||||
session_ticket_handler: Optional[SessionTicketHandler] = None,
|
||||
retry: bool = False,
|
||||
stream_handler: Optional[QuicStreamHandler] = None,
|
||||
) -> None:
|
||||
self._configuration = configuration
|
||||
self._create_protocol = create_protocol
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
|
||||
self._session_ticket_fetcher = session_ticket_fetcher
|
||||
self._session_ticket_handler = session_ticket_handler
|
||||
self._transport: Optional[asyncio.DatagramTransport] = None
|
||||
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
if retry:
|
||||
self._retry = QuicRetryTokenHandler()
|
||||
else:
|
||||
self._retry = None
|
||||
|
||||
def close(self):
|
||||
for protocol in set(self._protocols.values()):
|
||||
protocol.close()
|
||||
self._protocols.clear()
|
||||
self._transport.close()
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
|
||||
data = cast(bytes, data)
|
||||
buf = Buffer(data=data)
|
||||
|
||||
try:
|
||||
header = pull_quic_header(
|
||||
buf, host_cid_length=self._configuration.connection_id_length
|
||||
)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
# version negotiation
|
||||
if (
|
||||
header.version is not None
|
||||
and header.version not in self._configuration.supported_versions
|
||||
):
|
||||
self._transport.sendto(
|
||||
encode_quic_version_negotiation(
|
||||
source_cid=header.destination_cid,
|
||||
destination_cid=header.source_cid,
|
||||
supported_versions=self._configuration.supported_versions,
|
||||
),
|
||||
addr,
|
||||
)
|
||||
return
|
||||
|
||||
protocol = self._protocols.get(header.destination_cid, None)
|
||||
original_destination_connection_id: Optional[bytes] = None
|
||||
retry_source_connection_id: Optional[bytes] = None
|
||||
if (
|
||||
protocol is None
|
||||
and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE
|
||||
and header.packet_type == QuicPacketType.INITIAL
|
||||
):
|
||||
# retry
|
||||
if self._retry is not None:
|
||||
if not header.token:
|
||||
# create a retry token
|
||||
source_cid = os.urandom(8)
|
||||
self._transport.sendto(
|
||||
encode_quic_retry(
|
||||
version=header.version,
|
||||
source_cid=source_cid,
|
||||
destination_cid=header.source_cid,
|
||||
original_destination_cid=header.destination_cid,
|
||||
retry_token=self._retry.create_token(
|
||||
addr, header.destination_cid, source_cid
|
||||
),
|
||||
),
|
||||
addr,
|
||||
)
|
||||
return
|
||||
else:
|
||||
# validate retry token
|
||||
try:
|
||||
(
|
||||
original_destination_connection_id,
|
||||
retry_source_connection_id,
|
||||
) = self._retry.validate_token(addr, header.token)
|
||||
except ValueError:
|
||||
return
|
||||
else:
|
||||
original_destination_connection_id = header.destination_cid
|
||||
|
||||
# create new connection
|
||||
connection = QuicConnection(
|
||||
configuration=self._configuration,
|
||||
original_destination_connection_id=original_destination_connection_id,
|
||||
retry_source_connection_id=retry_source_connection_id,
|
||||
session_ticket_fetcher=self._session_ticket_fetcher,
|
||||
session_ticket_handler=self._session_ticket_handler,
|
||||
)
|
||||
protocol = self._create_protocol(
|
||||
connection, stream_handler=self._stream_handler
|
||||
)
|
||||
protocol.connection_made(self._transport)
|
||||
|
||||
# register callbacks
|
||||
protocol._connection_id_issued_handler = partial(
|
||||
self._connection_id_issued, protocol=protocol
|
||||
)
|
||||
protocol._connection_id_retired_handler = partial(
|
||||
self._connection_id_retired, protocol=protocol
|
||||
)
|
||||
protocol._connection_terminated_handler = partial(
|
||||
self._connection_terminated, protocol=protocol
|
||||
)
|
||||
|
||||
self._protocols[header.destination_cid] = protocol
|
||||
self._protocols[connection.host_cid] = protocol
|
||||
|
||||
if protocol is not None:
|
||||
protocol.datagram_received(data, addr)
|
||||
|
||||
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
|
||||
self._protocols[cid] = protocol
|
||||
|
||||
def _connection_id_retired(
|
||||
self, cid: bytes, protocol: QuicConnectionProtocol
|
||||
) -> None:
|
||||
assert self._protocols[cid] == protocol
|
||||
del self._protocols[cid]
|
||||
|
||||
def _connection_terminated(self, protocol: QuicConnectionProtocol):
|
||||
for cid, proto in list(self._protocols.items()):
|
||||
if proto == protocol:
|
||||
del self._protocols[cid]
|
||||
|
||||
|
||||
async def serve(
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
configuration: QuicConfiguration,
|
||||
create_protocol: Callable = QuicConnectionProtocol,
|
||||
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
|
||||
session_ticket_handler: Optional[SessionTicketHandler] = None,
|
||||
retry: bool = False,
|
||||
stream_handler: QuicStreamHandler = None,
|
||||
) -> QuicServer:
|
||||
"""
|
||||
Start a QUIC server at the given `host` and `port`.
|
||||
|
||||
:func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
|
||||
containing TLS certificate and private key as the ``configuration`` argument.
|
||||
|
||||
:func:`serve` also accepts the following optional arguments:
|
||||
|
||||
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
|
||||
manages the connection. It should be a callable or class accepting the same
|
||||
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
|
||||
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
|
||||
* ``session_ticket_fetcher`` is a callback which is invoked by the TLS
|
||||
engine when a session ticket is presented by the peer. It should return
|
||||
the session ticket with the specified ID or `None` if it is not found.
|
||||
* ``session_ticket_handler`` is a callback which is invoked by the TLS
|
||||
engine when a new session ticket is issued. It should store the session
|
||||
ticket for future lookup.
|
||||
* ``retry`` specifies whether client addresses should be validated prior to
|
||||
the cryptographic handshake using a retry packet.
|
||||
* ``stream_handler`` is a callback which is invoked whenever a stream is
|
||||
created. It must accept two arguments: a :class:`asyncio.StreamReader`
|
||||
and a :class:`asyncio.StreamWriter`.
|
||||
"""
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
_, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: QuicServer(
|
||||
configuration=configuration,
|
||||
create_protocol=create_protocol,
|
||||
session_ticket_fetcher=session_ticket_fetcher,
|
||||
session_ticket_handler=session_ticket_handler,
|
||||
retry=retry,
|
||||
stream_handler=stream_handler,
|
||||
),
|
||||
local_addr=(host, port),
|
||||
)
|
||||
return protocol
|
||||
30
venv/Lib/site-packages/aioquic/buffer.py
Normal file
30
venv/Lib/site-packages/aioquic/buffer.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from ._buffer import Buffer, BufferReadError, BufferWriteError # noqa
|
||||
|
||||
UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
|
||||
UINT_VAR_MAX_SIZE = 8
|
||||
|
||||
|
||||
def encode_uint_var(value: int) -> bytes:
|
||||
"""
|
||||
Encode a variable-length unsigned integer.
|
||||
"""
|
||||
buf = Buffer(capacity=UINT_VAR_MAX_SIZE)
|
||||
buf.push_uint_var(value)
|
||||
return buf.data
|
||||
|
||||
|
||||
def size_uint_var(value: int) -> int:
|
||||
"""
|
||||
Return the number of bytes required to encode the given value
|
||||
as a QUIC variable-length unsigned integer.
|
||||
"""
|
||||
if value <= 0x3F:
|
||||
return 1
|
||||
elif value <= 0x3FFF:
|
||||
return 2
|
||||
elif value <= 0x3FFFFFFF:
|
||||
return 4
|
||||
elif value <= 0x3FFFFFFFFFFFFFFF:
|
||||
return 8
|
||||
else:
|
||||
raise ValueError("Integer is too big for a variable-length integer")
|
||||
0
venv/Lib/site-packages/aioquic/h0/__init__.py
Normal file
0
venv/Lib/site-packages/aioquic/h0/__init__.py
Normal file
68
venv/Lib/site-packages/aioquic/h0/connection.py
Normal file
68
venv/Lib/site-packages/aioquic/h0/connection.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.events import QuicEvent, StreamDataReceived
|
||||
|
||||
H0_ALPN = ["hq-interop"]
|
||||
|
||||
|
||||
class H0Connection:
|
||||
"""
|
||||
An HTTP/0.9 connection object.
|
||||
"""
|
||||
|
||||
def __init__(self, quic: QuicConnection):
|
||||
self._buffer: Dict[int, bytes] = {}
|
||||
self._headers_received: Dict[int, bool] = {}
|
||||
self._is_client = quic.configuration.is_client
|
||||
self._quic = quic
|
||||
|
||||
def handle_event(self, event: QuicEvent) -> List[H3Event]:
|
||||
http_events: List[H3Event] = []
|
||||
|
||||
if isinstance(event, StreamDataReceived) and (event.stream_id % 4) == 0:
|
||||
data = self._buffer.pop(event.stream_id, b"") + event.data
|
||||
if not self._headers_received.get(event.stream_id, False):
|
||||
if self._is_client:
|
||||
http_events.append(
|
||||
HeadersReceived(
|
||||
headers=[], stream_ended=False, stream_id=event.stream_id
|
||||
)
|
||||
)
|
||||
elif data.endswith(b"\r\n") or event.end_stream:
|
||||
method, path = data.rstrip().split(b" ", 1)
|
||||
http_events.append(
|
||||
HeadersReceived(
|
||||
headers=[(b":method", method), (b":path", path)],
|
||||
stream_ended=False,
|
||||
stream_id=event.stream_id,
|
||||
)
|
||||
)
|
||||
data = b""
|
||||
else:
|
||||
# incomplete request, stash the data
|
||||
self._buffer[event.stream_id] = data
|
||||
return http_events
|
||||
self._headers_received[event.stream_id] = True
|
||||
|
||||
http_events.append(
|
||||
DataReceived(
|
||||
data=data, stream_ended=event.end_stream, stream_id=event.stream_id
|
||||
)
|
||||
)
|
||||
|
||||
return http_events
|
||||
|
||||
def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None:
|
||||
self._quic.send_stream_data(stream_id, data, end_stream)
|
||||
|
||||
def send_headers(
|
||||
self, stream_id: int, headers: Headers, end_stream: bool = False
|
||||
) -> None:
|
||||
if self._is_client:
|
||||
headers_dict = dict(headers)
|
||||
data = headers_dict[b":method"] + b" " + headers_dict[b":path"] + b"\r\n"
|
||||
else:
|
||||
data = b""
|
||||
self._quic.send_stream_data(stream_id, data, end_stream)
|
||||
0
venv/Lib/site-packages/aioquic/h3/__init__.py
Normal file
0
venv/Lib/site-packages/aioquic/h3/__init__.py
Normal file
1218
venv/Lib/site-packages/aioquic/h3/connection.py
Normal file
1218
venv/Lib/site-packages/aioquic/h3/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
100
venv/Lib/site-packages/aioquic/h3/events.py
Normal file
100
venv/Lib/site-packages/aioquic/h3/events.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
Headers = List[Tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class H3Event:
|
||||
"""
|
||||
Base class for HTTP/3 events.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataReceived(H3Event):
|
||||
"""
|
||||
The DataReceived event is fired whenever data is received on a stream from
|
||||
the remote peer.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
push_id: Optional[int] = None
|
||||
"The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatagramReceived(H3Event):
|
||||
"""
|
||||
The DatagramReceived is fired whenever a datagram is received from the
|
||||
the remote peer.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeadersReceived(H3Event):
|
||||
"""
|
||||
The HeadersReceived event is fired whenever headers are received.
|
||||
"""
|
||||
|
||||
headers: Headers
|
||||
"The headers."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the headers were received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
push_id: Optional[int] = None
|
||||
"The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
|
||||
class PushPromiseReceived(H3Event):
|
||||
"""
|
||||
The PushedStreamReceived event is fired whenever a pushed stream has been
|
||||
received from the remote peer.
|
||||
"""
|
||||
|
||||
headers: Headers
|
||||
"The request headers."
|
||||
|
||||
push_id: int
|
||||
"The Push ID of the push promise."
|
||||
|
||||
stream_id: int
|
||||
"The Stream ID of the stream that the push is related to."
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebTransportStreamDataReceived(H3Event):
|
||||
"""
|
||||
The WebTransportStreamDataReceived is fired whenever data is received
|
||||
for a WebTransport stream.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
"The data which was received."
|
||||
|
||||
stream_id: int
|
||||
"The ID of the stream the data was received for."
|
||||
|
||||
stream_ended: bool
|
||||
"Whether the STREAM frame had the FIN bit set."
|
||||
|
||||
session_id: int
|
||||
"The ID of the session the data was received for."
|
||||
17
venv/Lib/site-packages/aioquic/h3/exceptions.py
Normal file
17
venv/Lib/site-packages/aioquic/h3/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
class H3Error(Exception):
|
||||
"""
|
||||
Base class for HTTP/3 exceptions.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStreamTypeError(H3Error):
|
||||
"""
|
||||
An action was attempted on an invalid stream type.
|
||||
"""
|
||||
|
||||
|
||||
class NoAvailablePushIDError(H3Error):
|
||||
"""
|
||||
There are no available push IDs left, or push is not supported
|
||||
by the remote party.
|
||||
"""
|
||||
1
venv/Lib/site-packages/aioquic/py.typed
Normal file
1
venv/Lib/site-packages/aioquic/py.typed
Normal file
@@ -0,0 +1 @@
|
||||
Marker
|
||||
0
venv/Lib/site-packages/aioquic/quic/__init__.py
Normal file
0
venv/Lib/site-packages/aioquic/quic/__init__.py
Normal file
163
venv/Lib/site-packages/aioquic/quic/configuration.py
Normal file
163
venv/Lib/site-packages/aioquic/quic/configuration.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from dataclasses import dataclass, field
|
||||
from os import PathLike
|
||||
from re import split
|
||||
from typing import Any, List, Optional, TextIO, Union
|
||||
|
||||
from ..tls import (
|
||||
CipherSuite,
|
||||
SessionTicket,
|
||||
load_pem_private_key,
|
||||
load_pem_x509_certificates,
|
||||
)
|
||||
from .logger import QuicLogger
|
||||
from .packet import QuicProtocolVersion
|
||||
|
||||
SMALLEST_MAX_DATAGRAM_SIZE = 1200
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuicConfiguration:
|
||||
"""
|
||||
A QUIC configuration.
|
||||
"""
|
||||
|
||||
alpn_protocols: Optional[List[str]] = None
|
||||
"""
|
||||
A list of supported ALPN protocols.
|
||||
"""
|
||||
|
||||
congestion_control_algorithm: str = "reno"
|
||||
"""
|
||||
The name of the congestion control algorithm to use.
|
||||
|
||||
Currently supported algorithms: `"reno", `"cubic"`.
|
||||
"""
|
||||
|
||||
connection_id_length: int = 8
|
||||
"""
|
||||
The length in bytes of local connection IDs.
|
||||
"""
|
||||
|
||||
idle_timeout: float = 60.0
|
||||
"""
|
||||
The idle timeout in seconds.
|
||||
|
||||
The connection is terminated if nothing is received for the given duration.
|
||||
"""
|
||||
|
||||
is_client: bool = True
|
||||
"""
|
||||
Whether this is the client side of the QUIC connection.
|
||||
"""
|
||||
|
||||
max_data: int = 1048576
|
||||
"""
|
||||
Connection-wide flow control limit.
|
||||
"""
|
||||
|
||||
max_datagram_size: int = SMALLEST_MAX_DATAGRAM_SIZE
|
||||
"""
|
||||
The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
|
||||
"""
|
||||
|
||||
max_stream_data: int = 1048576
|
||||
"""
|
||||
Per-stream flow control limit.
|
||||
"""
|
||||
|
||||
quic_logger: Optional[QuicLogger] = None
|
||||
"""
|
||||
The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
|
||||
"""
|
||||
|
||||
secrets_log_file: TextIO = None
|
||||
"""
|
||||
A file-like object in which to log traffic secrets.
|
||||
|
||||
This is useful to analyze traffic captures with Wireshark.
|
||||
"""
|
||||
|
||||
server_name: Optional[str] = None
|
||||
"""
|
||||
The server name to use when verifying the server's TLS certificate, which
|
||||
can either be a DNS name or an IP address.
|
||||
|
||||
If it is a DNS name, it is also sent during the TLS handshake in the
|
||||
Server Name Indication (SNI) extension.
|
||||
|
||||
.. note:: This is only used by clients.
|
||||
"""
|
||||
|
||||
session_ticket: Optional[SessionTicket] = None
|
||||
"""
|
||||
The TLS session ticket which should be used for session resumption.
|
||||
"""
|
||||
|
||||
token: bytes = b""
|
||||
"""
|
||||
The address validation token that can be used to validate future connections.
|
||||
|
||||
.. note:: This is only used by clients.
|
||||
"""
|
||||
|
||||
# For internal purposes, not guaranteed to be stable.
|
||||
cadata: Optional[bytes] = None
|
||||
cafile: Optional[str] = None
|
||||
capath: Optional[str] = None
|
||||
certificate: Any = None
|
||||
certificate_chain: List[Any] = field(default_factory=list)
|
||||
cipher_suites: Optional[List[CipherSuite]] = None
|
||||
initial_rtt: float = 0.1
|
||||
max_datagram_frame_size: Optional[int] = None
|
||||
original_version: Optional[int] = None
|
||||
private_key: Any = None
|
||||
quantum_readiness_test: bool = False
|
||||
supported_versions: List[int] = field(
|
||||
default_factory=lambda: [
|
||||
QuicProtocolVersion.VERSION_1,
|
||||
QuicProtocolVersion.VERSION_2,
|
||||
]
|
||||
)
|
||||
verify_mode: Optional[int] = None
|
||||
|
||||
def load_cert_chain(
    self,
    certfile: PathLike,
    keyfile: Optional[PathLike] = None,
    password: Optional[Union[bytes, str]] = None,
) -> None:
    """
    Load a private key and the corresponding certificate.

    `certfile` may contain the PEM private key after the certificate
    chain; alternatively the key may be given in a separate `keyfile`,
    optionally protected by `password`.
    """
    with open(certfile, "rb") as fp:
        # Split at the first "BEGIN PRIVATE KEY" boundary: everything
        # before it is the certificate chain, everything after it (if
        # present) is the PEM-encoded private key.
        boundary = b"-----BEGIN PRIVATE KEY-----\n"
        chunks = split(b"\n" + boundary, fp.read())
        certificates = load_pem_x509_certificates(chunks[0])
        if len(chunks) == 2:
            # Re-attach the boundary consumed by the split before parsing.
            private_key = boundary + chunks[1]
            self.private_key = load_pem_private_key(private_key)
    # The first certificate is the leaf; the rest form the chain.
    self.certificate = certificates[0]
    self.certificate_chain = certificates[1:]

    if keyfile is not None:
        # An explicit key file overrides any key found in `certfile`.
        with open(keyfile, "rb") as fp:
            self.private_key = load_pem_private_key(
                fp.read(),
                # The loader expects bytes, so encode str passwords.
                password=password.encode("utf8")
                if isinstance(password, str)
                else password,
            )
|
||||
|
||||
def load_verify_locations(
    self,
    cafile: Optional[str] = None,
    capath: Optional[str] = None,
    cadata: Optional[bytes] = None,
) -> None:
    """
    Load a set of "certification authority" (CA) certificates used to
    validate other peers' certificates.

    The three sources mirror the arguments of
    ``ssl.SSLContext.load_verify_locations``: a single file, a directory
    path, or raw certificate data.
    """
    # Only record the locations; they are consumed later when the TLS
    # context is built.
    self.cafile, self.capath, self.cadata = cafile, capath, cadata
|
||||
128
venv/Lib/site-packages/aioquic/quic/congestion/base.py
Normal file
128
venv/Lib/site-packages/aioquic/quic/congestion/base.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import abc
|
||||
from typing import Any, Dict, Iterable, Optional, Protocol
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
|
||||
K_GRANULARITY = 0.001 # seconds
|
||||
K_INITIAL_WINDOW = 10
|
||||
K_MINIMUM_WINDOW = 2
|
||||
|
||||
|
||||
class QuicCongestionControl(abc.ABC):
    """
    Base class for congestion control implementations.

    Concrete algorithms subclass this and are built through
    `create_congestion_control` with the connection's maximum datagram size.
    """

    # Bytes sent but not yet acknowledged, expired or declared lost.
    bytes_in_flight: int = 0
    # Current congestion window, in bytes.
    congestion_window: int = 0
    # Slow start threshold in bytes; `None` while slow start has not ended.
    ssthresh: Optional[int] = None

    def __init__(self, *, max_datagram_size: int) -> None:
        # Initial window is K_INITIAL_WINDOW datagrams.
        self.congestion_window = K_INITIAL_WINDOW * max_datagram_size

    @abc.abstractmethod
    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        """Handle acknowledgement of `packet` at time `now`."""
        ...

    @abc.abstractmethod
    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        """Record that `packet` has been sent."""
        ...

    @abc.abstractmethod
    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        """Remove `packets` from accounting without treating them as lost."""
        ...

    @abc.abstractmethod
    def on_packets_lost(
        self, *, now: float, packets: Iterable[QuicSentPacket]
    ) -> None:
        """Handle the loss of `packets`, detected at time `now`."""
        ...

    @abc.abstractmethod
    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        """Handle a new round-trip time measurement `rtt` taken at `now`."""
        ...

    def get_log_data(self) -> Dict[str, Any]:
        """Return the current congestion state as a dict, for logging."""
        data = {"cwnd": self.congestion_window, "bytes_in_flight": self.bytes_in_flight}
        if self.ssthresh is not None:
            data["ssthresh"] = self.ssthresh
        return data
|
||||
|
||||
|
||||
class QuicCongestionControlFactory(Protocol):
    """
    Structural type for callables that build a congestion controller.
    """

    def __call__(self, *, max_datagram_size: int) -> QuicCongestionControl: ...
|
||||
|
||||
|
||||
class QuicRttMonitor:
    """
    Roundtrip time monitor for HyStart.

    Keeps a small ring of recent RTT samples and reports when the window
    minimum has been persistently above the lowest window maximum seen.
    """

    def __init__(self) -> None:
        self._increases = 0
        self._last_time = None
        self._ready = False
        self._size = 5

        self._filtered_min: Optional[float] = None

        self._sample_idx = 0
        self._sample_max: Optional[float] = None
        self._sample_min: Optional[float] = None
        self._sample_time = 0.0
        self._samples = [0.0] * self._size

    def add_rtt(self, *, rtt: float) -> None:
        # Store the sample into the ring buffer.
        idx = self._sample_idx
        self._samples[idx] = rtt
        idx += 1
        if idx >= self._size:
            idx = 0
            self._ready = True
        self._sample_idx = idx

        # Once the ring has been filled, track its extremes.
        if self._ready:
            self._sample_min = min(self._samples)
            self._sample_max = max(self._samples)

    def is_rtt_increasing(self, *, now: float, rtt: float) -> bool:
        # Rate-limit sampling to at most once per K_GRANULARITY seconds.
        if now <= self._sample_time + K_GRANULARITY:
            return False

        self.add_rtt(rtt=rtt)
        self._sample_time = now

        if self._ready:
            # Track the lowest window-maximum observed so far.
            if self._filtered_min is None or self._filtered_min > self._sample_max:
                self._filtered_min = self._sample_max

            delta = self._sample_min - self._filtered_min
            if delta * 4 >= self._filtered_min:
                self._increases += 1
                if self._increases >= self._size:
                    return True
            elif delta > 0:
                self._increases = 0
        return False
|
||||
|
||||
|
||||
_factories: Dict[str, QuicCongestionControlFactory] = {}


def create_congestion_control(
    name: str, *, max_datagram_size: int
) -> QuicCongestionControl:
    """
    Create an instance of the `name` congestion control algorithm.

    :param name: the registered algorithm name, e.g. ``"reno"`` or ``"cubic"``.
    :param max_datagram_size: the connection's maximum datagram size in bytes.
    :raises ValueError: if no algorithm was registered under `name`.
    """
    try:
        factory = _factories[name]
    except KeyError:
        # Raise a specific, documented exception rather than a bare
        # Exception, and suppress the uninteresting KeyError context.
        # ValueError is still caught by callers handling Exception.
        raise ValueError(f"Unknown congestion control algorithm: {name}") from None
    return factory(max_datagram_size=max_datagram_size)
|
||||
|
||||
|
||||
def register_congestion_control(
    name: str, factory: QuicCongestionControlFactory
) -> None:
    """
    Register a congestion control algorithm named `name`.

    Registering the same `name` again overwrites the previous factory.
    """
    _factories[name] = factory
|
||||
212
venv/Lib/site-packages/aioquic/quic/congestion/cubic.py
Normal file
212
venv/Lib/site-packages/aioquic/quic/congestion/cubic.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_INITIAL_WINDOW,
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions)
|
||||
K_CUBIC_C = 0.4
|
||||
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
|
||||
K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity
|
||||
|
||||
|
||||
def better_cube_root(x: float) -> float:
    """
    Return the real cube root of `x`.

    Raising a negative float to a fractional power would yield a complex
    number, so negative inputs are handled by negating twice.
    """
    if x >= 0:
        return x ** (1.0 / 3.0)
    return -((-x) ** (1.0 / 3.0))
|
||||
|
||||
|
||||
class CubicCongestionControl(QuicCongestionControl):
    """
    Cubic congestion control implementation for aioquic (RFC 9438).

    The window grows along a cubic curve anchored at the window reached
    before the last loss (`_W_max`), with a Reno-like estimate (`_W_est`)
    acting as a lower bound in the reno-friendly region.
    """

    def __init__(self, max_datagram_size: int) -> None:
        super().__init__(max_datagram_size=max_datagram_size)
        # increase by one segment
        self.additive_increase_factor: int = max_datagram_size
        self._max_datagram_size: int = max_datagram_size
        self._congestion_recovery_start_time = 0.0

        self._rtt_monitor = QuicRttMonitor()

        self.rtt = 0.02  # starting RTT is considered to be 20ms

        self.reset()

        # time the last ACKed packet was sent; 0.0 until the first ACK
        self.last_ack = 0.0

    def W_cubic(self, t: float) -> int:
        """Return the cubic window (bytes) at time `t` since the epoch."""
        W_max_segments = self._W_max / self._max_datagram_size
        target_segments = K_CUBIC_C * (t - self.K) ** 3 + (W_max_segments)
        return int(target_segments * self._max_datagram_size)

    def is_reno_friendly(self, t: float) -> bool:
        """Return whether the Reno estimate currently exceeds the cubic curve."""
        return self.W_cubic(t) < self._W_est

    def is_concave(self) -> bool:
        """Return whether the window is still below the pre-loss maximum."""
        return self.congestion_window < self._W_max

    def reset(self) -> None:
        """Reset the controller to its initial (slow start) state."""
        self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size
        self.ssthresh = None

        self._first_slow_start = True
        self._starting_congestion_avoidance = False
        self.K: float = 0.0
        self._W_est = 0
        self._cwnd_epoch = 0
        self._t_epoch = 0.0
        self._W_max = self.congestion_window

    def _begin_epoch(self, now: float) -> None:
        """
        Start a new congestion avoidance epoch at time `now`.

        Shared by the two entry paths into congestion avoidance: leaving
        slow start without a loss, and recovering after a loss. (This was
        previously duplicated inline in `on_packet_acked`.)
        """
        self._t_epoch = now
        self._cwnd_epoch = self.congestion_window
        self._W_est = self._cwnd_epoch
        # calculate K, the time offset at which the cubic curve reaches _W_max
        W_max_segments = self._W_max / self._max_datagram_size
        cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size
        self.K = better_cube_root((W_max_segments - cwnd_epoch_segments) / K_CUBIC_C)

    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        self.bytes_in_flight -= packet.sent_bytes
        self.last_ack = packet.sent_time

        if self.ssthresh is None or self.congestion_window < self.ssthresh:
            # slow start
            self.congestion_window += packet.sent_bytes
        else:
            # congestion avoidance
            if self._first_slow_start and not self._starting_congestion_avoidance:
                # exiting slow start without having had a loss
                self._first_slow_start = False
                self._W_max = self.congestion_window
                self._begin_epoch(now)

            if self._starting_congestion_avoidance:
                # first ACK after a loss: initialize the new epoch
                self._starting_congestion_avoidance = False
                self._first_slow_start = False
                self._begin_epoch(now)

            # update the Reno-style window estimate
            self._W_est = int(
                self._W_est
                + self.additive_increase_factor
                * (packet.sent_bytes / self.congestion_window)
            )

            t = now - self._t_epoch

            # target window one RTT ahead, clamped to [cwnd, 1.5 * cwnd]
            W_cubic = self.W_cubic(t + self.rtt)
            if W_cubic < self.congestion_window:
                target = self.congestion_window
            elif W_cubic > 1.5 * self.congestion_window:
                target = int(self.congestion_window * 1.5)
            else:
                target = W_cubic

            if self.is_reno_friendly(t):
                # reno friendly region of cubic
                # (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region)
                self.congestion_window = self._W_est
            else:
                # concave and convex regions
                # (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region)
                # used identical update formulas, so they share one branch.
                self.congestion_window = int(
                    self.congestion_window
                    + (
                        (target - self.congestion_window)
                        * (self._max_datagram_size / self.congestion_window)
                    )
                )

    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        self.bytes_in_flight += packet.sent_bytes
        if self.last_ack == 0.0:
            return
        elapsed_idle = packet.sent_time - self.last_ack
        # reset the window after a long idle period
        if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
            self.reset()

    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes

    def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
        lost_largest_time = 0.0
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes
            lost_largest_time = packet.sent_time

        # start a new congestion event if packet was sent after the
        # start of the previous congestion recovery period.
        if lost_largest_time > self._congestion_recovery_start_time:
            self._congestion_recovery_start_time = now

            # fast convergence: shrink the remembered maximum further when
            # the window was still below it, releasing bandwidth earlier.
            if self._W_max is not None and self.congestion_window < self._W_max:
                self._W_max = int(
                    self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
                )
            else:
                self._W_max = self.congestion_window

            # normal congestion multiplicative decrease
            flight_size = self.bytes_in_flight
            new_ssthresh = max(
                int(flight_size * K_CUBIC_LOSS_REDUCTION_FACTOR),
                K_MINIMUM_WINDOW * self._max_datagram_size,
            )
            self.ssthresh = new_ssthresh
            self.congestion_window = max(
                self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size
            )

            # restart a new congestion avoidance phase on the next ACK
            self._starting_congestion_avoidance = True

    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        self.rtt = rtt
        # check whether we should exit slow start
        if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
            rtt=rtt, now=now
        ):
            self.ssthresh = self.congestion_window

    def get_log_data(self) -> Dict[str, Any]:
        data = super().get_log_data()

        data["cubic-wmax"] = int(self._W_max)

        return data
|
||||
|
||||
|
||||
register_congestion_control("cubic", CubicCongestionControl)
|
||||
77
venv/Lib/site-packages/aioquic/quic/congestion/reno.py
Normal file
77
venv/Lib/site-packages/aioquic/quic/congestion/reno.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Iterable
|
||||
|
||||
from ..packet_builder import QuicSentPacket
|
||||
from .base import (
|
||||
K_MINIMUM_WINDOW,
|
||||
QuicCongestionControl,
|
||||
QuicRttMonitor,
|
||||
register_congestion_control,
|
||||
)
|
||||
|
||||
K_LOSS_REDUCTION_FACTOR = 0.5
|
||||
|
||||
|
||||
class RenoCongestionControl(QuicCongestionControl):
    """
    New Reno congestion control.
    """

    def __init__(self, *, max_datagram_size: int) -> None:
        super().__init__(max_datagram_size=max_datagram_size)
        self._max_datagram_size = max_datagram_size
        self._congestion_recovery_start_time = 0.0
        self._congestion_stash = 0
        self._rtt_monitor = QuicRttMonitor()

    def on_packet_acked(self, *, now: float, packet: QuicSentPacket) -> None:
        self.bytes_in_flight -= packet.sent_bytes

        # Packets sent before the current recovery period began do not
        # grow the window.
        if packet.sent_time <= self._congestion_recovery_start_time:
            return

        if self.ssthresh is not None and self.congestion_window >= self.ssthresh:
            # Congestion avoidance: bank the acknowledged bytes and grow the
            # window by one datagram for each full window acknowledged.
            self._congestion_stash += packet.sent_bytes
            full_windows, self._congestion_stash = divmod(
                self._congestion_stash, self.congestion_window
            )
            self.congestion_window += full_windows * self._max_datagram_size
        else:
            # Slow start: grow the window by the acknowledged bytes.
            self.congestion_window += packet.sent_bytes

    def on_packet_sent(self, *, packet: QuicSentPacket) -> None:
        self.bytes_in_flight += packet.sent_bytes

    def on_packets_expired(self, *, packets: Iterable[QuicSentPacket]) -> None:
        for expired in packets:
            self.bytes_in_flight -= expired.sent_bytes

    def on_packets_lost(self, *, now: float, packets: Iterable[QuicSentPacket]) -> None:
        latest_loss_time = 0.0
        for lost in packets:
            self.bytes_in_flight -= lost.sent_bytes
            latest_loss_time = lost.sent_time

        # Start a new congestion event only if a lost packet was sent after
        # the start of the previous congestion recovery period.
        if latest_loss_time > self._congestion_recovery_start_time:
            self._congestion_recovery_start_time = now
            self.congestion_window = max(
                int(self.congestion_window * K_LOSS_REDUCTION_FACTOR),
                K_MINIMUM_WINDOW * self._max_datagram_size,
            )
            self.ssthresh = self.congestion_window

        # TODO : collapse congestion window if persistent congestion

    def on_rtt_measurement(self, *, now: float, rtt: float) -> None:
        # Leave slow start once the RTT starts trending upwards.
        if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
            now=now, rtt=rtt
        ):
            self.ssthresh = self.congestion_window
|
||||
|
||||
|
||||
register_congestion_control("reno", RenoCongestionControl)
|
||||
3623
venv/Lib/site-packages/aioquic/quic/connection.py
Normal file
3623
venv/Lib/site-packages/aioquic/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
246
venv/Lib/site-packages/aioquic/quic/crypto.py
Normal file
246
venv/Lib/site-packages/aioquic/quic/crypto.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import binascii
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from .._crypto import AEAD, CryptoError, HeaderProtection
|
||||
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
|
||||
from .packet import (
|
||||
QuicProtocolVersion,
|
||||
decode_packet_number,
|
||||
is_long_header,
|
||||
)
|
||||
|
||||
CIPHER_SUITES = {
|
||||
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
|
||||
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
|
||||
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
|
||||
}
|
||||
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
|
||||
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
|
||||
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
|
||||
SAMPLE_SIZE = 16
|
||||
|
||||
|
||||
Callback = Callable[[str], None]
|
||||
|
||||
|
||||
def NoCallback(trigger: str) -> None:
    # Default no-op callback used when no setup/teardown hook is supplied.
    pass
|
||||
|
||||
|
||||
class KeyUnavailableError(CryptoError):
    """
    Raised when decryption is attempted before the required key is set up.
    """

    pass
|
||||
|
||||
|
||||
def derive_key_iv_hp(
    *, cipher_suite: CipherSuite, secret: bytes, version: int
) -> Tuple[bytes, bytes, bytes]:
    """
    Derive the packet protection key, IV and header protection key from
    `secret` for the given cipher suite and QUIC version.
    """
    algorithm = cipher_suite_hash(cipher_suite)

    # AES-256 and ChaCha20 suites use 256-bit keys, the others 128-bit keys.
    if cipher_suite in [
        CipherSuite.AES_256_GCM_SHA384,
        CipherSuite.CHACHA20_POLY1305_SHA256,
    ]:
        key_size = 32
    else:
        key_size = 16

    # QUIC v2 uses "quicv2 *" HKDF labels instead of v1's "quic *" labels.
    label_prefix = b"quicv2 " if version == QuicProtocolVersion.VERSION_2 else b"quic "
    return (
        hkdf_expand_label(algorithm, secret, label_prefix + b"key", b"", key_size),
        hkdf_expand_label(algorithm, secret, label_prefix + b"iv", b"", 12),
        hkdf_expand_label(algorithm, secret, label_prefix + b"hp", b"", key_size),
    )
|
||||
|
||||
|
||||
class CryptoContext:
    """
    Encryption/decryption state for one direction of a QUIC connection.

    Holds the AEAD, header protection and secret for one key phase.
    """

    def __init__(
        self,
        key_phase: int = 0,
        setup_cb: Callback = NoCallback,
        teardown_cb: Callback = NoCallback,
    ) -> None:
        self.aead: Optional[AEAD] = None
        self.cipher_suite: Optional[CipherSuite] = None
        self.hp: Optional[HeaderProtection] = None
        self.key_phase = key_phase
        self.secret: Optional[bytes] = None
        self.version: Optional[int] = None
        # Callbacks fired when keys are installed / discarded.
        self._setup_cb = setup_cb
        self._teardown_cb = teardown_cb

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int, bool]:
        """
        Decrypt `packet` and return a tuple of
        (plain header, plain payload, packet number, key update detected).

        :raises KeyUnavailableError: if no decryption key has been set up.
        """
        if self.aead is None:
            raise KeyUnavailableError("Decryption key is not available")

        # header protection
        plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
        first_byte = plain_header[0]

        # packet number: expand the truncated value relative to the
        # packet number we expect next
        pn_length = (first_byte & 0x03) + 1
        packet_number = decode_packet_number(
            packet_number, pn_length * 8, expected_packet_number
        )

        # detect key phase change: short-header packets carry a key phase
        # bit; a mismatch means the peer switched keys, so decrypt with
        # the next key phase instead of our current one.
        crypto = self
        if not is_long_header(first_byte):
            key_phase = (first_byte & 4) >> 2
            if key_phase != self.key_phase:
                crypto = next_key_phase(self)

        # payload protection
        payload = crypto.aead.decrypt(
            packet[len(plain_header) :], plain_header, packet_number
        )

        return plain_header, payload, packet_number, crypto != self

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Encrypt a packet and return the protected bytes.
        """
        assert self.is_valid(), "Encryption key is not available"

        # payload protection
        protected_payload = self.aead.encrypt(
            plain_payload, plain_header, packet_number
        )

        # header protection
        return self.hp.apply(plain_header, protected_payload)

    def is_valid(self) -> bool:
        """Return `True` if encryption keys have been set up."""
        return self.aead is not None

    def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
        """Derive and install keys from `secret` for the given suite/version."""
        hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]

        key, iv, hp = derive_key_iv_hp(
            cipher_suite=cipher_suite,
            secret=secret,
            version=version,
        )
        self.aead = AEAD(aead_cipher_name, key, iv)
        self.cipher_suite = cipher_suite
        self.hp = HeaderProtection(hp_cipher_name, hp)
        self.secret = secret
        self.version = version

        # trigger callback
        self._setup_cb("tls")

    def teardown(self) -> None:
        """Discard all keys and secrets."""
        self.aead = None
        self.cipher_suite = None
        self.hp = None
        self.secret = None

        # trigger callback
        self._teardown_cb("tls")
|
||||
|
||||
|
||||
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
    """
    Copy the key material and key phase from `crypto` into `self`, then
    fire the setup callback with `trigger`.
    """
    for name in ("aead", "key_phase", "secret"):
        setattr(self, name, getattr(crypto, name))

    # trigger callback
    self._setup_cb(trigger)
|
||||
|
||||
|
||||
def next_key_phase(self: CryptoContext) -> CryptoContext:
    """
    Return a new CryptoContext for the opposite key phase, with its secret
    derived from the current secret using the "quic ku" label.
    """
    algorithm = cipher_suite_hash(self.cipher_suite)

    # Flip the key phase bit and derive the next-generation secret.
    crypto = CryptoContext(key_phase=int(not self.key_phase))
    crypto.setup(
        cipher_suite=self.cipher_suite,
        secret=hkdf_expand_label(
            algorithm, self.secret, b"quic ku", b"", algorithm.digest_size
        ),
        version=self.version,
    )
    return crypto
|
||||
|
||||
|
||||
class CryptoPair:
    """
    The receive and send crypto contexts for a connection, including the
    coordination of key updates across both directions.
    """

    def __init__(
        self,
        recv_setup_cb: Callback = NoCallback,
        recv_teardown_cb: Callback = NoCallback,
        send_setup_cb: Callback = NoCallback,
        send_teardown_cb: Callback = NoCallback,
    ) -> None:
        # AEAD authentication tag size in bytes.
        self.aead_tag_size = 16
        self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
        self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
        # Set when a local key update was requested but not yet applied.
        self._update_key_requested = False

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> Tuple[bytes, bytes, int]:
        """
        Decrypt `packet`, applying a key update if the peer initiated one.
        """
        plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
            packet, encrypted_offset, expected_packet_number
        )
        if update_key:
            self._update_key("remote_update")
        return plain_header, payload, packet_number

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """
        Encrypt a packet, first applying any locally requested key update.
        """
        if self._update_key_requested:
            self._update_key("local_update")
        return self.send.encrypt_packet(plain_header, plain_payload, packet_number)

    def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
        """
        Derive and install the Initial keys from the connection ID `cid`.
        """
        # The "client in" / "server in" labels are swapped depending on role.
        if is_client:
            recv_label, send_label = b"server in", b"client in"
        else:
            recv_label, send_label = b"client in", b"server in"

        # QUIC v2 uses a different initial salt than v1.
        if version == QuicProtocolVersion.VERSION_2:
            initial_salt = INITIAL_SALT_VERSION_2
        else:
            initial_salt = INITIAL_SALT_VERSION_1

        algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
        initial_secret = hkdf_extract(algorithm, initial_salt, cid)
        self.recv.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, recv_label, b"", algorithm.digest_size
            ),
            version=version,
        )
        self.send.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, send_label, b"", algorithm.digest_size
            ),
            version=version,
        )

    def teardown(self) -> None:
        """Discard the keys for both directions."""
        self.recv.teardown()
        self.send.teardown()

    def update_key(self) -> None:
        """Request a key update; it is applied on the next packet sent."""
        self._update_key_requested = True

    @property
    def key_phase(self) -> int:
        # Report the key phase the next sent packet will use, accounting
        # for a pending (not yet applied) key update.
        if self._update_key_requested:
            return int(not self.recv.key_phase)
        else:
            return self.recv.key_phase

    def _update_key(self, trigger: str) -> None:
        # Advance both directions to the next key phase.
        apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
        apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
        self._update_key_requested = False
|
||||
126
venv/Lib/site-packages/aioquic/quic/events.py
Normal file
126
venv/Lib/site-packages/aioquic/quic/events.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuicEvent:
    """
    Base class for QUIC events.

    Subclasses describe things which happened on the connection.
    """

    pass
|
||||
|
||||
|
||||
@dataclass
class ConnectionIdIssued(QuicEvent):
    """
    The ConnectionIdIssued event is fired when a connection ID is issued.
    """

    # The connection ID which was issued.
    connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
class ConnectionIdRetired(QuicEvent):
    """
    The ConnectionIdRetired event is fired when a connection ID is retired.
    """

    # The connection ID which was retired.
    connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
class ConnectionTerminated(QuicEvent):
    """
    The ConnectionTerminated event is fired when the QUIC connection is terminated.
    """

    error_code: int
    "The error code which was specified when closing the connection."

    # NOTE: a `None` frame type is treated elsewhere (qlog encoding) as an
    # application-level close, as opposed to a transport-level close.
    frame_type: Optional[int]
    "The frame type which caused the connection to be closed, or `None`."

    reason_phrase: str
    "The human-readable reason for which the connection was closed."
|
||||
|
||||
|
||||
@dataclass
class DatagramFrameReceived(QuicEvent):
    """
    The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
    """

    data: bytes
    "The data which was received."
|
||||
|
||||
|
||||
@dataclass
class HandshakeCompleted(QuicEvent):
    """
    The HandshakeCompleted event is fired when the TLS handshake completes.
    """

    alpn_protocol: Optional[str]
    "The protocol which was negotiated using ALPN, or `None`."

    early_data_accepted: bool
    "Whether early (0-RTT) data was accepted by the remote peer."

    session_resumed: bool
    "Whether a TLS session was resumed."
|
||||
|
||||
|
||||
@dataclass
class PingAcknowledged(QuicEvent):
    """
    The PingAcknowledged event is fired when a PING frame is acknowledged.
    """

    uid: int
    "The unique ID of the PING."
|
||||
|
||||
|
||||
@dataclass
class ProtocolNegotiated(QuicEvent):
    """
    The ProtocolNegotiated event is fired when ALPN negotiation completes.
    """

    alpn_protocol: Optional[str]
    "The protocol which was negotiated using ALPN, or `None`."
|
||||
|
||||
|
||||
@dataclass
class StopSendingReceived(QuicEvent):
    """
    The StopSendingReceived event is fired when the remote peer requests
    stopping data transmission on a stream.
    """

    error_code: int
    "The error code that was sent from the peer."

    stream_id: int
    "The ID of the stream that the peer requested stopping data transmission."
|
||||
|
||||
|
||||
@dataclass
class StreamDataReceived(QuicEvent):
    """
    The StreamDataReceived event is fired whenever data is received on a
    stream.
    """

    data: bytes
    "The data which was received."

    end_stream: bool
    "Whether the STREAM frame had the FIN bit set."

    stream_id: int
    "The ID of the stream the data was received for."
|
||||
|
||||
|
||||
@dataclass
class StreamReset(QuicEvent):
    """
    The StreamReset event is fired when the remote peer resets a stream.
    """

    error_code: int
    "The error code that triggered the reset."

    stream_id: int
    "The ID of the stream that was reset."
|
||||
329
venv/Lib/site-packages/aioquic/quic/logger.py
Normal file
329
venv/Lib/site-packages/aioquic/quic/logger.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
from ..h3.events import Headers
|
||||
from .packet import (
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
QuicStreamFrame,
|
||||
QuicTransportParameters,
|
||||
)
|
||||
from .rangeset import RangeSet
|
||||
|
||||
PACKET_TYPE_NAMES = {
|
||||
QuicPacketType.INITIAL: "initial",
|
||||
QuicPacketType.HANDSHAKE: "handshake",
|
||||
QuicPacketType.ZERO_RTT: "0RTT",
|
||||
QuicPacketType.ONE_RTT: "1RTT",
|
||||
QuicPacketType.RETRY: "retry",
|
||||
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
|
||||
}
|
||||
QLOG_VERSION = "0.3"
|
||||
|
||||
|
||||
def hexdump(data: bytes) -> str:
    """Return `data` as a lowercase hexadecimal string."""
    return data.hex()
|
||||
|
||||
|
||||
class QuicLoggerTrace:
|
||||
"""
|
||||
A QUIC event trace.
|
||||
|
||||
Events are logged in the format defined by qlog.
|
||||
|
||||
See:
|
||||
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
|
||||
"""
|
||||
|
||||
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
|
||||
self._odcid = odcid
|
||||
self._events: Deque[Dict[str, Any]] = deque()
|
||||
self._vantage_point = {
|
||||
"name": "aioquic",
|
||||
"type": "client" if is_client else "server",
|
||||
}
|
||||
|
||||
# QUIC
|
||||
|
||||
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict:
|
||||
return {
|
||||
"ack_delay": self.encode_time(delay),
|
||||
"acked_ranges": [[x.start, x.stop - 1] for x in ranges],
|
||||
"frame_type": "ack",
|
||||
}
|
||||
|
||||
def encode_connection_close_frame(
|
||||
self, error_code: int, frame_type: Optional[int], reason_phrase: str
|
||||
) -> Dict:
|
||||
attrs = {
|
||||
"error_code": error_code,
|
||||
"error_space": "application" if frame_type is None else "transport",
|
||||
"frame_type": "connection_close",
|
||||
"raw_error_code": error_code,
|
||||
"reason": reason_phrase,
|
||||
}
|
||||
if frame_type is not None:
|
||||
attrs["trigger_frame_type"] = frame_type
|
||||
|
||||
return attrs
|
||||
|
||||
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> Dict:
|
||||
if frame_type == QuicFrameType.MAX_DATA:
|
||||
return {"frame_type": "max_data", "maximum": maximum}
|
||||
else:
|
||||
return {
|
||||
"frame_type": "max_streams",
|
||||
"maximum": maximum,
|
||||
"stream_type": "unidirectional"
|
||||
if frame_type == QuicFrameType.MAX_STREAMS_UNI
|
||||
else "bidirectional",
|
||||
}
|
||||
|
||||
def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict:
|
||||
return {
|
||||
"frame_type": "crypto",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
}
|
||||
|
||||
def encode_data_blocked_frame(self, limit: int) -> Dict:
|
||||
return {"frame_type": "data_blocked", "limit": limit}
|
||||
|
||||
def encode_datagram_frame(self, length: int) -> Dict:
|
||||
return {"frame_type": "datagram", "length": length}
|
||||
|
||||
def encode_handshake_done_frame(self) -> Dict:
|
||||
return {"frame_type": "handshake_done"}
|
||||
|
||||
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict:
    """Represent a MAX_STREAM_DATA frame as a qlog dictionary."""
    return dict(
        frame_type="max_stream_data",
        maximum=maximum,
        stream_id=stream_id,
    )
|
||||
|
||||
def encode_new_connection_id_frame(
    self,
    connection_id: bytes,
    retire_prior_to: int,
    sequence_number: int,
    stateless_reset_token: bytes,
) -> Dict:
    """Represent a NEW_CONNECTION_ID frame as a qlog dictionary."""
    return {
        # Binary fields are serialized as hex strings.
        "connection_id": hexdump(connection_id),
        "frame_type": "new_connection_id",
        "length": len(connection_id),
        "reset_token": hexdump(stateless_reset_token),
        "retire_prior_to": retire_prior_to,
        "sequence_number": sequence_number,
    }
|
||||
|
||||
def encode_new_token_frame(self, token: bytes) -> Dict:
    """Represent a NEW_TOKEN frame as a qlog dictionary."""
    return {
        "frame_type": "new_token",
        "length": len(token),
        # The token itself is serialized as a hex string.
        "token": hexdump(token),
    }
|
||||
|
||||
def encode_padding_frame(self) -> Dict:
    """Represent a PADDING frame as a qlog dictionary."""
    return dict(frame_type="padding")
|
||||
|
||||
def encode_path_challenge_frame(self, data: bytes) -> Dict:
    """Represent a PATH_CHALLENGE frame as a qlog dictionary (data as hex)."""
    return {"data": hexdump(data), "frame_type": "path_challenge"}
|
||||
|
||||
def encode_path_response_frame(self, data: bytes) -> Dict:
    """Represent a PATH_RESPONSE frame as a qlog dictionary (data as hex)."""
    return {"data": hexdump(data), "frame_type": "path_response"}
|
||||
|
||||
def encode_ping_frame(self) -> Dict:
    """Represent a PING frame as a qlog dictionary."""
    return dict(frame_type="ping")
|
||||
|
||||
def encode_reset_stream_frame(
    self, error_code: int, final_size: int, stream_id: int
) -> Dict:
    """Represent a RESET_STREAM frame as a qlog dictionary."""
    return dict(
        error_code=error_code,
        final_size=final_size,
        frame_type="reset_stream",
        stream_id=stream_id,
    )
|
||||
|
||||
def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict:
    """Represent a RETIRE_CONNECTION_ID frame as a qlog dictionary."""
    return dict(
        frame_type="retire_connection_id",
        sequence_number=sequence_number,
    )
|
||||
|
||||
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict:
    """Represent a STREAM_DATA_BLOCKED frame as a qlog dictionary."""
    return dict(
        frame_type="stream_data_blocked",
        limit=limit,
        stream_id=stream_id,
    )
|
||||
|
||||
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict:
    """Represent a STOP_SENDING frame as a qlog dictionary."""
    return dict(
        frame_type="stop_sending",
        error_code=error_code,
        stream_id=stream_id,
    )
|
||||
|
||||
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict:
    """Represent a STREAM frame as a qlog dictionary."""
    return dict(
        fin=frame.fin,
        frame_type="stream",
        length=len(frame.data),
        offset=frame.offset,
        stream_id=stream_id,
    )
|
||||
|
||||
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict:
    """Represent a STREAMS_BLOCKED frame as a qlog dictionary."""
    stream_type = "unidirectional" if is_unidirectional else "bidirectional"
    return {"frame_type": "streams_blocked", "limit": limit, "stream_type": stream_type}
|
||||
|
||||
def encode_time(self, seconds: float) -> float:
    """Convert a duration in seconds to qlog's millisecond unit."""
    return 1000 * seconds
|
||||
|
||||
def encode_transport_parameters(
    self, owner: str, parameters: QuicTransportParameters
) -> Dict[str, Any]:
    """Represent transport parameters as a qlog dictionary.

    Only bool, bytes and int values are serialized; bytes become hex strings.
    """
    data: Dict[str, Any] = {"owner": owner}
    for name, value in parameters.__dict__.items():
        # bool must be tested before int, since bool subclasses int.
        if isinstance(value, bool):
            data[name] = value
        elif isinstance(value, bytes):
            data[name] = hexdump(value)
        elif isinstance(value, int):
            data[name] = value
    return data
|
||||
|
||||
def packet_type(self, packet_type: QuicPacketType) -> str:
    """Return the qlog name for a packet type."""
    return PACKET_TYPE_NAMES[packet_type]
|
||||
|
||||
# HTTP/3
|
||||
|
||||
def encode_http3_data_frame(self, length: int, stream_id: int) -> Dict:
    """Represent an HTTP/3 DATA frame as a qlog dictionary."""
    return dict(
        frame=dict(frame_type="data"),
        length=length,
        stream_id=stream_id,
    )
|
||||
|
||||
def encode_http3_headers_frame(
    self, length: int, headers: Headers, stream_id: int
) -> Dict:
    """Represent an HTTP/3 HEADERS frame as a qlog dictionary."""
    return {
        "frame": {
            "frame_type": "headers",
            "headers": self._encode_http3_headers(headers),
        },
        "length": length,
        "stream_id": stream_id,
    }
|
||||
|
||||
def encode_http3_push_promise_frame(
    self, length: int, headers: Headers, push_id: int, stream_id: int
) -> Dict:
    """Represent an HTTP/3 PUSH_PROMISE frame as a qlog dictionary."""
    return {
        "frame": {
            "frame_type": "push_promise",
            "headers": self._encode_http3_headers(headers),
            "push_id": push_id,
        },
        "length": length,
        "stream_id": stream_id,
    }
|
||||
|
||||
def _encode_http3_headers(self, headers: Headers) -> List[Dict]:
|
||||
return [
|
||||
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
|
||||
]
|
||||
|
||||
# CORE
|
||||
|
||||
def log_event(self, *, category: str, event: str, data: Dict) -> None:
    """Append a qlog event, stamped with the current wall-clock time."""
    self._events.append(
        {
            "data": data,
            # qlog event names are "category:event".
            "name": category + ":" + event,
            "time": self.encode_time(time.time()),
        }
    )
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Return the trace as a dictionary which can be written as JSON.
    """
    return {
        "common_fields": {
            # Original destination connection ID, hex-encoded.
            "ODCID": hexdump(self._odcid),
        },
        "events": list(self._events),
        "vantage_point": self._vantage_point,
    }
|
||||
|
||||
|
||||
class QuicLogger:
    """
    A QUIC event logger which stores traces in memory.
    """

    def __init__(self) -> None:
        # Traces created via `start_trace` and not yet discarded.
        self._traces: List[QuicLoggerTrace] = []

    def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
        """Create and register a trace for a new connection."""
        trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
        self._traces.append(trace)
        return trace

    def end_trace(self, trace: QuicLoggerTrace) -> None:
        """Finish a trace; the in-memory logger keeps it for `to_dict`."""
        assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the traces as a dictionary which can be written as JSON.
        """
        return {
            "qlog_format": "JSON",
            "qlog_version": QLOG_VERSION,
            "traces": [trace.to_dict() for trace in self._traces],
        }
|
||||
|
||||
|
||||
class QuicFileLogger(QuicLogger):
    """
    A QUIC event logger which writes one trace per file.
    """

    def __init__(self, path: str) -> None:
        # Fail fast: the output directory must already exist.
        if not os.path.isdir(path):
            raise ValueError("QUIC log output directory '%s' does not exist" % path)
        self.path = path
        super().__init__()

    def end_trace(self, trace: QuicLoggerTrace) -> None:
        """Write the finished trace to `<path>/<ODCID>.qlog` and drop it."""
        trace_dict = trace.to_dict()
        # The file is named after the original destination connection ID.
        trace_path = os.path.join(
            self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
        )
        with open(trace_path, "w") as logger_fp:
            json.dump(
                {
                    "qlog_format": "JSON",
                    "qlog_version": QLOG_VERSION,
                    "traces": [trace_dict],
                },
                logger_fp,
            )
        # Unlike the in-memory logger, the trace is released after writing.
        self._traces.remove(trace)
|
||||
640
venv/Lib/site-packages/aioquic/quic/packet.py
Normal file
640
venv/Lib/site-packages/aioquic/quic/packet.py
Normal file
@@ -0,0 +1,640 @@
|
||||
import binascii
|
||||
import ipaddress
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ..buffer import Buffer
|
||||
from .rangeset import RangeSet
|
||||
|
||||
# First-byte flag bits shared by all QUIC packets (RFC 9000 section 17).
PACKET_LONG_HEADER = 0x80
PACKET_FIXED_BIT = 0x40
PACKET_SPIN_BIT = 0x20

# Wire-format size limits.
CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
# Fixed AEAD keys / nonces used for Retry integrity tags
# (RFC 9001 section 5.8, RFC 9369 section 3.3.3).
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16
|
||||
|
||||
|
||||
class QuicErrorCode(IntEnum):
    """Transport error codes (RFC 9000 section 20.1)."""

    NO_ERROR = 0x0
    INTERNAL_ERROR = 0x1
    CONNECTION_REFUSED = 0x2
    FLOW_CONTROL_ERROR = 0x3
    STREAM_LIMIT_ERROR = 0x4
    STREAM_STATE_ERROR = 0x5
    FINAL_SIZE_ERROR = 0x6
    FRAME_ENCODING_ERROR = 0x7
    TRANSPORT_PARAMETER_ERROR = 0x8
    CONNECTION_ID_LIMIT_ERROR = 0x9
    PROTOCOL_VIOLATION = 0xA
    INVALID_TOKEN = 0xB
    APPLICATION_ERROR = 0xC
    CRYPTO_BUFFER_EXCEEDED = 0xD
    KEY_UPDATE_ERROR = 0xE
    AEAD_LIMIT_REACHED = 0xF
    VERSION_NEGOTIATION_ERROR = 0x11
    # Base of the 0x0100-0x01ff range reserved for TLS alerts.
    CRYPTO_ERROR = 0x100
|
||||
|
||||
|
||||
class QuicPacketType(Enum):
    """Packet types, in a version-independent internal numbering."""

    INITIAL = 0
    ZERO_RTT = 1
    HANDSHAKE = 2
    RETRY = 3
    VERSION_NEGOTIATION = 4
    ONE_RTT = 5
|
||||
|
||||
|
||||
# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL

# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
    QuicPacketType.INITIAL: 0,
    QuicPacketType.ZERO_RTT: 1,
    QuicPacketType.HANDSHAKE: 2,
    QuicPacketType.RETRY: 3,
}
# Inverse mapping: long-header type bits -> packet type.
PACKET_LONG_TYPE_DECODE_VERSION_1 = dict(
    (v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
)

# QUIC version 2 renumbers the long packet types.
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
    QuicPacketType.INITIAL: 1,
    QuicPacketType.ZERO_RTT: 2,
    QuicPacketType.HANDSHAKE: 3,
    QuicPacketType.RETRY: 0,
}
# Inverse mapping: long-header type bits -> packet type.
PACKET_LONG_TYPE_DECODE_VERSION_2 = dict(
    (v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
)
|
||||
|
||||
|
||||
class QuicProtocolVersion(IntEnum):
    """Known protocol version numbers (RFC 9000, RFC 9369)."""

    NEGOTIATION = 0
    VERSION_1 = 0x00000001
    VERSION_2 = 0x6B3343CF
|
||||
|
||||
|
||||
@dataclass
class QuicHeader:
    """Parsed representation of a QUIC packet header (long or short form)."""

    version: Optional[int]
    "The protocol version. Only present in long header packets."

    packet_type: QuicPacketType
    "The type of the packet."

    packet_length: int
    "The total length of the packet, in bytes."

    destination_cid: bytes
    "The destination connection ID."

    source_cid: bytes
    "The source connection ID."

    token: bytes
    "The address verification token. Only present in `INITIAL` and `RETRY` packets."

    integrity_tag: bytes
    "The retry integrity tag. Only present in `RETRY` packets."

    supported_versions: List[int]
    "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
|
||||
|
||||
|
||||
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
    """
    Recover a full packet number from a truncated packet number.

    See RFC 9000, Appendix A - Sample Packet Number Decoding Algorithm.
    """
    pn_win = 1 << num_bits
    pn_hwin = pn_win >> 1
    # Candidate closest to `expected` with the same low-order bits.
    candidate = (expected & ~(pn_win - 1)) | truncated
    if candidate <= expected - pn_hwin and candidate < (1 << 62) - pn_win:
        candidate += pn_win
    elif candidate > expected + pn_hwin and candidate >= pn_win:
        candidate -= pn_win
    return candidate
|
||||
|
||||
|
||||
def get_retry_integrity_tag(
    packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
    """
    Calculate the integrity tag for a RETRY packet.
    """
    # Build the Retry pseudo-packet: ODCID length + ODCID + retry packet.
    buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
    buf.push_uint8(len(original_destination_cid))
    buf.push_bytes(original_destination_cid)
    buf.push_bytes(packet_without_tag)
    assert buf.eof()

    # The key / nonce pair depends on the protocol version.
    if version == QuicProtocolVersion.VERSION_2:
        aead_key = RETRY_AEAD_KEY_VERSION_2
        aead_nonce = RETRY_AEAD_NONCE_VERSION_2
    else:
        aead_key = RETRY_AEAD_KEY_VERSION_1
        aead_nonce = RETRY_AEAD_NONCE_VERSION_1

    # Run AES-128-GCM over no plaintext; the pseudo-packet is the
    # associated data, so the ciphertext is the 16-byte tag alone.
    aead = AESGCM(aead_key)
    integrity_tag = aead.encrypt(aead_nonce, b"", buf.data)
    assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
    return integrity_tag
|
||||
|
||||
|
||||
def get_spin_bit(first_byte: int) -> bool:
    """Return whether the latency spin bit is set in the first byte."""
    return (first_byte & PACKET_SPIN_BIT) != 0
|
||||
|
||||
|
||||
def is_long_header(first_byte: int) -> bool:
    """Return whether the first byte announces a long header packet."""
    return (first_byte & PACKET_LONG_HEADER) != 0
|
||||
|
||||
|
||||
def pretty_protocol_version(version: int) -> str:
    """
    Return a user-friendly representation of a protocol version.
    """
    try:
        name = QuicProtocolVersion(version).name
    except ValueError:
        # Not one of the versions this implementation knows about.
        name = "UNKNOWN"
    return "0x%08x (%s)" % (version, name)
|
||||
|
||||
|
||||
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
    """
    Parse a QUIC packet header from `buf`.

    :param buf: buffer positioned at the start of the packet.
    :param host_cid_length: length of the connection IDs issued by this host,
        needed to delimit the destination CID of short header packets.
    :raises ValueError: if the header is malformed or truncated.
    """
    packet_start = buf.tell()

    # Fields that only some packet types carry.
    version = None
    integrity_tag = b""
    supported_versions = []
    token = b""

    first_byte = buf.pull_uint8()
    if is_long_header(first_byte):
        # Long Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
        version = buf.pull_uint32()

        destination_cid_length = buf.pull_uint8()
        if destination_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(
                "Destination CID is too long (%d bytes)" % destination_cid_length
            )
        destination_cid = buf.pull_bytes(destination_cid_length)

        source_cid_length = buf.pull_uint8()
        if source_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError("Source CID is too long (%d bytes)" % source_cid_length)
        source_cid = buf.pull_bytes(source_cid_length)

        if version == QuicProtocolVersion.NEGOTIATION:
            # Version Negotiation Packet.
            # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
            packet_type = QuicPacketType.VERSION_NEGOTIATION
            while not buf.eof():
                supported_versions.append(buf.pull_uint32())
            packet_end = buf.tell()
        else:
            if not (first_byte & PACKET_FIXED_BIT):
                raise ValueError("Packet fixed bit is zero")

            # The long packet type bits (0x30) mean different things in
            # version 1 and version 2.
            if version == QuicProtocolVersion.VERSION_2:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
                    (first_byte & 0x30) >> 4
                ]
            else:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
                    (first_byte & 0x30) >> 4
                ]

            if packet_type == QuicPacketType.INITIAL:
                token_length = buf.pull_uint_var()
                token = buf.pull_bytes(token_length)
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.ZERO_RTT:
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.HANDSHAKE:
                rest_length = buf.pull_uint_var()
            else:
                # RETRY: the token extends up to the 16-byte integrity tag
                # at the end of the datagram.
                token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
                token = buf.pull_bytes(token_length)
                integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
                rest_length = 0

            # Check remainder length.
            packet_end = buf.tell() + rest_length
            if packet_end > buf.capacity:
                raise ValueError("Packet payload is truncated")

    else:
        # Short Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
        if not (first_byte & PACKET_FIXED_BIT):
            raise ValueError("Packet fixed bit is zero")

        version = None
        packet_type = QuicPacketType.ONE_RTT
        # Short headers do not encode the CID length; the host's own
        # CID length delimits it.
        destination_cid = buf.pull_bytes(host_cid_length)
        source_cid = b""
        # A short header packet extends to the end of the datagram.
        packet_end = buf.capacity

    return QuicHeader(
        version=version,
        packet_type=packet_type,
        packet_length=packet_end - packet_start,
        destination_cid=destination_cid,
        source_cid=source_cid,
        token=token,
        integrity_tag=integrity_tag,
        supported_versions=supported_versions,
    )
|
||||
|
||||
|
||||
def encode_long_header_first_byte(
    version: int, packet_type: QuicPacketType, bits: int
) -> int:
    """
    Encode the first byte of a long header packet.

    :param bits: value of the low four (type-specific) bits.
    """
    # Version 2 renumbers the long packet type bits (RFC 9369 section 3.2).
    if version == QuicProtocolVersion.VERSION_2:
        long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
    else:
        long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
    return (
        PACKET_LONG_HEADER
        | PACKET_FIXED_BIT
        | long_type_encode[packet_type] << 4
        | bits
    )
|
||||
|
||||
|
||||
def encode_quic_retry(
    version: int,
    source_cid: bytes,
    destination_cid: bytes,
    original_destination_cid: bytes,
    retry_token: bytes,
    unused: int = 0,
) -> bytes:
    """
    Build a RETRY packet, including its integrity tag.

    :param unused: value for the four "unused" low bits of the first byte.
    """
    # 7 = first byte + 4-byte version + the two CID length bytes.
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + len(retry_token)
        + RETRY_INTEGRITY_TAG_SIZE
    )
    buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
    buf.push_uint32(version)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    buf.push_bytes(retry_token)
    # The tag covers everything written so far, keyed by the original DCID.
    buf.push_bytes(
        get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
    )
    assert buf.eof()
    return buf.data
|
||||
|
||||
|
||||
def encode_quic_version_negotiation(
    source_cid: bytes, destination_cid: bytes, supported_versions: List[int]
) -> bytes:
    """
    Build a Version Negotiation packet advertising `supported_versions`.
    """
    # 7 = first byte + 4-byte version field + the two CID length bytes.
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + 4 * len(supported_versions)
    )
    # All bits other than the long-header bit are unused; randomize them.
    buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
    buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    for version in supported_versions:
        buf.push_uint32(version)
    return buf.data
|
||||
|
||||
|
||||
# TLS EXTENSION
|
||||
|
||||
|
||||
@dataclass
class QuicPreferredAddress:
    """Value of the preferred_address transport parameter."""

    ipv4_address: Optional[Tuple[str, int]]  # (host, port), None if absent
    ipv6_address: Optional[Tuple[str, int]]  # (host, port), None if absent
    connection_id: bytes
    stateless_reset_token: bytes
|
||||
|
||||
|
||||
@dataclass
class QuicVersionInformation:
    """Value of the version_information transport parameter (RFC 9368)."""

    chosen_version: int
    available_versions: List[int]
|
||||
|
||||
|
||||
@dataclass
class QuicTransportParameters:
    """
    QUIC transport parameters (RFC 9000 section 18.2 plus extensions).

    A value of None means the parameter was not sent / received.
    """

    original_destination_connection_id: Optional[bytes] = None
    max_idle_timeout: Optional[int] = None
    stateless_reset_token: Optional[bytes] = None
    max_udp_payload_size: Optional[int] = None
    initial_max_data: Optional[int] = None
    initial_max_stream_data_bidi_local: Optional[int] = None
    initial_max_stream_data_bidi_remote: Optional[int] = None
    initial_max_stream_data_uni: Optional[int] = None
    initial_max_streams_bidi: Optional[int] = None
    initial_max_streams_uni: Optional[int] = None
    ack_delay_exponent: Optional[int] = None
    max_ack_delay: Optional[int] = None
    disable_active_migration: Optional[bool] = False
    preferred_address: Optional[QuicPreferredAddress] = None
    active_connection_id_limit: Optional[int] = None
    initial_source_connection_id: Optional[bytes] = None
    retry_source_connection_id: Optional[bytes] = None
    version_information: Optional[QuicVersionInformation] = None
    # Extension parameters.
    max_datagram_frame_size: Optional[int] = None
    quantum_readiness: Optional[bytes] = None
|
||||
|
||||
|
||||
# Transport parameter registry: wire ID -> (attribute name, codec type).
# The codec type selects how pull/push_quic_transport_parameters
# (de)serializes the value.
PARAMS = {
    0x00: ("original_destination_connection_id", bytes),
    0x01: ("max_idle_timeout", int),
    0x02: ("stateless_reset_token", bytes),
    0x03: ("max_udp_payload_size", int),
    0x04: ("initial_max_data", int),
    0x05: ("initial_max_stream_data_bidi_local", int),
    0x06: ("initial_max_stream_data_bidi_remote", int),
    0x07: ("initial_max_stream_data_uni", int),
    0x08: ("initial_max_streams_bidi", int),
    0x09: ("initial_max_streams_uni", int),
    0x0A: ("ack_delay_exponent", int),
    0x0B: ("max_ack_delay", int),
    0x0C: ("disable_active_migration", bool),
    0x0D: ("preferred_address", QuicPreferredAddress),
    0x0E: ("active_connection_id_limit", int),
    0x0F: ("initial_source_connection_id", bytes),
    0x10: ("retry_source_connection_id", bytes),
    # https://datatracker.ietf.org/doc/html/rfc9368#section-3
    0x11: ("version_information", QuicVersionInformation),
    # extensions
    0x0020: ("max_datagram_frame_size", int),
    0x0C37: ("quantum_readiness", bytes),
}
|
||||
|
||||
|
||||
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
    """
    Parse a preferred_address transport parameter value.

    An all-zero host means the corresponding address family is absent.
    """
    ipv4_address = None
    ipv4_host = buf.pull_bytes(4)
    ipv4_port = buf.pull_uint16()
    if ipv4_host != bytes(4):
        ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)

    ipv6_address = None
    ipv6_host = buf.pull_bytes(16)
    ipv6_port = buf.pull_uint16()
    if ipv6_host != bytes(16):
        ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)

    connection_id_length = buf.pull_uint8()
    connection_id = buf.pull_bytes(connection_id_length)
    stateless_reset_token = buf.pull_bytes(16)

    return QuicPreferredAddress(
        ipv4_address=ipv4_address,
        ipv6_address=ipv6_address,
        connection_id=connection_id,
        stateless_reset_token=stateless_reset_token,
    )
|
||||
|
||||
|
||||
def push_quic_preferred_address(
    buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
    """
    Serialize a preferred_address transport parameter value.

    An absent address family is written as an all-zero host and port.
    """
    if preferred_address.ipv4_address is not None:
        buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
        buf.push_uint16(preferred_address.ipv4_address[1])
    else:
        # 4-byte zero host + 2-byte zero port.
        buf.push_bytes(bytes(6))

    if preferred_address.ipv6_address is not None:
        buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
        buf.push_uint16(preferred_address.ipv6_address[1])
    else:
        # 16-byte zero host + 2-byte zero port.
        buf.push_bytes(bytes(18))

    buf.push_uint8(len(preferred_address.connection_id))
    buf.push_bytes(preferred_address.connection_id)
    buf.push_bytes(preferred_address.stateless_reset_token)
|
||||
|
||||
|
||||
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
    """
    Parse a version_information transport parameter value of `length` bytes.

    :raises ValueError: if any listed version is zero.
    """
    chosen_version = buf.pull_uint32()
    available_versions = []
    # The remainder of the value is a list of 4-byte versions.
    for i in range(length // 4 - 1):
        available_versions.append(buf.pull_uint32())

    # If an endpoint receives a Chosen Version equal to zero, or any Available Version
    # equal to zero, it MUST treat it as a parsing failure.
    #
    # https://datatracker.ietf.org/doc/html/rfc9368#section-4
    if chosen_version == 0 or 0 in available_versions:
        raise ValueError("Version Information must not contain version 0")

    return QuicVersionInformation(
        chosen_version=chosen_version,
        available_versions=available_versions,
    )
|
||||
|
||||
|
||||
def push_quic_version_information(
    buf: Buffer, version_information: QuicVersionInformation
) -> None:
    """Serialize a version_information transport parameter value."""
    buf.push_uint32(version_information.chosen_version)
    for version in version_information.available_versions:
        buf.push_uint32(version)
|
||||
|
||||
|
||||
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
    """
    Parse a sequence of transport parameters until the buffer is exhausted.

    Unknown parameters are skipped; known parameters are decoded according
    to the codec type registered in PARAMS.

    :raises ValueError: if a value's encoded size disagrees with its
        declared length.
    """
    params = QuicTransportParameters()
    while not buf.eof():
        param_id = buf.pull_uint_var()
        param_len = buf.pull_uint_var()
        param_start = buf.tell()
        if param_id in PARAMS:
            # Parse known parameter.
            param_name, param_type = PARAMS[param_id]
            if param_type is int:
                setattr(params, param_name, buf.pull_uint_var())
            elif param_type is bytes:
                setattr(params, param_name, buf.pull_bytes(param_len))
            elif param_type is QuicPreferredAddress:
                setattr(params, param_name, pull_quic_preferred_address(buf))
            elif param_type is QuicVersionInformation:
                setattr(
                    params,
                    param_name,
                    pull_quic_version_information(buf, param_len),
                )
            else:
                # bool parameters are zero-length; presence means True.
                setattr(params, param_name, True)
        else:
            # Skip unknown parameter.
            buf.pull_bytes(param_len)

        # Every branch must have consumed exactly `param_len` bytes.
        if buf.tell() != param_start + param_len:
            raise ValueError("Transport parameter length does not match")

    return params
|
||||
|
||||
|
||||
def push_quic_transport_parameters(
    buf: Buffer, params: QuicTransportParameters
) -> None:
    """
    Serialize all set transport parameters into `buf`.

    Parameters that are None or False are omitted entirely.
    """
    for param_id, (param_name, param_type) in PARAMS.items():
        param_value = getattr(params, param_name)
        if param_value is not None and param_value is not False:
            # Encode the value into a scratch buffer first, since the
            # length prefix must precede the value on the wire.
            param_buf = Buffer(capacity=65536)
            if param_type is int:
                param_buf.push_uint_var(param_value)
            elif param_type is bytes:
                param_buf.push_bytes(param_value)
            elif param_type is QuicPreferredAddress:
                push_quic_preferred_address(param_buf, param_value)
            elif param_type is QuicVersionInformation:
                push_quic_version_information(param_buf, param_value)
            # bool parameters encode as a zero-length value.
            buf.push_uint_var(param_id)
            buf.push_uint_var(param_buf.tell())
            buf.push_bytes(param_buf.data)
|
||||
|
||||
|
||||
# FRAMES
|
||||
|
||||
|
||||
class QuicFrameType(IntEnum):
    """
    Frame types from RFC 9000 section 19, plus the DATAGRAM extension
    (RFC 9221).
    """

    PADDING = 0x00
    PING = 0x01
    ACK = 0x02
    ACK_ECN = 0x03
    RESET_STREAM = 0x04
    STOP_SENDING = 0x05
    CRYPTO = 0x06
    NEW_TOKEN = 0x07
    # STREAM frames occupy 0x08-0x0f; the low bits carry OFF/LEN/FIN flags.
    STREAM_BASE = 0x08
    MAX_DATA = 0x10
    MAX_STREAM_DATA = 0x11
    MAX_STREAMS_BIDI = 0x12
    MAX_STREAMS_UNI = 0x13
    DATA_BLOCKED = 0x14
    STREAM_DATA_BLOCKED = 0x15
    STREAMS_BLOCKED_BIDI = 0x16
    STREAMS_BLOCKED_UNI = 0x17
    NEW_CONNECTION_ID = 0x18
    RETIRE_CONNECTION_ID = 0x19
    PATH_CHALLENGE = 0x1A
    PATH_RESPONSE = 0x1B
    TRANSPORT_CLOSE = 0x1C
    APPLICATION_CLOSE = 0x1D
    HANDSHAKE_DONE = 0x1E
    DATAGRAM = 0x30
    DATAGRAM_WITH_LENGTH = 0x31
|
||||
|
||||
|
||||
# Frames that do not elicit an acknowledgement (RFC 9000 section 13.2.1).
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.PADDING,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)
# Frames that do not count towards bytes in flight for congestion control.
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)

# Frames allowed in path-probing packets (RFC 9000 section 9.1).
PROBING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.PATH_CHALLENGE,
        QuicFrameType.PATH_RESPONSE,
        QuicFrameType.PADDING,
        QuicFrameType.NEW_CONNECTION_ID,
    ]
)
|
||||
|
||||
|
||||
@dataclass
class QuicResetStreamFrame:
    """Parsed RESET_STREAM frame."""

    error_code: int
    final_size: int
    stream_id: int
|
||||
|
||||
|
||||
@dataclass
class QuicStopSendingFrame:
    """Parsed STOP_SENDING frame."""

    error_code: int
    stream_id: int
|
||||
|
||||
|
||||
@dataclass
class QuicStreamFrame:
    """A chunk of STREAM or CRYPTO data at a given offset."""

    data: bytes = b""
    fin: bool = False
    offset: int = 0
|
||||
|
||||
|
||||
def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]:
    """
    Parse an ACK frame body.

    :return: the acknowledged packet-number ranges and the raw ACK delay.
    """
    rangeset = RangeSet()
    end = buf.pull_uint_var()  # largest acknowledged
    delay = buf.pull_uint_var()
    ack_range_count = buf.pull_uint_var()
    ack_count = buf.pull_uint_var()  # first ack range
    # RangeSet entries are half-open; the wire ranges are inclusive.
    rangeset.add(end - ack_count, end + 1)
    end -= ack_count
    for _ in range(ack_range_count):
        # Each further range is encoded as (gap, length), both minus one.
        end -= buf.pull_uint_var() + 2
        ack_count = buf.pull_uint_var()
        rangeset.add(end - ack_count, end + 1)
        end -= ack_count
    return rangeset, delay
|
||||
|
||||
|
||||
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
    """
    Serialize an ACK frame body.

    Ranges are written from the highest downwards, as required by the
    wire format.

    :return: the number of ranges written.
    """
    ranges = len(rangeset)
    index = ranges - 1
    r = rangeset[index]
    buf.push_uint_var(r.stop - 1)  # largest acknowledged
    buf.push_uint_var(delay)
    buf.push_uint_var(index)  # ack range count
    buf.push_uint_var(r.stop - 1 - r.start)  # first ack range
    start = r.start
    while index > 0:
        index -= 1
        r = rangeset[index]
        # (gap, length) pair, each encoded minus one.
        buf.push_uint_var(start - r.stop - 1)
        buf.push_uint_var(r.stop - r.start - 1)
        start = r.start
    return ranges
|
||||
384
venv/Lib/site-packages/aioquic/quic/packet_builder.py
Normal file
384
venv/Lib/site-packages/aioquic/quic/packet_builder.py
Normal file
@@ -0,0 +1,384 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..buffer import Buffer, size_uint_var
|
||||
from ..tls import Epoch
|
||||
from .crypto import CryptoPair
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet import (
|
||||
NON_ACK_ELICITING_FRAME_TYPES,
|
||||
NON_IN_FLIGHT_FRAME_TYPES,
|
||||
PACKET_FIXED_BIT,
|
||||
PACKET_NUMBER_MAX_SIZE,
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
encode_long_header_first_byte,
|
||||
)
|
||||
|
||||
# Bytes reserved in long headers for the length field and the packet
# number as actually written on the wire.
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2


# Callback invoked when a sent frame is acknowledged or declared lost.
QuicDeliveryHandler = Callable[..., None]
|
||||
|
||||
|
||||
class QuicDeliveryState(Enum):
    """Final delivery outcome reported to frame delivery handlers."""

    ACKED = 0
    LOST = 1
|
||||
|
||||
|
||||
@dataclass
class QuicSentPacket:
    """Bookkeeping for a packet that was sent and awaits acknowledgement."""

    epoch: Epoch  # encryption epoch the packet belongs to
    in_flight: bool  # counts towards bytes in flight
    is_ack_eliciting: bool
    is_crypto_packet: bool
    packet_number: int
    packet_type: QuicPacketType
    sent_time: Optional[float] = None  # filled in when actually sent
    sent_bytes: int = 0

    # (handler, args) pairs invoked on acknowledgement or loss.
    delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
        default_factory=list
    )
    quic_logger_frames: List[Dict] = field(default_factory=list)
|
||||
|
||||
|
||||
class QuicPacketBuilderStop(Exception):
    """Raised when no more frames fit in the current packet/datagram."""

    pass
|
||||
|
||||
|
||||
class QuicPacketBuilder:
|
||||
"""
|
||||
Helper for building QUIC packets.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    *,
    host_cid: bytes,
    peer_cid: bytes,
    version: int,
    is_client: bool,
    max_datagram_size: int,
    packet_number: int = 0,
    peer_token: bytes = b"",
    quic_logger: Optional[QuicLoggerTrace] = None,
    spin_bit: bool = False,
):
    # Optional caps the caller may set: congestion-window budget and
    # amplification-limit budget. None means unlimited.
    self.max_flight_bytes: Optional[int] = None
    self.max_total_bytes: Optional[int] = None
    self.quic_logger_frames: Optional[List[Dict]] = None

    self._host_cid = host_cid
    self._is_client = is_client
    self._peer_cid = peer_cid
    self._peer_token = peer_token
    self._quic_logger = quic_logger
    self._spin_bit = spin_bit
    self._version = version

    # assembled datagrams and packets
    self._datagrams: List[bytes] = []
    self._datagram_flight_bytes = 0
    self._datagram_init = True
    self._datagram_needs_padding = False
    self._packets: List[QuicSentPacket] = []
    self._flight_bytes = 0
    self._total_bytes = 0

    # current packet
    self._header_size = 0
    self._packet: Optional[QuicSentPacket] = None
    self._packet_crypto: Optional[CryptoPair] = None
    self._packet_number = packet_number
    self._packet_start = 0
    self._packet_type: Optional[QuicPacketType] = None

    # One datagram is assembled at a time, within a fixed-size buffer.
    self._buffer = Buffer(max_datagram_size)
    self._buffer_capacity = max_datagram_size
    self._flight_capacity = max_datagram_size
|
||||
|
||||
@property
def packet_is_empty(self) -> bool:
    """
    Returns `True` if the current packet is empty.

    A packet containing only its header (no frames yet) is empty.
    """
    assert self._packet is not None
    packet_size = self._buffer.tell() - self._packet_start
    return packet_size <= self._header_size
|
||||
|
||||
@property
|
||||
def packet_number(self) -> int:
|
||||
"""
|
||||
Returns the packet number for the next packet.
|
||||
"""
|
||||
return self._packet_number
|
||||
|
||||
@property
|
||||
def remaining_buffer_space(self) -> int:
|
||||
"""
|
||||
Returns the remaining number of bytes which can be used in
|
||||
the current packet.
|
||||
"""
|
||||
return (
|
||||
self._buffer_capacity
|
||||
- self._buffer.tell()
|
||||
- self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_flight_space(self) -> int:
|
||||
"""
|
||||
Returns the remaining number of bytes which can be used in
|
||||
the current packet.
|
||||
"""
|
||||
return (
|
||||
self._flight_capacity
|
||||
- self._buffer.tell()
|
||||
- self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
|
||||
"""
|
||||
Returns the assembled datagrams.
|
||||
"""
|
||||
if self._packet is not None:
|
||||
self._end_packet()
|
||||
self._flush_current_datagram()
|
||||
|
||||
datagrams = self._datagrams
|
||||
packets = self._packets
|
||||
self._datagrams = []
|
||||
self._packets = []
|
||||
return datagrams, packets
|
||||
|
||||
def start_frame(
|
||||
self,
|
||||
frame_type: int,
|
||||
capacity: int = 1,
|
||||
handler: Optional[QuicDeliveryHandler] = None,
|
||||
handler_args: Sequence[Any] = [],
|
||||
) -> Buffer:
|
||||
"""
|
||||
Starts a new frame.
|
||||
"""
|
||||
if self.remaining_buffer_space < capacity or (
|
||||
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
|
||||
and self.remaining_flight_space < capacity
|
||||
):
|
||||
raise QuicPacketBuilderStop
|
||||
|
||||
self._buffer.push_uint_var(frame_type)
|
||||
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
|
||||
self._packet.is_ack_eliciting = True
|
||||
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
|
||||
self._packet.in_flight = True
|
||||
if frame_type == QuicFrameType.CRYPTO:
|
||||
self._packet.is_crypto_packet = True
|
||||
if handler is not None:
|
||||
self._packet.delivery_handlers.append((handler, handler_args))
|
||||
return self._buffer
|
||||
|
||||
def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
|
||||
"""
|
||||
Starts a new packet.
|
||||
"""
|
||||
assert packet_type in (
|
||||
QuicPacketType.INITIAL,
|
||||
QuicPacketType.HANDSHAKE,
|
||||
QuicPacketType.ZERO_RTT,
|
||||
QuicPacketType.ONE_RTT,
|
||||
), "Invalid packet type"
|
||||
buf = self._buffer
|
||||
|
||||
# finish previous datagram
|
||||
if self._packet is not None:
|
||||
self._end_packet()
|
||||
|
||||
# if there is too little space remaining, start a new datagram
|
||||
# FIXME: the limit is arbitrary!
|
||||
packet_start = buf.tell()
|
||||
if self._buffer_capacity - packet_start < 128:
|
||||
self._flush_current_datagram()
|
||||
packet_start = 0
|
||||
|
||||
# initialize datagram if needed
|
||||
if self._datagram_init:
|
||||
if self.max_total_bytes is not None:
|
||||
remaining_total_bytes = self.max_total_bytes - self._total_bytes
|
||||
if remaining_total_bytes < self._buffer_capacity:
|
||||
self._buffer_capacity = remaining_total_bytes
|
||||
|
||||
self._flight_capacity = self._buffer_capacity
|
||||
if self.max_flight_bytes is not None:
|
||||
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
|
||||
if remaining_flight_bytes < self._flight_capacity:
|
||||
self._flight_capacity = remaining_flight_bytes
|
||||
self._datagram_flight_bytes = 0
|
||||
self._datagram_init = False
|
||||
self._datagram_needs_padding = False
|
||||
|
||||
# calculate header size
|
||||
if packet_type != QuicPacketType.ONE_RTT:
|
||||
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
token_length = len(self._peer_token)
|
||||
header_size += size_uint_var(token_length) + token_length
|
||||
else:
|
||||
header_size = 3 + len(self._peer_cid)
|
||||
|
||||
# check we have enough space
|
||||
if packet_start + header_size >= self._buffer_capacity:
|
||||
raise QuicPacketBuilderStop
|
||||
|
||||
# determine ack epoch
|
||||
if packet_type == QuicPacketType.INITIAL:
|
||||
epoch = Epoch.INITIAL
|
||||
elif packet_type == QuicPacketType.HANDSHAKE:
|
||||
epoch = Epoch.HANDSHAKE
|
||||
else:
|
||||
epoch = Epoch.ONE_RTT
|
||||
|
||||
self._header_size = header_size
|
||||
self._packet = QuicSentPacket(
|
||||
epoch=epoch,
|
||||
in_flight=False,
|
||||
is_ack_eliciting=False,
|
||||
is_crypto_packet=False,
|
||||
packet_number=self._packet_number,
|
||||
packet_type=packet_type,
|
||||
)
|
||||
self._packet_crypto = crypto
|
||||
self._packet_start = packet_start
|
||||
self._packet_type = packet_type
|
||||
self.quic_logger_frames = self._packet.quic_logger_frames
|
||||
|
||||
buf.seek(self._packet_start + self._header_size)
|
||||
|
||||
def _end_packet(self) -> None:
|
||||
"""
|
||||
Ends the current packet.
|
||||
"""
|
||||
buf = self._buffer
|
||||
packet_size = buf.tell() - self._packet_start
|
||||
if packet_size > self._header_size:
|
||||
# padding to ensure sufficient sample size
|
||||
padding_size = (
|
||||
PACKET_NUMBER_MAX_SIZE
|
||||
- PACKET_NUMBER_SEND_SIZE
|
||||
+ self._header_size
|
||||
- packet_size
|
||||
)
|
||||
|
||||
# Padding for datagrams containing initial packets; see RFC 9000
|
||||
# section 14.1.
|
||||
if (
|
||||
self._is_client or self._packet.is_ack_eliciting
|
||||
) and self._packet_type == QuicPacketType.INITIAL:
|
||||
self._datagram_needs_padding = True
|
||||
|
||||
# For datagrams containing 1-RTT data, we *must* apply the padding
|
||||
# inside the packet, we cannot tack bytes onto the end of the
|
||||
# datagram.
|
||||
if (
|
||||
self._datagram_needs_padding
|
||||
and self._packet_type == QuicPacketType.ONE_RTT
|
||||
):
|
||||
if self.remaining_flight_space > padding_size:
|
||||
padding_size = self.remaining_flight_space
|
||||
self._datagram_needs_padding = False
|
||||
|
||||
# write padding
|
||||
if padding_size > 0:
|
||||
buf.push_bytes(bytes(padding_size))
|
||||
packet_size += padding_size
|
||||
self._packet.in_flight = True
|
||||
|
||||
# log frame
|
||||
if self._quic_logger is not None:
|
||||
self._packet.quic_logger_frames.append(
|
||||
self._quic_logger.encode_padding_frame()
|
||||
)
|
||||
|
||||
# write header
|
||||
if self._packet_type != QuicPacketType.ONE_RTT:
|
||||
length = (
|
||||
packet_size
|
||||
- self._header_size
|
||||
+ PACKET_NUMBER_SEND_SIZE
|
||||
+ self._packet_crypto.aead_tag_size
|
||||
)
|
||||
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_uint8(
|
||||
encode_long_header_first_byte(
|
||||
self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
|
||||
)
|
||||
)
|
||||
buf.push_uint32(self._version)
|
||||
buf.push_uint8(len(self._peer_cid))
|
||||
buf.push_bytes(self._peer_cid)
|
||||
buf.push_uint8(len(self._host_cid))
|
||||
buf.push_bytes(self._host_cid)
|
||||
if self._packet_type == QuicPacketType.INITIAL:
|
||||
buf.push_uint_var(len(self._peer_token))
|
||||
buf.push_bytes(self._peer_token)
|
||||
buf.push_uint16(length | 0x4000)
|
||||
buf.push_uint16(self._packet_number & 0xFFFF)
|
||||
else:
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_uint8(
|
||||
PACKET_FIXED_BIT
|
||||
| (self._spin_bit << 5)
|
||||
| (self._packet_crypto.key_phase << 2)
|
||||
| (PACKET_NUMBER_SEND_SIZE - 1)
|
||||
)
|
||||
buf.push_bytes(self._peer_cid)
|
||||
buf.push_uint16(self._packet_number & 0xFFFF)
|
||||
|
||||
# encrypt in place
|
||||
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
|
||||
buf.seek(self._packet_start)
|
||||
buf.push_bytes(
|
||||
self._packet_crypto.encrypt_packet(
|
||||
plain[0 : self._header_size],
|
||||
plain[self._header_size : packet_size],
|
||||
self._packet_number,
|
||||
)
|
||||
)
|
||||
self._packet.sent_bytes = buf.tell() - self._packet_start
|
||||
self._packets.append(self._packet)
|
||||
if self._packet.in_flight:
|
||||
self._datagram_flight_bytes += self._packet.sent_bytes
|
||||
|
||||
# Short header packets cannot be coalesced, we need a new datagram.
|
||||
if self._packet_type == QuicPacketType.ONE_RTT:
|
||||
self._flush_current_datagram()
|
||||
|
||||
self._packet_number += 1
|
||||
else:
|
||||
# "cancel" the packet
|
||||
buf.seek(self._packet_start)
|
||||
|
||||
self._packet = None
|
||||
self.quic_logger_frames = None
|
||||
|
||||
def _flush_current_datagram(self) -> None:
|
||||
datagram_bytes = self._buffer.tell()
|
||||
if datagram_bytes:
|
||||
# Padding for datagrams containing initial packets; see RFC 9000
|
||||
# section 14.1.
|
||||
if self._datagram_needs_padding:
|
||||
extra_bytes = self._flight_capacity - self._buffer.tell()
|
||||
if extra_bytes > 0:
|
||||
self._buffer.push_bytes(bytes(extra_bytes))
|
||||
self._datagram_flight_bytes += extra_bytes
|
||||
datagram_bytes += extra_bytes
|
||||
|
||||
self._datagrams.append(self._buffer.data)
|
||||
self._flight_bytes += self._datagram_flight_bytes
|
||||
self._total_bytes += datagram_bytes
|
||||
self._datagram_init = True
|
||||
self._buffer.seek(0)
|
||||
98
venv/Lib/site-packages/aioquic/quic/rangeset.py
Normal file
98
venv/Lib/site-packages/aioquic/quic/rangeset.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
|
||||
class RangeSet(Sequence):
    """
    A sorted collection of disjoint, non-adjacent ``range`` intervals.

    Used to track sets of integers such as received packet numbers or
    stream offsets.  Ranges are kept in ascending order; any ranges which
    touch or overlap are merged, so iteration always yields disjoint
    half-open intervals with ``step == 1``.
    """

    def __init__(self, ranges: Iterable[range] = ()):
        # An immutable default avoids the shared-mutable-default pitfall
        # (the previous `[]` default was never mutated, but `()` is safer).
        self.__ranges: List[range] = []
        for r in ranges:
            assert r.step == 1
            self.add(r.start, r.stop)

    def add(self, start: int, stop: Optional[int] = None) -> None:
        """
        Add the half-open interval ``[start, stop)`` to the set.

        If `stop` is omitted, the single value `start` is added.  Existing
        ranges which touch or overlap the new interval are merged into it.
        """
        if stop is None:
            stop = start + 1
        assert stop > start

        for i, r in enumerate(self.__ranges):
            # the added range is entirely before current item, insert here
            if stop < r.start:
                self.__ranges.insert(i, range(start, stop))
                return

            # the added range is entirely after current item, keep looking
            if start > r.stop:
                continue

            # the added range touches the current item, merge it
            start = min(start, r.start)
            stop = max(stop, r.stop)
            # absorb any following ranges that the merged interval now reaches
            while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop:
                stop = max(self.__ranges[i + 1].stop, stop)
                self.__ranges.pop(i + 1)
            self.__ranges[i] = range(start, stop)
            return

        # the added range is entirely after all existing items, append it
        self.__ranges.append(range(start, stop))

    def bounds(self) -> range:
        """
        Return the range spanning the lowest to the highest contained value.

        Raises IndexError if the set is empty.
        """
        return range(self.__ranges[0].start, self.__ranges[-1].stop)

    def shift(self) -> range:
        """Remove and return the first (lowest) range."""
        return self.__ranges.pop(0)

    def subtract(self, start: int, stop: int) -> None:
        """Remove the half-open interval ``[start, stop)`` from the set."""
        assert stop > start

        i = 0
        while i < len(self.__ranges):
            r = self.__ranges[i]

            # the removed range is entirely before current item, stop here
            if stop <= r.start:
                return

            # the removed range is entirely after current item, keep looking
            if start >= r.stop:
                i += 1
                continue

            # the removed range completely covers the current item, remove it
            if start <= r.start and stop >= r.stop:
                self.__ranges.pop(i)
                continue

            # the removed range touches the current item: trim it, possibly
            # splitting it into two pieces
            if start > r.start:
                self.__ranges[i] = range(r.start, start)
                if stop < r.stop:
                    self.__ranges.insert(i + 1, range(stop, r.stop))
            else:
                self.__ranges[i] = range(stop, r.stop)

            i += 1

    def __bool__(self) -> bool:
        # Truthiness is deliberately disabled: callers must use len() or
        # explicit comparisons instead of implicit emptiness checks.
        raise NotImplementedError

    def __contains__(self, val: Any) -> bool:
        for r in self.__ranges:
            if val in r:
                return True
        return False

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, RangeSet):
            return NotImplemented

        return self.__ranges == other.__ranges

    def __getitem__(self, key: Any) -> range:
        return self.__ranges[key]

    def __len__(self) -> int:
        return len(self.__ranges)

    def __repr__(self) -> str:
        return "RangeSet({})".format(repr(self.__ranges))
|
||||
389
venv/Lib/site-packages/aioquic/quic/recovery.py
Normal file
389
venv/Lib/site-packages/aioquic/quic/recovery.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from .congestion import cubic, reno # noqa
|
||||
from .congestion.base import K_GRANULARITY, create_congestion_control
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet_builder import QuicDeliveryState, QuicSentPacket
|
||||
from .rangeset import RangeSet
|
||||
|
||||
# loss detection
# Packet reordering threshold: a packet is declared lost once a packet at
# least this much newer has been acknowledged (see its use in _detect_loss;
# matches kPacketThreshold from RFC 9002).
K_PACKET_THRESHOLD = 3
# Time reordering threshold as a multiple of the RTT (9/8 per RFC 9002).
K_TIME_THRESHOLD = 9 / 8
# Time unit helpers — durations in this module are float seconds.
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
|
||||
|
||||
|
||||
class QuicPacketSpace:
    """
    State tracked for a single packet number space.

    Holds both the acknowledgement bookkeeping for packets we received and
    the loss-detection bookkeeping for packets we sent.
    """

    def __init__(self) -> None:
        # receive side: what we have seen and still owe acknowledgements for
        self.discarded = False
        self.expected_packet_number = 0
        self.largest_received_packet = -1
        self.largest_received_time: Optional[float] = None
        self.ack_queue = RangeSet()
        self.ack_at: Optional[float] = None

        # sent packets and loss
        self.sent_packets: Dict[int, QuicSentPacket] = {}
        self.ack_eliciting_in_flight = 0
        self.largest_acked_packet = 0
        self.loss_time: Optional[float] = None
|
||||
|
||||
|
||||
class QuicPacketPacer:
    """
    Token-bucket pacer which spreads packet transmissions over time.

    All times are float seconds.  The bucket holds "send credit" measured
    in seconds of transmission time at the current pacing rate.
    """

    def __init__(self, *, max_datagram_size: int) -> None:
        self._max_datagram_size = max_datagram_size
        self.bucket_max: float = 0.0  # bucket capacity, in seconds
        self.bucket_time: float = 0.0  # available credit, in seconds
        self.evaluation_time: float = 0.0  # when the bucket was last refilled
        # time to send one full datagram at the current rate;
        # None until update_rate() has been called
        self.packet_time: Optional[float] = None

    def next_send_time(self, now: float) -> Optional[float]:
        """
        Return the time at which the next packet may be sent, or None when
        sending may happen immediately.

        None is also returned while no pacing rate is known yet (before the
        first update_rate() call).
        """
        # BUG FIX: was annotated `-> float` but returns None on two paths.
        if self.packet_time is not None:
            self.update_bucket(now=now)
            if self.bucket_time <= 0:
                return now + self.packet_time
        return None

    def update_after_send(self, now: float) -> None:
        """Consume one packet's worth of credit from the bucket."""
        if self.packet_time is not None:
            self.update_bucket(now=now)
            if self.bucket_time < self.packet_time:
                self.bucket_time = 0.0
            else:
                self.bucket_time -= self.packet_time

    def update_bucket(self, now: float) -> None:
        """Refill the bucket with credit accrued since the last evaluation."""
        if now > self.evaluation_time:
            self.bucket_time = min(
                self.bucket_time + (now - self.evaluation_time), self.bucket_max
            )
            self.evaluation_time = now

    def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
        """
        Derive the per-packet pacing interval and the maximum burst credit
        from the congestion window and smoothed RTT.
        """
        pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
        # clamp the inter-packet interval to [1 microsecond, 1 second]
        self.packet_time = max(
            K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND)
        )

        # allow bursts of cwnd/4, clamped between 2 and 16 datagrams,
        # converted to seconds of credit at the pacing rate
        self.bucket_max = (
            max(
                2 * self._max_datagram_size,
                min(congestion_window // 4, 16 * self._max_datagram_size),
            )
            / pacing_rate
        )
        if self.bucket_time > self.bucket_max:
            self.bucket_time = self.bucket_max
|
||||
|
||||
|
||||
class QuicPacketRecovery:
    """
    Packet loss and congestion controller.

    Maintains RTT estimates, runs loss detection and probe timeouts, and
    delegates congestion control to a pluggable algorithm.
    """

    def __init__(
        self,
        *,
        congestion_control_algorithm: str,
        initial_rtt: float,
        max_datagram_size: int,
        peer_completed_address_validation: bool,
        send_probe: Callable[[], None],
        logger: Optional[logging.LoggerAdapter] = None,
        quic_logger: Optional[QuicLoggerTrace] = None,
    ) -> None:
        # maximum ACK delay granted to the peer, in seconds
        self.max_ack_delay = 0.025
        self.peer_completed_address_validation = peer_completed_address_validation
        self.spaces: List[QuicPacketSpace] = []

        # callbacks
        self._logger = logger
        self._quic_logger = quic_logger
        self._send_probe = send_probe

        # loss detection
        self._pto_count = 0
        self._rtt_initial = initial_rtt
        self._rtt_initialized = False
        self._rtt_latest = 0.0
        self._rtt_min = math.inf
        self._rtt_smoothed = 0.0
        self._rtt_variance = 0.0
        self._time_of_last_sent_ack_eliciting_packet = 0.0

        # congestion control
        self._cc = create_congestion_control(
            congestion_control_algorithm, max_datagram_size=max_datagram_size
        )
        self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size)

    @property
    def bytes_in_flight(self) -> int:
        """Bytes currently in flight, as tracked by the congestion controller."""
        return self._cc.bytes_in_flight

    @property
    def congestion_window(self) -> int:
        """Current congestion window, in bytes."""
        return self._cc.congestion_window

    def discard_space(self, space: QuicPacketSpace) -> None:
        """
        Discard a packet number space: expire its in-flight packets and
        reset its acknowledgement and loss state.
        """
        assert space in self.spaces

        self._cc.on_packets_expired(
            packets=filter(lambda x: x.in_flight, space.sent_packets.values())
        )
        space.sent_packets.clear()

        space.ack_at = None
        space.ack_eliciting_in_flight = 0
        space.loss_time = None

        # reset PTO count
        self._pto_count = 0

        if self._quic_logger is not None:
            self._log_metrics_updated()

    def get_loss_detection_time(self) -> Optional[float]:
        """
        Return when the loss detection timer should fire, or None when no
        timer is needed.
        """
        # loss timer
        loss_space = self._get_loss_space()
        if loss_space is not None:
            return loss_space.loss_time

        # packet timer
        if (
            not self.peer_completed_address_validation
            or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
        ):
            # PTO backs off exponentially with each unanswered probe
            timeout = self.get_probe_timeout() * (2**self._pto_count)
            return self._time_of_last_sent_ack_eliciting_packet + timeout

        return None

    def get_probe_timeout(self) -> float:
        """Return the probe timeout (PTO) from the current RTT estimates."""
        if not self._rtt_initialized:
            return 2 * self._rtt_initial
        return (
            self._rtt_smoothed
            + max(4 * self._rtt_variance, K_GRANULARITY)
            + self.max_ack_delay
        )

    def on_ack_received(
        self,
        *,
        ack_rangeset: RangeSet,
        ack_delay: float,
        now: float,
        space: QuicPacketSpace,
    ) -> None:
        """
        Update metrics as the result of an ACK being received.
        """
        is_ack_eliciting = False
        largest_acked = ack_rangeset.bounds().stop - 1
        largest_newly_acked = None
        largest_sent_time = None

        if largest_acked > space.largest_acked_packet:
            space.largest_acked_packet = largest_acked

        for packet_number in sorted(space.sent_packets.keys()):
            if packet_number > largest_acked:
                break
            if packet_number in ack_rangeset:
                # remove packet and update counters
                packet = space.sent_packets.pop(packet_number)
                if packet.is_ack_eliciting:
                    is_ack_eliciting = True
                    space.ack_eliciting_in_flight -= 1
                if packet.in_flight:
                    self._cc.on_packet_acked(packet=packet, now=now)
                largest_newly_acked = packet_number
                largest_sent_time = packet.sent_time

                # trigger callbacks
                for handler, args in packet.delivery_handlers:
                    handler(QuicDeliveryState.ACKED, *args)

        # nothing to do if there are no newly acked packets
        if largest_newly_acked is None:
            return

        # only take an RTT sample when the largest acknowledged packet is
        # newly acknowledged and at least one acked packet was ack-eliciting
        if largest_acked == largest_newly_acked and is_ack_eliciting:
            latest_rtt = now - largest_sent_time
            log_rtt = True

            # limit ACK delay to max_ack_delay
            ack_delay = min(ack_delay, self.max_ack_delay)

            # update RTT estimate, which cannot be < 1 ms
            self._rtt_latest = max(latest_rtt, 0.001)
            if self._rtt_latest < self._rtt_min:
                self._rtt_min = self._rtt_latest
            # only deduct the peer's ACK delay if it does not push the
            # sample below the minimum RTT
            if self._rtt_latest > self._rtt_min + ack_delay:
                self._rtt_latest -= ack_delay

            if not self._rtt_initialized:
                self._rtt_initialized = True
                self._rtt_variance = latest_rtt / 2
                self._rtt_smoothed = latest_rtt
            else:
                # exponentially weighted moving averages
                self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
                    self._rtt_min - self._rtt_latest
                )
                self._rtt_smoothed = (
                    7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
                )

            # inform congestion controller
            self._cc.on_rtt_measurement(now=now, rtt=latest_rtt)
            self._pacer.update_rate(
                congestion_window=self._cc.congestion_window,
                smoothed_rtt=self._rtt_smoothed,
            )

        else:
            log_rtt = False

        self._detect_loss(now=now, space=space)

        # reset PTO count
        self._pto_count = 0

        if self._quic_logger is not None:
            self._log_metrics_updated(log_rtt=log_rtt)

    def on_loss_detection_timeout(self, *, now: float) -> None:
        """
        Handle expiry of the loss detection timer: either declare losses in
        the responsible space, or back off and schedule a probe.
        """
        loss_space = self._get_loss_space()
        if loss_space is not None:
            self._detect_loss(now=now, space=loss_space)
        else:
            self._pto_count += 1
            self.reschedule_data(now=now)

    def on_packet_sent(self, *, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
        """Record a freshly sent packet and update in-flight accounting."""
        space.sent_packets[packet.packet_number] = packet

        if packet.is_ack_eliciting:
            space.ack_eliciting_in_flight += 1
        if packet.in_flight:
            if packet.is_ack_eliciting:
                self._time_of_last_sent_ack_eliciting_packet = packet.sent_time

            # add packet to bytes in flight
            self._cc.on_packet_sent(packet=packet)

            if self._quic_logger is not None:
                self._log_metrics_updated()

    def reschedule_data(self, *, now: float) -> None:
        """
        Schedule some data for retransmission.
        """
        # if there is any outstanding CRYPTO, retransmit it
        crypto_scheduled = False
        for space in self.spaces:
            packets = tuple(
                filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
            )
            if packets:
                self._on_packets_lost(now=now, packets=packets, space=space)
                crypto_scheduled = True
        if crypto_scheduled and self._logger is not None:
            self._logger.debug("Scheduled CRYPTO data for retransmission")

        # ensure an ACK-elliciting packet is sent
        self._send_probe()

    def _detect_loss(self, *, now: float, space: QuicPacketSpace) -> None:
        """
        Check whether any packets should be declared lost.
        """
        loss_delay = K_TIME_THRESHOLD * (
            max(self._rtt_latest, self._rtt_smoothed)
            if self._rtt_initialized
            else self._rtt_initial
        )
        packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
        time_threshold = now - loss_delay

        lost_packets = []
        space.loss_time = None
        # relies on sent_packets preserving ascending insertion order, so
        # iteration can stop at the first packet above the largest ACK
        for packet_number, packet in space.sent_packets.items():
            if packet_number > space.largest_acked_packet:
                break

            if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
                lost_packets.append(packet)
            else:
                # not lost yet: remember the earliest time it could become so
                packet_loss_time = packet.sent_time + loss_delay
                if space.loss_time is None or space.loss_time > packet_loss_time:
                    space.loss_time = packet_loss_time

        self._on_packets_lost(now=now, packets=lost_packets, space=space)

    def _get_loss_space(self) -> Optional[QuicPacketSpace]:
        """Return the space with the earliest pending loss time, if any."""
        loss_space = None
        for space in self.spaces:
            if space.loss_time is not None and (
                loss_space is None or space.loss_time < loss_space.loss_time
            ):
                loss_space = space
        return loss_space

    def _log_metrics_updated(self, log_rtt: bool = False) -> None:
        """Emit a qlog "metrics_updated" event, optionally with RTT data."""
        data: Dict[str, Any] = self._cc.get_log_data()

        if log_rtt:
            data.update(
                {
                    "latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
                    "min_rtt": self._quic_logger.encode_time(self._rtt_min),
                    "smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
                    "rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
                }
            )

        self._quic_logger.log_event(
            category="recovery", event="metrics_updated", data=data
        )

    def _on_packets_lost(
        self, *, now: float, packets: Iterable[QuicSentPacket], space: QuicPacketSpace
    ) -> None:
        """
        Declare *packets* lost: drop them from the space, notify their
        delivery handlers and inform the congestion controller.
        """
        lost_packets_cc = []
        for packet in packets:
            del space.sent_packets[packet.packet_number]

            if packet.in_flight:
                lost_packets_cc.append(packet)

            if packet.is_ack_eliciting:
                space.ack_eliciting_in_flight -= 1

            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="recovery",
                    event="packet_lost",
                    data={
                        "type": self._quic_logger.packet_type(packet.packet_type),
                        "packet_number": packet.packet_number,
                    },
                )
                self._log_metrics_updated()

            # trigger callbacks
            for handler, args in packet.delivery_handlers:
                handler(QuicDeliveryState.LOST, *args)

        # inform congestion controller
        if lost_packets_cc:
            self._cc.on_packets_lost(now=now, packets=lost_packets_cc)
            self._pacer.update_rate(
                congestion_window=self._cc.congestion_window,
                smoothed_rtt=self._rtt_smoothed,
            )
            if self._quic_logger is not None:
                self._log_metrics_updated()
|
||||
53
venv/Lib/site-packages/aioquic/quic/retry.py
Normal file
53
venv/Lib/site-packages/aioquic/quic/retry.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import ipaddress
|
||||
from typing import Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from ..buffer import Buffer
|
||||
from ..tls import pull_opaque, push_opaque
|
||||
from .connection import NetworkAddress
|
||||
|
||||
|
||||
def encode_address(addr: "NetworkAddress") -> bytes:
    """Serialize a network address as the packed IP address followed by
    the port in network byte order (two bytes, big-endian)."""
    host, port = addr[0], addr[1]
    return ipaddress.ip_address(host).packed + bytes([port >> 8, port & 0xFF])
|
||||
|
||||
|
||||
class QuicRetryTokenHandler:
    """
    Issues and validates encrypted retry tokens which bind a client's
    network address to the connection IDs seen during the retry exchange.
    """

    def __init__(self) -> None:
        # Ephemeral RSA key: tokens only need to remain valid for the
        # lifetime of this process.
        self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    @staticmethod
    def _oaep_padding() -> "padding.OAEP":
        # The identical OAEP/SHA-256 padding is used for both directions.
        return padding.OAEP(
            mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
        )

    def create_token(
        self,
        addr: "NetworkAddress",
        original_destination_connection_id: bytes,
        retry_source_connection_id: bytes,
    ) -> bytes:
        """Serialize the address and connection IDs, then encrypt them."""
        buf = Buffer(capacity=512)
        for item in (
            encode_address(addr),
            original_destination_connection_id,
            retry_source_connection_id,
        ):
            push_opaque(buf, 1, item)
        return self._key.public_key().encrypt(buf.data, self._oaep_padding())

    def validate_token(self, addr: "NetworkAddress", token: bytes) -> Tuple[bytes, bytes]:
        """
        Decrypt *token* and return the connection IDs it carries.

        Raises ValueError when the token was issued for a different
        address; decryption itself raises if the token is malformed.
        """
        buf = Buffer(data=self._key.decrypt(token, self._oaep_padding()))
        encoded_addr = pull_opaque(buf, 1)
        original_destination_connection_id = pull_opaque(buf, 1)
        retry_source_connection_id = pull_opaque(buf, 1)
        if encoded_addr != encode_address(addr):
            raise ValueError("Remote address does not match.")
        return original_destination_connection_id, retry_source_connection_id
|
||||
364
venv/Lib/site-packages/aioquic/quic/stream.py
Normal file
364
venv/Lib/site-packages/aioquic/quic/stream.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from typing import Optional
|
||||
|
||||
from . import events
|
||||
from .packet import (
|
||||
QuicErrorCode,
|
||||
QuicResetStreamFrame,
|
||||
QuicStopSendingFrame,
|
||||
QuicStreamFrame,
|
||||
)
|
||||
from .packet_builder import QuicDeliveryState
|
||||
from .rangeset import RangeSet
|
||||
|
||||
|
||||
class FinalSizeError(Exception):
    """
    Raised when received stream data contradicts the stream's final size
    (data beyond the FIN, or a FIN at a different offset).
    """
|
||||
|
||||
|
||||
class StreamFinishedError(Exception):
    """
    Raised when an operation is attempted on a stream part that has
    already finished.
    """
|
||||
|
||||
|
||||
class QuicStreamReceiver:
|
||||
"""
|
||||
The receive part of a QUIC stream.
|
||||
|
||||
It finishes:
|
||||
- immediately for a send-only stream
|
||||
- upon reception of a STREAM_RESET frame
|
||||
- upon reception of a data frame with the FIN bit set
|
||||
"""
|
||||
|
||||
def __init__(self, stream_id: Optional[int], readable: bool) -> None:
|
||||
self.highest_offset = 0 # the highest offset ever seen
|
||||
self.is_finished = False
|
||||
self.stop_pending = False
|
||||
|
||||
self._buffer = bytearray()
|
||||
self._buffer_start = 0 # the offset for the start of the buffer
|
||||
self._final_size: Optional[int] = None
|
||||
self._ranges = RangeSet()
|
||||
self._stream_id = stream_id
|
||||
self._stop_error_code: Optional[int] = None
|
||||
|
||||
def get_stop_frame(self) -> QuicStopSendingFrame:
|
||||
self.stop_pending = False
|
||||
return QuicStopSendingFrame(
|
||||
error_code=self._stop_error_code,
|
||||
stream_id=self._stream_id,
|
||||
)
|
||||
|
||||
def starting_offset(self) -> int:
|
||||
return self._buffer_start
|
||||
|
||||
def handle_frame(
|
||||
self, frame: QuicStreamFrame
|
||||
) -> Optional[events.StreamDataReceived]:
|
||||
"""
|
||||
Handle a frame of received data.
|
||||
"""
|
||||
pos = frame.offset - self._buffer_start
|
||||
count = len(frame.data)
|
||||
frame_end = frame.offset + count
|
||||
|
||||
# we should receive no more data beyond FIN!
|
||||
if self._final_size is not None:
|
||||
if frame_end > self._final_size:
|
||||
raise FinalSizeError("Data received beyond final size")
|
||||
elif frame.fin and frame_end != self._final_size:
|
||||
raise FinalSizeError("Cannot change final size")
|
||||
if frame.fin:
|
||||
self._final_size = frame_end
|
||||
if frame_end > self.highest_offset:
|
||||
self.highest_offset = frame_end
|
||||
|
||||
# fast path: new in-order chunk
|
||||
if pos == 0 and count and not self._buffer:
|
||||
self._buffer_start += count
|
||||
if frame.fin:
|
||||
# all data up to the FIN has been received, we're done receiving
|
||||
self.is_finished = True
|
||||
return events.StreamDataReceived(
|
||||
data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
|
||||
)
|
||||
|
||||
# discard duplicate data
|
||||
if pos < 0:
|
||||
frame.data = frame.data[-pos:]
|
||||
frame.offset -= pos
|
||||
pos = 0
|
||||
count = len(frame.data)
|
||||
|
||||
# marked received range
|
||||
if frame_end > frame.offset:
|
||||
self._ranges.add(frame.offset, frame_end)
|
||||
|
||||
# add new data
|
||||
gap = pos - len(self._buffer)
|
||||
if gap > 0:
|
||||
self._buffer += bytearray(gap)
|
||||
self._buffer[pos : pos + count] = frame.data
|
||||
|
||||
# return data from the front of the buffer
|
||||
data = self._pull_data()
|
||||
end_stream = self._buffer_start == self._final_size
|
||||
if end_stream:
|
||||
# all data up to the FIN has been received, we're done receiving
|
||||
self.is_finished = True
|
||||
if data or end_stream:
|
||||
return events.StreamDataReceived(
|
||||
data=data, end_stream=end_stream, stream_id=self._stream_id
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
def handle_reset(
|
||||
self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
|
||||
) -> Optional[events.StreamReset]:
|
||||
"""
|
||||
Handle an abrupt termination of the receiving part of the QUIC stream.
|
||||
"""
|
||||
if self._final_size is not None and final_size != self._final_size:
|
||||
raise FinalSizeError("Cannot change final size")
|
||||
|
||||
# we are done receiving
|
||||
self._final_size = final_size
|
||||
self.is_finished = True
|
||||
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)
|
||||
|
||||
def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
|
||||
"""
|
||||
Callback when a STOP_SENDING is ACK'd.
|
||||
"""
|
||||
if delivery != QuicDeliveryState.ACKED:
|
||||
self.stop_pending = True
|
||||
|
||||
def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
    """
    Request the peer stop sending data on the QUIC stream.

    Records *error_code* and schedules a STOP_SENDING frame to be sent.
    """
    self._stop_error_code = error_code
    self.stop_pending = True
|
||||
|
||||
def _pull_data(self) -> bytes:
|
||||
"""
|
||||
Remove data from the front of the buffer.
|
||||
"""
|
||||
try:
|
||||
has_data_to_read = self._ranges[0].start == self._buffer_start
|
||||
except IndexError:
|
||||
has_data_to_read = False
|
||||
if not has_data_to_read:
|
||||
return b""
|
||||
|
||||
r = self._ranges.shift()
|
||||
pos = r.stop - r.start
|
||||
data = bytes(self._buffer[:pos])
|
||||
del self._buffer[:pos]
|
||||
self._buffer_start = r.stop
|
||||
return data
|
||||
|
||||
|
||||
class QuicStreamSender:
    """
    The send part of a QUIC stream.

    It finishes:
    - immediately for a receive-only stream
    - upon acknowledgement of a STREAM_RESET frame
    - upon acknowledgement of a data frame with the FIN bit set
    """

    def __init__(self, stream_id: Optional[int], writable: bool) -> None:
        # Public state consulted by the connection layer.
        self.buffer_is_empty = True  # nothing is pending transmission
        self.highest_offset = 0  # highest offset ever sent on this stream
        self.is_finished = not writable
        self.reset_pending = False  # a RESET_STREAM frame needs to be sent

        # Internal bookkeeping.
        self._acked = RangeSet()  # byte ranges acknowledged by the peer
        self._acked_fin = False  # the FIN bit has been acknowledged
        self._buffer = bytearray()  # unacknowledged data, starting at _buffer_start
        self._buffer_fin: Optional[int] = None  # offset of the FIN, once written
        self._buffer_start = 0  # the offset for the start of the buffer
        self._buffer_stop = 0  # the offset for the stop of the buffer
        self._pending = RangeSet()  # byte ranges awaiting (re)transmission
        self._pending_eof = False  # the FIN bit awaits (re)transmission
        self._reset_error_code: Optional[int] = None
        self._stream_id = stream_id

    @property
    def next_offset(self) -> int:
        """
        The offset for the next frame to send.

        This is used to determine the space needed for the frame's `offset` field.
        """
        try:
            return self._pending[0].start
        except IndexError:
            return self._buffer_stop

    def get_frame(
        self, max_size: int, max_offset: Optional[int] = None
    ) -> Optional[QuicStreamFrame]:
        """
        Get a frame of data to send.

        At most `max_size` bytes of payload are returned, and the frame never
        extends beyond `max_offset` when one is given (flow control). Returns
        None when there is nothing to send.
        """
        assert self._reset_error_code is None, "cannot call get_frame() after reset()"

        # get the first pending data range
        try:
            r = self._pending[0]
        except IndexError:
            if self._pending_eof:
                # FIN only
                self._pending_eof = False
                return QuicStreamFrame(fin=True, offset=self._buffer_fin)

            self.buffer_is_empty = True
            return None

        # apply flow control
        start = r.start
        stop = min(r.stop, start + max_size)
        if max_offset is not None and stop > max_offset:
            stop = max_offset
        if stop <= start:
            # Flow control leaves no room to send anything right now.
            return None

        # create frame
        # The buffer only holds unacknowledged data, so stream offsets are
        # translated by _buffer_start before slicing.
        frame = QuicStreamFrame(
            data=bytes(
                self._buffer[start - self._buffer_start : stop - self._buffer_start]
            ),
            offset=start,
        )
        self._pending.subtract(start, stop)

        # track the highest offset ever sent
        if stop > self.highest_offset:
            self.highest_offset = stop

        # if the buffer is empty and EOF was written, set the FIN bit
        if self._buffer_fin == stop:
            frame.fin = True
            self._pending_eof = False

        return frame

    def get_reset_frame(self) -> QuicResetStreamFrame:
        """
        Get the RESET_STREAM frame to send, clearing the pending flag.
        """
        self.reset_pending = False
        return QuicResetStreamFrame(
            error_code=self._reset_error_code,
            final_size=self.highest_offset,
            stream_id=self._stream_id,
        )

    def on_data_delivery(
        self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
    ) -> None:
        """
        Callback when sent data is ACK'd.

        The `[start, stop)` range and `fin` flag describe the frame whose
        delivery was resolved; on loss the range is rescheduled for sending.
        """
        # If the frame had the FIN bit set, its end MUST match otherwise
        # we have a programming error.
        assert (
            not fin or stop == self._buffer_fin
        ), "on_data_delivered() was called with inconsistent fin / stop"

        # If a reset has been requested, stop processing data delivery.
        # The transition to the finished state only depends on the reset
        # being acknowledged.
        if self._reset_error_code is not None:
            return

        if delivery == QuicDeliveryState.ACKED:
            if stop > start:
                # Some data has been ACK'd, discard it.
                self._acked.add(start, stop)
                # Only data at the very front of the buffer can be freed;
                # ACKs for later ranges stay recorded until the gap closes.
                first_range = self._acked[0]
                if first_range.start == self._buffer_start:
                    size = first_range.stop - first_range.start
                    self._acked.shift()
                    self._buffer_start += size
                    del self._buffer[:size]

            if fin:
                # The FIN has been ACK'd.
                self._acked_fin = True

            if self._buffer_start == self._buffer_fin and self._acked_fin:
                # All data and the FIN have been ACK'd, we're done sending.
                self.is_finished = True
        else:
            if stop > start:
                # Some data has been lost, reschedule it.
                self.buffer_is_empty = False
                self._pending.add(start, stop)

            if fin:
                # The FIN has been lost, reschedule it.
                self.buffer_is_empty = False
                self._pending_eof = True

    def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
        """
        Callback when a reset is ACK'd.
        """
        if delivery == QuicDeliveryState.ACKED:
            # The reset has been ACK'd, we're done sending.
            self.is_finished = True
        else:
            # The reset has been lost, reschedule it.
            self.reset_pending = True

    def reset(self, error_code: int) -> None:
        """
        Abruptly terminate the sending part of the QUIC stream.
        """
        assert self._reset_error_code is None, "cannot call reset() more than once"
        self._reset_error_code = error_code
        self.reset_pending = True

        # Prevent any more data from being sent or re-sent.
        self.buffer_is_empty = True

    def write(self, data: bytes, end_stream: bool = False) -> None:
        """
        Write some data bytes to the QUIC stream.

        With `end_stream` set, the current end of the buffered data becomes
        the stream's final size and a FIN is scheduled.
        """
        assert self._buffer_fin is None, "cannot call write() after FIN"
        assert self._reset_error_code is None, "cannot call write() after reset()"
        size = len(data)

        if size:
            self.buffer_is_empty = False
            self._pending.add(self._buffer_stop, self._buffer_stop + size)
            self._buffer += data
            self._buffer_stop += size
        if end_stream:
            self.buffer_is_empty = False
            self._buffer_fin = self._buffer_stop
            self._pending_eof = True
|
||||
|
||||
|
||||
class QuicStream:
    """
    A QUIC stream, composed of an independent receive half and send half.
    """

    def __init__(
        self,
        stream_id: Optional[int] = None,
        max_stream_data_local: int = 0,
        max_stream_data_remote: int = 0,
        readable: bool = True,
        writable: bool = True,
    ) -> None:
        # Flow-control bookkeeping.
        self.is_blocked = False
        self.max_stream_data_local = max_stream_data_local
        self.max_stream_data_local_sent = max_stream_data_local
        self.max_stream_data_remote = max_stream_data_remote

        # The two halves of the stream.
        self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
        self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
        self.stream_id = stream_id

    @property
    def is_finished(self) -> bool:
        """True once both the receive and send halves have finished."""
        if not self.receiver.is_finished:
            return False
        return self.sender.is_finished
|
||||
2185
venv/Lib/site-packages/aioquic/tls.py
Normal file
2185
venv/Lib/site-packages/aioquic/tls.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user