mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-13 07:47:26 +00:00
chore: Adjust json_schema.py slightly (#7055)
* Slighly adjust json_schema.py * Adjust test structures
This commit is contained in:
parent
6d0d373def
commit
49cad21a2e
@ -68,7 +68,7 @@ class JsonSchemaValidator:
|
|||||||
|
|
||||||
# Default error description template
|
# Default error description template
|
||||||
default_error_template = (
|
default_error_template = (
|
||||||
"The JSON content in the previous message does not conform to the provided schema.\n"
|
"The JSON content in the next message does not conform to the provided schema.\n"
|
||||||
"Error details:\n- Message: {error_message}\n"
|
"Error details:\n- Message: {error_message}\n"
|
||||||
"- Error Path in JSON: {error_path}\n"
|
"- Error Path in JSON: {error_path}\n"
|
||||||
"- Schema Path: {error_schema_path}\n"
|
"- Schema Path: {error_schema_path}\n"
|
||||||
@ -126,7 +126,7 @@ class JsonSchemaValidator:
|
|||||||
last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json
|
last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json
|
||||||
for content in last_message_json:
|
for content in last_message_json:
|
||||||
if using_openai_schema:
|
if using_openai_schema:
|
||||||
validate(instance=content["function"]["arguments"]["parameters"], schema=validation_schema)
|
validate(instance=content["function"]["arguments"], schema=validation_schema)
|
||||||
else:
|
else:
|
||||||
validate(instance=content, schema=validation_schema)
|
validate(instance=content, schema=validation_schema)
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ class JsonSchemaValidator:
|
|||||||
recovery_prompt = self.construct_error_recovery_message(
|
recovery_prompt = self.construct_error_recovery_message(
|
||||||
error_template, str(e), error_path, error_schema_path, validation_schema
|
error_template, str(e), error_path, error_schema_path, validation_schema
|
||||||
)
|
)
|
||||||
complete_message_list = messages + [ChatMessage.from_user(recovery_prompt)]
|
complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages
|
||||||
return {"validation_error": complete_message_list}
|
return {"validation_error": complete_message_list}
|
||||||
|
|
||||||
def construct_error_recovery_message(
|
def construct_error_recovery_message(
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from haystack.dataclasses import ChatMessage
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def genuine_fc_message():
|
def genuine_fc_message():
|
||||||
return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"parameters\\": {\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }\\n}", "name": "compare_branches"}, "type": "function"}]"""
|
return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }", "name": "compare_branches"}, "type": "function"}]"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -26,23 +26,16 @@ def json_schema_github_compare():
|
|||||||
"arguments": {
|
"arguments": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"parameters": {
|
"basehead": {
|
||||||
"type": "object",
|
"type": "string",
|
||||||
"properties": {
|
"pattern": "^[^\\.]+(\\.{3}).+$",
|
||||||
"basehead": {
|
"description": "Branch names must be in the format 'base_branch...head_branch'",
|
||||||
"type": "string",
|
},
|
||||||
"pattern": "^[^\\.]+(\\.{3}).+$",
|
"owner": {"type": "string", "description": "Owner of the repository"},
|
||||||
"description": "Branch names must be in the format 'base_branch...head_branch'",
|
"repo": {"type": "string", "description": "Name of the repository"},
|
||||||
},
|
|
||||||
"owner": {"type": "string", "description": "Owner of the repository"},
|
|
||||||
"repo": {"type": "string", "description": "Name of the repository"},
|
|
||||||
},
|
|
||||||
"required": ["basehead", "owner", "repo"],
|
|
||||||
"description": "Parameters for the function call",
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["parameters"],
|
"required": ["basehead", "owner", "repo"],
|
||||||
"description": "Arguments for the function",
|
"description": "Parameters for the function call",
|
||||||
},
|
},
|
||||||
"name": {"type": "string", "description": "Name of the function to be called"},
|
"name": {"type": "string", "description": "Name of the function to be called"},
|
||||||
},
|
},
|
||||||
@ -102,7 +95,7 @@ class TestJsonSchemaValidator:
|
|||||||
result = validator.recursive_json_to_object({"key": genuine_fc_message})
|
result = validator.recursive_json_to_object({"key": genuine_fc_message})
|
||||||
|
|
||||||
# we need this recursive json conversion to validate the message
|
# we need this recursive json conversion to validate the message
|
||||||
assert result["key"][0]["function"]["arguments"]["parameters"]["basehead"] == "main...amzn_chat"
|
assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat"
|
||||||
|
|
||||||
# Validates multiple messages against a provided JSON schema successfully.
|
# Validates multiple messages against a provided JSON schema successfully.
|
||||||
def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
|
def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
|
||||||
@ -206,4 +199,4 @@ class TestJsonSchemaValidator:
|
|||||||
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
|
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
|
||||||
assert "validation_error" in result["schema_validator"]
|
assert "validation_error" in result["schema_validator"]
|
||||||
assert len(result["schema_validator"]["validation_error"]) > 1
|
assert len(result["schema_validator"]["validation_error"]) > 1
|
||||||
assert "Error details" in result["schema_validator"]["validation_error"][1].content
|
assert "Error details" in result["schema_validator"]["validation_error"][0].content
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user