fix: Fix retrieval configuration handling in dataset components (#26361)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Wu Tianwei 2025-09-29 14:58:28 +08:00 committed by GitHub
parent af662b100b
commit 1a7898dff1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 216 additions and 134 deletions

View File

@ -65,13 +65,40 @@ const DatasetConfig: FC = () => {
const onRemove = (id: string) => { const onRemove = (id: string) => {
const filteredDataSets = dataSet.filter(item => item.id !== id) const filteredDataSets = dataSet.filter(item => item.id !== id)
setDataSet(filteredDataSets) setDataSet(filteredDataSets)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
const {
top_k,
score_threshold,
reranking_model,
reranking_mode,
weights,
reranking_enable,
} = restConfigs
const oldRetrievalConfig = {
top_k,
score_threshold,
reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? {
provider: reranking_model.reranking_provider_name,
model: reranking_model.reranking_model_name,
} : undefined,
reranking_mode,
weights,
reranking_enable,
}
const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, filteredDataSets, dataSet, {
provider: currentRerankProvider?.provider, provider: currentRerankProvider?.provider,
model: currentRerankModel?.model, model: currentRerankModel?.model,
}) })
setDatasetConfigs({ setDatasetConfigs({
...(datasetConfigs as any), ...datasetConfigsRef.current,
...retrievalConfig, ...retrievalConfig,
reranking_model: {
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,
datasets,
}) })
const { const {
allExternal, allExternal,

View File

@ -30,11 +30,11 @@ import { noop } from 'lodash-es'
type Props = { type Props = {
datasetConfigs: DatasetConfigs datasetConfigs: DatasetConfigs
onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void
selectedDatasets?: DataSet[]
isInWorkflow?: boolean isInWorkflow?: boolean
singleRetrievalModelConfig?: ModelConfig singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
selectedDatasets?: DataSet[]
} }
const ConfigContent: FC<Props> = ({ const ConfigContent: FC<Props> = ({
@ -61,22 +61,28 @@ const ConfigContent: FC<Props> = ({
const { const {
modelList: rerankModelList, modelList: rerankModelList,
currentModel: validDefaultRerankModel,
currentProvider: validDefaultRerankProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
/**
* If reranking model is set and is valid, use the reranking model
* Otherwise, check if the default reranking model is valid
*/
const { const {
currentModel: currentRerankModel, currentModel: currentRerankModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
{ {
provider: datasetConfigs.reranking_model?.reranking_provider_name, provider: datasetConfigs.reranking_model?.reranking_provider_name || validDefaultRerankProvider?.provider || '',
model: datasetConfigs.reranking_model?.reranking_model_name, model: datasetConfigs.reranking_model?.reranking_model_name || validDefaultRerankModel?.model || '',
}, },
) )
const rerankModel = useMemo(() => { const rerankModel = useMemo(() => {
return { return {
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', provider_name: datasetConfigs.reranking_model?.reranking_provider_name ?? '',
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', model_name: datasetConfigs.reranking_model?.reranking_model_name ?? '',
} }
}, [datasetConfigs.reranking_model]) }, [datasetConfigs.reranking_model])
@ -135,7 +141,7 @@ const ConfigContent: FC<Props> = ({
}) })
} }
const model = singleRetrievalConfig const model = singleRetrievalConfig // Legacy code, for compatibility, have to keep it
const rerankingModeOptions = [ const rerankingModeOptions = [
{ {
@ -168,7 +174,7 @@ const ConfigContent: FC<Props> = ({
return datasetConfigs.reranking_enable return datasetConfigs.reranking_enable
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) }, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
const handleDisabledSwitchClick = useCallback((enable: boolean) => { const handleManuallyToggleRerank = useCallback((enable: boolean) => {
if (!currentRerankModel && enable) if (!currentRerankModel && enable)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange({ onChange({
@ -255,12 +261,11 @@ const ConfigContent: FC<Props> = ({
<div className='mt-2'> <div className='mt-2'>
<div className='flex items-center'> <div className='flex items-center'>
{ {
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( canManuallyToggleRerank && (
<Switch <Switch
size='md' size='md'
defaultValue={showRerankModel} defaultValue={showRerankModel}
disabled={!canManuallyToggleRerank} onChange={handleManuallyToggleRerank}
onChange={handleDisabledSwitchClick}
/> />
) )
} }

View File

@ -284,18 +284,28 @@ const Configuration: FC = () => {
setRerankSettingModalOpen(true) setRerankSettingModalOpen(true)
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
const {
top_k,
score_threshold,
reranking_model,
reranking_mode,
weights,
reranking_enable,
} = restConfigs
const retrievalConfig = getMultipleRetrievalConfig({ const oldRetrievalConfig = {
top_k: restConfigs.top_k, top_k,
score_threshold: restConfigs.score_threshold, score_threshold,
reranking_model: restConfigs.reranking_model && { reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? {
provider: restConfigs.reranking_model.reranking_provider_name, provider: reranking_model.reranking_provider_name,
model: restConfigs.reranking_model.reranking_model_name, model: reranking_model.reranking_model_name,
}, } : undefined,
reranking_mode: restConfigs.reranking_mode, reranking_mode,
weights: restConfigs.weights, weights,
reranking_enable: restConfigs.reranking_enable, reranking_enable,
}, newDatasets, dataSets, { }
const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, newDatasets, dataSets, {
provider: currentRerankProvider?.provider, provider: currentRerankProvider?.provider,
model: currentRerankModel?.model, model: currentRerankModel?.model,
}) })

View File

@ -40,7 +40,7 @@ const RetrievalMethodConfig: FC<Props> = ({
onChange({ onChange({
...value, ...value,
search_method: retrieveMethod, search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name)
? { ? {
reranking_model: { reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
@ -57,7 +57,7 @@ const RetrievalMethodConfig: FC<Props> = ({
onChange({ onChange({
...value, ...value,
search_method: retrieveMethod, search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name)
? { ? {
reranking_model: { reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',

View File

@ -54,7 +54,7 @@ const RetrievalParamConfig: FC<Props> = ({
}, },
) )
const handleDisabledSwitchClick = useCallback((enable: boolean) => { const handleToggleRerankEnable = useCallback((enable: boolean) => {
if (enable && !currentModel) if (enable && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange({ onChange({
@ -119,7 +119,7 @@ const RetrievalParamConfig: FC<Props> = ({
<Switch <Switch
size='md' size='md'
defaultValue={value.reranking_enable} defaultValue={value.reranking_enable}
onChange={handleDisabledSwitchClick} onChange={handleToggleRerankEnable}
/> />
)} )}
<div className='flex items-center'> <div className='flex items-center'>

View File

@ -1,6 +1,6 @@
'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React, { useCallback, useState } from 'react' import React, { useCallback, useMemo } from 'react'
import { RiEqualizer2Line } from '@remixicon/react' import { RiEqualizer2Line } from '@remixicon/react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
@ -14,8 +14,6 @@ import {
import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content' import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
import { RETRIEVE_TYPE } from '@/types/app' import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config' import { DATASET_DEFAULT } from '@/config'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import type { DatasetConfigs } from '@/models/debug' import type { DatasetConfigs } from '@/models/debug'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
@ -32,8 +30,8 @@ type Props = {
onSingleRetrievalModelChange?: (config: ModelConfig) => void onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
readonly?: boolean readonly?: boolean
openFromProps?: boolean rerankModalOpen: boolean
onOpenFromPropsChange?: (openFromProps: boolean) => void onRerankModelOpenChange: (open: boolean) => void
selectedDatasets: DataSet[] selectedDatasets: DataSet[]
} }
@ -45,26 +43,52 @@ const RetrievalConfig: FC<Props> = ({
onSingleRetrievalModelChange, onSingleRetrievalModelChange,
onSingleRetrievalModelParamsChange, onSingleRetrievalModelParamsChange,
readonly, readonly,
openFromProps, rerankModalOpen,
onOpenFromPropsChange, onRerankModelOpenChange,
selectedDatasets, selectedDatasets,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const [open, setOpen] = useState(false) const { retrieval_mode, multiple_retrieval_config } = payload
const mergedOpen = openFromProps !== undefined ? openFromProps : open
const handleOpen = useCallback((newOpen: boolean) => { const handleOpen = useCallback((newOpen: boolean) => {
setOpen(newOpen) onRerankModelOpenChange(newOpen)
onOpenFromPropsChange?.(newOpen) }, [onRerankModelOpenChange])
}, [onOpenFromPropsChange])
const datasetConfigs = useMemo(() => {
const { const {
currentProvider: validRerankDefaultProvider, reranking_model,
currentModel: validRerankDefaultModel, top_k,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) score_threshold,
reranking_mode,
weights,
reranking_enable,
} = multiple_retrieval_config || {}
return {
retrieval_model: retrieval_mode,
reranking_model: (reranking_model?.provider && reranking_model?.model)
? {
reranking_provider_name: reranking_model?.provider,
reranking_model_name: reranking_model?.model,
}
: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: top_k || DATASET_DEFAULT.top_k,
score_threshold_enabled: !(score_threshold === undefined || score_threshold === null),
score_threshold,
datasets: {
datasets: [],
},
reranking_mode,
weights,
reranking_enable,
}
}, [retrieval_mode, multiple_retrieval_config])
const { multiple_retrieval_config } = payload
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
// Legacy code, for compatibility, have to keep it
if (isRetrievalModeChange) { if (isRetrievalModeChange) {
onRetrievalModeChange(configs.retrieval_model) onRetrievalModeChange(configs.retrieval_model)
return return
@ -72,13 +96,11 @@ const RetrievalConfig: FC<Props> = ({
onMultipleRetrievalConfigChange({ onMultipleRetrievalConfigChange({
top_k: configs.top_k, top_k: configs.top_k,
score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null,
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay
? undefined ? undefined
// eslint-disable-next-line sonarjs/no-nested-conditional
: (!configs.reranking_model?.reranking_provider_name : (!configs.reranking_model?.reranking_provider_name
? { ? undefined
provider: validRerankDefaultProvider?.provider || '',
model: validRerankDefaultModel?.model || '',
}
: { : {
provider: configs.reranking_model?.reranking_provider_name, provider: configs.reranking_model?.reranking_provider_name,
model: configs.reranking_model?.reranking_model_name, model: configs.reranking_model?.reranking_model_name,
@ -87,11 +109,11 @@ const RetrievalConfig: FC<Props> = ({
weights: configs.weights, weights: configs.weights,
reranking_enable: configs.reranking_enable, reranking_enable: configs.reranking_enable,
}) })
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, validRerankDefaultProvider, validRerankDefaultModel, onRetrievalModeChange]) }, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange])
return ( return (
<PortalToFollowElem <PortalToFollowElem
open={mergedOpen} open={rerankModalOpen}
onOpenChange={handleOpen} onOpenChange={handleOpen}
placement='bottom-end' placement='bottom-end'
offset={{ offset={{
@ -102,14 +124,14 @@ const RetrievalConfig: FC<Props> = ({
onClick={() => { onClick={() => {
if (readonly) if (readonly)
return return
handleOpen(!mergedOpen) handleOpen(!rerankModalOpen)
}} }}
> >
<Button <Button
variant='ghost' variant='ghost'
size='small' size='small'
disabled={readonly} disabled={readonly}
className={cn(open && 'bg-components-button-ghost-bg-hover')} className={cn(rerankModalOpen && 'bg-components-button-ghost-bg-hover')}
> >
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
{t('dataset.retrievalSettings')} {t('dataset.retrievalSettings')}
@ -118,35 +140,13 @@ const RetrievalConfig: FC<Props> = ({
<PortalToFollowElemContent style={{ zIndex: 1001 }}> <PortalToFollowElemContent style={{ zIndex: 1001 }}>
<div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'> <div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'>
<ConfigRetrievalContent <ConfigRetrievalContent
datasetConfigs={ datasetConfigs={datasetConfigs}
{
retrieval_model: payload.retrieval_mode,
reranking_model: multiple_retrieval_config?.reranking_model?.provider
? {
reranking_provider_name: multiple_retrieval_config.reranking_model?.provider,
reranking_model_name: multiple_retrieval_config.reranking_model?.model,
}
: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config.score_threshold === null),
score_threshold: multiple_retrieval_config?.score_threshold,
datasets: {
datasets: [],
},
reranking_mode: multiple_retrieval_config?.reranking_mode,
weights: multiple_retrieval_config?.weights,
reranking_enable: multiple_retrieval_config?.reranking_enable,
}
}
onChange={handleChange} onChange={handleChange}
selectedDatasets={selectedDatasets}
isInWorkflow isInWorkflow
singleRetrievalModelConfig={singleRetrievalModelConfig} singleRetrievalModelConfig={singleRetrievalModelConfig}
onSingleRetrievalModelChange={onSingleRetrievalModelChange} onSingleRetrievalModelChange={onSingleRetrievalModelChange}
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange} onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
selectedDatasets={selectedDatasets}
/> />
</div> </div>
</PortalToFollowElemContent> </PortalToFollowElemContent>

View File

@ -1,6 +1,6 @@
import type { NodeDefault } from '../../types' import type { NodeDefault } from '../../types'
import type { KnowledgeRetrievalNodeType } from './types' import type { KnowledgeRetrievalNodeType } from './types'
import { checkoutRerankModelConfigedInRetrievalSettings } from './utils' import { checkoutRerankModelConfiguredInRetrievalSettings } from './utils'
import { DATASET_DEFAULT } from '@/config' import { DATASET_DEFAULT } from '@/config'
import { RETRIEVE_TYPE } from '@/types/app' import { RETRIEVE_TYPE } from '@/types/app'
import { genNodeMetaData } from '@/app/components/workflow/utils' import { genNodeMetaData } from '@/app/components/workflow/utils'
@ -36,7 +36,7 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
const { _datasets, multiple_retrieval_config, retrieval_mode } = payload const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
if (retrieval_mode === RETRIEVE_TYPE.multiWay) { if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config) const checked = checkoutRerankModelConfiguredInRetrievalSettings(_datasets || [], multiple_retrieval_config)
if (!errorMessages && !checked) if (!errorMessages && !checked)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })

View File

@ -1,7 +1,6 @@
import type { FC } from 'react' import type { FC } from 'react'
import { import {
memo, memo,
useCallback,
useMemo, useMemo,
} from 'react' } from 'react'
import { intersectionBy } from 'lodash-es' import { intersectionBy } from 'lodash-es'
@ -53,10 +52,6 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
availableNumberNodesWithParent, availableNumberNodesWithParent,
} = useConfig(id, data) } = useConfig(id, data)
const handleOpenFromPropsChange = useCallback((openFromProps: boolean) => {
setRerankModelOpen(openFromProps)
}, [setRerankModelOpen])
const metadataList = useMemo(() => { const metadataList = useMemo(() => {
return intersectionBy(...selectedDatasets.filter((dataset) => { return intersectionBy(...selectedDatasets.filter((dataset) => {
return !!dataset.doc_metadata return !!dataset.doc_metadata
@ -68,7 +63,6 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
return ( return (
<div className='pt-2'> <div className='pt-2'>
<div className='space-y-4 px-4 pb-2'> <div className='space-y-4 px-4 pb-2'>
{/* {JSON.stringify(inputs, null, 2)} */}
<Field <Field
title={t(`${i18nPrefix}.queryVariable`)} title={t(`${i18nPrefix}.queryVariable`)}
required required
@ -100,8 +94,8 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
onSingleRetrievalModelChange={handleModelChanged as any} onSingleRetrievalModelChange={handleModelChanged as any}
onSingleRetrievalModelParamsChange={handleCompletionParamsChange} onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
readonly={readOnly || !selectedDatasets.length} readonly={readOnly || !selectedDatasets.length}
openFromProps={rerankModelOpen} rerankModalOpen={rerankModelOpen}
onOpenFromPropsChange={handleOpenFromPropsChange} onRerankModelOpenChange={setRerankModelOpen}
selectedDatasets={selectedDatasets} selectedDatasets={selectedDatasets}
/> />
{!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)} {!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)}

View File

@ -204,10 +204,11 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { const newMultipleRetrievalConfig = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider, provider: currentRerankProvider?.provider,
model: currentRerankModel?.model, model: currentRerankModel?.model,
}) })
draft.multiple_retrieval_config = newMultipleRetrievalConfig
}) })
setInputs(newInputs) setInputs(newInputs)
}, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
@ -254,10 +255,11 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { const newMultipleRetrievalConfig = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider, provider: currentRerankProvider?.provider,
model: currentRerankModel?.model, model: currentRerankModel?.model,
}) })
draft.multiple_retrieval_config = newMultipleRetrievalConfig
} }
}) })
updateDatasetsDetail(newDatasets) updateDatasetsDetail(newDatasets)

View File

@ -10,6 +10,7 @@ import type {
import { import {
DEFAULT_WEIGHTED_SCORE, DEFAULT_WEIGHTED_SCORE,
RerankingModeEnum, RerankingModeEnum,
WeightedScoreEnum,
} from '@/models/datasets' } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app'
import { DATASET_DEFAULT } from '@/config' import { DATASET_DEFAULT } from '@/config'
@ -93,10 +94,12 @@ export const getMultipleRetrievalConfig = (
multipleRetrievalConfig: MultipleRetrievalConfig, multipleRetrievalConfig: MultipleRetrievalConfig,
selectedDatasets: DataSet[], selectedDatasets: DataSet[],
originalDatasets: DataSet[], originalDatasets: DataSet[],
validRerankModel?: { provider?: string; model?: string }, fallbackRerankModel?: { provider?: string; model?: string }, // fallback rerank model
) => { ) => {
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 // Check if the selected datasets are different from the original datasets
const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model const isDatasetsChanged = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
// Check if the rerank model is valid
const isFallbackRerankModelValid = !!(fallbackRerankModel?.provider && fallbackRerankModel?.model)
const { const {
allHighQuality, allHighQuality,
@ -125,14 +128,16 @@ export const getMultipleRetrievalConfig = (
reranking_mode, reranking_mode,
reranking_model, reranking_model,
weights, weights,
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : shouldSetWeightDefaultValue, reranking_enable,
} }
const setDefaultWeights = () => { const setDefaultWeights = () => {
result.weights = { result.weights = {
weight_type: WeightedScoreEnum.Customized,
vector_setting: { vector_setting: {
vector_weight: allHighQualityVectorSearch vector_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
// eslint-disable-next-line sonarjs/no-nested-conditional
: allHighQualityFullTextSearch : allHighQualityFullTextSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
: DEFAULT_WEIGHTED_SCORE.other.semantic, : DEFAULT_WEIGHTED_SCORE.other.semantic,
@ -142,6 +147,7 @@ export const getMultipleRetrievalConfig = (
keyword_setting: { keyword_setting: {
keyword_weight: allHighQualityVectorSearch keyword_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
// eslint-disable-next-line sonarjs/no-nested-conditional
: allHighQualityFullTextSearch : allHighQualityFullTextSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
: DEFAULT_WEIGHTED_SCORE.other.keyword, : DEFAULT_WEIGHTED_SCORE.other.keyword,
@ -149,65 +155,106 @@ export const getMultipleRetrievalConfig = (
} }
} }
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) { /**
* In this case, user can manually toggle reranking
* So should keep the reranking_enable value
* But the default reranking_model should be set
*/
if ((allEconomic && allInternal) || allExternal) {
result.reranking_mode = RerankingModeEnum.RerankingModel result.reranking_mode = RerankingModeEnum.RerankingModel
if (!result.reranking_model?.provider || !result.reranking_model?.model) { // Need to check if the reranking model should be set to default when first time initialized
if (rerankModelIsValid) { if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
result.reranking_enable = reranking_enable !== false
result.reranking_model = { result.reranking_model = {
provider: validRerankModel?.provider || '', provider: fallbackRerankModel.provider || '',
model: validRerankModel?.model || '', model: fallbackRerankModel.model || '',
} }
} }
else { result.reranking_enable = reranking_enable
result.reranking_model = {
provider: '',
model: '',
}
}
}
else {
result.reranking_enable = reranking_enable !== false
}
} }
/**
* In this case, reranking_enable must be true
* And if rerank model is not set, should set the default rerank model
*/
if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal) {
result.reranking_mode = RerankingModeEnum.RerankingModel
// Need to check if the reranking model should be set to default when first time initialized
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
result.reranking_model = {
provider: fallbackRerankModel.provider || '',
model: fallbackRerankModel.model || '',
}
}
result.reranking_enable = true
}
/**
* In this case, user can choose to use weighted score or rerank model
* But if the reranking_mode is not initialized, should set the default rerank model and reranking_enable to true
* and set reranking_mode to reranking_model
*/
if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
// If not initialized, check if the default rerank model is valid
if (!reranking_mode) { if (!reranking_mode) {
if (validRerankModel?.provider && validRerankModel?.model) { if (isFallbackRerankModelValid) {
result.reranking_mode = RerankingModeEnum.RerankingModel result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_enable = reranking_enable !== false result.reranking_enable = true
result.reranking_model = { result.reranking_model = {
provider: validRerankModel.provider, provider: fallbackRerankModel.provider || '',
model: validRerankModel.model, model: fallbackRerankModel.model || '',
} }
} }
else { else {
result.reranking_mode = RerankingModeEnum.WeightedScore result.reranking_mode = RerankingModeEnum.WeightedScore
result.reranking_enable = false
setDefaultWeights() setDefaultWeights()
} }
} }
if (reranking_mode === RerankingModeEnum.WeightedScore && !weights) // After initialization, if datasets has no change, make sure the config has correct value
if (reranking_mode === RerankingModeEnum.WeightedScore) {
result.reranking_enable = false
if (!weights)
setDefaultWeights() setDefaultWeights()
}
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) { if (reranking_mode === RerankingModeEnum.RerankingModel) {
if (rerankModelIsValid) { if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_enable = reranking_enable !== false
result.reranking_model = { result.reranking_model = {
provider: validRerankModel.provider || '', provider: fallbackRerankModel.provider || '',
model: validRerankModel.model || '', model: fallbackRerankModel.model || '',
}
}
result.reranking_enable = true
}
// Need to check if reranking_mode should be set to reranking_model when datasets changed
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && isDatasetsChanged) {
if ((result.reranking_model?.provider && result.reranking_model?.model) || isFallbackRerankModelValid) {
result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_enable = true
// eslint-disable-next-line sonarjs/nested-control-flow
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
result.reranking_model = {
provider: fallbackRerankModel.provider || '',
model: fallbackRerankModel.model || '',
}
} }
} }
else { else {
setDefaultWeights() setDefaultWeights()
} }
} }
if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { // Need to switch to weighted score when reranking model is not valid and datasets changed
if (
reranking_mode === RerankingModeEnum.RerankingModel
&& (!result.reranking_model?.provider || !result.reranking_model?.model)
&& !isFallbackRerankModelValid
&& isDatasetsChanged
) {
result.reranking_mode = RerankingModeEnum.WeightedScore result.reranking_mode = RerankingModeEnum.WeightedScore
result.reranking_enable = false
setDefaultWeights() setDefaultWeights()
} }
} }
@ -215,7 +262,7 @@ export const getMultipleRetrievalConfig = (
return result return result
} }
export const checkoutRerankModelConfigedInRetrievalSettings = ( export const checkoutRerankModelConfiguredInRetrievalSettings = (
datasets: DataSet[], datasets: DataSet[],
multipleRetrievalConfig?: MultipleRetrievalConfig, multipleRetrievalConfig?: MultipleRetrievalConfig,
) => { ) => {
@ -225,6 +272,7 @@ export const checkoutRerankModelConfigedInRetrievalSettings = (
const { const {
allEconomic, allEconomic,
allExternal, allExternal,
allInternal,
} = getSelectedDatasetsMode(datasets) } = getSelectedDatasetsMode(datasets)
const { const {
@ -233,12 +281,8 @@ export const checkoutRerankModelConfigedInRetrievalSettings = (
reranking_model, reranking_model,
} = multipleRetrievalConfig } = multipleRetrievalConfig
if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model))
if ((allEconomic || allExternal) && !reranking_enable) return ((allEconomic && allInternal) || allExternal) && !reranking_enable
return true
return false
}
return true return true
} }