mirror of
https://github.com/langgenius/dify.git
synced 2025-11-22 07:56:37 +00:00
refactor: Enable type checking for dataset config manager (#26494)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
654d522b31
commit
11f7a89e25
@ -1,4 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
from core.app.app_config.entities import (
|
from core.app.app_config.entities import (
|
||||||
DatasetEntity,
|
DatasetEntity,
|
||||||
@ -74,6 +75,9 @@ class DatasetConfigManager:
|
|||||||
return None
|
return None
|
||||||
query_variable = config.get("dataset_query_variable")
|
query_variable = config.get("dataset_query_variable")
|
||||||
|
|
||||||
|
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
|
||||||
|
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
|
||||||
|
|
||||||
if dataset_configs["retrieval_model"] == "single":
|
if dataset_configs["retrieval_model"] == "single":
|
||||||
return DatasetEntity(
|
return DatasetEntity(
|
||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
@ -82,18 +86,23 @@ class DatasetConfigManager:
|
|||||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||||
dataset_configs["retrieval_model"]
|
dataset_configs["retrieval_model"]
|
||||||
),
|
),
|
||||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
metadata_filtering_mode=cast(
|
||||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
Literal["disabled", "automatic", "manual"],
|
||||||
if dataset_configs.get("metadata_model_config")
|
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||||
|
),
|
||||||
|
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||||
|
if isinstance(metadata_model_config_dict, dict)
|
||||||
else None,
|
else None,
|
||||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||||
)
|
|
||||||
if dataset_configs.get("metadata_filtering_conditions")
|
|
||||||
else None,
|
else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
score_threshold_val = dataset_configs.get("score_threshold")
|
||||||
|
reranking_model_val = dataset_configs.get("reranking_model")
|
||||||
|
weights_val = dataset_configs.get("weights")
|
||||||
|
|
||||||
return DatasetEntity(
|
return DatasetEntity(
|
||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
retrieve_config=DatasetRetrieveConfigEntity(
|
retrieve_config=DatasetRetrieveConfigEntity(
|
||||||
@ -101,22 +110,23 @@ class DatasetConfigManager:
|
|||||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||||
dataset_configs["retrieval_model"]
|
dataset_configs["retrieval_model"]
|
||||||
),
|
),
|
||||||
top_k=dataset_configs.get("top_k", 4),
|
top_k=int(dataset_configs.get("top_k", 4)),
|
||||||
score_threshold=dataset_configs.get("score_threshold")
|
score_threshold=float(score_threshold_val)
|
||||||
if dataset_configs.get("score_threshold_enabled", False)
|
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||||
else None,
|
else None,
|
||||||
reranking_model=dataset_configs.get("reranking_model"),
|
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||||
weights=dataset_configs.get("weights"),
|
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||||
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
|
metadata_filtering_mode=cast(
|
||||||
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
|
Literal["disabled", "automatic", "manual"],
|
||||||
if dataset_configs.get("metadata_model_config")
|
dataset_configs.get("metadata_filtering_mode", "disabled"),
|
||||||
|
),
|
||||||
|
metadata_model_config=ModelConfig(**metadata_model_config_dict)
|
||||||
|
if isinstance(metadata_model_config_dict, dict)
|
||||||
else None,
|
else None,
|
||||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
|
||||||
**dataset_configs.get("metadata_filtering_conditions", {})
|
if isinstance(metadata_filtering_conditions_dict, dict)
|
||||||
)
|
|
||||||
if dataset_configs.get("metadata_filtering_conditions")
|
|
||||||
else None,
|
else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -134,18 +144,17 @@ class DatasetConfigManager:
|
|||||||
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
|
||||||
|
|
||||||
# dataset_configs
|
# dataset_configs
|
||||||
if not config.get("dataset_configs"):
|
if "dataset_configs" not in config or not config.get("dataset_configs"):
|
||||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
config["dataset_configs"] = {}
|
||||||
|
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
|
||||||
|
|
||||||
if not isinstance(config["dataset_configs"], dict):
|
if not isinstance(config["dataset_configs"], dict):
|
||||||
raise ValueError("dataset_configs must be of object type")
|
raise ValueError("dataset_configs must be of object type")
|
||||||
|
|
||||||
if not config["dataset_configs"].get("datasets"):
|
if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
|
||||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||||
|
|
||||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
|
||||||
"datasets", {}
|
|
||||||
).get("datasets")
|
|
||||||
|
|
||||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||||
# Only check when mode is completion
|
# Only check when mode is completion
|
||||||
@ -166,8 +175,8 @@ class DatasetConfigManager:
|
|||||||
:param config: app model config args
|
:param config: app model config args
|
||||||
"""
|
"""
|
||||||
# Extract dataset config for legacy compatibility
|
# Extract dataset config for legacy compatibility
|
||||||
if not config.get("agent_mode"):
|
if "agent_mode" not in config or not config.get("agent_mode"):
|
||||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
config["agent_mode"] = {}
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"], dict):
|
if not isinstance(config["agent_mode"], dict):
|
||||||
raise ValueError("agent_mode must be of object type")
|
raise ValueError("agent_mode must be of object type")
|
||||||
@ -180,19 +189,22 @@ class DatasetConfigManager:
|
|||||||
raise ValueError("enabled in agent_mode must be of boolean type")
|
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||||
|
|
||||||
# tools
|
# tools
|
||||||
if not config["agent_mode"].get("tools"):
|
if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
|
||||||
config["agent_mode"]["tools"] = []
|
config["agent_mode"]["tools"] = []
|
||||||
|
|
||||||
if not isinstance(config["agent_mode"]["tools"], list):
|
if not isinstance(config["agent_mode"]["tools"], list):
|
||||||
raise ValueError("tools in agent_mode must be a list of objects")
|
raise ValueError("tools in agent_mode must be a list of objects")
|
||||||
|
|
||||||
# strategy
|
# strategy
|
||||||
if not config["agent_mode"].get("strategy"):
|
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||||
|
|
||||||
has_datasets = False
|
has_datasets = False
|
||||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
if config.get("agent_mode", {}).get("strategy") in {
|
||||||
for tool in config["agent_mode"]["tools"]:
|
PlanningStrategy.ROUTER.value,
|
||||||
|
PlanningStrategy.REACT_ROUTER.value,
|
||||||
|
}:
|
||||||
|
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||||
key = list(tool.keys())[0]
|
key = list(tool.keys())[0]
|
||||||
if key == "dataset":
|
if key == "dataset":
|
||||||
# old style, use tool name as key
|
# old style, use tool name as key
|
||||||
@ -217,7 +229,7 @@ class DatasetConfigManager:
|
|||||||
|
|
||||||
has_datasets = True
|
has_datasets = True
|
||||||
|
|
||||||
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
|
need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
|
||||||
|
|
||||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||||
# Only check when mode is completion
|
# Only check when mode is completion
|
||||||
|
|||||||
@ -4,8 +4,7 @@
|
|||||||
"tests/",
|
"tests/",
|
||||||
".venv",
|
".venv",
|
||||||
"migrations/",
|
"migrations/",
|
||||||
"core/rag",
|
"core/rag"
|
||||||
"core/app/app_config/easy_ui_based_app/dataset"
|
|
||||||
],
|
],
|
||||||
"typeCheckingMode": "strict",
|
"typeCheckingMode": "strict",
|
||||||
"allowedUntypedLibraries": [
|
"allowedUntypedLibraries": [
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user