diff --git a/agentrun/knowledgebase/__init__.py b/agentrun/knowledgebase/__init__.py index fd314f1..50eb9c7 100644 --- a/agentrun/knowledgebase/__init__.py +++ b/agentrun/knowledgebase/__init__.py @@ -13,6 +13,7 @@ from .knowledgebase import KnowledgeBase from .model import ( ADBProviderSettings, + ADBRerankModel, ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, @@ -64,6 +65,7 @@ "RetrieveSettings", "RagFlowRetrieveSettings", "BailianRetrieveSettings", + "ADBRerankModel", "ADBRetrieveSettings", "OTSRetrieveSettings", "OTSDenseVectorSearchConfig", diff --git a/agentrun/knowledgebase/__knowledgebase_async_template.py b/agentrun/knowledgebase/__knowledgebase_async_template.py index 114496f..6b90eab 100644 --- a/agentrun/knowledgebase/__knowledgebase_async_template.py +++ b/agentrun/knowledgebase/__knowledgebase_async_template.py @@ -15,6 +15,7 @@ from .api.data import get_data_api from .model import ( ADBProviderSettings, + ADBRerankModel, ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, @@ -344,6 +345,14 @@ def _get_data_api(self, config: Optional[Config] = None): rerank_factor=self.retrieve_settings.get( "RerankFactor" ), + rerank_model=( + ADBRerankModel( + name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""), + instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"), + ) + if self.retrieve_settings.get("RerankModel") + else None + ), recall_window=self.retrieve_settings.get( "RecallWindow" ), diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index a068ff5..e1cb8e8 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -405,6 +405,8 @@ def _build_query_content_request( "namespace_password": self.provider_settings.namespace_password, "collection": self.knowledge_base_name, "region_id": cfg.get_region_id(), + # 固定设置 URL 过期时间为 356 天 / Fixed URL expiration to 356 days + "url_expiration": "356d", } # 添加可选的提供商设置 / Add optional provider settings @@ -423,6 +425,15 @@ def _build_query_content_request( request_params["rerank_factor"] = ( self.retrieve_settings.rerank_factor ) + if self.retrieve_settings.rerank_model is not None: + rerank_model_params: Dict[str, Any] = { + "Name": self.retrieve_settings.rerank_model.name, + } + if self.retrieve_settings.rerank_model.instruct is not None: + rerank_model_params["Instruct"] = ( + self.retrieve_settings.rerank_model.instruct + ) + request_params["rerank_model"] = rerank_model_params if self.retrieve_settings.recall_window is not None: request_params["recall_window"] = ( self.retrieve_settings.recall_window diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index 2b4fd55..7980ff6 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -615,6 +615,8 @@ def _build_query_content_request( "namespace_password": self.provider_settings.namespace_password, "collection": self.knowledge_base_name, "region_id": cfg.get_region_id(), + # 固定设置 URL 过期时间为 356 天 / Fixed URL expiration to 356 days + "url_expiration": "356d", } # 添加可选的提供商设置 / Add optional provider settings @@ -633,6 +635,15 @@ def _build_query_content_request( request_params["rerank_factor"] = ( self.retrieve_settings.rerank_factor ) + if self.retrieve_settings.rerank_model is not None: + rerank_model_params: Dict[str, Any] = { + "Name": self.retrieve_settings.rerank_model.name, + } + if self.retrieve_settings.rerank_model.instruct is not None: + rerank_model_params["Instruct"] = ( + self.retrieve_settings.rerank_model.instruct + ) + request_params["rerank_model"] = rerank_model_params if self.retrieve_settings.recall_window is not None: request_params["recall_window"] = ( self.retrieve_settings.recall_window diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index 5b075a1..f08c90c 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -25,6 +25,7 @@ from .api.data import get_data_api from .model import ( ADBProviderSettings, + ADBRerankModel, ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, @@ -526,6 +527,14 @@ def _get_data_api(self, config: Optional[Config] = None): rerank_factor=self.retrieve_settings.get( "RerankFactor" ), + rerank_model=( + ADBRerankModel( + name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""), + instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"), + ) + if self.retrieve_settings.get("RerankModel") + else None + ), recall_window=self.retrieve_settings.get( "RecallWindow" ), diff --git a/agentrun/knowledgebase/model.py b/agentrun/knowledgebase/model.py index c3e5cfc..aaaf1e0 100644 --- a/agentrun/knowledgebase/model.py +++ b/agentrun/knowledgebase/model.py @@ -106,6 +106,22 @@ class ADBProviderSettings(BaseModel): """元数据配置,JSON 字符串格式 / Metadata configuration in JSON string format""" +class ADBRerankModel(BaseModel): + """ADB 重排模型配置 / ADB Rerank Model Configuration + + 配置重排模型的名称和排序任务类型说明。 + Configure the rerank model name and instruct for sorting task type. + """ + + name: str + """重排模型名称,可选值:qwen3-rerank、gte-rerank-v2 + Rerank model name, options: qwen3-rerank, gte-rerank-v2""" + instruct: Optional[str] = None + """排序任务类型说明,仅当 name 为 qwen3-rerank 时可设置,指导模型采用不同的排序策略 + Instruct for sorting task type, only available when name is qwen3-rerank, + guides the model to adopt different sorting strategies""" + + class ADBRetrieveSettings(BaseModel): """ADB 检索设置 / ADB Retrieve Settings @@ -122,6 +138,9 @@ class ADBRetrieveSettings(BaseModel): rerank_factor: Optional[float] = None """重排序因子,取值范围 1 < RerankFactor <= 5 Re-ranking factor, value range: 1 < RerankFactor <= 5""" + rerank_model: Optional[ADBRerankModel] = None + """重排模型配置,当启用重排因子时可设置 + Rerank model configuration, available when rerank factor is enabled""" recall_window: Optional[List[int]] = None """召回窗口,格式为 [A, B],其中 -10 <= A <= 0,0 <= B <= 10 Recall window, format [A, B] where -10 <= A <= 0, 0 <= B <= 10""" diff --git a/examples/knowledgebase.py b/examples/knowledgebase.py index fff3e9f..50376a9 100644 --- a/examples/knowledgebase.py +++ b/examples/knowledgebase.py @@ -45,6 +45,7 @@ from agentrun.knowledgebase import ( ADBProviderSettings, + ADBRerankModel, ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, @@ -474,6 +475,10 @@ def create_or_get_adb_kb() -> KnowledgeBase: top_k=10, use_full_text_retrieval=False, # 仅使用向量检索 / Vector only rerank_factor=2.0, # 重排序因子 / Rerank factor + rerank_model=ADBRerankModel( + name="qwen3-rerank", # 重排模型名称 / Rerank model name + instruct="按相关性排序", # 排序任务类型说明(仅 qwen3-rerank 支持)/ Instruct (only for qwen3-rerank) + ), ), ) ) diff --git a/tests/unittests/knowledgebase/api/test_data.py b/tests/unittests/knowledgebase/api/test_data.py index 8e15902..08f0a7f 100644 --- a/tests/unittests/knowledgebase/api/test_data.py +++ b/tests/unittests/knowledgebase/api/test_data.py @@ -14,6 +14,7 @@ ) from agentrun.knowledgebase.model import ( ADBProviderSettings, + ADBRerankModel, ADBRetrieveSettings, BailianProviderSettings, BailianRetrieveSettings, @@ -995,6 +996,7 @@ def test_build_query_content_request(self): assert request.dbinstance_id == "gp-123456" assert request.namespace == "public" assert request.collection == "test-kb" + assert request.url_expiration == "356d" @patch.dict( os.environ, @@ -1017,6 +1019,10 @@ def test_build_query_content_request_with_settings(self): top_k=10, use_full_text_retrieval=True, rerank_factor=1.5, + rerank_model=ADBRerankModel( + name="qwen3-rerank", + instruct="按相关性排序", + ), recall_window=[-5, 5], hybrid_search="RRF", hybrid_search_args={"RRF": {"k": 60}}, @@ -1024,10 +1030,12 @@ def test_build_query_content_request_with_settings(self): ) request = api._build_query_content_request("test query") + assert request.url_expiration == "356d" assert request.metrics == "cosine" assert request.top_k == 10 assert request.use_full_text_retrieval is True assert request.rerank_factor == 1.5 + assert request.rerank_model is not None @patch.dict( os.environ, diff --git a/tests/unittests/knowledgebase/test_knowledgebase.py b/tests/unittests/knowledgebase/test_knowledgebase.py index 8936d9d..6617fc8 100644 --- a/tests/unittests/knowledgebase/test_knowledgebase.py +++ b/tests/unittests/knowledgebase/test_knowledgebase.py @@ -702,6 +702,7 @@ def test_get_data_api_adb_with_raw_dict_settings(self): "TopK": 10, "UseFullTextRetrieval": True, "RerankFactor": 1.5, + "RerankModel": {"Name": "qwen3-rerank", "Instruct": "按相关性排序"}, "RecallWindow": [-5, 5], "HybridSearch": "RRF", "HybridSearchArgs": {"RRF": {"k": 60}}, diff --git a/tests/unittests/knowledgebase/test_model.py b/tests/unittests/knowledgebase/test_model.py index 5f5e335..ba3cf17 100644 --- a/tests/unittests/knowledgebase/test_model.py +++ b/tests/unittests/knowledgebase/test_model.py @@ -5,6 +5,7 @@ import pytest from agentrun.knowledgebase.model import ( + ADBRerankModel, ADBProviderSettings, ADBRetrieveSettings, BailianProviderSettings, @@ -207,6 +208,10 @@ def test_create_adb_retrieve_settings(self): top_k=10, use_full_text_retrieval=True, rerank_factor=1.5, + rerank_model=ADBRerankModel( + name="qwen3-rerank", + instruct="按相关性排序", + ), recall_window=[-5, 5], hybrid_search="RRF", hybrid_search_args={"RRF": {"k": 60}}, @@ -214,6 +219,9 @@ def test_create_adb_retrieve_settings(self): assert settings.top_k == 10 assert settings.use_full_text_retrieval is True assert settings.rerank_factor == 1.5 + assert settings.rerank_model is not None + assert settings.rerank_model.name == "qwen3-rerank" + assert settings.rerank_model.instruct == "按相关性排序" assert settings.recall_window == [-5, 5] assert settings.hybrid_search == "RRF" assert settings.hybrid_search_args == {"RRF": {"k": 60}} @@ -224,6 +232,7 @@ def test_adb_retrieve_settings_optional(self): assert settings.top_k is None assert settings.use_full_text_retrieval is None assert settings.rerank_factor is None + assert settings.rerank_model is None assert settings.recall_window is None assert settings.hybrid_search is None assert settings.hybrid_search_args is None