mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Feat: tenant llm provider (#14595)
### What problem does this PR solve?

Python implementation of the Go-based model_provider API suite.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: bill <yibie_jingnian@163.com>
This commit is contained in:
@@ -92,6 +92,7 @@ def get_email():
|
||||
|
||||
|
||||
def get_my_llms(auth, name):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + "/v1/llm/my_llms"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
@@ -103,7 +104,20 @@ def get_my_llms(auth, name):
|
||||
return False
|
||||
|
||||
|
||||
def get_added_models(auth, factory_name):
    """Report whether the given provider appears among the added models.

    Sends a GET request to ``/api/v1/models`` and collects the
    ``provider_name`` of every entry in the response's ``data`` list.

    Args:
        auth: Value sent as the ``Authorization`` header.
        factory_name: Provider name to look for (callers in this file pass
            ``"ZHIPU-AI"``).

    Returns:
        bool: True if at least one listed model has ``provider_name`` equal
        to ``factory_name``, otherwise False.

    Raises:
        Exception: If the response ``code`` is not 0; the exception carries
            the response's ``message`` field.
    """
    url = HOST_ADDRESS + "/api/v1/models"
    authorization = {"Authorization": auth}
    response = requests.get(url=url, headers=authorization)
    res = response.json()
    if res.get("code") != 0:
        raise Exception(res.get("message"))
    # dict.get's default only applies when the key is missing; a JSON null
    # ("data": null) would come back as None and break the iteration below.
    models = res.get("data") or []
    added_factory = {model["provider_name"] for model in models}
    return factory_name in added_factory
|
||||
|
||||
def add_models(auth):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + "/v1/llm/set_api_key"
|
||||
authorization = {"Authorization": auth}
|
||||
models_info = {
|
||||
@@ -118,7 +132,32 @@ def add_models(auth):
|
||||
pytest.exit(f"Critical error in add_models: {res.get('message')}")
|
||||
|
||||
|
||||
def add_model_instance(auth):
    """Add the ZHIPU-AI provider plus a "CI" instance, then confirm it is listed.

    Performs, in order:
      1. PUT ``/api/v1/providers`` with the provider name.
      2. POST ``/api/v1/providers/ZHIPU-AI/instances`` with the instance
         settings (API key taken from ``ZHIPU_AI_API_KEY``).
      3. A ``get_added_models`` check for the provider.

    A non-zero ``code`` in either response, or a failed check in step 3,
    stops the test session via ``pytest.exit``.

    Args:
        auth: Value sent as the ``Authorization`` header.
    """
    headers = {"Authorization": auth}

    # Step 1: add the provider.
    provider_reply = requests.put(
        url=HOST_ADDRESS + "/api/v1/providers",
        headers=headers,
        json={"provider_name": "ZHIPU-AI"},
    ).json()
    if provider_reply.get("code") != 0:
        pytest.exit(f"Critical error in add model provider: {provider_reply.get('message')}")

    # Step 2: add an instance under that provider.
    instance_payload = {
        "instance_name": "CI",
        "api_key": ZHIPU_AI_API_KEY,
        "region": "default",
        "base_url": "",
    }
    instance_reply = requests.post(
        url=HOST_ADDRESS + "/api/v1/providers/ZHIPU-AI/instances",
        headers=headers,
        json=instance_payload,
    ).json()
    if instance_reply.get("code") != 0:
        pytest.exit(f"Critical error in add model instance: {instance_reply.get('message')}")

    # Step 3: verify the provider now shows up in the model listing.
    if not get_added_models(auth, "ZHIPU-AI"):
        pytest.exit("Critical error in check added model: add model failed")
|
||||
|
||||
|
||||
def get_tenant_info(auth):
|
||||
# todo deprecated
|
||||
url = HOST_ADDRESS + "/api/v1/users/me/models"
|
||||
authorization = {"Authorization": auth}
|
||||
response = requests.get(url=url, headers=authorization)
|
||||
@@ -131,22 +170,49 @@ def get_tenant_info(auth):
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info(get_auth):
|
||||
auth = get_auth
|
||||
try:
|
||||
add_models(auth)
|
||||
tenant_id = get_tenant_info(auth)
|
||||
except Exception as e:
|
||||
pytest.exit(f"Error in set_tenant_info: {str(e)}")
|
||||
url = HOST_ADDRESS + "/api/v1/users/me/models"
|
||||
if not get_added_models(auth, "ZHIPU-AI"):
|
||||
try:
|
||||
add_model_instance(auth)
|
||||
except Exception as e:
|
||||
pytest.exit(f"Error in set_tenant_info: {str(e)}")
|
||||
url = HOST_ADDRESS + "/api/v1/models/default"
|
||||
authorization = {"Authorization": get_auth}
|
||||
tenant_info = {
|
||||
"tenant_id": tenant_id,
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI",
|
||||
"embd_id": "BAAI/bge-small-en-v1.5@Builtin",
|
||||
"img2txt_id": "glm-4v@ZHIPU-AI",
|
||||
"asr_id": "",
|
||||
"tts_id": None,
|
||||
}
|
||||
response = requests.patch(url=url, headers=authorization, json=tenant_info)
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
raise Exception(res.get("message"))
|
||||
# set chat model
|
||||
set_default_llm_response = requests.patch(
|
||||
url=url,
|
||||
headers=authorization,
|
||||
json={
|
||||
"model_provider": "ZHIPU-AI",
|
||||
"model_instance": "CI",
|
||||
"model_type": "chat",
|
||||
"model_name": "glm-4-flash"
|
||||
})
|
||||
llm_res = set_default_llm_response.json()
|
||||
if llm_res.get("code") != 0:
|
||||
raise Exception(llm_res.get("message"))
|
||||
# set embedding model
|
||||
set_default_embedding_response = requests.patch(
|
||||
url=url,
|
||||
headers=authorization,
|
||||
json={
|
||||
"model_provider": "Builtin",
|
||||
"model_instance": "Local",
|
||||
"model_type": "embedding",
|
||||
"model_name": "BAAI/bge-small-en-v1.5"
|
||||
})
|
||||
embd_res = set_default_embedding_response.json()
|
||||
if embd_res.get("code") != 0:
|
||||
raise Exception(embd_res.get("message"))
|
||||
# set image to text model
|
||||
set_default_img2txt_response = requests.patch(
|
||||
url=url,
|
||||
headers=authorization,
|
||||
json={
|
||||
"model_provider": "ZHIPU-AI",
|
||||
"model_instance": "CI",
|
||||
"model_type": "vision",
|
||||
"model_name": "glm-4v"
|
||||
})
|
||||
img2txt_res = set_default_img2txt_response.json()
|
||||
if img2txt_res.get("code") != 0:
|
||||
raise Exception(img2txt_res.get("message"))
|
||||
|
||||
Reference in New Issue
Block a user