2024-01-15 08:46:22 +08:00
#
2024-01-19 19:51:57 +08:00
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
2024-01-15 08:46:22 +08:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2025-10-21 09:36:13 +08:00
2025-03-18 14:52:20 +08:00
import base64
import json
2025-11-19 13:17:22 +08:00
import logging
2023-12-28 13:50:13 +08:00
import os
2025-11-19 13:17:22 +08:00
import re
2025-10-21 09:36:13 +08:00
import tempfile
2025-03-18 14:52:20 +08:00
from abc import ABC
2025-07-30 19:41:09 +08:00
from copy import deepcopy
2023-12-28 13:50:13 +08:00
from io import BytesIO
2025-10-21 09:36:13 +08:00
from pathlib import Path
2025-06-03 14:18:40 +08:00
from urllib . parse import urljoin
2025-11-19 13:17:22 +08:00
2024-07-16 15:19:43 +08:00
import requests
2025-12-09 13:08:37 +08:00
from openai import OpenAI , AsyncOpenAI
from openai . lib . azure import AzureOpenAI , AsyncAzureOpenAI
2025-11-19 13:17:22 +08:00
from common . token_utils import num_tokens_from_string , total_token_count_from_response
2025-03-18 14:52:20 +08:00
from rag . nlp import is_english
2025-09-23 10:19:25 +08:00
from rag . prompts . generator import vision_llm_describe_prompt
2025-11-19 13:17:22 +08:00
2024-02-23 18:28:12 +08:00
2026-01-20 13:29:37 +08:00
from common . misc_utils import thread_pool_exec
2023-12-28 13:50:13 +08:00
class Base(ABC):
    """Shared helpers for vision-capable chat model wrappers.

    The request helpers read ``self.async_client``, ``self.model_name`` and
    ``self.lang``; this class never assigns them, so subclasses must.
    """

    def __init__(self, **kwargs):
        """Initialise retry and tool-call settings.

        Recognised keyword arguments: ``max_retries`` (default from the
        ``LLM_MAX_RETRIES`` environment variable, else 5), ``retry_interval``
        (default from ``LLM_BASE_DELAY``, else 2.0) and ``max_rounds``
        (default 5). Other keyword arguments are ignored.
        """
        # Configure retry parameters
        self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
        self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
        self.max_rounds = kwargs.get("max_rounds", 5)
        self.is_tools = False
        self.tools = []
        self.toolcall_sessions = {}
        # Forwarded verbatim as ``extra_body`` on chat-completion requests.
        self.extra_body = None

    def describe(self, image):
        """Describe ``image``; subclasses must override this."""
        raise NotImplementedError("Please implement describe method!")

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` guided by ``prompt``; subclasses must override this."""
        raise NotImplementedError("Please implement describe_with_prompt method!")

    def _form_history(self, system, history, images=None):
        """Build the ``messages`` list for a chat-completion request.

        A system message is prepended when ``system`` is truthy. ``images`` are
        attached to the first message whose role is ``"user"`` only.

        The caller's message dicts are never modified: the message that
        receives the images is shallow-copied before its content is replaced.
        """
        hist = []
        if system:
            hist.append({"role": "system", "content": system})
        for h in history:
            if images and h["role"] == "user":
                # Copy instead of assigning h["content"], so a history list that
                # is reused across turns keeps its original plain-text content.
                h = {**h, "content": self._image_prompt(h["content"], images)}
                images = []
            hist.append(h)
        return hist

    @staticmethod
    def _blob_to_data_url(blob, mime_type="image/png"):
        """Convert ``blob`` to a URL string usable in an ``image_url`` part.

        Strings that already start with ``data:``, ``http://``, ``https://`` or
        ``file://`` are returned stripped; any other string is treated as a
        base64 payload. ``bytes``, ``bytearray``, ``memoryview`` and ``BytesIO``
        are base64-encoded. Returns ``None`` for every other type.
        """
        if isinstance(blob, str):
            blob = blob.strip()
            if blob.startswith(("data:", "http://", "https://", "file://")):
                return blob
            return f"data:{mime_type};base64,{blob}"
        if isinstance(blob, BytesIO):
            blob = blob.getvalue()
        if isinstance(blob, memoryview):
            blob = blob.tobytes()
        if isinstance(blob, bytearray):
            blob = bytes(blob)
        if isinstance(blob, bytes):
            b64 = base64.b64encode(blob).decode("utf-8")
            return f"data:{mime_type};base64,{b64}"
        return None

    def _normalize_image(self, image):
        """Return a URL / data-URL string for one image in any supported shape.

        Dicts are probed in order for ``inline_data``, ``image_url``, ``url``
        and then ``blob``/``data``. Raw byte containers and other objects are
        handed to :meth:`image2base64`, which raises for objects it cannot
        serialise.
        """
        if isinstance(image, dict):
            default_mime = image.get("mime_type") or "image/png"

            inline_data = image.get("inline_data")
            if isinstance(inline_data, dict):
                mime = inline_data.get("mime_type") or "image/png"
                data_url = self._blob_to_data_url(inline_data.get("data"), mime)
                if data_url:
                    return data_url

            image_url = image.get("image_url")
            if isinstance(image_url, dict):
                data_url = self._blob_to_data_url(image_url.get("url"), default_mime)
                if data_url:
                    return data_url
            if isinstance(image_url, str):
                data_url = self._blob_to_data_url(image_url, default_mime)
                if data_url:
                    return data_url

            if "url" in image:
                data_url = self._blob_to_data_url(image.get("url"), default_mime)
                if data_url:
                    return data_url

            mime = image.get("mime_type") or image.get("media_type") or "image/png"
            for key in ("blob", "data"):
                if key in image:
                    data_url = self._blob_to_data_url(image.get(key), mime)
                    if data_url:
                        return data_url

        if isinstance(image, (bytes, bytearray, memoryview, BytesIO)):
            return self.image2base64(image)
        if isinstance(image, str):
            return self._blob_to_data_url(image, "image/png")
        # Anything else (including an unrecognised dict) must expose ``save``.
        return self.image2base64(image)

    def _image_prompt(self, text, images):
        """Combine ``text`` and ``images`` into multi-part message content.

        Returns ``text`` unchanged when ``images`` is falsy; otherwise a list
        with one text part followed by one ``image_url`` part per usable image.
        Images that cannot be normalised are skipped with a warning.
        """
        if not images:
            return text

        # A single image may be passed bare. Without wrapping, a bytearray,
        # BytesIO or dict would be iterated element by element.
        single_types = (str, bytes, bytearray, memoryview, BytesIO, dict)
        if isinstance(images, single_types) or "bytes" in type(images).__name__:
            images = [images]

        pmpt = [{"type": "text", "text": text}]
        for img in images:
            try:
                pmpt.append({"type": "image_url", "image_url": {"url": self._normalize_image(img)}})
            except Exception:
                logging.warning("[%s] Skip invalid image input in request payload.", self.__class__.__name__)
                continue
        return pmpt

    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
        """Send one non-streaming chat request.

        Returns ``(text, total_tokens)``. Any exception is converted into
        ``("**ERROR**: <message>", 0)`` instead of being raised.
        ``gen_conf`` and ``kwargs`` are accepted but not forwarded here.
        """
        try:
            response = await self.async_client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                extra_body=self.extra_body,
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat response.

        Yields one string per chunk that carries content, then finally yields
        the accumulated token count (an int). Errors are yielded as text.
        ``gen_conf`` and ``kwargs`` are accepted but not forwarded here.
        """
        ans = ""
        tk_count = 0
        try:
            response = await self.async_client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True,
                extra_body=self.extra_body,
            )
            async for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                # Each yield carries only the newest fragment, not the running text.
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                # NOTE(review): chunks without content are skipped above, so this
                # only runs when the finishing chunk also carries text - confirm
                # against the provider's streaming behaviour.
                if resp.choices[0].finish_reason == "stop":
                    tk_count += resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    @staticmethod
    def _raw_image_bytes(image):
        """Return the payload of a bytes-like or ``BytesIO`` image, else ``None``."""
        if isinstance(image, bytes):
            return image
        if isinstance(image, BytesIO):
            return image.getvalue()
        if isinstance(image, bytearray):
            return bytes(image)
        if isinstance(image, memoryview):
            return image.tobytes()
        return None

    @staticmethod
    def _save_image(image):
        """Serialise an object exposing ``save(fp, format=...)``.

        Tries JPEG first and falls back to PNG. Returns ``(data, fmt)`` where
        ``fmt`` is ``"jpeg"`` or ``"png"``. Raises whatever the PNG attempt
        raises when both fail.
        """
        with BytesIO() as buffered:
            fmt = "jpeg"
            try:
                image.save(buffered, format="JPEG")
            except Exception:
                # reset buffer before saving PNG
                buffered.seek(0)
                buffered.truncate()
                image.save(buffered, format="PNG")
                fmt = "png"
            return buffered.getvalue(), fmt

    @staticmethod
    def image2base64_rawvalue(self, image=None):
        """Return ``image`` as a bare base64 string (no ``data:`` header).

        This is a static method whose historical signature nevertheless
        declares ``self``; that made ``obj.image2base64_rawvalue(img)`` fail
        with a missing-argument ``TypeError``. Both call shapes now work:
        ``f(img)`` and the legacy ``f(anything, img)``.
        """
        if image is None:
            # Called with a single argument: it was bound to ``self``.
            image = self
        raw = Base._raw_image_bytes(image)
        if raw is None:
            raw, _ = Base._save_image(image)
        return base64.b64encode(raw).decode("utf-8")

    @staticmethod
    def image2base64(image):
        """Return ``image`` as a ``data:<mime>;base64,...`` URL.

        For raw bytes the MIME type is sniffed from the first two bytes
        (``FF D8`` means JPEG, anything else is labelled PNG). Other objects
        are serialised through their ``save`` method.
        """
        # Return a data URL with the correct MIME to avoid provider mismatches
        raw = Base._raw_image_bytes(image)
        if raw is not None:
            # Best-effort magic number sniffing
            mime = "image/jpeg" if raw[:2] == b"\xff\xd8" else "image/png"
        else:
            raw, fmt = Base._save_image(image)
            mime = f"image/{fmt}"
        b64 = base64.b64encode(raw).decode("utf-8")
        return f"data:{mime};base64,{b64}"

    def prompt(self, b64):
        """Build the default "describe this picture" user message.

        The Chinese wording is used when ``self.lang`` equals ``"chinese"``
        case-insensitively; the English wording otherwise.
        """
        return [
            {
                "role": "user",
                "content": self._image_prompt(
                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                    if self.lang.lower() == "chinese"
                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    b64,
                ),
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        """Build a user message from ``prompt``, or the project default when it is falsy."""
        return [{"role": "user", "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)}]
2025-03-18 14:52:20 +08:00
2024-07-19 18:36:34 +08:00
2023-12-28 13:50:13 +08:00
class GptV4(Base):
    """OpenAI vision model wrapper holding a sync and an async client."""

    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
        """Create both clients; a falsy ``base_url`` falls back to the OpenAI endpoint."""
        endpoint = base_url or "https://api.openai.com/v1"
        self.api_key = key
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.async_client = AsyncOpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang
        super().__init__(**kwargs)

    def _describe_messages(self, messages):
        """Run one blocking completion and return ``(stripped_text, token_count)``."""
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            extra_body=self.extra_body,
        )
        text = res.choices[0].message.content.strip()
        return text, total_token_count_from_response(res)

    def describe(self, image):
        """Describe ``image`` with the built-in language-specific prompt."""
        data_url = self.image2base64(image)
        return self._describe_messages(self.prompt(data_url))

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` using ``prompt`` (or the default vision prompt)."""
        data_url = self.image2base64(image)
        return self._describe_messages(self.vision_llm_prompt(data_url, prompt))
2025-03-18 14:52:20 +08:00
2025-02-27 14:06:49 +08:00
2025-07-30 19:41:09 +08:00
class AzureGptV4(GptV4):
    """Azure OpenAI variant; ``key`` is a JSON document, not a bare API key."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Build Azure clients.

        ``key`` must be JSON with optional ``api_key`` and ``api_version``
        fields; ``kwargs["base_url"]`` is required and used as the Azure
        endpoint (a missing entry raises ``KeyError``).
        """
        credentials = json.loads(key)
        client_args = {
            "api_key": credentials.get("api_key", ""),
            "azure_endpoint": kwargs["base_url"],
            "api_version": credentials.get("api_version", "2024-02-01"),
        }
        self.client = AzureOpenAI(**client_args)
        self.async_client = AsyncAzureOpenAI(**client_args)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
2024-07-04 09:57:16 +08:00
2025-07-30 19:41:09 +08:00
class xAICV(GptV4):
    """xAI vision model served through the OpenAI-compatible client."""

    _FACTORY_NAME = "xAI"

    def __init__(self, key, model_name="grok-3", lang="Chinese", base_url=None, **kwargs):
        """Delegate to :class:`GptV4`, defaulting ``base_url`` to the xAI endpoint."""
        endpoint = base_url or "https://api.x.ai/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
2025-07-11 10:35:23 +08:00
2025-07-30 19:41:09 +08:00
class QWenCV(GptV4):
    """Tongyi-Qianwen vision wrapper with an extra video-summary path.

    Image chat goes through the OpenAI-compatible client inherited from
    :class:`GptV4`; video input is sent through the ``dashscope`` SDK.
    """

    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
        """Delegate to :class:`GptV4`, defaulting ``base_url`` to DashScope's compatible mode."""
        if not base_url:
            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)

    @staticmethod
    def _extract_text_from_content(content):
        """Return the plain text of a message ``content`` value.

        Strings are stripped. For lists, the ``text`` of every dict block is
        collected and joined with newlines. Anything else yields ``""``.
        """
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            texts = []
            for blk in content:
                if not isinstance(blk, dict):
                    continue
                if blk.get("type") in {"text", "input_text"} and blk.get("text"):
                    texts.append(str(blk["text"]))
                elif "text" in blk and isinstance(blk.get("text"), (str, int, float)):
                    texts.append(str(blk["text"]))
            return "\n".join(texts).strip()
        return ""

    def _resolve_video_prompt(self, system, history, **kwargs):
        """Pick the prompt for a video request.

        Priority: ``video_prompt``/``prompt`` keyword argument, then the text
        of the most recent user message, then ``system``, then a generic
        summarisation instruction.
        """
        prompt = kwargs.get("video_prompt") or kwargs.get("prompt")
        if isinstance(prompt, str) and prompt.strip():
            return prompt.strip()
        for h in reversed(history or []):
            if h.get("role") != "user":
                continue
            txt = self._extract_text_from_content(h.get("content"))
            if txt:
                return txt
        if isinstance(system, str) and system.strip():
            return system.strip()
        return "Please summarize this video in proper sentences."

    async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
        """Chat with optional video input.

        With ``video_bytes`` the video is summarised and
        ``(summary, token_count)`` is returned; failures become
        ``("**ERROR**: <message>", 0)``. Without it the call is forwarded to
        the parent implementation.
        """
        if video_bytes:
            try:
                # NOTE(review): _process_video is a blocking call made directly
                # from this coroutine - confirm that is acceptable for callers.
                summary, summary_num_tokens = self._process_video(video_bytes, filename, self._resolve_video_prompt(system, history, **kwargs))
                return summary, summary_num_tokens
            except Exception as e:
                return "**ERROR**: " + str(e), 0

        return await super().async_chat(system, history, gen_conf, images=images, **kwargs)

    def _process_video(self, video_bytes, filename, prompt):
        """Summarise a video through ``dashscope.MultiModalConversation``.

        The bytes are written to a temporary file that is passed as a
        ``file://`` URL. If the first call fails, ``dashscope.base_http_api_url``
        is switched to the international endpoint and the call is retried once.

        Returns ``(summary, token_count)``; raises ``RuntimeError`` (chained to
        the second failure) when both attempts fail. The temporary file is
        removed in every case.
        """
        from dashscope import MultiModalConversation

        video_suffix = Path(filename).suffix or ".mp4"
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
                # Record the path before writing so a failed write is still
                # cleaned up by the finally block below.
                tmp_path = tmp.name
                tmp.write(video_bytes)

            video_path = f"file://{tmp_path}"
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "video": video_path,
                            "fps": 2,
                        },
                        {
                            "text": prompt,
                        },
                    ],
                }
            ]

            def call_api():
                response = MultiModalConversation.call(
                    api_key=self.api_key,
                    model=self.model_name,
                    messages=messages,
                )
                # A truthy top-level "message" is treated as an error report.
                if response.get("message"):
                    raise Exception(response["message"])
                summary = response["output"]["choices"][0]["message"].content[0]["text"]
                return summary, num_tokens_from_string(summary)

            try:
                return call_api()
            except Exception as e1:
                import dashscope

                # This changes module-level dashscope state and is not reverted.
                dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
                try:
                    return call_api()
                except Exception as e2:
                    raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}") from e2
        finally:
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except Exception:
                    logging.warning("[QWenCV] Failed to cleanup temp video file: %s", tmp_path)
2025-10-21 09:36:13 +08:00
2024-02-23 18:28:12 +08:00
2025-07-30 19:41:09 +08:00
class HunyuanCV(GptV4):
    """Tencent Hunyuan vision model served through the OpenAI-compatible client."""

    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        """Delegate to :class:`GptV4`, defaulting ``base_url`` to the Hunyuan endpoint."""
        endpoint = base_url or "https://api.hunyuan.cloud.tencent.com/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
2025-03-18 14:52:20 +08:00
2025-07-30 19:41:09 +08:00
class Zhipu4V(GptV4):
    """ZHIPU-AI vision wrapper using the OpenAI-compatible clients.

    Responses are post-processed to drop ``<|begin_of_box|>`` and
    ``<|end_of_box|>`` markers.
    """

    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        """Create sync and async clients against the fixed bigmodel.cn endpoint."""
        self.client = OpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
        self.async_client = AsyncOpenAI(api_key=key, base_url="https://open.bigmodel.cn/api/paas/v4/")
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)

    @staticmethod
    def _strip_box_markers(content):
        """Remove ``<|begin_of_box|>``/``<|end_of_box|>`` markers and trim whitespace."""
        return re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()

    def _clean_conf(self, gen_conf):
        """Return a copy of ``gen_conf`` without ``max_tokens`` and the penalty keys.

        The argument itself is left untouched (``None`` is treated as empty),
        so a config dict reused across calls keeps all of its entries.
        """
        gen_conf = dict(gen_conf or {})
        gen_conf.pop("max_tokens", None)
        return self._clean_conf_plealty(gen_conf)

    def _clean_conf_plealty(self, gen_conf):
        """Return a copy of ``gen_conf`` without ``presence_penalty``/``frequency_penalty``."""
        gen_conf = dict(gen_conf or {})
        gen_conf.pop("presence_penalty", None)
        gen_conf.pop("frequency_penalty", None)
        return gen_conf

    def _request(self, msg, stream, gen_conf=None):
        """POST a chat request with ``requests`` and return the decoded JSON body.

        NOTE(review): ``self.base_url`` and ``self.api_key`` are not assigned
        by this class's ``__init__`` - confirm they are set before this is
        called.
        """
        # ``None`` default plus a defensive copy: a mutable ``{}`` default would
        # be shared by every call, and the caller's dict must not be modified.
        gen_conf = dict(gen_conf or {})
        response = requests.post(
            self.base_url,
            json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf},
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )
        return response.json()

    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
        """Send one non-streaming chat request.

        Returns ``(cleaned_text, token_count)``. Neither ``history`` nor
        ``gen_conf`` is modified. Exceptions propagate to the caller.
        """
        # The system message is added once by _form_history. Inserting it into
        # ``history`` here as well would send it twice and alter the caller's list.
        gen_conf = self._clean_conf(gen_conf)

        logging.info(json.dumps(history, ensure_ascii=False, indent=2))
        response = await self.async_client.chat.completions.create(
            model=self.model_name,
            messages=self._form_history(system, history, images),
            stream=False,
            **gen_conf,
        )
        content = response.choices[0].message.content.strip()
        cleaned = self._strip_box_markers(content)
        return cleaned, total_token_count_from_response(response)

    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat response.

        Yields one string per chunk that carries content, then finally yields
        the token count (an int). Errors are yielded as text. Neither
        ``history`` nor ``gen_conf`` is modified.
        """
        from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
        from rag.nlp import is_chinese

        # See async_chat: the system message is added once by _form_history.
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        tk_count = 0
        try:
            logging.info(json.dumps(history, ensure_ascii=False, indent=2))
            response = await self.async_client.chat.completions.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True,
                **gen_conf,
            )
            async for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                # Each yield carries only the newest fragment, not the running text.
                ans = delta
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    tk_count = total_token_count_from_response(resp)
                if resp.choices[0].finish_reason == "stop":
                    tk_count = total_token_count_from_response(resp)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    def describe(self, image):
        """Describe ``image`` with the default prompt of :meth:`describe_with_prompt`."""
        return self.describe_with_prompt(image)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image``; returns ``(cleaned_text, token_count_of_cleaned_text)``."""
        b64 = self.image2base64(image)
        if prompt is None:
            prompt = "Describe this image."

        # Chat messages
        messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": b64}}, {"type": "text", "text": prompt}]}]

        resp = self.client.chat.completions.create(model=self.model_name, messages=messages, stream=False)

        content = resp.choices[0].message.content.strip()
        cleaned = self._strip_box_markers(content)
        return cleaned, num_tokens_from_string(cleaned)
2025-11-19 13:17:22 +08:00
2025-11-18 13:09:39 +08:00
2025-07-30 19:41:09 +08:00
class StepFunCV(GptV4):
    """StepFun vision model served through the OpenAI-compatible clients."""

    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
        """Create both clients; a falsy ``base_url`` falls back to the StepFun endpoint."""
        client_args = {"api_key": key, "base_url": base_url or "https://api.stepfun.com/v1"}
        self.client = OpenAI(**client_args)
        self.async_client = AsyncOpenAI(**client_args)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
2025-03-18 14:52:20 +08:00
2025-11-19 13:17:22 +08:00
2025-10-17 11:43:22 +08:00
class VolcEngineCV(GptV4):
    """VolcEngine vision wrapper; ``key`` is a JSON document."""

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
        """Build clients from the JSON ``key``.

        ``ark_api_key`` supplies the credential. The model name is the
        concatenation of the ``ep_id`` and ``endpoint_id`` fields (each
        defaulting to ``""``); the ``model_name`` argument is not used.
        """
        endpoint = base_url or "https://ark.cn-beijing.volces.com/api/v3"
        settings = json.loads(key)
        client_args = {"api_key": settings.get("ark_api_key", ""), "base_url": endpoint}
        self.client = OpenAI(**client_args)
        self.async_client = AsyncOpenAI(**client_args)
        self.model_name = settings.get("ep_id", "") + settings.get("endpoint_id", "")
        self.lang = lang
        Base.__init__(self, **kwargs)
2025-03-18 14:52:20 +08:00
2025-11-19 13:17:22 +08:00
2025-07-30 19:41:09 +08:00
class LmStudioCV(GptV4):
    """LM-Studio vision wrapper talking to a user-supplied server."""

    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        """Create both clients at ``urljoin(base_url, "v1")``.

        Raises ``ValueError`` when ``base_url`` is falsy. ``key`` is not used;
        the clients are given the fixed API key ``"lm-studio"``.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        client_args = {"api_key": "lm-studio", "base_url": urljoin(base_url, "v1")}
        self.client = OpenAI(**client_args)
        self.async_client = AsyncOpenAI(**client_args)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
2025-03-18 14:52:20 +08:00
2025-07-03 19:05:31 +08:00
2025-07-30 19:41:09 +08:00
class OpenAI_APICV(GptV4):
    """Wrapper for VLLM and other OpenAI-API-compatible servers."""

    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        """Create both clients at ``urljoin(base_url, "v1")``.

        Raises ``ValueError`` when ``base_url`` is falsy. Only the part of
        ``model_name`` before the first ``"___"`` is kept.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        client_args = {"api_key": key, "base_url": urljoin(base_url, "v1")}
        self.client = OpenAI(**client_args)
        self.async_client = AsyncOpenAI(**client_args)
        self.model_name = model_name.partition("___")[0]
        self.lang = lang
        Base.__init__(self, **kwargs)
2024-07-19 18:36:34 +08:00
2025-07-30 19:41:09 +08:00
class TogetherAICV(GptV4):
    """TogetherAI vision model served through the OpenAI-compatible client."""

    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1", **kwargs):
        """Delegate to :class:`GptV4`; a falsy ``base_url`` falls back to the TogetherAI endpoint."""
        endpoint = base_url or "https://api.together.xyz/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
2025-03-18 14:52:20 +08:00
2025-07-03 19:05:31 +08:00
2025-07-30 19:41:09 +08:00
class YiCV(GptV4):
    """01.AI vision model served through the OpenAI-compatible client."""

    _FACTORY_NAME = "01.AI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", **kwargs):
        """Delegate to :class:`GptV4`; a falsy ``base_url`` falls back to the 01.AI endpoint."""
        endpoint = base_url or "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
2024-07-19 18:36:34 +08:00
2025-07-30 19:41:09 +08:00
class SILICONFLOWCV(GptV4):
    """SILICONFLOW vision model served through the OpenAI-compatible client."""

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1", **kwargs):
        """Delegate to :class:`GptV4`; a falsy ``base_url`` falls back to the SILICONFLOW endpoint."""
        endpoint = base_url or "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, endpoint, **kwargs)
2024-02-08 17:01:01 +08:00
2025-07-03 19:05:31 +08:00
2025-07-30 19:41:09 +08:00
class OpenRouterCV(GptV4):
    """OpenRouter vision wrapper; ``key`` is a JSON document."""

    _FACTORY_NAME = "OpenRouter"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://openrouter.ai/api/v1", **kwargs):
        """Build clients and the optional provider-routing ``extra_body``.

        ``key`` is JSON with ``api_key`` and, optionally, ``provider_order``
        (a comma-separated string or a list/tuple). When ``provider_order`` is
        truthy, ``extra_body["provider"]`` is set to the parsed order with
        fallbacks disabled; otherwise ``extra_body`` stays an empty dict.
        """
        endpoint = base_url or "https://openrouter.ai/api/v1"
        settings = json.loads(key)
        client_args = {"api_key": settings.get("api_key", ""), "base_url": endpoint}
        self.client = OpenAI(**client_args)
        self.async_client = AsyncOpenAI(**client_args)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)

        # Base.__init__ resets extra_body to None, so it is assigned afterwards.
        self.extra_body = {}
        raw_order = settings.get("provider_order", "")
        if not raw_order:
            return
        self.extra_body["provider"] = {
            "order": self._to_order_list(raw_order),
            "allow_fallbacks": False,
        }

    @staticmethod
    def _to_order_list(x):
        """Normalise a provider order value into a list of non-empty, stripped names."""
        if x is None:
            return []
        if isinstance(x, str):
            return [part.strip() for part in x.split(",") if part.strip()]
        if isinstance(x, (list, tuple)):
            return [str(item).strip() for item in x if str(item).strip()]
        return []
2024-02-08 17:01:01 +08:00
2025-07-30 19:41:09 +08:00
class LocalAICV(GptV4):
    """GptV4 variant for a LocalAI server reached through OpenAI clients."""

    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url, lang="Chinese", **kwargs):
        """Create the clients against ``<base_url>`` joined with ``v1``.

        Args:
            key: accepted for signature parity; the clients use the literal key "empty".
            model_name: only the part before the first ``___`` is kept.
            base_url: server address; must be non-empty.
            lang: language tag stored on the instance.
            **kwargs: forwarded to ``Base.__init__``.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        endpoint = urljoin(base_url, "v1")

        self.lang = lang
        self.model_name = model_name.partition("___")[0]
        self.client = OpenAI(api_key="empty", base_url=endpoint)
        self.async_client = AsyncOpenAI(api_key="empty", base_url=endpoint)
        Base.__init__(self, **kwargs)
2025-03-18 14:52:20 +08:00
2025-07-30 19:41:09 +08:00
class XinferenceCV(GptV4):
    """GptV4 variant for an Xinference server reached through OpenAI clients."""

    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
        """Create the clients against ``<base_url>`` joined with ``v1``.

        Args:
            key: API key handed to the OpenAI clients.
            model_name: model identifier stored on the instance.
            lang: language tag stored on the instance.
            base_url: server address; must be non-empty.
            **kwargs: forwarded to ``Base.__init__``.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        # urljoin("", "v1") yields the relative URL "v1", which can only fail
        # later at request time; reject it up front like LocalAICV/GPUStackCV do.
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.async_client = AsyncOpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
        Base.__init__(self, **kwargs)
2024-07-19 18:36:34 +08:00
2025-07-30 19:41:09 +08:00
class GPUStackCV(GptV4):
    """GptV4 variant for a GPUStack server reached through OpenAI clients."""

    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        """Create the clients against ``<base_url>`` joined with ``v1``.

        Args:
            key: API key handed to the OpenAI clients.
            model_name: model identifier stored on the instance.
            lang: language tag stored on the instance.
            base_url: server address; must be non-empty.
            **kwargs: forwarded to ``Base.__init__``.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        endpoint = urljoin(base_url, "v1")

        self.lang = lang
        self.model_name = model_name
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.async_client = AsyncOpenAI(api_key=key, base_url=endpoint)
        Base.__init__(self, **kwargs)
2024-07-19 18:36:34 +08:00
2025-07-30 19:41:09 +08:00
class LocalCV(Base):
    """Placeholder model for the "Local" factory; it performs no inference."""

    _FACTORY_NAME = "Local"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        """Accept the common constructor arguments and store none of them."""

    def describe(self, image):
        """Return an empty description together with a token count of zero."""
        description, token_count = "", 0
        return description, token_count
2024-07-19 18:36:34 +08:00
2024-03-12 11:57:08 +08:00
2024-04-08 19:20:57 +08:00
class OllamaCV(Base):
    """Vision model client that talks to an Ollama server."""

    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Create the Ollama client.

        Args:
            key: accepted for signature parity with the other models; unused here.
            model_name: model identifier passed on every request.
            lang: language tag stored on the instance.
            **kwargs: must contain ``base_url``; ``ollama_keep_alive`` is optional
                and otherwise taken from the ``OLLAMA_KEEP_ALIVE`` environment
                variable (default ``-1``). All of it is forwarded to ``Base.__init__``.

        Raises:
            KeyError: if ``base_url`` is missing from ``kwargs``.
        """
        from ollama import Client

        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
        self.lang = lang
        self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
        Base.__init__(self, **kwargs)

    def _clean_img(self, img):
        """Strip a ``data:*;base64,`` header from a string image; return anything else unchanged."""
        if not isinstance(img, str):
            return img
        if img.startswith("data:") and ";base64," in img:
            img = img.split(";base64,")[1]
        return img

    def _clean_conf(self, gen_conf):
        """Pick the sampling options that are forwarded to Ollama.

        Returns a new dict, so the caller's ``gen_conf`` is never modified.
        """
        gen_conf = gen_conf or {}
        # top_p is forwarded as top_p; it used to be written into top_k, which is
        # a different (integer) sampling option.
        forwarded = ("temperature", "top_p", "presence_penalty", "frequency_penalty")
        return {name: gen_conf[name] for name in forwarded if name in gen_conf}

    def _form_history(self, system, history, images=None):
        """Build the message list for ``client.chat`` from a copy of ``history``.

        The system prompt is prepended only when the first message has role
        ``user``. Images (with data-URL headers removed) are attached to the
        first ``user`` message.
        """
        hist = deepcopy(history) if history else []
        # Guard against an empty history before peeking at the first message.
        if system and hist and hist[0]["role"] == "user":
            hist.insert(0, {"role": "system", "content": system})
        if not images:
            return hist
        cleaned = [self._clean_img(img) for img in images]
        for message in hist:
            if message["role"] == "user":
                message["images"] = cleaned
                break
        return hist

    @staticmethod
    def _token_total(resp):
        """Sum the prompt and completion token counters of a response, treating missing ones as 0."""
        return (resp.get("prompt_eval_count", 0) or 0) + (resp.get("eval_count", 0) or 0)

    def _generate(self, messages, image):
        """Run ``client.generate`` with the text of the first message and a single image.

        Returns ``(text, 128)`` on success and ``("**ERROR**: ...", 0)`` on any
        failure. The 128 is a fixed count; the response's usage counters are not read.
        """
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=messages[0]["content"],
                # Same header stripping that _form_history applies to chat images.
                images=[self._clean_img(image)],
            )
            return response["response"].strip(), 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def describe(self, image):
        """Describe ``image`` using the default prompt built by ``self.prompt``."""
        return self._generate(self.prompt(""), image)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` using ``self.vision_llm_prompt``, optionally with a custom ``prompt``."""
        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
        return self._generate(vision_prompt, image)

    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one non-streaming chat turn in the thread pool.

        Returns ``(answer, total_tokens)``, or ``("**ERROR**: ...", 0)`` on failure.
        """
        try:
            response = await thread_pool_exec(
                self.client.chat,
                model=self.model_name,
                messages=self._form_history(system, history, images),
                options=self._clean_conf(gen_conf),
                keep_alive=self.keep_alive,
            )
            ans = response["message"]["content"].strip()
            return ans, self._token_total(response)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat turn.

        Yields each non-empty content chunk as ``str`` and finally exactly one
        ``int``: the token total reported by the ``done`` chunk, or 0 if the
        stream did not reach it. Errors are yielded as text before the count.
        """
        ans = ""
        total_tokens = 0
        try:
            response = await thread_pool_exec(
                self.client.chat,
                model=self.model_name,
                messages=self._form_history(system, history, images),
                stream=True,
                options=self._clean_conf(gen_conf),
                keep_alive=self.keep_alive,
            )
            # NOTE(review): `response` is iterated synchronously inside this
            # coroutine, as before — confirm this is acceptable for the event loop.
            for resp in response:
                ans = resp["message"]["content"]
                if ans:
                    yield ans
                if resp["done"]:
                    # Remember the count and emit it once, after all text.
                    total_tokens = self._token_total(resp)
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens
2024-04-11 18:25:37 +08:00
2024-07-11 15:41:00 +08:00
class GeminiCV(Base):
    """Vision model client built on the ``google.genai`` SDK."""

    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        """Create the genai client and remember key, model name and language."""
        from google import genai

        self.api_key = key
        self.model_name = model_name
        self.client = genai.Client(api_key=key)
        self.lang = lang
        Base.__init__(self, **kwargs)
        logging.info(f"[GeminiCV] Initialized with model={self.model_name} lang={self.lang}")

    def _image_to_part(self, image):
        """Convert an image into an inline-data ``types.Part``.

        A ``data:...;base64,...`` string is decoded directly; any other input
        is first converted with ``self.image2base64`` and then decoded the same way.
        """
        from google.genai import types

        if isinstance(image, str) and image.startswith("data:") and ";base64," in image:
            data_url = image
        else:
            data_url = self.image2base64(image)
        header, b64data = data_url.split(",", 1)
        mime = header.split(":", 1)[1].split(";", 1)[0]
        return types.Part(inline_data=types.Blob(mime_type=mime, data=base64.b64decode(b64data)))

    def _form_history(self, system, history, images=None):
        """Build the ``contents`` list for a generate call.

        The system prompt, the text of the first history entry and all images
        are merged into one leading ``user`` content. The remaining history
        entries follow, with ``assistant`` mapped to ``model`` and every other
        role mapped to ``user``.
        """
        from google.genai import types

        images = images or []
        remaining_history = history or []
        system_len = len(system) if isinstance(system, str) else 0
        logging.info(f"[GeminiCV] _form_history called: system_len={system_len} history_len={len(remaining_history)} images_len={len(images)}")

        image_parts = []
        for img in images:
            try:
                image_parts.append(self._image_to_part(img))
            except Exception as e:
                # Still best-effort, but no longer silent: record why an image was dropped.
                logging.warning(f"[GeminiCV] skipping image that could not be converted: {e}")

        contents = []
        if system or remaining_history:
            parts = []
            if system:
                parts.append(types.Part(text=system))
            if remaining_history:
                first = remaining_history[0]
                parts.append(types.Part(text=first.get("content", "")))
                remaining_history = remaining_history[1:]
            parts.extend(image_parts)
            contents.append(types.Content(role="user", parts=parts))
        elif image_parts:
            contents.append(types.Content(role="user", parts=image_parts))

        role_map = {"user": "user", "assistant": "model", "system": "user"}
        for h in remaining_history:
            contents.append(
                types.Content(
                    role=role_map.get(h.get("role"), "user"),
                    parts=[types.Part(text=h.get("content", ""))],
                )
            )
        return contents

    @staticmethod
    def _generation_config(gen_conf):
        """Build the generation config, defaulting temperature to 0.3 and top_p to 0.7."""
        from google.genai import types

        gen_conf = gen_conf or {}
        return types.GenerateContentConfig(
            temperature=gen_conf.get("temperature", 0.3),
            top_p=gen_conf.get("top_p", 0.7),
        )

    def _describe_with_text(self, text, image):
        """Send ``text`` plus ``image`` as one user turn; return ``(answer, total_tokens)``."""
        from google.genai import types

        contents = [
            types.Content(
                role="user",
                parts=[types.Part(text=text), self._image_to_part(image)],
            )
        ]
        res = self.client.models.generate_content(model=self.model_name, contents=contents)
        return res.text, total_token_count_from_response(res)

    def describe(self, image):
        """Describe ``image`` with a built-in prompt chosen by ``self.lang``."""
        prompt = (
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        )
        return self._describe_with_text(prompt, image)

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt``, or with ``vision_llm_describe_prompt()`` when none is given."""
        vision_prompt = prompt if prompt else vision_llm_describe_prompt()
        return self._describe_with_text(vision_prompt, image)

    async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
        """Run one chat turn, or summarize a video when ``video_bytes`` is given.

        Returns ``(answer, total_tokens)``, or ``("**ERROR**: ...", 0)`` on failure.
        """
        if video_bytes:
            try:
                logging.info(f"[GeminiCV] async_chat called with video: filename={filename} size={len(video_bytes)}")
                summary, summary_num_tokens = await thread_pool_exec(self._process_video, video_bytes, filename)
                return summary, summary_num_tokens
            except Exception as e:
                # Logged as a warning, like the non-video error path below.
                logging.warning(f"[GeminiCV] async_chat video error: {e}")
                return "**ERROR**: " + str(e), 0

        history_len = len(history) if history else 0
        images_len = len(images) if images else 0
        logging.info(f"[GeminiCV] async_chat called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")
        try:
            response = await self.client.aio.models.generate_content(
                model=self.model_name,
                contents=self._form_history(system, history, images),
                config=self._generation_config(gen_conf),
            )
            ans = response.text
            logging.info("[GeminiCV] async_chat completed")
            return ans, total_token_count_from_response(response)
        except Exception as e:
            logging.warning(f"[GeminiCV] async_chat error: {e}")
            return "**ERROR**: " + str(e), 0

    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat turn.

        Yields each non-empty text chunk and finally one token count derived
        from the last chunk received from the stream.
        """
        ans = ""
        # The token count used to be computed from a variable that was never
        # assigned (always None); keep the last streamed chunk for it instead.
        last_chunk = None
        try:
            history_len = len(history) if history else 0
            images_len = len(images) if images else 0
            logging.info(f"[GeminiCV] async_chat_streamly called: history_len={history_len} images_len={images_len} gen_conf={gen_conf}")

            response_stream = await self.client.aio.models.generate_content_stream(
                model=self.model_name,
                contents=self._form_history(system, history, images),
                config=self._generation_config(gen_conf),
            )
            async for chunk in response_stream:
                last_chunk = chunk
                if chunk.text:
                    ans += chunk.text
                    yield chunk.text
            logging.info("[GeminiCV] chat_streamly completed")
        except Exception as e:
            logging.warning(f"[GeminiCV] chat_streamly error: {e}")
            yield ans + "\n**ERROR**: " + str(e)
        yield total_token_count_from_response(last_chunk)

    def _process_video(self, video_bytes, filename):
        """Summarize a video and return ``(summary, token_count_of_summary)``.

        Videos up to 20 MB are sent inline; larger ones are written to a
        temporary file, uploaded through ``client.files.upload`` and referenced
        in the request. The temporary file is always removed. Any failure is
        logged and re-raised.
        """
        from google import genai
        from google.genai import types

        video_size_mb = len(video_bytes) / (1024 * 1024)
        client = self.client if hasattr(self, "client") else genai.Client(api_key=self.api_key)
        logging.info(f"[GeminiCV] _process_video called: filename={filename} size_mb={video_size_mb:.2f}")

        tmp_path = None
        try:
            if video_size_mb <= 20:
                response = client.models.generate_content(
                    model="models/gemini-2.5-flash",
                    contents=types.Content(
                        parts=[
                            types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
                            types.Part(text="Please summarize the video in proper sentences."),
                        ]
                    ),
                )
            else:
                logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
                video_suffix = Path(filename).suffix or ".mp4"
                with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
                    tmp.write(video_bytes)
                    tmp_path = Path(tmp.name)
                uploaded_file = client.files.upload(file=tmp_path)
                response = client.models.generate_content(
                    model="gemini-2.5-flash",
                    contents=[uploaded_file, "Please summarize this video in proper sentences."],
                )
            summary = response.text or ""
            logging.info(f"[GeminiCV] Video summarized: {summary[:32]}...")
            return summary, num_tokens_from_string(summary)
        except Exception as e:
            logging.warning(f"[GeminiCV] Video processing failed: {e}")
            raise
        finally:
            if tmp_path and tmp_path.exists():
                tmp_path.unlink()
2024-07-16 15:19:43 +08:00
2024-07-23 10:43:09 +08:00
class NvidiaCV(Base):
    """Vision model client that posts chat messages to an NVIDIA-hosted model endpoint."""

    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs):
        """Derive the per-model endpoint URL from ``base_url`` and ``model_name``.

        Args:
            key: bearer token sent with every request.
            model_name: ``"<factory>/<model>"``.
            lang: language tag stored on the instance.
            base_url: API root; an empty value selects the default root.
            **kwargs: forwarded to ``Base.__init__``.

        Raises:
            ValueError: if ``model_name`` does not contain exactly one ``/``.
        """
        if not base_url:
            # Was a one-element tuple (stray trailing comma), which urljoin rejects.
            base_url = "https://ai.api.nvidia.com/v1/vlm"
        self.lang = lang
        factory, llm_name = model_name.split("/")
        # Join by hand: urljoin() replaces the last path segment of a base that
        # has no trailing slash, which silently dropped "vlm" / "community".
        root = base_url.rstrip("/")
        if factory != "liuhaotian":
            self.base_url = f"{root}/{factory}/{llm_name}"
        else:
            community_name = llm_name.replace("-v1.6", "16")
            self.base_url = f"{root}/community/{community_name}"
        self.key = key
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        """Append one ``<img>`` tag per image to ``text``.

        Images that do not already start with ``data`` are wrapped as JPEG
        base64 data URLs. Without images, ``text`` is returned unchanged.
        """
        if not images:
            return text
        tags = ['<img src="{}"/>'.format(img if img[:4] == "data" else f"data:image/jpeg;base64,{img}") for img in images]
        return text + "".join(tags)

    def _request(self, msg, gen_conf=None):
        """POST ``msg`` plus the generation options to the endpoint; return the decoded JSON body.

        ``gen_conf`` defaults to ``None`` and is copied before use, so neither
        a shared default nor the caller's dict can be mutated.
        """
        gen_conf = dict(gen_conf or {})
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={"messages": msg, **gen_conf},
        )
        return response.json()

    @staticmethod
    def _answer_and_tokens(response):
        """Extract the stripped first-choice content and the total token count."""
        return (
            response["choices"][0]["message"]["content"].strip(),
            total_token_count_from_response(response),
        )

    def describe(self, image):
        """Describe ``image`` with the default prompt; return ``(answer, total_tokens)``."""
        b64 = self.image2base64(image)
        # Same request as before, now sent through the shared _request helper.
        return self._answer_and_tokens(self._request(self.prompt(b64)))

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``self.vision_llm_prompt``, optionally with a custom ``prompt``."""
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        return self._answer_and_tokens(self._request(vision_prompt))

    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one chat turn in the thread pool.

        Returns ``(answer, total_tokens)``, or ``("**ERROR**: ...", 0)`` on failure.
        """
        try:
            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
            return self._answer_and_tokens(response)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Emulate streaming: fetch the full answer, then yield it character by character.

        The final item yielded is the total token count (0 if the request failed).
        """
        total_tokens = 0
        try:
            response = await thread_pool_exec(self._request, self._form_history(system, history, images), gen_conf)
            cnt = response["choices"][0]["message"]["content"]
            total_tokens += total_token_count_from_response(response)
            for resp in cnt:
                yield resp
        except Exception as e:
            yield "\n**ERROR**: " + str(e)

        yield total_tokens
2025-03-24 12:34:57 +08:00
class AnthropicCV(Base):
    """Vision model client built on the ``anthropic`` SDK."""

    _FACTORY_NAME = "Anthropic"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        """Create sync and async Anthropic clients.

        ``base_url`` is accepted but not used. ``max_tokens`` is 8192, or 4096
        when the model name contains "haiku" or "opus".
        """
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.async_client = anthropic.AsyncAnthropic(api_key=key)
        self.model_name = model_name
        self.system = ""
        self.max_tokens = 8192
        if "haiku" in self.model_name or "opus" in self.model_name:
            self.max_tokens = 4096
        Base.__init__(self, **kwargs)

    def _image_prompt(self, text, images):
        """Build a content list with one text block followed by one base64 image block per image.

        For ``data...`` strings the media type and payload are parsed out of
        the string; anything else is sent as-is and labelled ``image/png``.
        Without images, ``text`` is returned unchanged.
        """
        if not images:
            return text
        pmpt = [{"type": "text", "text": text}]
        for img in images:
            is_data_url = isinstance(img, str) and img[:4] == "data"
            pmpt.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": img.split(":")[1].split(";")[0] if is_data_url else "image/png",
                        "data": img.split(",")[1] if is_data_url else img,
                    },
                }
            )
        return pmpt

    def describe(self, image):
        """Describe ``image`` with the default prompt; return ``(text, input + output tokens)``."""
        b64 = self.image2base64(image)
        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=self.prompt(b64))
        # The SDK returns a model object; convert it before subscripting, as async_chat does.
        response = response.to_dict()
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with ``prompt`` (or the default vision prompt); return ``(text, total_tokens)``."""
        b64 = self.image2base64(image)
        # NOTE(review): other classes in this file call self.prompt() with a single
        # argument — confirm the prompt() in scope here accepts a second one.
        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
        response = response.to_dict()
        return response["content"][0]["text"].strip(), total_token_count_from_response(response)

    def _clean_conf(self, gen_conf):
        """Return a copy of ``gen_conf`` prepared for ``messages.create``.

        ``presence_penalty``, ``frequency_penalty`` and the misspelt
        ``max_token`` are dropped; ``max_tokens`` falls back to
        ``self.max_tokens`` when absent. The caller's dict is left untouched.
        """
        conf = dict(gen_conf or {})
        for unsupported in ("presence_penalty", "frequency_penalty", "max_token"):
            conf.pop(unsupported, None)
        conf.setdefault("max_tokens", self.max_tokens)
        return conf

    async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
        """Run one non-streaming chat turn.

        Returns ``(answer, total_tokens)``. When generation stopped because of
        ``max_tokens`` a continuation hint is appended to the answer. On
        failure returns ``(partial_answer + "\\n**ERROR**: ...", 0)``.
        """
        gen_conf = self._clean_conf(gen_conf)
        ans = ""
        try:
            # NOTE(review): `system` goes both into _form_history and the system=
            # argument — confirm _form_history does not also emit a system-role message.
            response = await self.async_client.messages.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                system=system,
                stream=False,
                **gen_conf,
            )
            response = response.to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return (
                ans,
                total_token_count_from_response(response),
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
        """Stream a chat turn.

        Yields text chunks, wrapping thinking output in ``<think>`` /
        ``</think>``, and finally the token total estimated with
        ``num_tokens_from_string`` over everything yielded.
        """
        gen_conf = self._clean_conf(gen_conf)
        total_tokens = 0
        try:
            # create() is a coroutine on the async client; it must be awaited to
            # obtain the stream that `async for` iterates.
            stream = await self.async_client.messages.create(
                model=self.model_name,
                messages=self._form_history(system, history, images),
                system=system,
                stream=True,
                **gen_conf,
            )
            thinking = False
            async for event in stream:
                if event.type != "content_block_delta":
                    continue
                delta = event.delta
                if delta.type == "thinking_delta":
                    if not delta.thinking:
                        continue
                    if not thinking:
                        yield "<think>"
                        thinking = True
                    yield delta.thinking
                    total_tokens += num_tokens_from_string(delta.thinking)
                    continue
                if thinking:
                    # Close the block once and reset the flag, so the text that
                    # follows is streamed instead of being replaced by "</think>".
                    yield "</think>"
                    thinking = False
                text = getattr(delta, "text", None)
                if text:
                    yield text
                    total_tokens += num_tokens_from_string(text)
        except Exception as e:
            yield "\n**ERROR**: " + str(e)

        yield total_tokens
2025-07-03 19:05:31 +08:00
2025-07-30 19:41:09 +08:00
class GoogleCV ( AnthropicCV , GeminiCV ) :
2025-07-03 19:05:31 +08:00
_FACTORY_NAME = " Google Cloud "
2025-07-02 09:02:01 +07:00
def __init__ ( self , key , model_name , lang = " Chinese " , base_url = None , * * kwargs ) :
import base64
from google . oauth2 import service_account
2025-07-03 19:05:31 +08:00
2025-07-02 09:02:01 +07:00
key = json . loads ( key )
access_token = json . loads ( base64 . b64decode ( key . get ( " google_service_account_key " , " " ) ) )
project_id = key . get ( " google_project_id " , " " )
region = key . get ( " google_region " , " " )
scopes = [ " https://www.googleapis.com/auth/cloud-platform " ]
self . model_name = model_name
self . lang = lang
if " claude " in self . model_name :
from anthropic import AnthropicVertex
from google . auth . transport . requests import Request
if access_token :
credits = service_account . Credentials . from_service_account_info ( access_token , scopes = scopes )
request = Request ( )
credits . refresh ( request )
token = credits . token
self . client = AnthropicVertex ( region = region , project_id = project_id , access_token = token )
else :
self . client = AnthropicVertex ( region = region , project_id = project_id )
else :
2026-02-24 10:28:33 +08:00
from google import genai
2025-07-02 09:02:01 +07:00
if access_token :
2026-02-24 10:28:33 +08:00
credits = service_account . Credentials . from_service_account_info ( access_token , scopes = scopes )
self . client = genai . Client ( vertexai = True , project = project_id , location = region , credentials = credits )
2025-07-02 09:02:01 +07:00
else :
2026-02-24 10:28:33 +08:00
self . client = genai . Client ( vertexai = True , project = project_id , location = region )
2025-07-30 19:41:09 +08:00
Base . __init__ ( self , * * kwargs )
2025-07-02 09:02:01 +07:00
def describe ( self , image ) :
if " claude " in self . model_name :
2025-07-30 19:41:09 +08:00
return AnthropicCV . describe ( self , image )
2025-07-02 09:02:01 +07:00
else :
2025-07-30 19:41:09 +08:00
return GeminiCV . describe ( self , image )
2025-07-02 09:02:01 +07:00
def describe_with_prompt(self, image, prompt=None):
    """Describe ``image`` with an optional custom ``prompt``.

    Delegates to ``AnthropicCV.describe_with_prompt`` when the model name
    contains ``"claude"`` and to ``GeminiCV.describe_with_prompt`` otherwise,
    returning whatever the delegate returns.
    """
    backend = AnthropicCV if "claude" in self.model_name else GeminiCV
    return backend.describe_with_prompt(self, image, prompt)
2025-07-03 19:05:31 +08:00
2025-12-09 13:08:37 +08:00
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
    """Run one chat turn through the implementation that matches the model.

    Awaits ``AnthropicCV.async_chat`` when the model name contains
    ``"claude"`` and ``GeminiCV.async_chat`` otherwise. Extra ``**kwargs``
    are accepted but not forwarded to the delegate.
    """
    backend = AnthropicCV if "claude" in self.model_name else GeminiCV
    return await backend.async_chat(self, system, history, gen_conf, images)
2025-07-02 09:02:01 +07:00
2025-12-09 13:08:37 +08:00
async def async_chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
    """Stream a chat turn through the implementation that matches the model.

    Async generator that re-yields every item produced by
    ``AnthropicCV.async_chat_streamly`` (model name contains ``"claude"``)
    or ``GeminiCV.async_chat_streamly`` (all other names). Extra
    ``**kwargs`` are accepted but not forwarded to the delegate.
    """
    backend = AnthropicCV if "claude" in self.model_name else GeminiCV
    stream = backend.async_chat_streamly(self, system, history, gen_conf, images)
    async for chunk in stream:
        yield chunk
2025-11-07 19:52:57 +08:00
class MoonshotCV(GptV4):
    """``GptV4`` subclass preconfigured with the Moonshot endpoint."""

    _FACTORY_NAME = "Moonshot"

    def __init__(self, key, model_name="moonshot-v1-8k-vision-preview", lang="Chinese", base_url="https://api.moonshot.cn/v1", **kwargs):
        """Initialise the parent with the given key, model and endpoint.

        A falsy ``base_url`` (``None`` or ``""``) falls back to the Moonshot
        default endpoint.
        """
        endpoint = base_url or "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
2026-03-06 02:37:27 +01:00
feat: add FuturMix as model provider (#14419)
## Summary
Add [FuturMix](https://futurmix.ai) as a new model provider. FuturMix is
an OpenAI-compatible unified AI gateway that provides access to 22+
models (GPT, Claude, Gemini, DeepSeek, and more) through a single API
endpoint and key.
- **API Base**: `https://futurmix.ai/v1` (OpenAI-compatible)
- **Supported capabilities**: Chat, Embedding, Image2Text, TTS,
Speech2Text, Rerank
### Changes
| File | Change |
|------|--------|
| `rag/llm/__init__.py` | Add `FuturMix` to `SupportedLiteLLMProvider`
enum, `FACTORY_DEFAULT_BASE_URL`, and `LITELLM_PROVIDER_PREFIX` |
| `rag/llm/chat_model.py` | Add `FuturMixChat(Base)` — follows
Astraflow/Avian pattern |
| `rag/llm/embedding_model.py` | Add `FuturMixEmbed(OpenAIEmbed)` —
follows Astraflow pattern |
| `rag/llm/cv_model.py` | Add `FuturMixCV(GptV4)` — follows
SILICONFLOW/OpenRouter pattern |
| `rag/llm/tts_model.py` | Add `FuturMixTTS(OpenAITTS)` — follows
CometAPI/DeerAPI pattern |
| `rag/llm/sequence2txt_model.py` | Add `FuturMixSeq2txt(GPTSeq2txt)` —
follows StepFun pattern |
| `rag/llm/rerank_model.py` | Add `FuturMixRerank(OpenAI_APIRerank)` |
| `conf/llm_factories.json` | Add factory config with 8 chat, 2
embedding, 1 image2text, 2 TTS, 1 speech2text models |
| `docs/guides/models/supported_models.mdx` | Add FuturMix to supported
models table |
### Models included
- **Chat**: claude-sonnet-4-20250514, claude-3.5-haiku, gpt-4o,
gpt-4o-mini, gemini-2.5-flash, gemini-2.0-flash, deepseek-chat,
deepseek-reasoner
- **Embedding**: text-embedding-3-small, text-embedding-3-large
- **Image2Text**: gpt-4o
- **TTS**: tts-1, tts-1-hd
- **Speech2Text**: whisper-1
## Test plan
- [ ] Verify FuturMix appears in the model provider list in RAGFlow UI
- [ ] Configure FuturMix with API key and test chat completion
- [ ] Test embedding model with document indexing
- [ ] Test image2text with a sample image
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-30 10:59:37 +08:00
class FuturMixCV(GptV4):
    """``GptV4`` subclass preconfigured with the FuturMix endpoint."""

    _FACTORY_NAME = "FuturMix"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://futurmix.ai/v1", **kwargs):
        """Initialise the parent client and log the selected model.

        A falsy ``base_url`` (``None`` or ``""``) falls back to the FuturMix
        default endpoint.
        """
        endpoint = base_url or "https://futurmix.ai/v1"
        super().__init__(key, model_name, lang=lang, base_url=endpoint, **kwargs)
        logging.info("[FuturMix] CV initialized with model %s", model_name)
2026-03-06 02:37:27 +01:00
class RAGconCV(GptV4):
    """
    RAGcon CV provider - routes through a LiteLLM proxy.

    Supports vision models through LiteLLM.
    Default base URL used by the code: https://connect.ragcon.com/v1

    NOTE(review): the previous docstring advertised
    https://connect.ragcon.ai/v1 while the constructor falls back to the
    ``.com`` host -- confirm which one is the intended default.
    """

    _FACTORY_NAME = "RAGcon"

    def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
        """Create sync and async OpenAI clients pointed at the RAGcon endpoint.

        A falsy ``base_url`` falls back to the built-in default. The parent
        ``GptV4.__init__`` is not invoked; ``Base.__init__`` is called
        directly with ``**kwargs`` after the clients are built.
        """
        endpoint = base_url or "https://connect.ragcon.com/v1"

        # Initialize clients
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.async_client = AsyncOpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang

        Base.__init__(self, **kwargs)
class BedrockCV(Base):
    """Image-to-text model that calls AWS Bedrock through ``litellm.completion``."""

    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Store the model settings and parse the AWS credential bundle.

        Args:
            key: JSON string with the Bedrock credential fields (see
                ``_parse_credentials``).
            model_name: Bedrock model id; stored with a ``bedrock/`` prefix.
            lang: Stored on ``self.lang``.
            **kwargs: Forwarded to ``Base.__init__``.
        """
        self.model_name = f"bedrock/{model_name}"
        self.lang = lang
        self._parse_credentials(key)
        Base.__init__(self, **kwargs)

    def _parse_credentials(self, key):
        """Decode the JSON ``key`` and copy its fields onto the instance.

        Missing fields default to ``""``, except the region, which defaults
        to ``"us-east-1"``.

        Raises:
            json.JSONDecodeError: If ``key`` is not valid JSON.
        """
        bedrock_key = json.loads(key)
        self.auth_mode = bedrock_key.get("auth_mode", "")
        self.aws_region = bedrock_key.get("bedrock_region", "us-east-1")
        self.aws_ak = bedrock_key.get("bedrock_ak", "")
        self.aws_sk = bedrock_key.get("bedrock_sk", "")
        self.aws_role_arn = bedrock_key.get("aws_role_arn", "")

    def _get_aws_creds(self):
        """Return the AWS keyword arguments to pass to ``litellm.completion``.

        * ``access_key_secret``: region plus the stored access key pair.
        * ``iam_role``: region plus temporary credentials from an STS
          ``assume_role`` call (performed on every invocation).
        * anything else: region only, leaving credential resolution to the
          underlying SDK.
        """
        if self.auth_mode == "access_key_secret":
            return {
                "aws_region_name": self.aws_region,
                "aws_access_key_id": self.aws_ak,
                "aws_secret_access_key": self.aws_sk,
            }

        if self.auth_mode == "iam_role":
            import boto3

            sts_client = boto3.client("sts", region_name=self.aws_region)
            resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockCVSession")
            creds = resp["Credentials"]
            return {
                "aws_region_name": self.aws_region,
                "aws_access_key_id": creds["AccessKeyId"],
                "aws_secret_access_key": creds["SecretAccessKey"],
                "aws_session_token": creds["SessionToken"],
            }

        return {"aws_region_name": self.aws_region}

    def describe_with_prompt(self, image, prompt=None):
        """Describe ``image`` with an optional ``prompt``.

        Returns:
            A ``(text, token_count)`` tuple: the stripped text of the first
            choice (``""`` when the model returned no text) and the value of
            ``total_token_count_from_response`` for the response.
        """
        import litellm

        # image2base64 / vision_llm_prompt are not defined in this class;
        # presumably inherited from Base -- TODO confirm.
        b64 = self.image2base64(image)
        messages = self.vision_llm_prompt(b64, prompt)
        res = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self._get_aws_creds(),
        )
        # message.content is optional in OpenAI-format responses; guard so a
        # text-less reply yields "" instead of AttributeError on None.strip().
        content = res.choices[0].message.content or ""
        return content.strip(), total_token_count_from_response(res)

    def describe(self, image):
        """Describe ``image`` without a custom prompt (``prompt=None``)."""
        return self.describe_with_prompt(image)