From f4f17e6deb46b8c44072830d45c714f4c2796644 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=81=B5=E8=BD=AE?= Date: Sat, 6 Jun 2026 19:08:51 +0800 Subject: [PATCH] feat:adb retrieve settings supports filter Change-Id: I665fa00b9362e58b43295c7167b70ff015ea3ea1 Co-developed-by: Qoder --- .../__knowledgebase_async_template.py | 21 +++++++++- .../api/__data_async_template.py | 14 +++---- agentrun/knowledgebase/api/data.py | 27 ++++++++----- agentrun/knowledgebase/knowledgebase.py | 39 +++++++++++++++---- agentrun/knowledgebase/model.py | 3 ++ .../unittests/knowledgebase/api/test_data.py | 2 + .../knowledgebase/test_knowledgebase.py | 2 + tests/unittests/knowledgebase/test_model.py | 3 ++ 8 files changed, 84 insertions(+), 27 deletions(-) diff --git a/agentrun/knowledgebase/__knowledgebase_async_template.py b/agentrun/knowledgebase/__knowledgebase_async_template.py index 6b90eab..11d64be 100644 --- a/agentrun/knowledgebase/__knowledgebase_async_template.py +++ b/agentrun/knowledgebase/__knowledgebase_async_template.py @@ -276,6 +276,18 @@ def _get_data_api(self, config: Optional[Config] = None): converted_provider_settings = None converted_retrieve_settings = None + # 当 retrieve_settings 被 pydantic Union 匹配到错误的类型时(由于 extra="allow"), + # 从 __pydantic_extra__ 提取原始数据作为 dict 使用 + # When retrieve_settings is matched to wrong Union type by pydantic (due to extra="allow"), + # extract raw data from __pydantic_extra__ as dict + if ( + self.retrieve_settings is not None + and not isinstance(self.retrieve_settings, dict) + and hasattr(self.retrieve_settings, "__pydantic_extra__") + and self.retrieve_settings.__pydantic_extra__ + ): + self.retrieve_settings = self.retrieve_settings.__pydantic_extra__ + if provider == KnowledgeBaseProvider.BAILIAN: # 百炼设置 / Bailian settings if self.provider_settings: @@ -347,8 +359,12 @@ def _get_data_api(self, config: Optional[Config] = None): ), rerank_model=( ADBRerankModel( - name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""), - instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"), + name=self.retrieve_settings.get( + "RerankModel", {} + ).get("Name", ""), + instruct=self.retrieve_settings.get( + "RerankModel", {} + ).get("Instruct"), ) if self.retrieve_settings.get("RerankModel") else None @@ -362,6 +378,7 @@ def _get_data_api(self, config: Optional[Config] = None): hybrid_search_args=self.retrieve_settings.get( "HybridSearchArgs" ), + filter=self.retrieve_settings.get("Filter"), ) elif provider == KnowledgeBaseProvider.OTS: diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index e1cb8e8..5277c3f 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -426,14 +426,12 @@ def _build_query_content_request( self.retrieve_settings.rerank_factor ) if self.retrieve_settings.rerank_model is not None: - rerank_model_params: Dict[str, Any] = { - "Name": self.retrieve_settings.rerank_model.name, - } - if self.retrieve_settings.rerank_model.instruct is not None: - rerank_model_params["Instruct"] = ( - self.retrieve_settings.rerank_model.instruct + request_params["rerank_model"] = ( + gpdb_models.QueryContentRequestRerankModel( + name=self.retrieve_settings.rerank_model.name, + instruct=self.retrieve_settings.rerank_model.instruct, ) - request_params["rerank_model"] = rerank_model_params + ) if self.retrieve_settings.recall_window is not None: request_params["recall_window"] = ( self.retrieve_settings.recall_window @@ -446,6 +444,8 @@ def _build_query_content_request( request_params["hybrid_search_args"] = ( self.retrieve_settings.hybrid_search_args ) + if self.retrieve_settings.filter is not None: + request_params["filter"] = self.retrieve_settings.filter return gpdb_models.QueryContentRequest(**request_params) diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index 7980ff6..74517b5 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -81,6 +81,7 @@ async def retrieve_async( """ raise NotImplementedError("Subclasses must implement retrieve_async") + @abstractmethod def retrieve( self, @@ -175,7 +176,9 @@ def _get_api_key(self, config: Optional[Config] = None) -> str: from agentrun.credential import Credential - credential = Credential.get_by_name(self.credential_name, config=config) + credential = Credential.get_by_name( + self.credential_name, config=config + ) if not credential.credential_secret: raise ValueError( f"Credential '{self.credential_name}' has no secret configured" @@ -282,6 +285,7 @@ async def retrieve_async( "error": True, } + def retrieve( self, query: str, @@ -315,7 +319,9 @@ def retrieve( body = self._build_request_body(query) # 发送请求 / Send request - with httpx.Client(timeout=self.config.get_timeout()) as client: + with httpx.Client( + timeout=self.config.get_timeout() + ) as client: response = client.post(url, json=body, headers=headers) response.raise_for_status() result = response.json() @@ -467,6 +473,7 @@ async def retrieve_async( "error": True, } + def retrieve( self, query: str, @@ -636,14 +643,10 @@ def _build_query_content_request( self.retrieve_settings.rerank_factor ) if self.retrieve_settings.rerank_model is not None: - rerank_model_params: Dict[str, Any] = { - "Name": self.retrieve_settings.rerank_model.name, - } - if self.retrieve_settings.rerank_model.instruct is not None: - rerank_model_params["Instruct"] = ( - self.retrieve_settings.rerank_model.instruct - ) - request_params["rerank_model"] = rerank_model_params + request_params["rerank_model"] = gpdb_models.QueryContentRequestRerankModel( + name=self.retrieve_settings.rerank_model.name, + instruct=self.retrieve_settings.rerank_model.instruct, + ) if self.retrieve_settings.recall_window is not None: request_params["recall_window"] = ( self.retrieve_settings.recall_window @@ -656,6 +659,8 @@ def _build_query_content_request( request_params["hybrid_search_args"] = ( self.retrieve_settings.hybrid_search_args ) + if self.retrieve_settings.filter is not None: + request_params["filter"] = self.retrieve_settings.filter return gpdb_models.QueryContentRequest(**request_params) @@ -768,6 +773,7 @@ async def retrieve_async( "error": True, } + def retrieve( self, query: str, @@ -1039,6 +1045,7 @@ async def retrieve_async( "error": True, } + def retrieve( self, query: str, diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index f08c90c..ad3f6ab 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -109,7 +109,9 @@ def create( Returns: KnowledgeBase: 创建的知识库对象 / Created knowledge base object """ - return cls.__get_client(config=config).create(input, config=config) + return cls.__get_client(config=config).create( + input, config=config + ) @classmethod async def delete_by_name_async( @@ -357,7 +359,9 @@ def delete(self, config: Optional[Config] = None): "knowledge_base_name is required to delete a KnowledgeBase" ) - return self.delete_by_name(self.knowledge_base_name, config=config) + return self.delete_by_name( + self.knowledge_base_name, config=config + ) async def get_async(self, config: Optional[Config] = None): """刷新知识库信息(异步)/ Refresh knowledge base info asynchronously @@ -394,7 +398,9 @@ def get(self, config: Optional[Config] = None): "knowledge_base_name is required to refresh a KnowledgeBase" ) - result = self.get_by_name(self.knowledge_base_name, config=config) + result = self.get_by_name( + self.knowledge_base_name, config=config + ) self.update_self(result) return self @@ -458,6 +464,18 @@ def _get_data_api(self, config: Optional[Config] = None): converted_provider_settings = None converted_retrieve_settings = None + # 当 retrieve_settings 被 pydantic Union 匹配到错误的类型时(由于 extra="allow"), + # 从 __pydantic_extra__ 提取原始数据作为 dict 使用 + # When retrieve_settings is matched to wrong Union type by pydantic (due to extra="allow"), + # extract raw data from __pydantic_extra__ as dict + if ( + self.retrieve_settings is not None + and not isinstance(self.retrieve_settings, dict) + and hasattr(self.retrieve_settings, "__pydantic_extra__") + and self.retrieve_settings.__pydantic_extra__ + ): + self.retrieve_settings = self.retrieve_settings.__pydantic_extra__ + if provider == KnowledgeBaseProvider.BAILIAN: # 百炼设置 / Bailian settings if self.provider_settings: @@ -544,6 +562,9 @@ def _get_data_api(self, config: Optional[Config] = None): hybrid_search_args=self.retrieve_settings.get( "HybridSearchArgs" ), + filter=self.retrieve_settings.get( + "Filter" + ), ) elif provider == KnowledgeBaseProvider.OTS: @@ -905,19 +926,21 @@ def multi_retrieve( """ # 1. 根据 knowledge_base_names 并发获取各知识库配置(安全方式) # Fetch all knowledge bases concurrently by name (safely) - knowledge_base_results = [ + knowledge_base_results = ([ cls._safe_get_kb(name, config=config) for name in knowledge_base_names - ] + ]) # 2. 并发执行各知识库的检索(安全方式) # Execute retrieval for each knowledge base concurrently (safely) - retrieve_results = [ - cls._safe_retrieve_kb(kb_name, kb_or_error, query, config=config) + retrieve_results = ([ + cls._safe_retrieve_kb( + kb_name, kb_or_error, query, config=config + ) for kb_name, kb_or_error in zip( knowledge_base_names, knowledge_base_results ) - ] + ]) # 3. 合并返回结果,按知识库名称分组 # Merge results, grouped by knowledge base name diff --git a/agentrun/knowledgebase/model.py b/agentrun/knowledgebase/model.py index aaaf1e0..ff0f7a6 100644 --- a/agentrun/knowledgebase/model.py +++ b/agentrun/knowledgebase/model.py @@ -150,6 +150,9 @@ class ADBRetrieveSettings(BaseModel): hybrid_search_args: Optional[Dict[str, Any]] = None """混合检索算法参数,如 {"RRF": {"k": 60}} 或 {"Weight": {"alpha": 0.5}} Hybrid search algorithm parameters""" + filter: Optional[str] = None + """过滤条件,SQL WHERE 格式,如 "category = 'tech' AND score > 0.5" + Filter condition in SQL WHERE format""" # ============================================================================= diff --git a/tests/unittests/knowledgebase/api/test_data.py b/tests/unittests/knowledgebase/api/test_data.py index 08f0a7f..4104a0a 100644 --- a/tests/unittests/knowledgebase/api/test_data.py +++ b/tests/unittests/knowledgebase/api/test_data.py @@ -1026,6 +1026,7 @@ def test_build_query_content_request_with_settings(self): recall_window=[-5, 5], hybrid_search="RRF", hybrid_search_args={"RRF": {"k": 60}}, + filter="category = 'tech' AND score > 0.5", ), ) @@ -1036,6 +1037,7 @@ def test_build_query_content_request_with_settings(self): assert request.use_full_text_retrieval is True assert request.rerank_factor == 1.5 assert request.rerank_model is not None + assert request.filter == "category = 'tech' AND score > 0.5" @patch.dict( os.environ, diff --git a/tests/unittests/knowledgebase/test_knowledgebase.py b/tests/unittests/knowledgebase/test_knowledgebase.py index 6617fc8..ffbf883 100644 --- a/tests/unittests/knowledgebase/test_knowledgebase.py +++ b/tests/unittests/knowledgebase/test_knowledgebase.py @@ -706,6 +706,7 @@ def test_get_data_api_adb_with_raw_dict_settings(self): "RecallWindow": [-5, 5], "HybridSearch": "RRF", "HybridSearchArgs": {"RRF": {"k": 60}}, + "Filter": "category = 'tech'", }, ) @@ -713,6 +714,7 @@ def test_get_data_api_adb_with_raw_dict_settings(self): data_api = kb._get_data_api() assert isinstance(data_api, ADBDataAPI) + assert data_api.retrieve_settings.filter == "category = 'tech'" def test_get_data_api_without_provider(self): """测试获取数据链路 API(无提供商)""" diff --git a/tests/unittests/knowledgebase/test_model.py b/tests/unittests/knowledgebase/test_model.py index ba3cf17..f0b8ac9 100644 --- a/tests/unittests/knowledgebase/test_model.py +++ b/tests/unittests/knowledgebase/test_model.py @@ -215,6 +215,7 @@ def test_create_adb_retrieve_settings(self): recall_window=[-5, 5], hybrid_search="RRF", hybrid_search_args={"RRF": {"k": 60}}, + filter="category = 'tech'", ) assert settings.top_k == 10 assert settings.use_full_text_retrieval is True @@ -225,6 +226,7 @@ def test_create_adb_retrieve_settings(self): assert settings.recall_window == [-5, 5] assert settings.hybrid_search == "RRF" assert settings.hybrid_search_args == {"RRF": {"k": 60}} + assert settings.filter == "category = 'tech'" def test_adb_retrieve_settings_optional(self): """测试 ADB 检索设置可选字段""" @@ -236,6 +238,7 @@ def test_adb_retrieve_settings_optional(self): assert settings.recall_window is None assert settings.hybrid_search is None assert settings.hybrid_search_args is None + assert settings.filter is None def test_adb_retrieve_settings_weight_hybrid(self): """测试 ADB 检索设置加权混合检索"""