mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-31 10:53:02 +00:00 
			
		
		
		
	 7709d9df20
			
		
	
	
		7709d9df20
		
			
		
	
	
	
	
		
			
			Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: jZonG <jzongcode@gmail.com>
		
			
				
	
	
		
			252 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			TypeScript
		
	
	
	
	
	
			
		
		
	
	
			252 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			TypeScript
		
	
	
	
	
	
| import type {
 | |
|   FC,
 | |
|   ReactNode,
 | |
| } from 'react'
 | |
| import { useMemo, useState } from 'react'
 | |
| import { useTranslation } from 'react-i18next'
 | |
| import type {
 | |
|   DefaultModel,
 | |
|   FormValue,
 | |
| } from '@/app/components/header/account-setting/model-provider-page/declarations'
 | |
| import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 | |
| import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
 | |
| import {
 | |
|   useModelList,
 | |
| } from '@/app/components/header/account-setting/model-provider-page/hooks'
 | |
| import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
 | |
| import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
 | |
| import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
 | |
| import {
 | |
|   PortalToFollowElem,
 | |
|   PortalToFollowElemContent,
 | |
|   PortalToFollowElemTrigger,
 | |
| } from '@/app/components/base/portal-to-follow-elem'
 | |
| import LLMParamsPanel from './llm-params-panel'
 | |
| import TTSParamsPanel from './tts-params-panel'
 | |
| import { useProviderContext } from '@/context/provider-context'
 | |
| import cn from '@/utils/classnames'
 | |
| 
 | |
/**
 * Props accepted by `ModelParameterModal`.
 */
export type ModelParameterModalProps = {
  /**
   * Currently selected model configuration. The component reads
   * `provider`, `model`, `completion_params`, `language` and `voice` from it.
   */
  value: any
  /** Called with the updated model configuration whenever the user changes the model or its parameters. */
  setModel: (model: any) => void
  /** Forwarded to the LLM parameters panel. */
  isAdvancedMode: boolean
  /**
   * '&'-separated scope tokens. Tokens equal to a model type (or `all`) choose which
   * model lists are offered; the remaining tokens are handed to the selector as feature filters.
   */
  scope?: string
  /** Custom trigger renderer; when provided it replaces the built-in trigger components. */
  renderTrigger?: (v: TriggerProps) => ReactNode
  /** When true, clicking the trigger does not toggle the popup. */
  readonly?: boolean
  /** Places the popup on the left (instead of bottom-end) and is forwarded to the default trigger. */
  isInWorkflow?: boolean
  /** Renders the agent-specific trigger when no `renderTrigger` is supplied. */
  isAgentStrategy?: boolean
  /** Extra class names merged onto the popup panel. */
  popupClassName?: string
  /** Extra class names merged onto the portal content wrapper. */
  portalToFollowElemContentClassName?: string
}
 | |
| 
 | |
/**
 * Scope tokens that name a model type. Any other token found in `scope`
 * (apart from `all`) is treated as a feature filter for the selector.
 */
const MODEL_TYPE_SCOPES: readonly string[] = [
  ModelTypeEnum.textGeneration,
  ModelTypeEnum.textEmbedding,
  ModelTypeEnum.rerank,
  ModelTypeEnum.moderation,
  ModelTypeEnum.speech2text,
  ModelTypeEnum.tts,
]

/**
 * Locates a provider entry and one of its models inside a (possibly merged) model list.
 *
 * When several entries share the same provider name — possible once the per-type
 * lists are concatenated for the `all` scope — the entry that actually contains
 * `model` wins; otherwise the first entry with a matching provider is returned,
 * which is what a plain provider-only lookup would have produced.
 */
const findProviderAndModel = (list: any[], provider?: string, model?: string) => {
  const candidates = list.filter(item => item.provider === provider)
  const owner = candidates.find(item => item.models?.some((entry: { model: string }) => entry.model === model))
    ?? candidates[0]
  return {
    provider: owner,
    model: owner?.models?.find((entry: { model: string }) => entry.model === model),
  }
}

/**
 * Popup that lets the user pick a model and, depending on the selected model's
 * type, edit its LLM completion parameters or its TTS language/voice.
 *
 * The list of selectable models is derived from `scope`; every change is
 * reported through `setModel` — the component keeps no copy of the selection.
 */
const ModelParameterModal: FC<ModelParameterModalProps> = ({
  popupClassName,
  portalToFollowElemContentClassName,
  isAdvancedMode,
  value,
  setModel,
  renderTrigger,
  readonly,
  isInWorkflow,
  isAgentStrategy,
  scope = ModelTypeEnum.textGeneration,
}) => {
  const { t } = useTranslation()
  const { isAPIKeySet } = useProviderContext()
  const [open, setOpen] = useState(false)

  // Memoized on the `scope` string: splitting inline would create a new array on
  // every render and invalidate every memo below that depends on it.
  const scopeArray = useMemo(() => scope.split('&'), [scope])

  // Tokens that are not model types are feature filters; `all` disables filtering.
  const scopeFeatures = useMemo(() => {
    if (scopeArray.includes('all'))
      return []
    return scopeArray.filter(item => !MODEL_TYPE_SCOPES.includes(item))
  }, [scopeArray])

  const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  const { data: ttsList } = useModelList(ModelTypeEnum.tts)

  // `all` merges every list; otherwise the first model type present in the scope
  // (checked in the fixed order below) decides which single list is offered.
  const scopedModelList = useMemo(() => {
    const resultList: any[] = []
    if (scopeArray.includes('all')) {
      return [
        ...textGenerationList,
        ...textEmbeddingList,
        ...rerankList,
        ...sttList,
        ...ttsList,
        ...moderationList,
      ]
    }
    if (scopeArray.includes(ModelTypeEnum.textGeneration))
      return textGenerationList
    if (scopeArray.includes(ModelTypeEnum.textEmbedding))
      return textEmbeddingList
    if (scopeArray.includes(ModelTypeEnum.rerank))
      return rerankList
    if (scopeArray.includes(ModelTypeEnum.moderation))
      return moderationList
    if (scopeArray.includes(ModelTypeEnum.speech2text))
      return sttList
    if (scopeArray.includes(ModelTypeEnum.tts))
      return ttsList
    return resultList
  }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

  const { currentProvider, currentModel } = useMemo(() => {
    const located = findProviderAndModel(scopedModelList, value?.provider, value?.model)
    return {
      currentProvider: located.provider,
      currentModel: located.model,
    }
  }, [scopedModelList, value?.provider, value?.model])

  // A selection that can no longer be resolved in the scoped list is shown as deprecated.
  const hasDeprecated = useMemo(() => {
    return !currentProvider || !currentModel
  }, [currentModel, currentProvider])
  const modelDisabled = useMemo(() => {
    return currentModel?.status !== ModelStatusEnum.active
  }, [currentModel?.status])
  const disabled = useMemo(() => {
    return !isAPIKeySet || hasDeprecated || modelDisabled
  }, [hasDeprecated, isAPIKeySet, modelDisabled])

  // Emits a fresh selection; text-generation models additionally get their mode
  // and an empty `completion_params` object.
  const handleChangeModel = ({ provider, model }: DefaultModel) => {
    const { model: targetModelItem } = findProviderAndModel(scopedModelList, provider, model)
    const model_type = targetModelItem?.model_type as string
    setModel({
      provider,
      model,
      model_type,
      ...(model_type === ModelTypeEnum.textGeneration ? {
        mode: targetModelItem?.model_properties.mode as string,
        completion_params: {},
      } : {}),
    })
  }

  const handleLLMParamsChange = (newParams: FormValue) => {
    setModel({
      ...value,
      // NOTE(review): this reads the camelCase `completionParams` key while the rest of the
      // component uses `completion_params`; kept as-is to preserve behavior — confirm intent.
      ...(value?.completionParams || {}),
      completion_params: newParams,
    })
  }

  const handleTTSParamsChange = (language: string, voice: string) => {
    setModel({
      ...value,
      language,
      voice,
    })
  }

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement={isInWorkflow ? 'left' : 'bottom-end'}
      offset={4}
    >
      <div className='relative'>
        <PortalToFollowElemTrigger
          onClick={() => {
            // Read-only mode keeps the popup closed.
            if (readonly)
              return
            setOpen(v => !v)
          }}
          className='block'
        >
          {
            renderTrigger
              ? renderTrigger({
                open,
                disabled,
                modelDisabled,
                hasDeprecated,
                currentProvider,
                currentModel,
                providerName: value?.provider,
                modelId: value?.model,
              })
              : (isAgentStrategy
                ? <AgentModelTrigger
                  disabled={disabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                  scope={scope}
                />
                : <Trigger
                  disabled={disabled}
                  isInWorkflow={isInWorkflow}
                  modelDisabled={modelDisabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                />
              )
          }
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
          <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
            <div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
              <div className='relative'>
                <div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
                  {t('common.modelProvider.model').toLocaleUpperCase()}
                </div>
                <ModelSelector
                  defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
                  modelList={scopedModelList}
                  scopeFeatures={scopeFeatures}
                  onSelect={handleChangeModel}
                />
              </div>
              {/* The divider is only needed when a parameter panel follows. */}
              {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
                <div className='my-3 h-[1px] bg-divider-subtle' />
              )}
              {currentModel?.model_type === ModelTypeEnum.textGeneration && (
                <LLMParamsPanel
                  provider={value?.provider}
                  modelId={value?.model}
                  completionParams={value?.completion_params || {}}
                  onCompletionParamsChange={handleLLMParamsChange}
                  isAdvancedMode={isAdvancedMode}
                />
              )}
              {currentModel?.model_type === ModelTypeEnum.tts && (
                <TTSParamsPanel
                  currentModel={currentModel}
                  language={value?.language}
                  voice={value?.voice}
                  onChange={handleTTSParamsChange}
                />
              )}
            </div>
          </div>
        </PortalToFollowElemContent>
      </div>
    </PortalToFollowElem>
  )
}
 | |
| 
 | |
| export default ModelParameterModal
 |