Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agentrun/knowledgebase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .knowledgebase import KnowledgeBase
from .model import (
ADBProviderSettings,
ADBRerankModel,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
Expand Down Expand Up @@ -64,6 +65,7 @@
"RetrieveSettings",
"RagFlowRetrieveSettings",
"BailianRetrieveSettings",
"ADBRerankModel",
"ADBRetrieveSettings",
"OTSRetrieveSettings",
"OTSDenseVectorSearchConfig",
Expand Down
9 changes: 9 additions & 0 deletions agentrun/knowledgebase/__knowledgebase_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .api.data import get_data_api
from .model import (
ADBProviderSettings,
ADBRerankModel,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
Expand Down Expand Up @@ -344,6 +345,14 @@ def _get_data_api(self, config: Optional[Config] = None):
rerank_factor=self.retrieve_settings.get(
"RerankFactor"
),
rerank_model=(
ADBRerankModel(
name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""),
instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"),
)
if self.retrieve_settings.get("RerankModel")
else None
),
recall_window=self.retrieve_settings.get(
"RecallWindow"
),
Expand Down
11 changes: 11 additions & 0 deletions agentrun/knowledgebase/api/__data_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def _build_query_content_request(
"namespace_password": self.provider_settings.namespace_password,
"collection": self.knowledge_base_name,
"region_id": cfg.get_region_id(),
# 固定设置 URL 过期时间为 356 天 / Fixed URL expiration to 356 days
"url_expiration": "356d",
}

# 添加可选的提供商设置 / Add optional provider settings
Expand All @@ -423,6 +425,15 @@ def _build_query_content_request(
request_params["rerank_factor"] = (
self.retrieve_settings.rerank_factor
)
if self.retrieve_settings.rerank_model is not None:
rerank_model_params: Dict[str, Any] = {
"Name": self.retrieve_settings.rerank_model.name,
}
if self.retrieve_settings.rerank_model.instruct is not None:
rerank_model_params["Instruct"] = (
self.retrieve_settings.rerank_model.instruct
)
request_params["rerank_model"] = rerank_model_params
if self.retrieve_settings.recall_window is not None:
request_params["recall_window"] = (
self.retrieve_settings.recall_window
Expand Down
11 changes: 11 additions & 0 deletions agentrun/knowledgebase/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ def _build_query_content_request(
"namespace_password": self.provider_settings.namespace_password,
"collection": self.knowledge_base_name,
"region_id": cfg.get_region_id(),
# 固定设置 URL 过期时间为 356 天 / Fixed URL expiration to 356 days
"url_expiration": "356d",
}

# 添加可选的提供商设置 / Add optional provider settings
Expand All @@ -633,6 +635,15 @@ def _build_query_content_request(
request_params["rerank_factor"] = (
self.retrieve_settings.rerank_factor
)
if self.retrieve_settings.rerank_model is not None:
rerank_model_params: Dict[str, Any] = {
"Name": self.retrieve_settings.rerank_model.name,
}
if self.retrieve_settings.rerank_model.instruct is not None:
rerank_model_params["Instruct"] = (
self.retrieve_settings.rerank_model.instruct
)
request_params["rerank_model"] = rerank_model_params
if self.retrieve_settings.recall_window is not None:
request_params["recall_window"] = (
self.retrieve_settings.recall_window
Expand Down
9 changes: 9 additions & 0 deletions agentrun/knowledgebase/knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .api.data import get_data_api
from .model import (
ADBProviderSettings,
ADBRerankModel,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
Expand Down Expand Up @@ -526,6 +527,14 @@ def _get_data_api(self, config: Optional[Config] = None):
rerank_factor=self.retrieve_settings.get(
"RerankFactor"
),
rerank_model=(
ADBRerankModel(
name=self.retrieve_settings.get("RerankModel", {}).get("Name", ""),
instruct=self.retrieve_settings.get("RerankModel", {}).get("Instruct"),
)
if self.retrieve_settings.get("RerankModel")
else None
),
recall_window=self.retrieve_settings.get(
"RecallWindow"
),
Expand Down
19 changes: 19 additions & 0 deletions agentrun/knowledgebase/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ class ADBProviderSettings(BaseModel):
"""元数据配置,JSON 字符串格式 / Metadata configuration in JSON string format"""


class ADBRerankModel(BaseModel):
    """ADB rerank model configuration.

    Groups the two settings that describe a rerank model: the model name and
    an optional instruction string describing the sorting task type.
    """

    # Required field (no default). Documented options are "qwen3-rerank" and
    # "gte-rerank-v2"; the annotation is a plain str, so this class itself
    # does not restrict the value to those names.
    name: str
    """重排模型名称,可选值:qwen3-rerank、gte-rerank-v2
Rerank model name, options: qwen3-rerank, gte-rerank-v2"""
    # Optional instruction text; defaults to None. Per the field note below it
    # is only meaningful when name is "qwen3-rerank". No validator in this
    # class enforces that pairing.
    # NOTE(review): presumably the backend service checks it - confirm.
    instruct: Optional[str] = None
    """排序任务类型说明,仅当 name 为 qwen3-rerank 时可设置,指导模型采用不同的排序策略
Instruct for sorting task type, only available when name is qwen3-rerank,
guides the model to adopt different sorting strategies"""


class ADBRetrieveSettings(BaseModel):
"""ADB 检索设置 / ADB Retrieve Settings

Expand All @@ -122,6 +138,9 @@ class ADBRetrieveSettings(BaseModel):
rerank_factor: Optional[float] = None
"""重排序因子,取值范围 1 < RerankFactor <= 5
Re-ranking factor, value range: 1 < RerankFactor <= 5"""
rerank_model: Optional[ADBRerankModel] = None
"""重排模型配置,当启用重排因子时可设置
Rerank model configuration, available when rerank factor is enabled"""
recall_window: Optional[List[int]] = None
"""召回窗口,格式为 [A, B],其中 -10 <= A <= 0,0 <= B <= 10
Recall window, format [A, B] where -10 <= A <= 0, 0 <= B <= 10"""
Expand Down
5 changes: 5 additions & 0 deletions examples/knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from agentrun.knowledgebase import (
ADBProviderSettings,
ADBRerankModel,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
Expand Down Expand Up @@ -474,6 +475,10 @@ def create_or_get_adb_kb() -> KnowledgeBase:
top_k=10,
use_full_text_retrieval=False, # 仅使用向量检索 / Vector only
rerank_factor=2.0, # 重排序因子 / Rerank factor
rerank_model=ADBRerankModel(
name="qwen3-rerank", # 重排模型名称 / Rerank model name
instruct="按相关性排序", # 排序任务类型说明(仅 qwen3-rerank 支持)/ Instruct (only for qwen3-rerank)
),
),
)
)
Expand Down
8 changes: 8 additions & 0 deletions tests/unittests/knowledgebase/api/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from agentrun.knowledgebase.model import (
ADBProviderSettings,
ADBRerankModel,
ADBRetrieveSettings,
BailianProviderSettings,
BailianRetrieveSettings,
Expand Down Expand Up @@ -995,6 +996,7 @@ def test_build_query_content_request(self):
assert request.dbinstance_id == "gp-123456"
assert request.namespace == "public"
assert request.collection == "test-kb"
assert request.url_expiration == "356d"

@patch.dict(
os.environ,
Expand All @@ -1017,17 +1019,23 @@ def test_build_query_content_request_with_settings(self):
top_k=10,
use_full_text_retrieval=True,
rerank_factor=1.5,
rerank_model=ADBRerankModel(
name="qwen3-rerank",
instruct="按相关性排序",
),
recall_window=[-5, 5],
hybrid_search="RRF",
hybrid_search_args={"RRF": {"k": 60}},
),
)

request = api._build_query_content_request("test query")
assert request.url_expiration == "356d"
assert request.metrics == "cosine"
assert request.top_k == 10
assert request.use_full_text_retrieval is True
assert request.rerank_factor == 1.5
assert request.rerank_model is not None

@patch.dict(
os.environ,
Expand Down
1 change: 1 addition & 0 deletions tests/unittests/knowledgebase/test_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ def test_get_data_api_adb_with_raw_dict_settings(self):
"TopK": 10,
"UseFullTextRetrieval": True,
"RerankFactor": 1.5,
"RerankModel": {"Name": "qwen3-rerank", "Instruct": "按相关性排序"},
"RecallWindow": [-5, 5],
"HybridSearch": "RRF",
"HybridSearchArgs": {"RRF": {"k": 60}},
Expand Down
9 changes: 9 additions & 0 deletions tests/unittests/knowledgebase/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from agentrun.knowledgebase.model import (
ADBRerankModel,
ADBProviderSettings,
ADBRetrieveSettings,
BailianProviderSettings,
Expand Down Expand Up @@ -207,13 +208,20 @@ def test_create_adb_retrieve_settings(self):
top_k=10,
use_full_text_retrieval=True,
rerank_factor=1.5,
rerank_model=ADBRerankModel(
name="qwen3-rerank",
instruct="按相关性排序",
),
recall_window=[-5, 5],
hybrid_search="RRF",
hybrid_search_args={"RRF": {"k": 60}},
)
assert settings.top_k == 10
assert settings.use_full_text_retrieval is True
assert settings.rerank_factor == 1.5
assert settings.rerank_model is not None
assert settings.rerank_model.name == "qwen3-rerank"
assert settings.rerank_model.instruct == "按相关性排序"
assert settings.recall_window == [-5, 5]
assert settings.hybrid_search == "RRF"
assert settings.hybrid_search_args == {"RRF": {"k": 60}}
Expand All @@ -224,6 +232,7 @@ def test_adb_retrieve_settings_optional(self):
assert settings.top_k is None
assert settings.use_full_text_retrieval is None
assert settings.rerank_factor is None
assert settings.rerank_model is None
assert settings.recall_window is None
assert settings.hybrid_search is None
assert settings.hybrid_search_args is None
Expand Down
Loading