mirror of
				https://github.com/langgenius/dify.git
				synced 2025-11-03 20:33:00 +00:00 
			
		
		
		
	feat: agent node add memory (#15976)
This commit is contained in:
		
							parent
							
								
									3d76f09c3a
								
							
						
					
					
						commit
						dcdec98c8e
					
				@ -70,11 +70,20 @@ class AgentStrategyIdentity(ToolIdentity):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AgentFeature(enum.StrEnum):
 | 
			
		||||
    """
 | 
			
		||||
    Agent Feature, used to describe the features of the agent strategy.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    HISTORY_MESSAGES = "history-messages"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AgentStrategyEntity(BaseModel):
 | 
			
		||||
    identity: AgentStrategyIdentity
 | 
			
		||||
    parameters: list[AgentStrategyParameter] = Field(default_factory=list)
 | 
			
		||||
    description: I18nObject = Field(..., description="The description of the agent strategy")
 | 
			
		||||
    output_schema: Optional[dict] = None
 | 
			
		||||
    features: Optional[list[AgentFeature]] = None
 | 
			
		||||
 | 
			
		||||
    # pydantic configs
 | 
			
		||||
    model_config = ConfigDict(protected_namespaces=())
 | 
			
		||||
 | 
			
		||||
@ -1,15 +1,18 @@
 | 
			
		||||
import json
 | 
			
		||||
from collections.abc import Generator, Mapping, Sequence
 | 
			
		||||
from typing import Any, cast
 | 
			
		||||
from typing import Any, Optional, cast
 | 
			
		||||
 | 
			
		||||
from core.agent.entities import AgentToolEntity
 | 
			
		||||
from core.agent.plugin_entities import AgentStrategyParameter
 | 
			
		||||
from core.model_manager import ModelManager
 | 
			
		||||
from core.model_runtime.entities.model_entities import ModelType
 | 
			
		||||
from core.memory.token_buffer_memory import TokenBufferMemory
 | 
			
		||||
from core.model_manager import ModelInstance, ModelManager
 | 
			
		||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 | 
			
		||||
from core.plugin.manager.exc import PluginDaemonClientSideError
 | 
			
		||||
from core.plugin.manager.plugin import PluginInstallationManager
 | 
			
		||||
from core.provider_manager import ProviderManager
 | 
			
		||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
 | 
			
		||||
from core.tools.tool_manager import ToolManager
 | 
			
		||||
from core.variables.segments import StringSegment
 | 
			
		||||
from core.workflow.entities.node_entities import NodeRunResult
 | 
			
		||||
from core.workflow.entities.variable_pool import VariablePool
 | 
			
		||||
from core.workflow.enums import SystemVariableKey
 | 
			
		||||
@ -19,7 +22,9 @@ from core.workflow.nodes.enums import NodeType
 | 
			
		||||
from core.workflow.nodes.event.event import RunCompletedEvent
 | 
			
		||||
from core.workflow.nodes.tool.tool_node import ToolNode
 | 
			
		||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
 | 
			
		||||
from extensions.ext_database import db
 | 
			
		||||
from factories.agent_factory import get_plugin_agent_strategy
 | 
			
		||||
from models.model import Conversation
 | 
			
		||||
from models.workflow import WorkflowNodeExecutionStatus
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -233,17 +238,20 @@ class AgentNode(ToolNode):
 | 
			
		||||
                    value = tool_value
 | 
			
		||||
                if parameter.type == "model-selector":
 | 
			
		||||
                    value = cast(dict[str, Any], value)
 | 
			
		||||
                    model_instance = ModelManager().get_model_instance(
 | 
			
		||||
                        tenant_id=self.tenant_id,
 | 
			
		||||
                        provider=value.get("provider", ""),
 | 
			
		||||
                        model_type=ModelType(value.get("model_type", "")),
 | 
			
		||||
                        model=value.get("model", ""),
 | 
			
		||||
                    )
 | 
			
		||||
                    models = model_instance.model_type_instance.plugin_model_provider.declaration.models
 | 
			
		||||
                    finded_model = next((model for model in models if model.model == value.get("model", "")), None)
 | 
			
		||||
 | 
			
		||||
                    value["entity"] = finded_model.model_dump(mode="json") if finded_model else None
 | 
			
		||||
 | 
			
		||||
                    model_instance, model_schema = self._fetch_model(value)
 | 
			
		||||
                    # memory config
 | 
			
		||||
                    history_prompt_messages = []
 | 
			
		||||
                    if node_data.memory:
 | 
			
		||||
                        memory = self._fetch_memory(model_instance)
 | 
			
		||||
                        if memory:
 | 
			
		||||
                            prompt_messages = memory.get_history_prompt_messages(
 | 
			
		||||
                                message_limit=node_data.memory.window.size if node_data.memory.window.size else None
 | 
			
		||||
                            )
 | 
			
		||||
                            history_prompt_messages = [
 | 
			
		||||
                                prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
 | 
			
		||||
                            ]
 | 
			
		||||
                    value["history_prompt_messages"] = history_prompt_messages
 | 
			
		||||
                    value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
 | 
			
		||||
            result[parameter_name] = value
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
@ -297,3 +305,46 @@ class AgentNode(ToolNode):
 | 
			
		||||
        except StopIteration:
 | 
			
		||||
            icon = None
 | 
			
		||||
        return icon
 | 
			
		||||
 | 
			
		||||
    def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
 | 
			
		||||
        # get conversation id
 | 
			
		||||
        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
 | 
			
		||||
            ["sys", SystemVariableKey.CONVERSATION_ID.value]
 | 
			
		||||
        )
 | 
			
		||||
        if not isinstance(conversation_id_variable, StringSegment):
 | 
			
		||||
            return None
 | 
			
		||||
        conversation_id = conversation_id_variable.value
 | 
			
		||||
 | 
			
		||||
        # get conversation
 | 
			
		||||
        conversation = (
 | 
			
		||||
            db.session.query(Conversation)
 | 
			
		||||
            .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
 | 
			
		||||
            .first()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not conversation:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 | 
			
		||||
 | 
			
		||||
        return memory
 | 
			
		||||
 | 
			
		||||
    def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
 | 
			
		||||
        provider_manager = ProviderManager()
 | 
			
		||||
        provider_model_bundle = provider_manager.get_provider_model_bundle(
 | 
			
		||||
            tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
 | 
			
		||||
        )
 | 
			
		||||
        model_name = value.get("model", "")
 | 
			
		||||
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
 | 
			
		||||
            model_type=ModelType.LLM, model=model_name
 | 
			
		||||
        )
 | 
			
		||||
        provider_name = provider_model_bundle.configuration.provider.provider
 | 
			
		||||
        model_type_instance = provider_model_bundle.model_type_instance
 | 
			
		||||
        model_instance = ModelManager().get_model_instance(
 | 
			
		||||
            tenant_id=self.tenant_id,
 | 
			
		||||
            provider=provider_name,
 | 
			
		||||
            model_type=ModelType(value.get("model_type", "")),
 | 
			
		||||
            model=model_name,
 | 
			
		||||
        )
 | 
			
		||||
        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
 | 
			
		||||
        return model_instance, model_schema
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ from typing import Any, Literal, Union
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 | 
			
		||||
from core.tools.entities.tool_entities import ToolSelector
 | 
			
		||||
from core.workflow.nodes.base.entities import BaseNodeData
 | 
			
		||||
 | 
			
		||||
@ -11,6 +12,7 @@ class AgentNodeData(BaseNodeData):
 | 
			
		||||
    agent_strategy_provider_name: str  # redundancy
 | 
			
		||||
    agent_strategy_name: str
 | 
			
		||||
    agent_strategy_label: str  # redundancy
 | 
			
		||||
    memory: MemoryConfig | None = None
 | 
			
		||||
 | 
			
		||||
    class AgentInput(BaseModel):
 | 
			
		||||
        value: Union[list[str], list[ToolSelector], Any]
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
import type { CredentialFormSchemaBase } from '../header/account-setting/model-provider-page/declarations'
 | 
			
		||||
import type { ToolCredential } from '@/app/components/tools/types'
 | 
			
		||||
import type { Locale } from '@/i18n'
 | 
			
		||||
 | 
			
		||||
import type { AgentFeature } from '@/app/components/workflow/nodes/agent/types'
 | 
			
		||||
export enum PluginType {
 | 
			
		||||
  tool = 'tool',
 | 
			
		||||
  model = 'model',
 | 
			
		||||
@ -418,6 +418,7 @@ export type StrategyDetail = {
 | 
			
		||||
  parameters: StrategyParamItem[]
 | 
			
		||||
  description: Record<Locale, string>
 | 
			
		||||
  output_schema: Record<string, any>
 | 
			
		||||
  features: AgentFeature[]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export type StrategyDeclaration = {
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
import type { FC } from 'react'
 | 
			
		||||
import { memo, useMemo } from 'react'
 | 
			
		||||
import type { NodePanelProps } from '../../types'
 | 
			
		||||
import type { AgentNodeType } from './types'
 | 
			
		||||
import { AgentFeature, type AgentNodeType } from './types'
 | 
			
		||||
import Field from '../_base/components/field'
 | 
			
		||||
import { AgentStrategy } from '../_base/components/agent-strategy'
 | 
			
		||||
import useConfig from './use-config'
 | 
			
		||||
@ -16,6 +16,8 @@ import { useLogs } from '@/app/components/workflow/run/hooks'
 | 
			
		||||
import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/components/before-run-form/form'
 | 
			
		||||
import { toType } from '@/app/components/tools/utils/to-form-schema'
 | 
			
		||||
import { useStore } from '../../store'
 | 
			
		||||
import Split from '../_base/components/split'
 | 
			
		||||
import MemoryConfig from '../_base/components/memory-config'
 | 
			
		||||
 | 
			
		||||
const i18nPrefix = 'workflow.nodes.agent'
 | 
			
		||||
 | 
			
		||||
@ -35,10 +37,10 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
 | 
			
		||||
    currentStrategy,
 | 
			
		||||
    formData,
 | 
			
		||||
    onFormChange,
 | 
			
		||||
 | 
			
		||||
    isChatMode,
 | 
			
		||||
    availableNodesWithParent,
 | 
			
		||||
    availableVars,
 | 
			
		||||
 | 
			
		||||
    readOnly,
 | 
			
		||||
    isShowSingleRun,
 | 
			
		||||
    hideSingleRun,
 | 
			
		||||
    runningStatus,
 | 
			
		||||
@ -49,6 +51,7 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
 | 
			
		||||
    setRunInputData,
 | 
			
		||||
    varInputs,
 | 
			
		||||
    outputSchema,
 | 
			
		||||
    handleMemoryChange,
 | 
			
		||||
  } = useConfig(props.id, props.data)
 | 
			
		||||
  const { t } = useTranslation()
 | 
			
		||||
  const nodeInfo = useMemo(() => {
 | 
			
		||||
@ -106,6 +109,20 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
 | 
			
		||||
        nodeId={props.id}
 | 
			
		||||
      />
 | 
			
		||||
    </Field>
 | 
			
		||||
    <div className='px-4 py-2'>
 | 
			
		||||
      {isChatMode && currentStrategy?.features.includes(AgentFeature.HISTORY_MESSAGES) && (
 | 
			
		||||
        <>
 | 
			
		||||
          <Split />
 | 
			
		||||
          <MemoryConfig
 | 
			
		||||
            className='mt-4'
 | 
			
		||||
            readonly={readOnly}
 | 
			
		||||
            config={{ data: inputs.memory }}
 | 
			
		||||
            onChange={handleMemoryChange}
 | 
			
		||||
            canSetRoleName={false}
 | 
			
		||||
          />
 | 
			
		||||
        </>
 | 
			
		||||
      )}
 | 
			
		||||
    </div>
 | 
			
		||||
    <div>
 | 
			
		||||
      <OutputVars>
 | 
			
		||||
        <VarItem
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
import type { CommonNodeType } from '@/app/components/workflow/types'
 | 
			
		||||
import type { CommonNodeType, Memory } from '@/app/components/workflow/types'
 | 
			
		||||
import type { ToolVarInputs } from '../tool/types'
 | 
			
		||||
 | 
			
		||||
export type AgentNodeType = CommonNodeType & {
 | 
			
		||||
@ -8,4 +8,9 @@ export type AgentNodeType = CommonNodeType & {
 | 
			
		||||
  agent_parameters?: ToolVarInputs
 | 
			
		||||
  output_schema: Record<string, any>
 | 
			
		||||
  plugin_unique_identifier?: string
 | 
			
		||||
  memory?: Memory
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export enum AgentFeature {
 | 
			
		||||
  HISTORY_MESSAGES = 'history-messages',
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -4,14 +4,16 @@ import useVarList from '../_base/hooks/use-var-list'
 | 
			
		||||
import useOneStepRun from '../_base/hooks/use-one-step-run'
 | 
			
		||||
import type { AgentNodeType } from './types'
 | 
			
		||||
import {
 | 
			
		||||
  useIsChatMode,
 | 
			
		||||
  useNodesReadOnly,
 | 
			
		||||
} from '@/app/components/workflow/hooks'
 | 
			
		||||
import { useCallback, useMemo } from 'react'
 | 
			
		||||
import { type ToolVarInputs, VarType } from '../tool/types'
 | 
			
		||||
import { useCheckInstalled, useFetchPluginsInMarketPlaceByIds } from '@/service/use-plugins'
 | 
			
		||||
import type { Var } from '../../types'
 | 
			
		||||
import type { Memory, Var } from '../../types'
 | 
			
		||||
import { VarType as VarKindType } from '../../types'
 | 
			
		||||
import useAvailableVarList from '../_base/hooks/use-available-var-list'
 | 
			
		||||
import produce from 'immer'
 | 
			
		||||
 | 
			
		||||
export type StrategyStatus = {
 | 
			
		||||
  plugin: {
 | 
			
		||||
@ -175,6 +177,13 @@ const useConfig = (id: string, payload: AgentNodeType) => {
 | 
			
		||||
    return res
 | 
			
		||||
  }, [inputs.output_schema])
 | 
			
		||||
 | 
			
		||||
  const handleMemoryChange = useCallback((newMemory?: Memory) => {
 | 
			
		||||
    const newInputs = produce(inputs, (draft) => {
 | 
			
		||||
      draft.memory = newMemory
 | 
			
		||||
    })
 | 
			
		||||
    setInputs(newInputs)
 | 
			
		||||
  }, [inputs, setInputs])
 | 
			
		||||
  const isChatMode = useIsChatMode()
 | 
			
		||||
  return {
 | 
			
		||||
    readOnly,
 | 
			
		||||
    inputs,
 | 
			
		||||
@ -202,6 +211,8 @@ const useConfig = (id: string, payload: AgentNodeType) => {
 | 
			
		||||
    runResult,
 | 
			
		||||
    varInputs,
 | 
			
		||||
    outputSchema,
 | 
			
		||||
    handleMemoryChange,
 | 
			
		||||
    isChatMode,
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user