mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix: propagate contextvars through thread_pool_exec (#16247)
## Problem
`thread_pool_exec()` dispatches work via `loop.run_in_executor()`, which
submits the callable with a plain `executor.submit(func, *args)` and
does **not** copy the caller's `contextvars.Context`. So a `ContextVar`
set in the async caller is not visible inside the function running in
the worker thread.
This differs from `asyncio.to_thread()`, which runs the callable inside
a copied context. `run_in_executor()` has never propagated context
(verified on Python 3.12 and 3.13) — so this is a pre-existing gap in
the helper, **not** a regression or a Python-version compatibility
issue.
Concretely, any code that sets a `ContextVar` in async code and reads it
inside a function dispatched via `thread_pool_exec` (request tracing,
per-task state, Langfuse trace propagation, etc.) silently loses that
context.
## Fix
Copy the current context before submitting and run the callable inside
it with `ctx.run()`, matching what `asyncio.to_thread()` does:
```python
async def thread_pool_exec(func, *args, **kwargs):
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    if kwargs:
        inner = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(_thread_pool_executor(), ctx.run, inner)
    return await loop.run_in_executor(_thread_pool_executor(), ctx.run, func, *args)
```
This explicitly **adds** ContextVar propagation to the helper (it does
not restore any prior behavior). Backward-compatible.
## Tests
`TestThreadPoolExec` covers propagation, the kwargs path, per-call
isolation and the unset-default case.
> Note: the branch name still contains `python313` for historical
reasons; the change is unrelated to any Python version.
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import contextvars
|
||||
import hashlib
|
||||
import sys
|
||||
import types
|
||||
@@ -21,8 +22,10 @@ import uuid
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from common import ssrf_guard
|
||||
from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int
|
||||
from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int, thread_pool_exec
|
||||
|
||||
|
||||
class _Hdr:
|
||||
@@ -105,6 +108,67 @@ def _fake_httpx_sys_modules(client):
|
||||
sys.modules.pop("httpx", None)
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestThreadPoolExec:
    """Checks that thread_pool_exec hands the caller's ContextVar values to the function it dispatches."""

    def test_contextvar_propagated_to_thread(self):
        """A value set by the async caller is readable from the dispatched function."""
        marker: contextvars.ContextVar[str] = contextvars.ContextVar("_var")

        def fetch():
            return marker.get(None)

        async def scenario():
            marker.set("hello")
            observed = await thread_pool_exec(fetch)
            return observed

        assert asyncio.run(scenario()) == "hello"

    def test_contextvar_propagated_with_kwargs(self):
        """Propagation still holds when the dispatched function receives keyword arguments."""
        counter: contextvars.ContextVar[int] = contextvars.ContextVar("_var_kw")

        def add_to_current(increment):
            current = counter.get(0)
            return current + increment

        async def scenario():
            counter.set(10)
            return await thread_pool_exec(add_to_current, increment=5)

        total = asyncio.run(scenario())
        assert total == 15

    def test_contextvar_isolation_between_calls(self):
        """Two consecutive calls each observe the value that was current when they were submitted."""
        marker: contextvars.ContextVar[str] = contextvars.ContextVar("_var_iso")

        def fetch():
            return marker.get(None)

        async def scenario():
            seen = []
            # Set, dispatch, await -- strictly in sequence, once per value.
            for value in ("first", "second"):
                marker.set(value)
                seen.append(await thread_pool_exec(fetch))
            return tuple(seen)

        first_seen, second_seen = asyncio.run(scenario())
        assert first_seen == "first"
        assert second_seen == "second"

    def test_unset_contextvar_returns_default(self):
        """When the caller never sets the variable, the dispatched function gets its declared default."""
        fallback: contextvars.ContextVar[str] = contextvars.ContextVar("_var_unset", default="default")

        def fetch():
            return fallback.get()

        outcome = asyncio.run(thread_pool_exec(fetch))
        assert outcome == "default"
|
||||
|
||||
|
||||
class TestGetUuid:
|
||||
"""Test cases for get_uuid function"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user