diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 6165cfdeec..65ef74bc27 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -65,13 +65,40 @@ const DatasetConfig: FC = () => { const onRemove = (id: string) => { const filteredDataSets = dataSet.filter(item => item.id !== id) setDataSet(filteredDataSets) - const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { + const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs + const { + top_k, + score_threshold, + reranking_model, + reranking_mode, + weights, + reranking_enable, + } = restConfigs + const oldRetrievalConfig = { + top_k, + score_threshold, + reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { + provider: reranking_model.reranking_provider_name, + model: reranking_model.reranking_model_name, + } : undefined, + reranking_mode, + weights, + reranking_enable, + } + const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, filteredDataSets, dataSet, { provider: currentRerankProvider?.provider, model: currentRerankModel?.model, }) setDatasetConfigs({ - ...(datasetConfigs as any), + ...datasetConfigsRef.current, ...retrievalConfig, + reranking_model: { + reranking_provider_name: retrievalConfig?.reranking_model?.provider || '', + reranking_model_name: retrievalConfig?.reranking_model?.model || '', + }, + retrieval_model, + score_threshold_enabled, + datasets, }) const { allExternal, diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index cb61b927bc..1558d32fc6 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -30,11 +30,11 @@ import { noop } from 'lodash-es' type Props = { datasetConfigs: DatasetConfigs onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void + selectedDatasets?: DataSet[] isInWorkflow?: boolean singleRetrievalModelConfig?: ModelConfig onSingleRetrievalModelChange?: (config: ModelConfig) => void onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void - selectedDatasets?: DataSet[] } const ConfigContent: FC = ({ @@ -61,22 +61,28 @@ const ConfigContent: FC = ({ const { modelList: rerankModelList, + currentModel: validDefaultRerankModel, + currentProvider: validDefaultRerankProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + /** + * If reranking model is set and is valid, use the reranking model + * Otherwise, check if the default reranking model is valid + */ const { currentModel: currentRerankModel, } = useCurrentProviderAndModel( rerankModelList, { - provider: datasetConfigs.reranking_model?.reranking_provider_name, - model: datasetConfigs.reranking_model?.reranking_model_name, + provider: datasetConfigs.reranking_model?.reranking_provider_name || validDefaultRerankProvider?.provider || '', + model: datasetConfigs.reranking_model?.reranking_model_name || validDefaultRerankModel?.model || '', }, ) const rerankModel = useMemo(() => { return { - provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', - model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', + provider_name: datasetConfigs.reranking_model?.reranking_provider_name ?? '', + model_name: datasetConfigs.reranking_model?.reranking_model_name ?? '', } }, [datasetConfigs.reranking_model]) @@ -135,7 +141,7 @@ const ConfigContent: FC = ({ }) } - const model = singleRetrievalConfig + const model = singleRetrievalConfig // Legacy code, for compatibility, have to keep it const rerankingModeOptions = [ { @@ -158,7 +164,7 @@ const ConfigContent: FC = ({ const canManuallyToggleRerank = useMemo(() => { return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) - || selectedDatasetsMode.allExternal + || selectedDatasetsMode.allExternal }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) const showRerankModel = useMemo(() => { @@ -168,7 +174,7 @@ const ConfigContent: FC = ({ return datasetConfigs.reranking_enable }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) - const handleDisabledSwitchClick = useCallback((enable: boolean) => { + const handleManuallyToggleRerank = useCallback((enable: boolean) => { if (!currentRerankModel && enable) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) onChange({ @@ -255,12 +261,11 @@ const ConfigContent: FC = ({
{ - selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( + canManuallyToggleRerank && ( ) } diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 091900642a..f1f81ebf97 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -284,18 +284,28 @@ const Configuration: FC = () => { setRerankSettingModalOpen(true) const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs + const { + top_k, + score_threshold, + reranking_model, + reranking_mode, + weights, + reranking_enable, + } = restConfigs - const retrievalConfig = getMultipleRetrievalConfig({ - top_k: restConfigs.top_k, - score_threshold: restConfigs.score_threshold, - reranking_model: restConfigs.reranking_model && { - provider: restConfigs.reranking_model.reranking_provider_name, - model: restConfigs.reranking_model.reranking_model_name, - }, - reranking_mode: restConfigs.reranking_mode, - weights: restConfigs.weights, - reranking_enable: restConfigs.reranking_enable, - }, newDatasets, dataSets, { + const oldRetrievalConfig = { + top_k, + score_threshold, + reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { + provider: reranking_model.reranking_provider_name, + model: reranking_model.reranking_model_name, + } : undefined, + reranking_mode, + weights, + reranking_enable, + } + + const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, newDatasets, dataSets, { provider: currentRerankProvider?.provider, model: currentRerankModel?.model, }) diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 57d357442f..ed230c52ce 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -40,7 +40,7 @@ const RetrievalMethodConfig: FC = ({ onChange({ ...value, search_method: retrieveMethod, - ...(!value.reranking_model.reranking_model_name + ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) ? { reranking_model: { reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', @@ -57,7 +57,7 @@ const RetrievalMethodConfig: FC = ({ onChange({ ...value, search_method: retrieveMethod, - ...(!value.reranking_model.reranking_model_name + ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) ? { reranking_model: { reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 216a56ab16..0c28149d56 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -54,7 +54,7 @@ const RetrievalParamConfig: FC = ({ }, ) - const handleDisabledSwitchClick = useCallback((enable: boolean) => { + const handleToggleRerankEnable = useCallback((enable: boolean) => { if (enable && !currentModel) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) onChange({ @@ -119,7 +119,7 @@ const RetrievalParamConfig: FC = ({ )}
diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx index 8a3dc1efba..619216d672 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/retrieval-config.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useState } from 'react' +import React, { useCallback, useMemo } from 'react' import { RiEqualizer2Line } from '@remixicon/react' import { useTranslation } from 'react-i18next' import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' @@ -14,8 +14,6 @@ import { import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content' import { RETRIEVE_TYPE } from '@/types/app' import { DATASET_DEFAULT } from '@/config' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import Button from '@/app/components/base/button' import type { DatasetConfigs } from '@/models/debug' import type { DataSet } from '@/models/datasets' @@ -32,8 +30,8 @@ type Props = { onSingleRetrievalModelChange?: (config: ModelConfig) => void onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void readonly?: boolean - openFromProps?: boolean - onOpenFromPropsChange?: (openFromProps: boolean) => void + rerankModalOpen: boolean + onRerankModelOpenChange: (open: boolean) => void selectedDatasets: DataSet[] } @@ -45,26 +43,52 @@ const RetrievalConfig: FC = ({ onSingleRetrievalModelChange, onSingleRetrievalModelParamsChange, readonly, - openFromProps, - onOpenFromPropsChange, + rerankModalOpen, + onRerankModelOpenChange, selectedDatasets, }) => { const { t } = useTranslation() - const [open, setOpen] = useState(false) - const mergedOpen = openFromProps !== undefined ? openFromProps : open + const { retrieval_mode, multiple_retrieval_config } = payload const handleOpen = useCallback((newOpen: boolean) => { - setOpen(newOpen) - onOpenFromPropsChange?.(newOpen) - }, [onOpenFromPropsChange]) + onRerankModelOpenChange(newOpen) + }, [onRerankModelOpenChange]) - const { - currentProvider: validRerankDefaultProvider, - currentModel: validRerankDefaultModel, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const datasetConfigs = useMemo(() => { + const { + reranking_model, + top_k, + score_threshold, + reranking_mode, + weights, + reranking_enable, + } = multiple_retrieval_config || {} + + return { + retrieval_model: retrieval_mode, + reranking_model: (reranking_model?.provider && reranking_model?.model) + ? { + reranking_provider_name: reranking_model?.provider, + reranking_model_name: reranking_model?.model, + } + : { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: top_k || DATASET_DEFAULT.top_k, + score_threshold_enabled: !(score_threshold === undefined || score_threshold === null), + score_threshold, + datasets: { + datasets: [], + }, + reranking_mode, + weights, + reranking_enable, + } + }, [retrieval_mode, multiple_retrieval_config]) - const { multiple_retrieval_config } = payload const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { + // Legacy code, for compatibility, have to keep it if (isRetrievalModeChange) { onRetrievalModeChange(configs.retrieval_model) return @@ -72,13 +96,11 @@ const RetrievalConfig: FC = ({ onMultipleRetrievalConfigChange({ top_k: configs.top_k, score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, - reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay + reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay ? undefined + // eslint-disable-next-line sonarjs/no-nested-conditional : (!configs.reranking_model?.reranking_provider_name - ? { - provider: validRerankDefaultProvider?.provider || '', - model: validRerankDefaultModel?.model || '', - } + ? undefined : { provider: configs.reranking_model?.reranking_provider_name, model: configs.reranking_model?.reranking_model_name, @@ -87,11 +109,11 @@ const RetrievalConfig: FC = ({ weights: configs.weights, reranking_enable: configs.reranking_enable, }) - }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, validRerankDefaultProvider, validRerankDefaultModel, onRetrievalModeChange]) + }, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange]) return ( = ({ onClick={() => { if (readonly) return - handleOpen(!mergedOpen) + handleOpen(!rerankModalOpen) }} >