Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions agentrun/knowledgebase/__knowledgebase_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,18 @@ def _get_data_api(self, config: Optional[Config] = None):
converted_provider_settings = None
converted_retrieve_settings = None

# 当 retrieve_settings 被 pydantic Union 匹配到错误的类型时(由于 extra="allow"),
# 从 __pydantic_extra__ 提取原始数据作为 dict 使用
# When retrieve_settings is matched to wrong Union type by pydantic (due to extra="allow"),
# extract raw data from __pydantic_extra__ as dict
if (
self.retrieve_settings is not None
and not isinstance(self.retrieve_settings, dict)
and hasattr(self.retrieve_settings, "__pydantic_extra__")
and self.retrieve_settings.__pydantic_extra__
):
self.retrieve_settings = self.retrieve_settings.__pydantic_extra__

if provider == KnowledgeBaseProvider.BAILIAN:
# 百炼设置 / Bailian settings
if self.provider_settings:
Expand Down Expand Up @@ -347,8 +359,12 @@ def _get_data_api(self, config: Optional[Config] = None):
),
rerank_model=(
ADBRerankModel(
name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""),
instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"),
name=self.retrieve_settings.get(
"RerankModel", {}
).get("Name", ""),
instruct=self.retrieve_settings.get(
"RerankModel", {}
).get("Instruct"),
)
if self.retrieve_settings.get("RerankModel")
else None
Expand All @@ -362,6 +378,7 @@ def _get_data_api(self, config: Optional[Config] = None):
hybrid_search_args=self.retrieve_settings.get(
"HybridSearchArgs"
),
filter=self.retrieve_settings.get("Filter"),
)

elif provider == KnowledgeBaseProvider.OTS:
Expand Down
14 changes: 7 additions & 7 deletions agentrun/knowledgebase/api/__data_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,14 +426,12 @@ def _build_query_content_request(
self.retrieve_settings.rerank_factor
)
if self.retrieve_settings.rerank_model is not None:
rerank_model_params: Dict[str, Any] = {
"Name": self.retrieve_settings.rerank_model.name,
}
if self.retrieve_settings.rerank_model.instruct is not None:
rerank_model_params["Instruct"] = (
self.retrieve_settings.rerank_model.instruct
request_params["rerank_model"] = (
gpdb_models.QueryContentRequestRerankModel(
name=self.retrieve_settings.rerank_model.name,
instruct=self.retrieve_settings.rerank_model.instruct,
)
request_params["rerank_model"] = rerank_model_params
)
if self.retrieve_settings.recall_window is not None:
request_params["recall_window"] = (
self.retrieve_settings.recall_window
Expand All @@ -446,6 +444,8 @@ def _build_query_content_request(
request_params["hybrid_search_args"] = (
self.retrieve_settings.hybrid_search_args
)
if self.retrieve_settings.filter is not None:
request_params["filter"] = self.retrieve_settings.filter

return gpdb_models.QueryContentRequest(**request_params)

Expand Down
27 changes: 17 additions & 10 deletions agentrun/knowledgebase/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ async def retrieve_async(
"""
raise NotImplementedError("Subclasses must implement retrieve_async")


@abstractmethod
def retrieve(
self,
Expand Down Expand Up @@ -175,7 +176,9 @@ def _get_api_key(self, config: Optional[Config] = None) -> str:

from agentrun.credential import Credential

credential = Credential.get_by_name(self.credential_name, config=config)
credential = Credential.get_by_name(
self.credential_name, config=config
)
if not credential.credential_secret:
raise ValueError(
f"Credential '{self.credential_name}' has no secret configured"
Expand Down Expand Up @@ -282,6 +285,7 @@ async def retrieve_async(
"error": True,
}


def retrieve(
self,
query: str,
Expand Down Expand Up @@ -315,7 +319,9 @@ def retrieve(
body = self._build_request_body(query)

# 发送请求 / Send request
with httpx.Client(timeout=self.config.get_timeout()) as client:
with httpx.Client(
timeout=self.config.get_timeout()
) as client:
response = client.post(url, json=body, headers=headers)
response.raise_for_status()
result = response.json()
Expand Down Expand Up @@ -467,6 +473,7 @@ async def retrieve_async(
"error": True,
}


def retrieve(
self,
query: str,
Expand Down Expand Up @@ -636,14 +643,10 @@ def _build_query_content_request(
self.retrieve_settings.rerank_factor
)
if self.retrieve_settings.rerank_model is not None:
rerank_model_params: Dict[str, Any] = {
"Name": self.retrieve_settings.rerank_model.name,
}
if self.retrieve_settings.rerank_model.instruct is not None:
rerank_model_params["Instruct"] = (
self.retrieve_settings.rerank_model.instruct
)
request_params["rerank_model"] = rerank_model_params
request_params["rerank_model"] = gpdb_models.QueryContentRequestRerankModel(
name=self.retrieve_settings.rerank_model.name,
instruct=self.retrieve_settings.rerank_model.instruct,
)
if self.retrieve_settings.recall_window is not None:
request_params["recall_window"] = (
self.retrieve_settings.recall_window
Expand All @@ -656,6 +659,8 @@ def _build_query_content_request(
request_params["hybrid_search_args"] = (
self.retrieve_settings.hybrid_search_args
)
if self.retrieve_settings.filter is not None:
request_params["filter"] = self.retrieve_settings.filter

return gpdb_models.QueryContentRequest(**request_params)

Expand Down Expand Up @@ -768,6 +773,7 @@ async def retrieve_async(
"error": True,
}


def retrieve(
self,
query: str,
Expand Down Expand Up @@ -1039,6 +1045,7 @@ async def retrieve_async(
"error": True,
}


def retrieve(
self,
query: str,
Expand Down
39 changes: 31 additions & 8 deletions agentrun/knowledgebase/knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def create(
Returns:
KnowledgeBase: 创建的知识库对象 / Created knowledge base object
"""
return cls.__get_client(config=config).create(input, config=config)
return cls.__get_client(config=config).create(
input, config=config
)

@classmethod
async def delete_by_name_async(
Expand Down Expand Up @@ -357,7 +359,9 @@ def delete(self, config: Optional[Config] = None):
"knowledge_base_name is required to delete a KnowledgeBase"
)

return self.delete_by_name(self.knowledge_base_name, config=config)
return self.delete_by_name(
self.knowledge_base_name, config=config
)

async def get_async(self, config: Optional[Config] = None):
"""刷新知识库信息(异步)/ Refresh knowledge base info asynchronously
Expand Down Expand Up @@ -394,7 +398,9 @@ def get(self, config: Optional[Config] = None):
"knowledge_base_name is required to refresh a KnowledgeBase"
)

result = self.get_by_name(self.knowledge_base_name, config=config)
result = self.get_by_name(
self.knowledge_base_name, config=config
)
self.update_self(result)

return self
Expand Down Expand Up @@ -458,6 +464,18 @@ def _get_data_api(self, config: Optional[Config] = None):
converted_provider_settings = None
converted_retrieve_settings = None

# 当 retrieve_settings 被 pydantic Union 匹配到错误的类型时(由于 extra="allow"),
# 从 __pydantic_extra__ 提取原始数据作为 dict 使用
# When retrieve_settings is matched to wrong Union type by pydantic (due to extra="allow"),
# extract raw data from __pydantic_extra__ as dict
if (
self.retrieve_settings is not None
and not isinstance(self.retrieve_settings, dict)
and hasattr(self.retrieve_settings, "__pydantic_extra__")
and self.retrieve_settings.__pydantic_extra__
):
self.retrieve_settings = self.retrieve_settings.__pydantic_extra__

if provider == KnowledgeBaseProvider.BAILIAN:
# 百炼设置 / Bailian settings
if self.provider_settings:
Expand Down Expand Up @@ -544,6 +562,9 @@ def _get_data_api(self, config: Optional[Config] = None):
hybrid_search_args=self.retrieve_settings.get(
"HybridSearchArgs"
),
filter=self.retrieve_settings.get(
"Filter"
),
)

elif provider == KnowledgeBaseProvider.OTS:
Expand Down Expand Up @@ -905,19 +926,21 @@ def multi_retrieve(
"""
# 1. 根据 knowledge_base_names 并发获取各知识库配置(安全方式)
# Fetch all knowledge bases concurrently by name (safely)
knowledge_base_results = [
knowledge_base_results = ([
cls._safe_get_kb(name, config=config)
for name in knowledge_base_names
]
])

# 2. 并发执行各知识库的检索(安全方式)
# Execute retrieval for each knowledge base concurrently (safely)
retrieve_results = [
cls._safe_retrieve_kb(kb_name, kb_or_error, query, config=config)
retrieve_results = ([
cls._safe_retrieve_kb(
kb_name, kb_or_error, query, config=config
)
for kb_name, kb_or_error in zip(
knowledge_base_names, knowledge_base_results
)
]
])

# 3. 合并返回结果,按知识库名称分组
# Merge results, grouped by knowledge base name
Expand Down
3 changes: 3 additions & 0 deletions agentrun/knowledgebase/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class ADBRetrieveSettings(BaseModel):
hybrid_search_args: Optional[Dict[str, Any]] = None
"""混合检索算法参数,如 {"RRF": {"k": 60}} 或 {"Weight": {"alpha": 0.5}}
Hybrid search algorithm parameters"""
filter: Optional[str] = None
"""过滤条件,SQL WHERE 格式,如 "category = 'tech' AND score > 0.5"
Filter condition in SQL WHERE format"""


# =============================================================================
Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/knowledgebase/api/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ def test_build_query_content_request_with_settings(self):
recall_window=[-5, 5],
hybrid_search="RRF",
hybrid_search_args={"RRF": {"k": 60}},
filter="category = 'tech' AND score > 0.5",
),
)

Expand All @@ -1036,6 +1037,7 @@ def test_build_query_content_request_with_settings(self):
assert request.use_full_text_retrieval is True
assert request.rerank_factor == 1.5
assert request.rerank_model is not None
assert request.filter == "category = 'tech' AND score > 0.5"

@patch.dict(
os.environ,
Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/knowledgebase/test_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,15 @@ def test_get_data_api_adb_with_raw_dict_settings(self):
"RecallWindow": [-5, 5],
"HybridSearch": "RRF",
"HybridSearchArgs": {"RRF": {"k": 60}},
"Filter": "category = 'tech'",
},
)

from agentrun.knowledgebase.api.data import ADBDataAPI

data_api = kb._get_data_api()
assert isinstance(data_api, ADBDataAPI)
assert data_api.retrieve_settings.filter == "category = 'tech'"

def test_get_data_api_without_provider(self):
"""测试获取数据链路 API(无提供商)"""
Expand Down
3 changes: 3 additions & 0 deletions tests/unittests/knowledgebase/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def test_create_adb_retrieve_settings(self):
recall_window=[-5, 5],
hybrid_search="RRF",
hybrid_search_args={"RRF": {"k": 60}},
filter="category = 'tech'",
)
assert settings.top_k == 10
assert settings.use_full_text_retrieval is True
Expand All @@ -225,6 +226,7 @@ def test_create_adb_retrieve_settings(self):
assert settings.recall_window == [-5, 5]
assert settings.hybrid_search == "RRF"
assert settings.hybrid_search_args == {"RRF": {"k": 60}}
assert settings.filter == "category = 'tech'"

def test_adb_retrieve_settings_optional(self):
"""测试 ADB 检索设置可选字段"""
Expand All @@ -236,6 +238,7 @@ def test_adb_retrieve_settings_optional(self):
assert settings.recall_window is None
assert settings.hybrid_search is None
assert settings.hybrid_search_args is None
assert settings.filter is None

def test_adb_retrieve_settings_weight_hybrid(self):
"""测试 ADB 检索设置加权混合检索"""
Expand Down
Loading