diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py index 8924d62ed6..59383e6f07 100644 --- a/api/apps/restful_apis/chunk_api.py +++ b/api/apps/restful_apis/chunk_api.py @@ -14,6 +14,7 @@ # limitations under the License. # import base64 +import binascii import datetime import logging import re @@ -66,6 +67,31 @@ DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has n DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" +def _decode_chunk_image_base64(image_base64): + if not isinstance(image_base64, str) or not image_base64.strip(): + return None, "`image_base64` must be a non-empty string" + try: + image_binary = base64.b64decode(image_base64, validate=True) + except (binascii.Error, ValueError): + return None, "Invalid `image_base64`" + if not image_binary: + return None, "`image_base64` is empty" + return image_binary, None + + +def _store_chunk_image_or_error(dataset_id, chunk_id, image_binary): + try: + store_chunk_image(dataset_id, chunk_id, image_binary) + except Exception: + logging.exception( + "Failed to store chunk image. 
dataset_id=%s chunk_id=%s", + dataset_id, + chunk_id, + ) + return "Failed to store chunk image" + return None + + class Chunk(BaseModel): id: str = "" content: str = "" @@ -507,8 +533,13 @@ async def add_chunk(tenant_id, dataset_id, document_id): except ValueError as exc: return get_error_data_result(f"`tag_feas` {exc}") - image_base64 = req.get("image_base64") - if image_base64: + if "image_base64" in req: + image_binary, image_err = _decode_chunk_image_base64(req.get("image_base64")) + if image_err: + return get_error_data_result(message=image_err) + store_err = _store_chunk_image_or_error(dataset_id, chunk_id, image_binary) + if store_err: + return get_error_data_result(message=store_err) d["img_id"] = f"{dataset_id}-{chunk_id}" d["doc_type_kwd"] = "image" @@ -520,9 +551,6 @@ async def add_chunk(tenant_id, dataset_id, document_id): d[f"q_{len(v)}_vec"] = v.tolist() settings.docStoreConn.insert([d], search.index_name(dataset_tenant_id), dataset_id) - if image_base64: - store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64)) - DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) key_mapping = { "id": "id", @@ -647,8 +675,13 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): d["tag_feas"] = validate_tag_features(req["tag_feas"]) except ValueError as exc: return get_error_data_result(f"`tag_feas` {exc}") - image_base64 = req.get("image_base64") - if image_base64: + if "image_base64" in req: + image_binary, image_err = _decode_chunk_image_base64(req.get("image_base64")) + if image_err: + return get_error_data_result(message=image_err) + store_err = _store_chunk_image_or_error(dataset_id, chunk_id, image_binary) + if store_err: + return get_error_data_result(message=store_err) d["img_id"] = f"{dataset_id}-{chunk_id}" d["doc_type_kwd"] = "image" @@ -671,8 +704,6 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] 
@pytest.mark.p2
def test_restful_add_chunk_invalid_image_base64_does_not_index_chunk(monkeypatch):
    """A malformed `image_base64` is rejected before anything is indexed."""
    module = _load_chunk_api_module(monkeypatch)
    module.request = SimpleNamespace(args={}, headers={})
    module.settings.docStoreConn.inserted.clear()

    payload = {"content": "chunk with bad image", "image_base64": "not-valid-base64!!!"}
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))

    res = _run(_route_core(module.add_chunk)("tenant-1", "kb-1", "doc-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Invalid `image_base64`", res
    assert module.settings.docStoreConn.inserted == [], res
    assert module.DocumentService.increment_calls == [], res


@pytest.mark.p2
def test_restful_add_chunk_empty_image_base64_does_not_index_chunk(monkeypatch):
    """An empty `image_base64` string is rejected before anything is indexed."""
    module = _load_chunk_api_module(monkeypatch)
    module.request = SimpleNamespace(args={}, headers={})
    module.settings.docStoreConn.inserted.clear()

    payload = {"content": "chunk with empty image", "image_base64": ""}
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))

    res = _run(_route_core(module.add_chunk)("tenant-1", "kb-1", "doc-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "`image_base64` must be a non-empty string", res
    assert module.settings.docStoreConn.inserted == [], res
    assert module.DocumentService.increment_calls == [], res


@pytest.mark.p2
def test_restful_update_chunk_invalid_image_base64_does_not_update_chunk(monkeypatch):
    """A malformed `image_base64` on update leaves the doc store untouched."""
    module = _load_chunk_api_module(monkeypatch)
    module.request = SimpleNamespace(args={}, headers={})
    module.settings.docStoreConn.updated.clear()

    payload = {"content": "updated chunk", "image_base64": "not-valid-base64!!!"}
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))

    res = _run(_route_core(module.update_chunk)("tenant-1", "kb-1", "doc-1", "chunk-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Invalid `image_base64`", res
    assert module.settings.docStoreConn.updated == [], res


@pytest.mark.p2
def test_restful_add_chunk_valid_image_base64_stores_before_insert(monkeypatch):
    """A valid image is stored first, then the chunk is indexed with its image id."""
    import base64  # local import: keeps this test independent of the module header

    module = _load_chunk_api_module(monkeypatch)
    module.request = SimpleNamespace(args={}, headers={})
    doc_store = module.settings.docStoreConn
    doc_store.inserted.clear()
    store_calls = []

    def _record_store(bucket, name, binary):
        # Snapshot the number of indexed chunks at store time so the test can
        # verify the ordering its name promises, not just that both happened.
        store_calls.append((bucket, name, binary, len(doc_store.inserted)))

    monkeypatch.setattr(module, "store_chunk_image", _record_store)

    valid_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"content": "chunk with image", "image_base64": valid_b64}),
    )

    res = _run(_route_core(module.add_chunk)("tenant-1", "kb-1", "doc-1"))
    assert res["code"] == 0, res
    assert store_calls, "store_chunk_image should run before doc-store insert"

    bucket, name, binary, inserted_at_store_time = store_calls[-1]
    assert inserted_at_store_time == 0, "image must be stored before the chunk is indexed"
    assert binary == base64.b64decode(valid_b64), "stored bytes should be the decoded payload"

    assert doc_store.inserted, "chunk should be indexed after image stored"
    inserted = doc_store.inserted[-1]
    assert inserted.get("img_id") == f"{bucket}-{name}", inserted
    assert inserted.get("doc_type_kwd") == "image", inserted