refactor: Enable type checking for dataset config manager (#26494)

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-10-07 15:50:44 +09:00 committed by GitHub
parent 654d522b31
commit 11f7a89e25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 35 deletions

View File

@ -1,4 +1,5 @@
import uuid
from typing import Literal, cast
from core.app.app_config.entities import (
DatasetEntity,
@ -74,6 +75,9 @@ class DatasetConfigManager:
return None
query_variable = config.get("dataset_query_variable")
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
if dataset_configs["retrieval_model"] == "single":
return DatasetEntity(
dataset_ids=dataset_ids,
@ -82,18 +86,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
metadata_filtering_mode=cast(
Literal["disabled", "automatic", "manual"],
dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
if isinstance(metadata_filtering_conditions_dict, dict)
else None,
),
)
else:
score_threshold_val = dataset_configs.get("score_threshold")
reranking_model_val = dataset_configs.get("reranking_model")
weights_val = dataset_configs.get("weights")
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
@ -101,22 +110,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold")
if dataset_configs.get("score_threshold_enabled", False)
top_k=int(dataset_configs.get("top_k", 4)),
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
if dataset_configs.get("metadata_model_config")
metadata_filtering_mode=cast(
Literal["disabled", "automatic", "manual"],
dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None,
metadata_filtering_conditions=MetadataFilteringCondition(
**dataset_configs.get("metadata_filtering_conditions", {})
)
if dataset_configs.get("metadata_filtering_conditions")
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
if isinstance(metadata_filtering_conditions_dict, dict)
else None,
),
)
@ -134,18 +144,17 @@ class DatasetConfigManager:
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"}
if "dataset_configs" not in config or not config.get("dataset_configs"):
config["dataset_configs"] = {}
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
if not config["dataset_configs"].get("datasets"):
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@ -166,8 +175,8 @@ class DatasetConfigManager:
:param config: app model config args
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
if "agent_mode" not in config or not config.get("agent_mode"):
config["agent_mode"] = {}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -180,19 +189,22 @@ class DatasetConfigManager:
raise ValueError("enabled in agent_mode must be of boolean type")
# tools
if not config["agent_mode"].get("tools"):
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
# strategy
if not config["agent_mode"].get("strategy"):
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
for tool in config["agent_mode"]["tools"]:
if config.get("agent_mode", {}).get("strategy") in {
PlanningStrategy.ROUTER.value,
PlanningStrategy.REACT_ROUTER.value,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0]
if key == "dataset":
# old style, use tool name as key
@ -217,7 +229,7 @@ class DatasetConfigManager:
has_datasets = True
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion

View File

@ -4,8 +4,7 @@
"tests/",
".venv",
"migrations/",
"core/rag",
"core/app/app_config/easy_ui_based_app/dataset"
"core/rag"
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [