mirror of
				https://github.com/langgenius/dify.git
				synced 2025-11-04 12:53:38 +00:00 
			
		
		
		
	fix(workflow): refine variable type checks in LLMNode (#10051)
This commit is contained in:
		
							parent
							
								
									4d38798dd5
								
							
						
					
					
						commit
						3b53e06e0d
					
				@ -349,13 +349,11 @@ class LLMNode(BaseNode[LLMNodeData]):
 | 
			
		||||
        variable = self.graph_runtime_state.variable_pool.get(selector)
 | 
			
		||||
        if variable is None:
 | 
			
		||||
            return []
 | 
			
		||||
        if isinstance(variable, FileSegment):
 | 
			
		||||
        elif isinstance(variable, FileSegment):
 | 
			
		||||
            return [variable.value]
 | 
			
		||||
        if isinstance(variable, ArrayFileSegment):
 | 
			
		||||
        elif isinstance(variable, ArrayFileSegment):
 | 
			
		||||
            return variable.value
 | 
			
		||||
        # FIXME: Temporary fix for empty array,
 | 
			
		||||
        # all variables added to variable pool should be a Segment instance.
 | 
			
		||||
        if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
 | 
			
		||||
        elif isinstance(variable, NoneSegment | ArrayAnySegment):
 | 
			
		||||
            return []
 | 
			
		||||
        raise ValueError(f"Invalid variable type: {type(variable)}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										125
									
								
								api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,125 @@
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from core.app.entities.app_invoke_entities import InvokeFrom
 | 
			
		||||
from core.file import File, FileTransferMethod, FileType
 | 
			
		||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 | 
			
		||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 | 
			
		||||
from core.workflow.entities.variable_pool import VariablePool
 | 
			
		||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 | 
			
		||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 | 
			
		||||
from core.workflow.nodes.end import EndStreamParam
 | 
			
		||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
 | 
			
		||||
from core.workflow.nodes.llm.node import LLMNode
 | 
			
		||||
from models.enums import UserFrom
 | 
			
		||||
from models.workflow import WorkflowType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLLMNode:
 | 
			
		||||
    @pytest.fixture
 | 
			
		||||
    def llm_node(self):
 | 
			
		||||
        data = LLMNodeData(
 | 
			
		||||
            title="Test LLM",
 | 
			
		||||
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
 | 
			
		||||
            prompt_template=[],
 | 
			
		||||
            memory=None,
 | 
			
		||||
            context=ContextConfig(enabled=False),
 | 
			
		||||
            vision=VisionConfig(
 | 
			
		||||
                enabled=True,
 | 
			
		||||
                configs=VisionConfigOptions(
 | 
			
		||||
                    variable_selector=["sys", "files"],
 | 
			
		||||
                    detail=ImagePromptMessageContent.DETAIL.HIGH,
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        variable_pool = VariablePool(
 | 
			
		||||
            system_variables={},
 | 
			
		||||
            user_inputs={},
 | 
			
		||||
        )
 | 
			
		||||
        node = LLMNode(
 | 
			
		||||
            id="1",
 | 
			
		||||
            config={
 | 
			
		||||
                "id": "1",
 | 
			
		||||
                "data": data.model_dump(),
 | 
			
		||||
            },
 | 
			
		||||
            graph_init_params=GraphInitParams(
 | 
			
		||||
                tenant_id="1",
 | 
			
		||||
                app_id="1",
 | 
			
		||||
                workflow_type=WorkflowType.WORKFLOW,
 | 
			
		||||
                workflow_id="1",
 | 
			
		||||
                graph_config={},
 | 
			
		||||
                user_id="1",
 | 
			
		||||
                user_from=UserFrom.ACCOUNT,
 | 
			
		||||
                invoke_from=InvokeFrom.SERVICE_API,
 | 
			
		||||
                call_depth=0,
 | 
			
		||||
            ),
 | 
			
		||||
            graph=Graph(
 | 
			
		||||
                root_node_id="1",
 | 
			
		||||
                answer_stream_generate_routes=AnswerStreamGenerateRoute(
 | 
			
		||||
                    answer_dependencies={},
 | 
			
		||||
                    answer_generate_route={},
 | 
			
		||||
                ),
 | 
			
		||||
                end_stream_param=EndStreamParam(
 | 
			
		||||
                    end_dependencies={},
 | 
			
		||||
                    end_stream_variable_selector_mapping={},
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
            graph_runtime_state=GraphRuntimeState(
 | 
			
		||||
                variable_pool=variable_pool,
 | 
			
		||||
                start_at=0,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return node
 | 
			
		||||
 | 
			
		||||
    def test_fetch_files_with_file_segment(self, llm_node):
 | 
			
		||||
        file = File(
 | 
			
		||||
            id="1",
 | 
			
		||||
            tenant_id="test",
 | 
			
		||||
            type=FileType.IMAGE,
 | 
			
		||||
            filename="test.jpg",
 | 
			
		||||
            transfer_method=FileTransferMethod.LOCAL_FILE,
 | 
			
		||||
            related_id="1",
 | 
			
		||||
        )
 | 
			
		||||
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
 | 
			
		||||
 | 
			
		||||
        result = llm_node._fetch_files(selector=["sys", "files"])
 | 
			
		||||
        assert result == [file]
 | 
			
		||||
 | 
			
		||||
    def test_fetch_files_with_array_file_segment(self, llm_node):
 | 
			
		||||
        files = [
 | 
			
		||||
            File(
 | 
			
		||||
                id="1",
 | 
			
		||||
                tenant_id="test",
 | 
			
		||||
                type=FileType.IMAGE,
 | 
			
		||||
                filename="test1.jpg",
 | 
			
		||||
                transfer_method=FileTransferMethod.LOCAL_FILE,
 | 
			
		||||
                related_id="1",
 | 
			
		||||
            ),
 | 
			
		||||
            File(
 | 
			
		||||
                id="2",
 | 
			
		||||
                tenant_id="test",
 | 
			
		||||
                type=FileType.IMAGE,
 | 
			
		||||
                filename="test2.jpg",
 | 
			
		||||
                transfer_method=FileTransferMethod.LOCAL_FILE,
 | 
			
		||||
                related_id="2",
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
 | 
			
		||||
 | 
			
		||||
        result = llm_node._fetch_files(selector=["sys", "files"])
 | 
			
		||||
        assert result == files
 | 
			
		||||
 | 
			
		||||
    def test_fetch_files_with_none_segment(self, llm_node):
 | 
			
		||||
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
 | 
			
		||||
 | 
			
		||||
        result = llm_node._fetch_files(selector=["sys", "files"])
 | 
			
		||||
        assert result == []
 | 
			
		||||
 | 
			
		||||
    def test_fetch_files_with_array_any_segment(self, llm_node):
 | 
			
		||||
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
 | 
			
		||||
 | 
			
		||||
        result = llm_node._fetch_files(selector=["sys", "files"])
 | 
			
		||||
        assert result == []
 | 
			
		||||
 | 
			
		||||
    def test_fetch_files_with_non_existent_variable(self, llm_node):
 | 
			
		||||
        result = llm_node._fetch_files(selector=["sys", "files"])
 | 
			
		||||
        assert result == []
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user