From 49cad21a2e934eb1f19f0f39a55c52ce89c56726 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 22 Feb 2024 14:33:07 +0100 Subject: [PATCH] chore: Adjust json_schema.py slightly (#7055) * Slighly adjust json_schema.py * Adjust test structures --- haystack/components/validators/json_schema.py | 6 ++-- .../components/validators/test_json_schema.py | 31 +++++++------------ 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/haystack/components/validators/json_schema.py b/haystack/components/validators/json_schema.py index e3eed9542..d0be03a74 100644 --- a/haystack/components/validators/json_schema.py +++ b/haystack/components/validators/json_schema.py @@ -68,7 +68,7 @@ class JsonSchemaValidator: # Default error description template default_error_template = ( - "The JSON content in the previous message does not conform to the provided schema.\n" + "The JSON content in the next message does not conform to the provided schema.\n" "Error details:\n- Message: {error_message}\n" "- Error Path in JSON: {error_path}\n" "- Schema Path: {error_schema_path}\n" @@ -126,7 +126,7 @@ class JsonSchemaValidator: last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json for content in last_message_json: if using_openai_schema: - validate(instance=content["function"]["arguments"]["parameters"], schema=validation_schema) + validate(instance=content["function"]["arguments"], schema=validation_schema) else: validate(instance=content, schema=validation_schema) @@ -140,7 +140,7 @@ class JsonSchemaValidator: recovery_prompt = self.construct_error_recovery_message( error_template, str(e), error_path, error_schema_path, validation_schema ) - complete_message_list = messages + [ChatMessage.from_user(recovery_prompt)] + complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages return {"validation_error": complete_message_list} def construct_error_recovery_message( diff --git a/test/components/validators/test_json_schema.py b/test/components/validators/test_json_schema.py index e21bf6ebc..9bf0f936e 100644 --- a/test/components/validators/test_json_schema.py +++ b/test/components/validators/test_json_schema.py @@ -11,7 +11,7 @@ from haystack.dataclasses import ChatMessage @pytest.fixture def genuine_fc_message(): - return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"parameters\\": {\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }\\n}", "name": "compare_branches"}, "type": "function"}]""" + return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }", "name": "compare_branches"}, "type": "function"}]""" @pytest.fixture @@ -26,23 +26,16 @@ def json_schema_github_compare(): "arguments": { "type": "object", "properties": { - "parameters": { - "type": "object", - "properties": { - "basehead": { - "type": "string", - "pattern": "^[^\\.]+(\\.{3}).+$", - "description": "Branch names must be in the format 'base_branch...head_branch'", - }, - "owner": {"type": "string", "description": "Owner of the repository"}, - "repo": {"type": "string", "description": "Name of the repository"}, - }, - "required": ["basehead", "owner", "repo"], - "description": "Parameters for the function call", - } + "basehead": { + "type": "string", + "pattern": "^[^\\.]+(\\.{3}).+$", + "description": "Branch names must be in the format 'base_branch...head_branch'", + }, + "owner": {"type": "string", "description": "Owner of the repository"}, + "repo": {"type": "string", "description": "Name of the repository"}, }, - "required": ["parameters"], - "description": "Arguments for the function", + "required": ["basehead", "owner", "repo"], + "description": "Parameters for the function call", }, "name": {"type": "string", "description": "Name of the function to be called"}, }, @@ -102,7 +95,7 @@ class TestJsonSchemaValidator: result = validator.recursive_json_to_object({"key": genuine_fc_message}) # we need this recursive json conversion to validate the message - assert result["key"][0]["function"]["arguments"]["parameters"]["basehead"] == "main...amzn_chat" + assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat" # Validates multiple messages against a provided JSON schema successfully. def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message): @@ -206,4 +199,4 @@ class TestJsonSchemaValidator: result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}}) assert "validation_error" in result["schema_validator"] assert len(result["schema_validator"]["validation_error"]) > 1 - assert "Error details" in result["schema_validator"]["validation_error"][1].content + assert "Error details" in result["schema_validator"]["validation_error"][0].content