mirror of
				https://github.com/langgenius/dify.git
				synced 2025-11-03 20:33:00 +00:00 
			
		
		
		
	robust for json parser (#17687)
This commit is contained in:
		
							parent
							
								
									0e0220bdbf
								
							
						
					
					
						commit
						5541a1f80e
					
				@ -1,4 +1,5 @@
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import uuid
 | 
			
		||||
from collections.abc import Mapping, Sequence
 | 
			
		||||
from typing import Any, Optional, cast
 | 
			
		||||
@ -58,6 +59,30 @@ from .prompts import (
 | 
			
		||||
    FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_json(text):
    """
    From a given JSON started from '{' or '[' extract the complete JSON object.

    The text is scanned left to right while a stack of open brackets is kept.
    Brackets that appear inside a double-quoted JSON string (including after
    backslash escapes such as ``\\"``) are treated as plain characters, so a
    value like ``{"a": "}"}`` is returned whole instead of being cut short.

    :param text: text expected to begin with a JSON object or array; anything
        after the first complete top-level value is ignored
    :return: the substring up to and including the bracket that closes the
        first top-level value; the text before the offending character when a
        closing bracket is unmatched or mismatched (possibly an empty string);
        ``None`` when the text ends before the top-level value is closed
    """
    closer_for = {"{": "}", "[": "]"}
    stack = []
    in_string = False
    escaped = False
    for i, c in enumerate(text):
        if in_string:
            if escaped:
                # the character right after a backslash never ends the string
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
            continue
        if c == '"' and stack:
            # only track strings once inside an object/array, so that stray
            # quotes before any bracket keep their previous (ignored) meaning
            in_string = True
        elif c in closer_for:
            stack.append(c)
        elif c in {"}", "]"}:
            # a closer with nothing open, or one that does not match the most
            # recent opener, ends the scan just before this character
            if not stack or closer_for[stack[-1]] != c:
                return text[:i]
            stack.pop()
            if not stack:
                return text[: i + 1]
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParameterExtractorNode(LLMNode):
 | 
			
		||||
    """
 | 
			
		||||
@ -594,27 +619,6 @@ class ParameterExtractorNode(LLMNode):
 | 
			
		||||
        Extract complete json response.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def extract_json(text):
            """
            From a given JSON started from '{' or '[' extract the complete JSON object.

            Brackets inside double-quoted JSON strings (including after
            backslash escapes) are treated as plain characters, so a value such
            as ``{"a": "}"}`` is returned whole instead of being cut short.

            :param text: text expected to begin with a JSON object or array
            :return: the substring up to and including the bracket closing the
                first top-level value; the text before an unmatched or
                mismatched closing bracket; ``None`` if the value never closes
            """
            closer_for = {"{": "}", "[": "]"}
            stack = []
            in_string = False
            escaped = False
            for i, c in enumerate(text):
                if in_string:
                    if escaped:
                        # the character right after a backslash never ends the string
                        escaped = False
                    elif c == "\\":
                        escaped = True
                    elif c == '"':
                        in_string = False
                    continue
                if c == '"' and stack:
                    # only track strings once inside an object/array
                    in_string = True
                elif c in closer_for:
                    stack.append(c)
                elif c in {"}", "]"}:
                    # unmatched or mismatched closer: stop just before it
                    if not stack or closer_for[stack[-1]] != c:
                        return text[:i]
                    stack.pop()
                    if not stack:
                        return text[: i + 1]
            return None
 | 
			
		||||
 | 
			
		||||
        # extract json from the text
 | 
			
		||||
        for idx in range(len(result)):
 | 
			
		||||
            if result[idx] == "{" or result[idx] == "[":
 | 
			
		||||
@ -624,6 +628,7 @@ class ParameterExtractorNode(LLMNode):
 | 
			
		||||
                        return cast(dict, json.loads(json_str))
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        pass
 | 
			
		||||
        logger.info(f"extra error: {result}")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
 | 
			
		||||
@ -633,7 +638,18 @@ class ParameterExtractorNode(LLMNode):
 | 
			
		||||
        if not tool_call or not tool_call.function.arguments:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        return cast(dict, json.loads(tool_call.function.arguments))
 | 
			
		||||
        result = tool_call.function.arguments
 | 
			
		||||
        # extract json from the arguments
 | 
			
		||||
        for idx in range(len(result)):
 | 
			
		||||
            if result[idx] == "{" or result[idx] == "[":
 | 
			
		||||
                json_str = extract_json(result[idx:])
 | 
			
		||||
                if json_str:
 | 
			
		||||
                    try:
 | 
			
		||||
                        return cast(dict, json.loads(json_str))
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        pass
 | 
			
		||||
        logger.info(f"extra error: {result}")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,7 @@ from typing import Optional
 | 
			
		||||
from unittest.mock import MagicMock
 | 
			
		||||
 | 
			
		||||
from core.app.entities.app_invoke_entities import InvokeFrom
 | 
			
		||||
from core.model_runtime.entities import AssistantPromptMessage
 | 
			
		||||
from core.workflow.entities.variable_pool import VariablePool
 | 
			
		||||
from core.workflow.enums import SystemVariableKey
 | 
			
		||||
from core.workflow.graph_engine.entities.graph import Graph
 | 
			
		||||
@ -311,6 +312,46 @@ def test_extract_json_response():
 | 
			
		||||
    assert result["location"] == "kawaii"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_extract_json_from_tool_call():
    """
    Check that tool-call arguments holding two concatenated JSON objects
    yield the first object.
    """

    node_data = {
        "title": "123",
        "type": "parameter-extractor",
        "model": {
            "provider": "langgenius/openai/openai",
            "name": "gpt-3.5-turbo-instruct",
            "mode": "completion",
            "completion_params": {},
        },
        "query": ["sys", "query"],
        "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
        "reasoning_mode": "prompt",
        "instruction": "{{#sys.query#}}",
        "memory": None,
    }
    node = init_parameter_extractor_node(config={"id": "llm", "data": node_data})

    # two JSON objects back to back; only the first one should be parsed
    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
        name="foo", arguments="""{"location":"kawaii"}{"location": 1}"""
    )
    tool_call = AssistantPromptMessage.ToolCall(
        id="llm",
        type="parameter-extractor",
        function=function,
    )

    result = node._extract_json_from_tool_call(tool_call)

    assert result is not None
    assert result["location"] == "kawaii"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_chat_parameter_extractor_with_memory(setup_model_mock):
 | 
			
		||||
    """
 | 
			
		||||
    Test chat parameter extractor with memory.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user