mirror of
https://github.com/langgenius/dify.git
synced 2025-12-05 15:26:11 +00:00
fix: Fix retrieval configuration handling in dataset components (#26361)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
af662b100b
commit
1a7898dff1
@ -65,13 +65,40 @@ const DatasetConfig: FC = () => {
|
|||||||
const onRemove = (id: string) => {
|
const onRemove = (id: string) => {
|
||||||
const filteredDataSets = dataSet.filter(item => item.id !== id)
|
const filteredDataSets = dataSet.filter(item => item.id !== id)
|
||||||
setDataSet(filteredDataSets)
|
setDataSet(filteredDataSets)
|
||||||
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, {
|
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
||||||
|
const {
|
||||||
|
top_k,
|
||||||
|
score_threshold,
|
||||||
|
reranking_model,
|
||||||
|
reranking_mode,
|
||||||
|
weights,
|
||||||
|
reranking_enable,
|
||||||
|
} = restConfigs
|
||||||
|
const oldRetrievalConfig = {
|
||||||
|
top_k,
|
||||||
|
score_threshold,
|
||||||
|
reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? {
|
||||||
|
provider: reranking_model.reranking_provider_name,
|
||||||
|
model: reranking_model.reranking_model_name,
|
||||||
|
} : undefined,
|
||||||
|
reranking_mode,
|
||||||
|
weights,
|
||||||
|
reranking_enable,
|
||||||
|
}
|
||||||
|
const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, filteredDataSets, dataSet, {
|
||||||
provider: currentRerankProvider?.provider,
|
provider: currentRerankProvider?.provider,
|
||||||
model: currentRerankModel?.model,
|
model: currentRerankModel?.model,
|
||||||
})
|
})
|
||||||
setDatasetConfigs({
|
setDatasetConfigs({
|
||||||
...(datasetConfigs as any),
|
...datasetConfigsRef.current,
|
||||||
...retrievalConfig,
|
...retrievalConfig,
|
||||||
|
reranking_model: {
|
||||||
|
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
|
||||||
|
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
|
||||||
|
},
|
||||||
|
retrieval_model,
|
||||||
|
score_threshold_enabled,
|
||||||
|
datasets,
|
||||||
})
|
})
|
||||||
const {
|
const {
|
||||||
allExternal,
|
allExternal,
|
||||||
|
|||||||
@ -30,11 +30,11 @@ import { noop } from 'lodash-es'
|
|||||||
type Props = {
|
type Props = {
|
||||||
datasetConfigs: DatasetConfigs
|
datasetConfigs: DatasetConfigs
|
||||||
onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void
|
onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void
|
||||||
|
selectedDatasets?: DataSet[]
|
||||||
isInWorkflow?: boolean
|
isInWorkflow?: boolean
|
||||||
singleRetrievalModelConfig?: ModelConfig
|
singleRetrievalModelConfig?: ModelConfig
|
||||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||||
selectedDatasets?: DataSet[]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConfigContent: FC<Props> = ({
|
const ConfigContent: FC<Props> = ({
|
||||||
@ -61,22 +61,28 @@ const ConfigContent: FC<Props> = ({
|
|||||||
|
|
||||||
const {
|
const {
|
||||||
modelList: rerankModelList,
|
modelList: rerankModelList,
|
||||||
|
currentModel: validDefaultRerankModel,
|
||||||
|
currentProvider: validDefaultRerankProvider,
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If reranking model is set and is valid, use the reranking model
|
||||||
|
* Otherwise, check if the default reranking model is valid
|
||||||
|
*/
|
||||||
const {
|
const {
|
||||||
currentModel: currentRerankModel,
|
currentModel: currentRerankModel,
|
||||||
} = useCurrentProviderAndModel(
|
} = useCurrentProviderAndModel(
|
||||||
rerankModelList,
|
rerankModelList,
|
||||||
{
|
{
|
||||||
provider: datasetConfigs.reranking_model?.reranking_provider_name,
|
provider: datasetConfigs.reranking_model?.reranking_provider_name || validDefaultRerankProvider?.provider || '',
|
||||||
model: datasetConfigs.reranking_model?.reranking_model_name,
|
model: datasetConfigs.reranking_model?.reranking_model_name || validDefaultRerankModel?.model || '',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
const rerankModel = useMemo(() => {
|
const rerankModel = useMemo(() => {
|
||||||
return {
|
return {
|
||||||
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
|
provider_name: datasetConfigs.reranking_model?.reranking_provider_name ?? '',
|
||||||
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
|
model_name: datasetConfigs.reranking_model?.reranking_model_name ?? '',
|
||||||
}
|
}
|
||||||
}, [datasetConfigs.reranking_model])
|
}, [datasetConfigs.reranking_model])
|
||||||
|
|
||||||
@ -135,7 +141,7 @@ const ConfigContent: FC<Props> = ({
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = singleRetrievalConfig
|
const model = singleRetrievalConfig // Legacy code, for compatibility, have to keep it
|
||||||
|
|
||||||
const rerankingModeOptions = [
|
const rerankingModeOptions = [
|
||||||
{
|
{
|
||||||
@ -168,7 +174,7 @@ const ConfigContent: FC<Props> = ({
|
|||||||
return datasetConfigs.reranking_enable
|
return datasetConfigs.reranking_enable
|
||||||
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
|
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
|
||||||
|
|
||||||
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
|
const handleManuallyToggleRerank = useCallback((enable: boolean) => {
|
||||||
if (!currentRerankModel && enable)
|
if (!currentRerankModel && enable)
|
||||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||||
onChange({
|
onChange({
|
||||||
@ -255,12 +261,11 @@ const ConfigContent: FC<Props> = ({
|
|||||||
<div className='mt-2'>
|
<div className='mt-2'>
|
||||||
<div className='flex items-center'>
|
<div className='flex items-center'>
|
||||||
{
|
{
|
||||||
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
|
canManuallyToggleRerank && (
|
||||||
<Switch
|
<Switch
|
||||||
size='md'
|
size='md'
|
||||||
defaultValue={showRerankModel}
|
defaultValue={showRerankModel}
|
||||||
disabled={!canManuallyToggleRerank}
|
onChange={handleManuallyToggleRerank}
|
||||||
onChange={handleDisabledSwitchClick}
|
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -284,18 +284,28 @@ const Configuration: FC = () => {
|
|||||||
setRerankSettingModalOpen(true)
|
setRerankSettingModalOpen(true)
|
||||||
|
|
||||||
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
||||||
|
const {
|
||||||
|
top_k,
|
||||||
|
score_threshold,
|
||||||
|
reranking_model,
|
||||||
|
reranking_mode,
|
||||||
|
weights,
|
||||||
|
reranking_enable,
|
||||||
|
} = restConfigs
|
||||||
|
|
||||||
const retrievalConfig = getMultipleRetrievalConfig({
|
const oldRetrievalConfig = {
|
||||||
top_k: restConfigs.top_k,
|
top_k,
|
||||||
score_threshold: restConfigs.score_threshold,
|
score_threshold,
|
||||||
reranking_model: restConfigs.reranking_model && {
|
reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? {
|
||||||
provider: restConfigs.reranking_model.reranking_provider_name,
|
provider: reranking_model.reranking_provider_name,
|
||||||
model: restConfigs.reranking_model.reranking_model_name,
|
model: reranking_model.reranking_model_name,
|
||||||
},
|
} : undefined,
|
||||||
reranking_mode: restConfigs.reranking_mode,
|
reranking_mode,
|
||||||
weights: restConfigs.weights,
|
weights,
|
||||||
reranking_enable: restConfigs.reranking_enable,
|
reranking_enable,
|
||||||
}, newDatasets, dataSets, {
|
}
|
||||||
|
|
||||||
|
const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, newDatasets, dataSets, {
|
||||||
provider: currentRerankProvider?.provider,
|
provider: currentRerankProvider?.provider,
|
||||||
model: currentRerankModel?.model,
|
model: currentRerankModel?.model,
|
||||||
})
|
})
|
||||||
|
|||||||
@ -40,7 +40,7 @@ const RetrievalMethodConfig: FC<Props> = ({
|
|||||||
onChange({
|
onChange({
|
||||||
...value,
|
...value,
|
||||||
search_method: retrieveMethod,
|
search_method: retrieveMethod,
|
||||||
...(!value.reranking_model.reranking_model_name
|
...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name)
|
||||||
? {
|
? {
|
||||||
reranking_model: {
|
reranking_model: {
|
||||||
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
|
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
|
||||||
@ -57,7 +57,7 @@ const RetrievalMethodConfig: FC<Props> = ({
|
|||||||
onChange({
|
onChange({
|
||||||
...value,
|
...value,
|
||||||
search_method: retrieveMethod,
|
search_method: retrieveMethod,
|
||||||
...(!value.reranking_model.reranking_model_name
|
...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name)
|
||||||
? {
|
? {
|
||||||
reranking_model: {
|
reranking_model: {
|
||||||
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
|
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
|
||||||
|
|||||||
@ -54,7 +54,7 @@ const RetrievalParamConfig: FC<Props> = ({
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
|
const handleToggleRerankEnable = useCallback((enable: boolean) => {
|
||||||
if (enable && !currentModel)
|
if (enable && !currentModel)
|
||||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||||
onChange({
|
onChange({
|
||||||
@ -119,7 +119,7 @@ const RetrievalParamConfig: FC<Props> = ({
|
|||||||
<Switch
|
<Switch
|
||||||
size='md'
|
size='md'
|
||||||
defaultValue={value.reranking_enable}
|
defaultValue={value.reranking_enable}
|
||||||
onChange={handleDisabledSwitchClick}
|
onChange={handleToggleRerankEnable}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<div className='flex items-center'>
|
<div className='flex items-center'>
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import React, { useCallback, useState } from 'react'
|
import React, { useCallback, useMemo } from 'react'
|
||||||
import { RiEqualizer2Line } from '@remixicon/react'
|
import { RiEqualizer2Line } from '@remixicon/react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
|
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
|
||||||
@ -14,8 +14,6 @@ import {
|
|||||||
import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
|
import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
|
||||||
import { RETRIEVE_TYPE } from '@/types/app'
|
import { RETRIEVE_TYPE } from '@/types/app'
|
||||||
import { DATASET_DEFAULT } from '@/config'
|
import { DATASET_DEFAULT } from '@/config'
|
||||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
import type { DatasetConfigs } from '@/models/debug'
|
import type { DatasetConfigs } from '@/models/debug'
|
||||||
import type { DataSet } from '@/models/datasets'
|
import type { DataSet } from '@/models/datasets'
|
||||||
@ -32,8 +30,8 @@ type Props = {
|
|||||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||||
readonly?: boolean
|
readonly?: boolean
|
||||||
openFromProps?: boolean
|
rerankModalOpen: boolean
|
||||||
onOpenFromPropsChange?: (openFromProps: boolean) => void
|
onRerankModelOpenChange: (open: boolean) => void
|
||||||
selectedDatasets: DataSet[]
|
selectedDatasets: DataSet[]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,26 +43,52 @@ const RetrievalConfig: FC<Props> = ({
|
|||||||
onSingleRetrievalModelChange,
|
onSingleRetrievalModelChange,
|
||||||
onSingleRetrievalModelParamsChange,
|
onSingleRetrievalModelParamsChange,
|
||||||
readonly,
|
readonly,
|
||||||
openFromProps,
|
rerankModalOpen,
|
||||||
onOpenFromPropsChange,
|
onRerankModelOpenChange,
|
||||||
selectedDatasets,
|
selectedDatasets,
|
||||||
}) => {
|
}) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const [open, setOpen] = useState(false)
|
const { retrieval_mode, multiple_retrieval_config } = payload
|
||||||
const mergedOpen = openFromProps !== undefined ? openFromProps : open
|
|
||||||
|
|
||||||
const handleOpen = useCallback((newOpen: boolean) => {
|
const handleOpen = useCallback((newOpen: boolean) => {
|
||||||
setOpen(newOpen)
|
onRerankModelOpenChange(newOpen)
|
||||||
onOpenFromPropsChange?.(newOpen)
|
}, [onRerankModelOpenChange])
|
||||||
}, [onOpenFromPropsChange])
|
|
||||||
|
|
||||||
|
const datasetConfigs = useMemo(() => {
|
||||||
const {
|
const {
|
||||||
currentProvider: validRerankDefaultProvider,
|
reranking_model,
|
||||||
currentModel: validRerankDefaultModel,
|
top_k,
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
score_threshold,
|
||||||
|
reranking_mode,
|
||||||
|
weights,
|
||||||
|
reranking_enable,
|
||||||
|
} = multiple_retrieval_config || {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
retrieval_model: retrieval_mode,
|
||||||
|
reranking_model: (reranking_model?.provider && reranking_model?.model)
|
||||||
|
? {
|
||||||
|
reranking_provider_name: reranking_model?.provider,
|
||||||
|
reranking_model_name: reranking_model?.model,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
reranking_provider_name: '',
|
||||||
|
reranking_model_name: '',
|
||||||
|
},
|
||||||
|
top_k: top_k || DATASET_DEFAULT.top_k,
|
||||||
|
score_threshold_enabled: !(score_threshold === undefined || score_threshold === null),
|
||||||
|
score_threshold,
|
||||||
|
datasets: {
|
||||||
|
datasets: [],
|
||||||
|
},
|
||||||
|
reranking_mode,
|
||||||
|
weights,
|
||||||
|
reranking_enable,
|
||||||
|
}
|
||||||
|
}, [retrieval_mode, multiple_retrieval_config])
|
||||||
|
|
||||||
const { multiple_retrieval_config } = payload
|
|
||||||
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
|
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
|
||||||
|
// Legacy code, for compatibility, have to keep it
|
||||||
if (isRetrievalModeChange) {
|
if (isRetrievalModeChange) {
|
||||||
onRetrievalModeChange(configs.retrieval_model)
|
onRetrievalModeChange(configs.retrieval_model)
|
||||||
return
|
return
|
||||||
@ -72,13 +96,11 @@ const RetrievalConfig: FC<Props> = ({
|
|||||||
onMultipleRetrievalConfigChange({
|
onMultipleRetrievalConfigChange({
|
||||||
top_k: configs.top_k,
|
top_k: configs.top_k,
|
||||||
score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null,
|
score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null,
|
||||||
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
|
reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||||
? undefined
|
? undefined
|
||||||
|
// eslint-disable-next-line sonarjs/no-nested-conditional
|
||||||
: (!configs.reranking_model?.reranking_provider_name
|
: (!configs.reranking_model?.reranking_provider_name
|
||||||
? {
|
? undefined
|
||||||
provider: validRerankDefaultProvider?.provider || '',
|
|
||||||
model: validRerankDefaultModel?.model || '',
|
|
||||||
}
|
|
||||||
: {
|
: {
|
||||||
provider: configs.reranking_model?.reranking_provider_name,
|
provider: configs.reranking_model?.reranking_provider_name,
|
||||||
model: configs.reranking_model?.reranking_model_name,
|
model: configs.reranking_model?.reranking_model_name,
|
||||||
@ -87,11 +109,11 @@ const RetrievalConfig: FC<Props> = ({
|
|||||||
weights: configs.weights,
|
weights: configs.weights,
|
||||||
reranking_enable: configs.reranking_enable,
|
reranking_enable: configs.reranking_enable,
|
||||||
})
|
})
|
||||||
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, validRerankDefaultProvider, validRerankDefaultModel, onRetrievalModeChange])
|
}, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<PortalToFollowElem
|
<PortalToFollowElem
|
||||||
open={mergedOpen}
|
open={rerankModalOpen}
|
||||||
onOpenChange={handleOpen}
|
onOpenChange={handleOpen}
|
||||||
placement='bottom-end'
|
placement='bottom-end'
|
||||||
offset={{
|
offset={{
|
||||||
@ -102,14 +124,14 @@ const RetrievalConfig: FC<Props> = ({
|
|||||||
onClick={() => {
|
onClick={() => {
|
||||||
if (readonly)
|
if (readonly)
|
||||||
return
|
return
|
||||||
handleOpen(!mergedOpen)
|
handleOpen(!rerankModalOpen)
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Button
|
<Button
|
||||||
variant='ghost'
|
variant='ghost'
|
||||||
size='small'
|
size='small'
|
||||||
disabled={readonly}
|
disabled={readonly}
|
||||||
className={cn(open && 'bg-components-button-ghost-bg-hover')}
|
className={cn(rerankModalOpen && 'bg-components-button-ghost-bg-hover')}
|
||||||
>
|
>
|
||||||
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
|
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
|
||||||
{t('dataset.retrievalSettings')}
|
{t('dataset.retrievalSettings')}
|
||||||
@ -118,35 +140,13 @@ const RetrievalConfig: FC<Props> = ({
|
|||||||
<PortalToFollowElemContent style={{ zIndex: 1001 }}>
|
<PortalToFollowElemContent style={{ zIndex: 1001 }}>
|
||||||
<div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'>
|
<div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'>
|
||||||
<ConfigRetrievalContent
|
<ConfigRetrievalContent
|
||||||
datasetConfigs={
|
datasetConfigs={datasetConfigs}
|
||||||
{
|
|
||||||
retrieval_model: payload.retrieval_mode,
|
|
||||||
reranking_model: multiple_retrieval_config?.reranking_model?.provider
|
|
||||||
? {
|
|
||||||
reranking_provider_name: multiple_retrieval_config.reranking_model?.provider,
|
|
||||||
reranking_model_name: multiple_retrieval_config.reranking_model?.model,
|
|
||||||
}
|
|
||||||
: {
|
|
||||||
reranking_provider_name: '',
|
|
||||||
reranking_model_name: '',
|
|
||||||
},
|
|
||||||
top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
|
||||||
score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config.score_threshold === null),
|
|
||||||
score_threshold: multiple_retrieval_config?.score_threshold,
|
|
||||||
datasets: {
|
|
||||||
datasets: [],
|
|
||||||
},
|
|
||||||
reranking_mode: multiple_retrieval_config?.reranking_mode,
|
|
||||||
weights: multiple_retrieval_config?.weights,
|
|
||||||
reranking_enable: multiple_retrieval_config?.reranking_enable,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
selectedDatasets={selectedDatasets}
|
||||||
isInWorkflow
|
isInWorkflow
|
||||||
singleRetrievalModelConfig={singleRetrievalModelConfig}
|
singleRetrievalModelConfig={singleRetrievalModelConfig}
|
||||||
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
|
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
|
||||||
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
|
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
|
||||||
selectedDatasets={selectedDatasets}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</PortalToFollowElemContent>
|
</PortalToFollowElemContent>
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import type { NodeDefault } from '../../types'
|
import type { NodeDefault } from '../../types'
|
||||||
import type { KnowledgeRetrievalNodeType } from './types'
|
import type { KnowledgeRetrievalNodeType } from './types'
|
||||||
import { checkoutRerankModelConfigedInRetrievalSettings } from './utils'
|
import { checkoutRerankModelConfiguredInRetrievalSettings } from './utils'
|
||||||
import { DATASET_DEFAULT } from '@/config'
|
import { DATASET_DEFAULT } from '@/config'
|
||||||
import { RETRIEVE_TYPE } from '@/types/app'
|
import { RETRIEVE_TYPE } from '@/types/app'
|
||||||
import { genNodeMetaData } from '@/app/components/workflow/utils'
|
import { genNodeMetaData } from '@/app/components/workflow/utils'
|
||||||
@ -36,7 +36,7 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
|
|||||||
|
|
||||||
const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
|
const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
|
||||||
if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
|
if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
|
||||||
const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config)
|
const checked = checkoutRerankModelConfiguredInRetrievalSettings(_datasets || [], multiple_retrieval_config)
|
||||||
|
|
||||||
if (!errorMessages && !checked)
|
if (!errorMessages && !checked)
|
||||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
|
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import {
|
import {
|
||||||
memo,
|
memo,
|
||||||
useCallback,
|
|
||||||
useMemo,
|
useMemo,
|
||||||
} from 'react'
|
} from 'react'
|
||||||
import { intersectionBy } from 'lodash-es'
|
import { intersectionBy } from 'lodash-es'
|
||||||
@ -53,10 +52,6 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
|||||||
availableNumberNodesWithParent,
|
availableNumberNodesWithParent,
|
||||||
} = useConfig(id, data)
|
} = useConfig(id, data)
|
||||||
|
|
||||||
const handleOpenFromPropsChange = useCallback((openFromProps: boolean) => {
|
|
||||||
setRerankModelOpen(openFromProps)
|
|
||||||
}, [setRerankModelOpen])
|
|
||||||
|
|
||||||
const metadataList = useMemo(() => {
|
const metadataList = useMemo(() => {
|
||||||
return intersectionBy(...selectedDatasets.filter((dataset) => {
|
return intersectionBy(...selectedDatasets.filter((dataset) => {
|
||||||
return !!dataset.doc_metadata
|
return !!dataset.doc_metadata
|
||||||
@ -68,7 +63,6 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
|||||||
return (
|
return (
|
||||||
<div className='pt-2'>
|
<div className='pt-2'>
|
||||||
<div className='space-y-4 px-4 pb-2'>
|
<div className='space-y-4 px-4 pb-2'>
|
||||||
{/* {JSON.stringify(inputs, null, 2)} */}
|
|
||||||
<Field
|
<Field
|
||||||
title={t(`${i18nPrefix}.queryVariable`)}
|
title={t(`${i18nPrefix}.queryVariable`)}
|
||||||
required
|
required
|
||||||
@ -100,8 +94,8 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
|||||||
onSingleRetrievalModelChange={handleModelChanged as any}
|
onSingleRetrievalModelChange={handleModelChanged as any}
|
||||||
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
|
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
|
||||||
readonly={readOnly || !selectedDatasets.length}
|
readonly={readOnly || !selectedDatasets.length}
|
||||||
openFromProps={rerankModelOpen}
|
rerankModalOpen={rerankModelOpen}
|
||||||
onOpenFromPropsChange={handleOpenFromPropsChange}
|
onRerankModelOpenChange={setRerankModelOpen}
|
||||||
selectedDatasets={selectedDatasets}
|
selectedDatasets={selectedDatasets}
|
||||||
/>
|
/>
|
||||||
{!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)}
|
{!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)}
|
||||||
|
|||||||
@ -204,10 +204,11 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
|
|
||||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||||
const newInputs = produce(inputs, (draft) => {
|
const newInputs = produce(inputs, (draft) => {
|
||||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
|
const newMultipleRetrievalConfig = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
|
||||||
provider: currentRerankProvider?.provider,
|
provider: currentRerankProvider?.provider,
|
||||||
model: currentRerankModel?.model,
|
model: currentRerankModel?.model,
|
||||||
})
|
})
|
||||||
|
draft.multiple_retrieval_config = newMultipleRetrievalConfig
|
||||||
})
|
})
|
||||||
setInputs(newInputs)
|
setInputs(newInputs)
|
||||||
}, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
|
}, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
|
||||||
@ -254,10 +255,11 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||||||
|
|
||||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
|
const newMultipleRetrievalConfig = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
|
||||||
provider: currentRerankProvider?.provider,
|
provider: currentRerankProvider?.provider,
|
||||||
model: currentRerankModel?.model,
|
model: currentRerankModel?.model,
|
||||||
})
|
})
|
||||||
|
draft.multiple_retrieval_config = newMultipleRetrievalConfig
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
updateDatasetsDetail(newDatasets)
|
updateDatasetsDetail(newDatasets)
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import type {
|
|||||||
import {
|
import {
|
||||||
DEFAULT_WEIGHTED_SCORE,
|
DEFAULT_WEIGHTED_SCORE,
|
||||||
RerankingModeEnum,
|
RerankingModeEnum,
|
||||||
|
WeightedScoreEnum,
|
||||||
} from '@/models/datasets'
|
} from '@/models/datasets'
|
||||||
import { RETRIEVE_METHOD } from '@/types/app'
|
import { RETRIEVE_METHOD } from '@/types/app'
|
||||||
import { DATASET_DEFAULT } from '@/config'
|
import { DATASET_DEFAULT } from '@/config'
|
||||||
@ -93,10 +94,12 @@ export const getMultipleRetrievalConfig = (
|
|||||||
multipleRetrievalConfig: MultipleRetrievalConfig,
|
multipleRetrievalConfig: MultipleRetrievalConfig,
|
||||||
selectedDatasets: DataSet[],
|
selectedDatasets: DataSet[],
|
||||||
originalDatasets: DataSet[],
|
originalDatasets: DataSet[],
|
||||||
validRerankModel?: { provider?: string; model?: string },
|
fallbackRerankModel?: { provider?: string; model?: string }, // fallback rerank model
|
||||||
) => {
|
) => {
|
||||||
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
|
// Check if the selected datasets are different from the original datasets
|
||||||
const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model
|
const isDatasetsChanged = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
|
||||||
|
// Check if the rerank model is valid
|
||||||
|
const isFallbackRerankModelValid = !!(fallbackRerankModel?.provider && fallbackRerankModel?.model)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
allHighQuality,
|
allHighQuality,
|
||||||
@ -125,14 +128,16 @@ export const getMultipleRetrievalConfig = (
|
|||||||
reranking_mode,
|
reranking_mode,
|
||||||
reranking_model,
|
reranking_model,
|
||||||
weights,
|
weights,
|
||||||
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : shouldSetWeightDefaultValue,
|
reranking_enable,
|
||||||
}
|
}
|
||||||
|
|
||||||
const setDefaultWeights = () => {
|
const setDefaultWeights = () => {
|
||||||
result.weights = {
|
result.weights = {
|
||||||
|
weight_type: WeightedScoreEnum.Customized,
|
||||||
vector_setting: {
|
vector_setting: {
|
||||||
vector_weight: allHighQualityVectorSearch
|
vector_weight: allHighQualityVectorSearch
|
||||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
|
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
|
||||||
|
// eslint-disable-next-line sonarjs/no-nested-conditional
|
||||||
: allHighQualityFullTextSearch
|
: allHighQualityFullTextSearch
|
||||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
|
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
|
||||||
: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||||
@ -142,6 +147,7 @@ export const getMultipleRetrievalConfig = (
|
|||||||
keyword_setting: {
|
keyword_setting: {
|
||||||
keyword_weight: allHighQualityVectorSearch
|
keyword_weight: allHighQualityVectorSearch
|
||||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
|
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
|
||||||
|
// eslint-disable-next-line sonarjs/no-nested-conditional
|
||||||
: allHighQualityFullTextSearch
|
: allHighQualityFullTextSearch
|
||||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
|
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
|
||||||
: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||||
@ -149,65 +155,106 @@ export const getMultipleRetrievalConfig = (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) {
|
/**
|
||||||
|
* In this case, user can manually toggle reranking
|
||||||
|
* So should keep the reranking_enable value
|
||||||
|
* But the default reranking_model should be set
|
||||||
|
*/
|
||||||
|
if ((allEconomic && allInternal) || allExternal) {
|
||||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||||
if (!result.reranking_model?.provider || !result.reranking_model?.model) {
|
// Need to check if the reranking model should be set to default when first time initialized
|
||||||
if (rerankModelIsValid) {
|
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||||
result.reranking_enable = reranking_enable !== false
|
|
||||||
|
|
||||||
result.reranking_model = {
|
result.reranking_model = {
|
||||||
provider: validRerankModel?.provider || '',
|
provider: fallbackRerankModel.provider || '',
|
||||||
model: validRerankModel?.model || '',
|
model: fallbackRerankModel.model || '',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
result.reranking_enable = reranking_enable
|
||||||
result.reranking_model = {
|
|
||||||
provider: '',
|
|
||||||
model: '',
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
result.reranking_enable = reranking_enable !== false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In this case, reranking_enable must be true
|
||||||
|
* And if rerank model is not set, should set the default rerank model
|
||||||
|
*/
|
||||||
|
if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal) {
|
||||||
|
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||||
|
// Need to check if the reranking model should be set to default when first time initialized
|
||||||
|
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||||
|
result.reranking_model = {
|
||||||
|
provider: fallbackRerankModel.provider || '',
|
||||||
|
model: fallbackRerankModel.model || '',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.reranking_enable = true
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In this case, user can choose to use weighted score or rerank model
|
||||||
|
* But if the reranking_mode is not initialized, should set the default rerank model and reranking_enable to true
|
||||||
|
* and set reranking_mode to reranking_model
|
||||||
|
*/
|
||||||
if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
|
if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
|
||||||
|
// If not initialized, check if the default rerank model is valid
|
||||||
if (!reranking_mode) {
|
if (!reranking_mode) {
|
||||||
if (validRerankModel?.provider && validRerankModel?.model) {
|
if (isFallbackRerankModelValid) {
|
||||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||||
result.reranking_enable = reranking_enable !== false
|
result.reranking_enable = true
|
||||||
|
|
||||||
result.reranking_model = {
|
result.reranking_model = {
|
||||||
provider: validRerankModel.provider,
|
provider: fallbackRerankModel.provider || '',
|
||||||
model: validRerankModel.model,
|
model: fallbackRerankModel.model || '',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||||
|
result.reranking_enable = false
|
||||||
setDefaultWeights()
|
setDefaultWeights()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (reranking_mode === RerankingModeEnum.WeightedScore && !weights)
|
// After initialization, if datasets has no change, make sure the config has correct value
|
||||||
|
if (reranking_mode === RerankingModeEnum.WeightedScore) {
|
||||||
|
result.reranking_enable = false
|
||||||
|
if (!weights)
|
||||||
setDefaultWeights()
|
setDefaultWeights()
|
||||||
|
}
|
||||||
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) {
|
if (reranking_mode === RerankingModeEnum.RerankingModel) {
|
||||||
if (rerankModelIsValid) {
|
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
|
||||||
result.reranking_enable = reranking_enable !== false
|
|
||||||
|
|
||||||
result.reranking_model = {
|
result.reranking_model = {
|
||||||
provider: validRerankModel.provider || '',
|
provider: fallbackRerankModel.provider || '',
|
||||||
model: validRerankModel.model || '',
|
model: fallbackRerankModel.model || '',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.reranking_enable = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need to check if reranking_mode should be set to reranking_model when datasets changed
|
||||||
|
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && isDatasetsChanged) {
|
||||||
|
if ((result.reranking_model?.provider && result.reranking_model?.model) || isFallbackRerankModelValid) {
|
||||||
|
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||||
|
result.reranking_enable = true
|
||||||
|
|
||||||
|
// eslint-disable-next-line sonarjs/nested-control-flow
|
||||||
|
if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) {
|
||||||
|
result.reranking_model = {
|
||||||
|
provider: fallbackRerankModel.provider || '',
|
||||||
|
model: fallbackRerankModel.model || '',
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
setDefaultWeights()
|
setDefaultWeights()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
|
// Need to switch to weighted score when reranking model is not valid and datasets changed
|
||||||
|
if (
|
||||||
|
reranking_mode === RerankingModeEnum.RerankingModel
|
||||||
|
&& (!result.reranking_model?.provider || !result.reranking_model?.model)
|
||||||
|
&& !isFallbackRerankModelValid
|
||||||
|
&& isDatasetsChanged
|
||||||
|
) {
|
||||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||||
|
result.reranking_enable = false
|
||||||
setDefaultWeights()
|
setDefaultWeights()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,7 +262,7 @@ export const getMultipleRetrievalConfig = (
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
export const checkoutRerankModelConfigedInRetrievalSettings = (
|
export const checkoutRerankModelConfiguredInRetrievalSettings = (
|
||||||
datasets: DataSet[],
|
datasets: DataSet[],
|
||||||
multipleRetrievalConfig?: MultipleRetrievalConfig,
|
multipleRetrievalConfig?: MultipleRetrievalConfig,
|
||||||
) => {
|
) => {
|
||||||
@ -225,6 +272,7 @@ export const checkoutRerankModelConfigedInRetrievalSettings = (
|
|||||||
const {
|
const {
|
||||||
allEconomic,
|
allEconomic,
|
||||||
allExternal,
|
allExternal,
|
||||||
|
allInternal,
|
||||||
} = getSelectedDatasetsMode(datasets)
|
} = getSelectedDatasetsMode(datasets)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -233,12 +281,8 @@ export const checkoutRerankModelConfigedInRetrievalSettings = (
|
|||||||
reranking_model,
|
reranking_model,
|
||||||
} = multipleRetrievalConfig
|
} = multipleRetrievalConfig
|
||||||
|
|
||||||
if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) {
|
if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model))
|
||||||
if ((allEconomic || allExternal) && !reranking_enable)
|
return ((allEconomic && allInternal) || allExternal) && !reranking_enable
|
||||||
return true
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user