chore: Adjust json_schema.py slightly (#7055)

* Slighly adjust json_schema.py

* Adjust test structures
This commit is contained in:
Vladimir Blagojevic 2024-02-22 14:33:07 +01:00 committed by GitHub
parent 6d0d373def
commit 49cad21a2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 22 deletions

View File

@ -68,7 +68,7 @@ class JsonSchemaValidator:
# Default error description template
default_error_template = (
"The JSON content in the previous message does not conform to the provided schema.\n"
"The JSON content in the next message does not conform to the provided schema.\n"
"Error details:\n- Message: {error_message}\n"
"- Error Path in JSON: {error_path}\n"
"- Schema Path: {error_schema_path}\n"
@ -126,7 +126,7 @@ class JsonSchemaValidator:
last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json
for content in last_message_json:
if using_openai_schema:
validate(instance=content["function"]["arguments"]["parameters"], schema=validation_schema)
validate(instance=content["function"]["arguments"], schema=validation_schema)
else:
validate(instance=content, schema=validation_schema)
@ -140,7 +140,7 @@ class JsonSchemaValidator:
recovery_prompt = self.construct_error_recovery_message(
error_template, str(e), error_path, error_schema_path, validation_schema
)
complete_message_list = messages + [ChatMessage.from_user(recovery_prompt)]
complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages
return {"validation_error": complete_message_list}
def construct_error_recovery_message(

View File

@ -11,7 +11,7 @@ from haystack.dataclasses import ChatMessage
@pytest.fixture
def genuine_fc_message():
return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"parameters\\": {\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }\\n}", "name": "compare_branches"}, "type": "function"}]"""
return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }", "name": "compare_branches"}, "type": "function"}]"""
@pytest.fixture
@ -26,23 +26,16 @@ def json_schema_github_compare():
"arguments": {
"type": "object",
"properties": {
"parameters": {
"type": "object",
"properties": {
"basehead": {
"type": "string",
"pattern": "^[^\\.]+(\\.{3}).+$",
"description": "Branch names must be in the format 'base_branch...head_branch'",
},
"owner": {"type": "string", "description": "Owner of the repository"},
"repo": {"type": "string", "description": "Name of the repository"},
},
"required": ["basehead", "owner", "repo"],
"description": "Parameters for the function call",
}
"basehead": {
"type": "string",
"pattern": "^[^\\.]+(\\.{3}).+$",
"description": "Branch names must be in the format 'base_branch...head_branch'",
},
"owner": {"type": "string", "description": "Owner of the repository"},
"repo": {"type": "string", "description": "Name of the repository"},
},
"required": ["parameters"],
"description": "Arguments for the function",
"required": ["basehead", "owner", "repo"],
"description": "Parameters for the function call",
},
"name": {"type": "string", "description": "Name of the function to be called"},
},
@ -102,7 +95,7 @@ class TestJsonSchemaValidator:
result = validator.recursive_json_to_object({"key": genuine_fc_message})
# we need this recursive json conversion to validate the message
assert result["key"][0]["function"]["arguments"]["parameters"]["basehead"] == "main...amzn_chat"
assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat"
# Validates multiple messages against a provided JSON schema successfully.
def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
@ -206,4 +199,4 @@ class TestJsonSchemaValidator:
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
assert "validation_error" in result["schema_validator"]
assert len(result["schema_validator"]["validation_error"]) > 1
assert "Error details" in result["schema_validator"]["validation_error"][1].content
assert "Error details" in result["schema_validator"]["validation_error"][0].content