diff --git a/common/misc_utils.py b/common/misc_utils.py index 9225fcd25d..8226041c1f 100644 --- a/common/misc_utils.py +++ b/common/misc_utils.py @@ -16,6 +16,7 @@ import asyncio import base64 +import contextvars import functools import hashlib import logging @@ -250,8 +251,13 @@ def _thread_pool_executor(): async def thread_pool_exec(func, *args, **kwargs): + # loop.run_in_executor() submits the callable without propagating the caller's + # contextvars (unlike asyncio.to_thread, which copies the context). Copy the + # current context and run the callable inside it so ContextVars set by the + # caller (e.g. tracing / per-request state) are visible in the worker thread. loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() if kwargs: - func = functools.partial(func, *args, **kwargs) - return await loop.run_in_executor(_thread_pool_executor(), func) - return await loop.run_in_executor(_thread_pool_executor(), func, *args) + inner = functools.partial(func, *args, **kwargs) + return await loop.run_in_executor(_thread_pool_executor(), ctx.run, inner) + return await loop.run_in_executor(_thread_pool_executor(), ctx.run, func, *args) diff --git a/test/unit_test/common/test_misc_utils.py b/test/unit_test/common/test_misc_utils.py index 6ca24f1bbd..d94de1027b 100644 --- a/test/unit_test/common/test_misc_utils.py +++ b/test/unit_test/common/test_misc_utils.py @@ -14,6 +14,7 @@ # limitations under the License. # import asyncio +import contextvars import hashlib import sys import types @@ -21,8 +22,10 @@ import uuid from contextlib import contextmanager from unittest.mock import patch +import pytest + from common import ssrf_guard -from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int +from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int, thread_pool_exec class _Hdr: @@ -105,6 +108,67 @@ def _fake_httpx_sys_modules(client): sys.modules.pop("httpx", None) +@pytest.mark.p1 +class TestThreadPoolExec: + """Test cases for thread_pool_exec — verifies ContextVar propagation into the worker thread.""" + + def test_contextvar_propagated_to_thread(self): + """ContextVar set in async caller must be visible inside the thread.""" + _var: contextvars.ContextVar[str] = contextvars.ContextVar("_var") + + def read_var(): + return _var.get(None) + + async def run(): + _var.set("hello") + return await thread_pool_exec(read_var) + + result = asyncio.run(run()) + assert result == "hello" + + def test_contextvar_propagated_with_kwargs(self): + """ContextVar propagation also works when kwargs are passed (functools.partial path).""" + _var: contextvars.ContextVar[int] = contextvars.ContextVar("_var_kw") + + def read_var_and_add(increment): + return (_var.get(0)) + increment + + async def run(): + _var.set(10) + return await thread_pool_exec(read_var_and_add, increment=5) + + result = asyncio.run(run()) + assert result == 15 + + def test_contextvar_isolation_between_calls(self): + """Each thread_pool_exec call captures the context at submission time.""" + _var: contextvars.ContextVar[str] = contextvars.ContextVar("_var_iso") + + def read_var(): + return _var.get(None) + + async def run(): + _var.set("first") + r1 = await thread_pool_exec(read_var) + _var.set("second") + r2 = await thread_pool_exec(read_var) + return r1, r2 + + r1, r2 = asyncio.run(run()) + assert r1 == "first" + assert r2 == "second" + + def test_unset_contextvar_returns_default(self): + """A ContextVar that was never set in caller returns its default inside the thread.""" + _var: contextvars.ContextVar[str] = contextvars.ContextVar("_var_unset", default="default") + + def read_var(): + return _var.get() + + result = asyncio.run(thread_pool_exec(read_var)) + assert result == "default" + + class TestGetUuid: """Test cases for get_uuid function"""