fix: propagate contextvars through thread_pool_exec (#16247)

## Problem

`thread_pool_exec()` dispatches work via `loop.run_in_executor()`, which
submits the callable with a plain `executor.submit(func, *args)` and
does **not** copy the caller's `contextvars.Context`. So a `ContextVar`
set in the async caller is not visible inside the function running in
the worker thread.

This differs from `asyncio.to_thread()`, which runs the callable inside
a copied context. `run_in_executor()` has never propagated context
(verified on Python 3.12 and 3.13) — so this is a pre-existing gap in
the helper, **not** a regression or a Python-version compatibility
issue.

Concretely, any code that sets a `ContextVar` in async code and reads it
inside a function dispatched via `thread_pool_exec` (request tracing,
per-task state, Langfuse trace propagation, etc.) silently loses that
context.

## Fix

Copy the current context before submitting and run the callable inside
it with `ctx.run()`, matching what `asyncio.to_thread()` does:

```python
async def thread_pool_exec(func, *args, **kwargs):
    """Await ``func(*args, **kwargs)`` run on an executor inside a copy of the caller's context.

    The caller's ``contextvars.Context`` is copied before the work is
    submitted, and the callable is invoked through ``Context.run`` so that
    ``ContextVar`` values set by the async caller are visible to ``func``.

    Args:
        func: Callable to execute via ``loop.run_in_executor()``.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    running_loop = asyncio.get_running_loop()
    # Snapshot the context here, in the caller, before the hand-off.
    snapshot = contextvars.copy_context()
    if not kwargs:
        return await running_loop.run_in_executor(_thread_pool_executor(), snapshot.run, func, *args)
    # run_in_executor() only forwards positional arguments, so bind the
    # keyword arguments to the callable up front.
    bound_call = functools.partial(func, *args, **kwargs)
    return await running_loop.run_in_executor(_thread_pool_executor(), snapshot.run, bound_call)
```

This explicitly **adds** ContextVar propagation to the helper (it does
not restore any prior behavior). The change is backward-compatible.

## Tests

`TestThreadPoolExec` covers propagation, the kwargs path, per-call
isolation and the unset-default case.

> Note: the branch name still contains `python313` for historical
reasons; the change is unrelated to any Python version.
This commit is contained in:
VincentLambert
2026-06-23 09:17:42 +02:00
committed by GitHub
parent d8ee1ffaad
commit 11e14a8353
2 changed files with 74 additions and 4 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
#
import asyncio
import contextvars
import hashlib
import sys
import types
@@ -21,8 +22,10 @@ import uuid
from contextlib import contextmanager
from unittest.mock import patch
import pytest
from common import ssrf_guard
from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int
from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int, thread_pool_exec
class _Hdr:
@@ -105,6 +108,67 @@ def _fake_httpx_sys_modules(client):
sys.modules.pop("httpx", None)
@pytest.mark.p1
class TestThreadPoolExec:
    """Checks that thread_pool_exec exposes the caller's ContextVar values to the dispatched function."""

    def test_contextvar_propagated_to_thread(self):
        """A value set by the async caller is readable from the dispatched function."""
        marker: contextvars.ContextVar[str] = contextvars.ContextVar("_var")

        async def scenario():
            marker.set("hello")
            # get(None) yields None rather than raising if the value was lost,
            # so a propagation failure shows up as a plain assertion mismatch.
            return await thread_pool_exec(lambda: marker.get(None))

        assert asyncio.run(scenario()) == "hello"

    def test_contextvar_propagated_with_kwargs(self):
        """Propagation also holds when the callable is invoked with keyword arguments."""
        base: contextvars.ContextVar[int] = contextvars.ContextVar("_var_kw")

        def add_to_base(increment):
            # Falls back to 0 when the caller's value is not visible.
            current = base.get(0)
            return current + increment

        async def scenario():
            base.set(10)
            return await thread_pool_exec(add_to_base, increment=5)

        assert asyncio.run(scenario()) == 15

    def test_contextvar_isolation_between_calls(self):
        """Every call observes the value that was current when that call was made."""
        tracked: contextvars.ContextVar[str] = contextvars.ContextVar("_var_iso")

        def current_value():
            return tracked.get(None)

        async def scenario():
            observed = []
            # Re-set the variable before each dispatch and record what the
            # dispatched function saw for that particular call.
            for value in ("first", "second"):
                tracked.set(value)
                observed.append(await thread_pool_exec(current_value))
            return observed

        first_seen, second_seen = asyncio.run(scenario())
        assert first_seen == "first"
        assert second_seen == "second"

    def test_unset_contextvar_returns_default(self):
        """A variable the caller never set resolves to its declared default."""
        untouched: contextvars.ContextVar[str] = contextvars.ContextVar("_var_unset", default="default")

        outcome = asyncio.run(thread_pool_exec(untouched.get))

        assert outcome == "default"
class TestGetUuid:
"""Test cases for get_uuid function"""