ragflow/rag/llm/tool_decorator.py
Kevin Hu e7544562cc Feat: @tool decorator for chat-model tool registration (#15047)
## Summary

- Adds a lightweight `@tool` decorator and `FunctionToolSession` adapter
in `rag/llm/tool_decorator.py` that let callers register plain Python
functions as LLM tools without hand-writing OpenAI function schemas or
building an MCP-style session.
- Refactors `Base.bind_tools` and `LiteLLMBase.bind_tools` in
`rag/llm/chat_model.py` to accept either the new decorator form
`bind_tools(tools=[fn1, fn2])` or the existing `(toolcall_session,
tools_schemas)` form, so existing agent/dialog call-sites in
`agent/component/agent_with_tools.py`, `api/db/services/llm_service.py`,
and `api/db/services/dialog_service.py` are unaffected.
- Adds 8 unit tests in `test/unit_test/rag/llm/test_tool_decorator.py`
covering schema shape, required/optional inference, sync + async
dispatch, and bad-input rejection.
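
Here is a minimal sketch of how a `bind_tools` method could accept both call forms. It assumes the new form passes `@tool`-marked callables through a `tools` keyword, and the legacy form passes a session object followed by a list of schema dicts. `tool`, `SimpleSession`, and `ChatModelSketch` are hypothetical stand-ins written for illustration. They are not the code in `rag/llm/chat_model.py`.

```python
from typing import Any, Callable, Optional


def tool(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical stand-in for the real decorator: sets only the marker
    # attribute and a trivial schema so the dispatch logic can be shown.
    fn._is_tool = True
    fn.openai_schema = {"type": "function", "function": {"name": fn.__name__}}
    return fn


class SimpleSession:
    # Hypothetical session object that maps tool names to callables.
    def __init__(self, fns: list[Callable[..., Any]]):
        self.tools_map = {fn.openai_schema["function"]["name"]: fn for fn in fns}


class ChatModelSketch:
    def __init__(self) -> None:
        self.toolcall_session: Optional[Any] = None
        self.tools: list[Any] = []

    def bind_tools(self, toolcall_session: Optional[Any] = None, tools: Optional[list[Any]] = None) -> None:
        tools = list(tools or [])
        is_fn = [callable(t) and getattr(t, "_is_tool", False) for t in tools]
        if tools and all(is_fn):
            # New form: bind_tools(tools=[fn1, fn2]); build the session here.
            self.toolcall_session = SimpleSession(tools)
            self.tools = [t.openai_schema for t in tools]
        elif any(is_fn):
            # Sketch-only policy: refuse a mix of callables and schema dicts.
            raise TypeError("cannot mix @tool callables with raw schema dicts")
        else:
            # Legacy form: bind_tools(session, schemas); keep both as given.
            self.toolcall_session = toolcall_session
            self.tools = tools
```

The sketch treats a list as the new form only when every entry carries the marker, which leaves the legacy schema-dict path untouched. Rejecting mixed lists is a choice made for this sketch, not something the PR states.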

## Usage

```python
from rag.llm.tool_decorator import tool

@tool
def get_weather(city: str) -> str:
    """Get current weather for a city.

    :param city: City name to look up.
    """
    return f"{city}: 21 C, partly cloudy"

chat_mdl.bind_tools(tools=[get_weather])
ans, tk = await chat_mdl.async_chat_with_tools(system, history)
```
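
The following self-contained sketch shows the kind of schema the decorator would attach to the `get_weather` example above. `build_schema` is a simplified, hypothetical re-implementation written for illustration, not the module's `_build_openai_schema`. It maps only a few builtin types, ignores `Optional`, and does not skip `self`/`cls`/`*args`/`**kwargs`.

```python
import inspect
import re
from typing import Any, Callable, get_type_hints

# Simplified type table; anything else falls back to "string".
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}
_PARAM_RE = re.compile(r"^\s*:param\s+(\w+)\s*:\s*(.+?)\s*$")


def build_schema(fn: Callable[..., Any]) -> dict[str, Any]:
    hints = get_type_hints(fn)
    doc = inspect.cleandoc(fn.__doc__ or "")
    # Everything before the first ":param" line is the description.
    desc_lines, param_docs = [], {}
    for line in doc.splitlines():
        m = _PARAM_RE.match(line)
        if m:
            param_docs[m.group(1)] = m.group(2)
        elif not param_docs:
            desc_lines.append(line)
    properties, required = {}, []
    for name, param in inspect.signature(fn).parameters.items():
        prop = {"type": _JSON_TYPES.get(hints.get(name), "string")}
        if name in param_docs:
            prop["description"] = param_docs[name]
        properties[name] = prop
        # Parameters without a default value are required.
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": "\n".join(desc_lines).strip() or fn.__name__,
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }


def get_weather(city: str) -> str:
    """Get current weather for a city.

    :param city: City name to look up.
    """
    return f"{city}: 21 C, partly cloudy"


schema = build_schema(get_weather)
```

With an extra parameter such as `units: str = "metric"`, the sketch would add `units` to `properties` but not to `required`. This is the required/optional inference the summary describes.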

The decorator builds the schema from the function's `inspect.signature`, its
type hints, and its docstring (`:param name:` style), and attaches the result
to the callable as an OpenAI-format `openai_schema` attribute.
`FunctionToolSession` duck-types the existing `ToolCallSession` protocol: it
awaits async callables directly and runs sync ones through `thread_pool_exec`,
so the event loop is never blocked.
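
The dispatch rule can be sketched with the standard library alone. This sketch uses `asyncio.to_thread` as a stand-in for RAGFlow's `thread_pool_exec`, whose exact signature is not shown here. `MiniToolSession` is a hypothetical name. As in the real module, a timeout cancels the await but cannot stop a sync function that is already running in a worker thread.

```python
import asyncio
import inspect
from typing import Any, Callable


class MiniToolSession:
    """Hypothetical, trimmed-down analogue of FunctionToolSession."""

    def __init__(self, fns: list[Callable[..., Any]]):
        self.tools_map = {fn.__name__: fn for fn in fns}

    async def tool_call_async(self, name: str, arguments: dict[str, Any], request_timeout: float = 10) -> Any:
        if name not in self.tools_map:
            raise KeyError(f"Tool {name!r} is not registered")
        fn = self.tools_map[name]
        if inspect.iscoroutinefunction(fn):
            # Async tools are awaited directly on the event loop.
            coro = fn(**arguments)
        else:
            # Sync tools run in a worker thread so the loop stays responsive
            # (asyncio.to_thread stands in for thread_pool_exec here).
            coro = asyncio.to_thread(fn, **arguments)
        return await asyncio.wait_for(coro, timeout=request_timeout)


def add(a: int, b: int) -> int:
    return a + b


async def echo(text: str) -> str:
    await asyncio.sleep(0)
    return text
```

A caller would use it as `await MiniToolSession([add, echo]).tool_call_async("add", {"a": 2, "b": 3})`. Sync and async tools are called the same way; only the execution path differs.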

## Design notes

- `tool_decorator.py` deliberately does **not** live inside
`rag/llm/__init__.py` to avoid forcing every consumer through the heavy
provider auto-discovery loop and to sidestep a circular import
(`__init__.py` imports `chat_model`, which would otherwise need symbols
from `__init__.py`).
- `FunctionToolSession` is duck-typed against
`common.mcp_tool_call_conn.ToolCallSession` rather than explicitly
inheriting from it, so importing the decorator doesn't pull the MCP
client SDK into the import graph.
- Docstring parsing is intentionally minimal (`:param name:` only) to
keep this dependency-free; Google/NumPy styles can be added later via
`docstring_parser` if needed.
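
Duck typing means the adapter only has to provide the methods the chat-model loop calls. The sketch below expresses that with `typing.Protocol`, assuming the protocol consists of `tool_call` and `tool_call_async` with the signatures used in `tool_decorator.py`. `ToolCallSessionLike` and `DemoSession` are hypothetical names; this is not the definition in `common.mcp_tool_call_conn`.

```python
from __future__ import annotations

import asyncio
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ToolCallSessionLike(Protocol):
    # Structural type: any class with these methods matches, no inheritance needed.
    def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> Any: ...

    async def tool_call_async(
        self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10
    ) -> Any: ...


class DemoSession:
    # Does NOT subclass ToolCallSessionLike; it just has the right methods.
    def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> Any:
        return asyncio.run(self.tool_call_async(name, arguments, request_timeout=timeout))

    async def tool_call_async(
        self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10
    ) -> Any:
        return f"{name}({arguments})"


def run_tool(session: ToolCallSessionLike, name: str, arguments: dict[str, Any]) -> Any:
    # A caller typed against the protocol accepts DemoSession as-is.
    return session.tool_call(name, arguments)
```

At runtime, `isinstance()` against a `runtime_checkable` protocol only checks that the methods exist, not their signatures. Static checkers such as mypy compare the signatures too.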

## Test plan

- [x] `python -m pytest test/unit_test/rag/llm/test_tool_decorator.py -v`: 8 passed
- [x] `python -m pytest test/unit_test/rag/llm/ --ignore=test/unit_test/rag/llm/test_perplexity_embed.py`: 11 passed
  (ignored because of a pre-existing `numpy` import unrelated to this change)
- [ ] Reviewer: smoke-test the new path end-to-end with a live model via
`chat_mdl.bind_tools(tools=[my_fn])` to confirm the OpenAI-format
schemas pass through unchanged

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-21 15:32:17 +08:00


#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Lightweight ``@tool`` decorator and matching ``ToolCallSession`` adapter.

Lets callers register plain Python functions as LLM tools without having to
hand-write the OpenAI function schema or build an MCP-style session::

    from rag.llm.tool_decorator import tool

    @tool
    def get_weather(city: str) -> str:
        \"\"\"Get current weather for a city.

        :param city: City name to look up.
        \"\"\"
        return f"{city}: 21 C, partly cloudy"

    chat_mdl.bind_tools(tools=[get_weather])

The decorator introspects the function signature, type hints, and docstring,
attaches an OpenAI-format schema as ``fn.openai_schema``, and marks the
function with ``fn._is_tool = True`` so :meth:`Base.bind_tools` can detect
the new style.
"""

from __future__ import annotations

import asyncio
import inspect
import logging
import re
import types
from collections.abc import Mapping
from typing import Any, Callable, Union, get_args, get_origin, get_type_hints

from common.misc_utils import thread_pool_exec

_PY_TO_JSON: dict[type, str] = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
    type(None): "null",
}

# ``get_origin`` reports ``typing.Union`` for ``Optional[T]`` but
# ``types.UnionType`` for PEP 604 unions such as ``T | None`` (Python 3.10+).
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def _py_type_to_json(py_type: Any) -> dict[str, Any]:
    """Best-effort mapping from a Python annotation to a JSON-schema fragment.

    Handles ``Optional[T]`` / ``T | None`` by unwrapping the non-None branch
    and lets the ``required`` list (built from defaults) carry optionality.
    Unknown types fall back to ``{"type": "string"}`` so the schema stays
    valid even when annotations are missing.
    """
    if py_type is inspect.Parameter.empty or py_type is Any:
        return {"type": "string"}
    origin = get_origin(py_type)
    if origin in _UNION_ORIGINS:
        non_none = [a for a in get_args(py_type) if a is not type(None)]
        if len(non_none) == 1:
            return _py_type_to_json(non_none[0])
        return {"type": "string"}
    if origin in (list, tuple, set, frozenset):
        item_args = get_args(py_type)
        item_schema = _py_type_to_json(item_args[0]) if item_args else {"type": "string"}
        return {"type": "array", "items": item_schema}
    if origin is dict:
        return {"type": "object"}
    if isinstance(py_type, type):
        return {"type": _PY_TO_JSON.get(py_type, "string")}
    return {"type": "string"}


_PARAM_RE = re.compile(r"^\s*:param\s+(?P<name>\w+)\s*:\s*(?P<desc>.+?)\s*$")


def _parse_param_docs(docstring: str | None) -> tuple[str, dict[str, str]]:
    """Pull a short function description and ``:param name:`` lines out of a docstring.

    Intentionally minimal: Google/NumPy styles are not parsed. Anything
    before the first ``:param`` line becomes the function description.
    """
    if not docstring:
        return "", {}
    lines = inspect.cleandoc(docstring).splitlines()
    desc_lines: list[str] = []
    param_docs: dict[str, str] = {}
    for line in lines:
        m = _PARAM_RE.match(line)
        if m:
            param_docs[m.group("name")] = m.group("desc")
        elif not param_docs:
            desc_lines.append(line)
    return "\n".join(desc_lines).strip(), param_docs


def _build_openai_schema(fn: Callable[..., Any]) -> dict[str, Any]:
    """Build the OpenAI ``tools`` entry for ``fn`` from its signature, type hints, and docstring."""
    sig = inspect.signature(fn)
    try:
        hints = get_type_hints(fn)
    except Exception:
        # Unresolvable annotations (e.g. bad forward references) fall back to
        # the raw ``param.annotation`` values below.
        hints = {}
    description, param_docs = _parse_param_docs(fn.__doc__)
    properties: dict[str, dict[str, Any]] = {}
    required: list[str] = []
    for name, param in sig.parameters.items():
        # Skip implicit receivers and *args / **kwargs, which the schema does not model.
        if name in ("self", "cls") or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        schema = _py_type_to_json(hints.get(name, param.annotation))
        if name in param_docs:
            schema["description"] = param_docs[name]
        properties[name] = schema
        # Parameters without a default are required; the rest are optional.
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": description or fn.__name__,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def tool(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Mark ``fn`` as an LLM tool and attach an OpenAI-format schema to it.

    The returned callable is ``fn`` itself, not a wrapper; only two
    attributes are set on it:

    * ``fn._is_tool = True``: sentinel so :meth:`Base.bind_tools` can tell a
      ``@tool`` callable apart from a raw schema dict.
    * ``fn.openai_schema``: the schema dict passed verbatim to the LLM
      provider in the ``tools=[...]`` request field.
    """
    fn.openai_schema = _build_openai_schema(fn)  # type: ignore[attr-defined]
    fn._is_tool = True  # type: ignore[attr-defined]
    return fn


def is_tool(obj: Any) -> bool:
    """Return True if ``obj`` is a callable marked by :func:`tool`."""
    return bool(callable(obj) and getattr(obj, "_is_tool", False))


class FunctionToolSession:
    """Adapter that lets a list of ``@tool``-decorated callables satisfy the
    :class:`common.mcp_tool_call_conn.ToolCallSession` protocol used by the
    chat model tool loop (duck-typed, no explicit inheritance to avoid
    pulling the MCP client SDK into this module's import graph).

    The chat model only ever calls ``tool_call`` / ``tool_call_async`` with
    ``(name, arguments)``. This class looks the name up in ``tools_map`` and
    invokes the callable, awaiting it if it is a coroutine function and
    otherwise pushing it through ``thread_pool_exec`` so the event loop is
    not blocked.
    """

    def __init__(self, tools: list[Callable[..., Any]]):
        self.tools_map: dict[str, Callable[..., Any]] = {}
        for fn in tools:
            if not is_tool(fn):
                raise TypeError(
                    f"{getattr(fn, '__name__', fn)!r} is not a @tool-decorated callable"
                )
            self.tools_map[fn.openai_schema["function"]["name"]] = fn

    @property
    def schemas(self) -> list[dict[str, Any]]:
        return [fn.openai_schema for fn in self.tools_map.values()]

    def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> Any:
        # Synchronous entry point; asyncio.run() raises RuntimeError if an
        # event loop is already running in the calling thread.
        return asyncio.run(self.tool_call_async(name, arguments, request_timeout=timeout))

    async def tool_call_async(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> Any:
        if name not in self.tools_map:
            raise KeyError(f"Tool {name!r} is not registered")
        if not isinstance(arguments, Mapping):
            raise TypeError(
                f"Tool arguments for {name} must be an object, got {type(arguments).__name__}"
            )
        fn = self.tools_map[name]
        logging.info(f"[FunctionTool] invoke name={name} args={str(arguments)[:200]}")
        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(fn):
            coro = fn(**arguments)
        else:
            # Sync callables run in the thread pool. asyncio.wait_for cancels
            # the awaiting task on timeout, but Python cannot interrupt the
            # underlying worker thread: the function keeps running in the
            # background until it returns. Callers should treat sync tools
            # that block on I/O accordingly.
            coro = thread_pool_exec(fn, **arguments)
        return await asyncio.wait_for(coro, timeout=request_timeout)