mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
feat: Add JsonSchemaValidator (#6937)
* Add JsonSchemaValidator --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent
cf221a9701
commit
5a8d02064b
26
docs/pydoc/config/validators_api.yml
Normal file
26
docs/pydoc/config/validators_api.yml
Normal file
@ -0,0 +1,26 @@
|
||||
loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/components/validators]
|
||||
modules: ["json_schema"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
expression:
|
||||
documented_only: true
|
||||
do_not_filter_modules: false
|
||||
skip_empty_modules: true
|
||||
- type: smart
|
||||
- type: crossref
|
||||
renderer:
|
||||
type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer
|
||||
excerpt: Validators validate LLM outputs
|
||||
category_slug: haystack-api
|
||||
title: Validators
|
||||
slug: validators-api
|
||||
order: 155
|
||||
markdown:
|
||||
descriptive_class_title: false
|
||||
descriptive_module_title: true
|
||||
add_method_class_prefix: true
|
||||
add_member_class_prefix: false
|
||||
filename: validators_api.md
|
||||
3
haystack/components/validators/__init__.py
Normal file
3
haystack/components/validators/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from haystack.components.validators.json_schema import JsonSchemaValidator
|
||||
|
||||
__all__ = ["JsonSchemaValidator"]
|
||||
212
haystack/components/validators/json_schema.py
Normal file
212
haystack/components/validators/json_schema.py
Normal file
@ -0,0 +1,212 @@
|
||||
import json
|
||||
from typing import List, Any, Dict, Optional
|
||||
|
||||
from haystack import component
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
|
||||
from jsonschema import validate, ValidationError
|
||||
|
||||
|
||||
@component
|
||||
class JsonSchemaValidator:
|
||||
"""
|
||||
JsonSchemaValidator validates JSON content of ChatMessage against a specified JSON schema.
|
||||
|
||||
If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output.
|
||||
If the JSON content does not conform to the schema, the message is passed along the "validation_error" output.
|
||||
In the latter case, the error message is constructed using the provided error_template or a default template.
|
||||
These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops.
|
||||
|
||||
Here is a small example of how to use this component in a pipeline implementing schema validation recovery loop:
|
||||
|
||||
```python
|
||||
from typing import List
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.components.others import Multiplexer
|
||||
from haystack.components.validators import JsonSchemaValidator
|
||||
from haystack import component
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
|
||||
@component
|
||||
class MessageProducer:
|
||||
|
||||
@component.output_types(messages=List[ChatMessage])
|
||||
def run(self, messages: List[ChatMessage]) -> dict:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
p = Pipeline()
|
||||
p.add_component("llm", OpenAIChatGenerator(model="gpt-4-1106-preview",
|
||||
generation_kwargs={"response_format": {"type": "json_object"}}))
|
||||
p.add_component("schema_validator", JsonSchemaValidator())
|
||||
p.add_component("mx_for_llm", Multiplexer(List[ChatMessage]))
|
||||
p.add_component("message_producer", MessageProducer())
|
||||
|
||||
p.connect("message_producer.messages", "mx_for_llm")
|
||||
p.connect("mx_for_llm", "llm")
|
||||
p.connect("llm.replies", "schema_validator.messages")
|
||||
p.connect("schema_validator.validation_error", "mx_for_llm")
|
||||
|
||||
|
||||
|
||||
result = p.run(
|
||||
data={"message_producer": {"messages":[ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]},
|
||||
"schema_validator": {"json_schema": {"type": "object",
|
||||
"properties": {"name": {"type": "string"},
|
||||
"age": {"type": "integer"}}}}})
|
||||
print(result)
|
||||
>> {'schema_validator': {'validated': [ChatMessage(content='\n{\n "name": "John",\n "age": 30\n}',
|
||||
role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0,
|
||||
'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}}
|
||||
```
|
||||
"""
|
||||
|
||||
# Default error description template
|
||||
default_error_template = (
|
||||
"The JSON content in the previous message does not conform to the provided schema.\n"
|
||||
"Error details:\n- Message: {error_message}\n"
|
||||
"- Error Path in JSON: {error_path}\n"
|
||||
"- Schema Path: {error_schema_path}\n"
|
||||
"Please match the following schema:\n"
|
||||
"{json_schema}\n"
|
||||
"and provide the corrected JSON content ONLY."
|
||||
)
|
||||
|
||||
def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
|
||||
"""
|
||||
Initializes a new JsonSchemaValidator instance.
|
||||
|
||||
:param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
|
||||
:param error_template: A custom template string for formatting the error message in case of validation failure.
|
||||
"""
|
||||
jsonschema_import.check()
|
||||
self.json_schema = json_schema
|
||||
self.error_template = error_template
|
||||
|
||||
@component.output_types(validated=List[ChatMessage], validation_error=List[ChatMessage])
|
||||
def run(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
json_schema: Optional[Dict[str, Any]] = None,
|
||||
error_template: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Checks if the last message and its content field conforms to json_schema. If it does, the message is passed
|
||||
along the "validated" output. If it does not, the message is passed along the "validation_error" output.
|
||||
|
||||
:param messages: A list of ChatMessage instances to be validated. The last message in this list is the one
|
||||
that is validated.
|
||||
:param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
|
||||
:param error_template: A custom template string for formatting the error message in case of validation
|
||||
failure, by default None.
|
||||
"""
|
||||
last_message = messages[-1]
|
||||
last_message_content = json.loads(last_message.content)
|
||||
|
||||
json_schema = json_schema or self.json_schema
|
||||
error_template = error_template or self.error_template or self.default_error_template
|
||||
|
||||
if not json_schema:
|
||||
raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.")
|
||||
|
||||
# fc payload is json object but subtree `parameters` is string - we need to convert to json object
|
||||
# we need complete json to validate it against schema
|
||||
last_message_json = self.recursive_json_to_object(last_message_content)
|
||||
using_openai_schema: bool = self.is_openai_function_calling_schema(json_schema)
|
||||
if using_openai_schema:
|
||||
validation_schema = json_schema["parameters"]
|
||||
else:
|
||||
validation_schema = json_schema
|
||||
try:
|
||||
last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json
|
||||
for content in last_message_json:
|
||||
if using_openai_schema:
|
||||
validate(instance=content["function"]["arguments"]["parameters"], schema=validation_schema)
|
||||
else:
|
||||
validate(instance=content, schema=validation_schema)
|
||||
|
||||
return {"validated": messages}
|
||||
except ValidationError as e:
|
||||
error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A"
|
||||
error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A"
|
||||
|
||||
error_template = error_template or self.default_error_template
|
||||
|
||||
recovery_prompt = self.construct_error_recovery_message(
|
||||
error_template, str(e), error_path, error_schema_path, validation_schema
|
||||
)
|
||||
complete_message_list = messages + [ChatMessage.from_user(recovery_prompt)]
|
||||
return {"validation_error": complete_message_list}
|
||||
|
||||
def construct_error_recovery_message(
|
||||
self,
|
||||
error_template: str,
|
||||
error_message: str,
|
||||
error_path: str,
|
||||
error_schema_path: str,
|
||||
json_schema: Dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
Constructs an error recovery message using a specified template or the default one if none is provided.
|
||||
|
||||
:param error_template: A custom template string for formatting the error message in case of validation failure.
|
||||
:param error_message: The error message returned by the JSON schema validator.
|
||||
:param error_path: The path in the JSON content where the error occurred.
|
||||
:param error_schema_path: The path in the JSON schema where the error occurred.
|
||||
:param json_schema: The JSON schema against which the content is validated.
|
||||
"""
|
||||
error_template = error_template or self.default_error_template
|
||||
|
||||
return error_template.format(
|
||||
error_message=error_message,
|
||||
error_path=error_path,
|
||||
error_schema_path=error_schema_path,
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
def is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Checks if the provided schema is a valid OpenAI function calling schema.
|
||||
|
||||
:param json_schema: The JSON schema to check
|
||||
:return: True if the schema is a valid OpenAI function calling schema; otherwise, False.
|
||||
"""
|
||||
return all(key in json_schema for key in ["name", "description", "parameters"])
|
||||
|
||||
def recursive_json_to_object(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively traverses a data structure (dictionary or list), converting any string values
|
||||
that are valid JSON objects into dictionary objects, and returns a new data structure.
|
||||
|
||||
:param data: The data structure to be traversed.
|
||||
:return: A new data structure with JSON strings converted to dictionary objects.
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
return [self.recursive_json_to_object(item) for item in data]
|
||||
|
||||
if isinstance(data, dict):
|
||||
new_dict = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
new_dict[key] = (
|
||||
self.recursive_json_to_object(json_value)
|
||||
if isinstance(json_value, (dict, list))
|
||||
else json_value
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
new_dict[key] = value
|
||||
elif isinstance(value, dict):
|
||||
new_dict[key] = self.recursive_json_to_object(value)
|
||||
else:
|
||||
new_dict[key] = value
|
||||
return new_dict
|
||||
|
||||
# If it's neither a list nor a dictionary, return the value directly
|
||||
raise ValueError("Input must be a dictionary or a list of dictionaries.")
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Introduced JsonSchemaValidator to validate the JSON content of ChatMessage against a provided JSON schema. Valid messages are emitted through the 'validated' output, while messages failing validation are sent via the 'validation_error' output, along with useful error details for troubleshooting.
|
||||
209
test/components/validators/test_json_schema.py
Normal file
209
test/components/validators/test_json_schema.py
Normal file
@ -0,0 +1,209 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from haystack import component, Pipeline
|
||||
from haystack.components.validators import JsonSchemaValidator
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def genuine_fc_message():
|
||||
return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"parameters\\": {\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }\\n}", "name": "compare_branches"}, "type": "function"}]"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_schema_github_compare():
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "description": "A unique identifier for the call"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"basehead": {
|
||||
"type": "string",
|
||||
"pattern": "^[^\\.]+(\\.{3}).+$",
|
||||
"description": "Branch names must be in the format 'base_branch...head_branch'",
|
||||
},
|
||||
"owner": {"type": "string", "description": "Owner of the repository"},
|
||||
"repo": {"type": "string", "description": "Name of the repository"},
|
||||
},
|
||||
"required": ["basehead", "owner", "repo"],
|
||||
"description": "Parameters for the function call",
|
||||
}
|
||||
},
|
||||
"required": ["parameters"],
|
||||
"description": "Arguments for the function",
|
||||
},
|
||||
"name": {"type": "string", "description": "Name of the function to be called"},
|
||||
},
|
||||
"required": ["arguments", "name"],
|
||||
"description": "Details of the function being called",
|
||||
},
|
||||
"type": {"type": "string", "description": "Type of the call (e.g., 'function')"},
|
||||
},
|
||||
"required": ["function", "type"],
|
||||
"description": "Structure representing a function call",
|
||||
}
|
||||
return json_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def json_schema_github_compare_openai():
|
||||
json_schema = {
|
||||
"name": "compare_branches",
|
||||
"description": "Compares two branches in a GitHub repository",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"basehead": {
|
||||
"type": "string",
|
||||
"pattern": "^[^\\.]+(\\.{3}).+$",
|
||||
"description": "Branch names must be in the format 'base_branch...head_branch'",
|
||||
},
|
||||
"owner": {"type": "string", "description": "Owner of the repository"},
|
||||
"repo": {"type": "string", "description": "Name of the repository"},
|
||||
},
|
||||
"required": ["basehead", "owner", "repo"],
|
||||
"description": "Parameters for the function call",
|
||||
},
|
||||
}
|
||||
return json_schema
|
||||
|
||||
|
||||
class TestJsonSchemaValidator:
|
||||
# Validates a message against a provided JSON schema successfully.
|
||||
def test_validates_message_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
|
||||
validator = JsonSchemaValidator()
|
||||
message = ChatMessage.from_assistant(genuine_fc_message)
|
||||
|
||||
result = validator.run([message], json_schema_github_compare)
|
||||
|
||||
assert "validated" in result
|
||||
assert len(result["validated"]) == 1
|
||||
assert result["validated"][0] == message
|
||||
|
||||
# Validates recursive_json_to_object method
|
||||
def test_recursive_json_to_object(self, genuine_fc_message):
|
||||
arguments_is_string = json.loads(genuine_fc_message)
|
||||
assert isinstance(arguments_is_string[0]["function"]["arguments"], str)
|
||||
|
||||
# but ensure_json_objects converts the string to a json object
|
||||
validator = JsonSchemaValidator()
|
||||
result = validator.recursive_json_to_object({"key": genuine_fc_message})
|
||||
|
||||
# we need this recursive json conversion to validate the message
|
||||
assert result["key"][0]["function"]["arguments"]["parameters"]["basehead"] == "main...amzn_chat"
|
||||
|
||||
# Validates multiple messages against a provided JSON schema successfully.
|
||||
def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
|
||||
validator = JsonSchemaValidator()
|
||||
|
||||
messages = [
|
||||
ChatMessage.from_user("I'm not being validated, but the message after me is!"),
|
||||
ChatMessage.from_assistant(genuine_fc_message),
|
||||
]
|
||||
|
||||
result = validator.run(messages, json_schema_github_compare)
|
||||
|
||||
assert "validated" in result
|
||||
assert len(result["validated"]) == 2
|
||||
assert result["validated"] == messages
|
||||
|
||||
# Validates a message against an OpenAI function calling schema successfully.
|
||||
def test_validates_message_against_openai_function_calling_schema(
|
||||
self, json_schema_github_compare_openai, genuine_fc_message
|
||||
):
|
||||
validator = JsonSchemaValidator()
|
||||
|
||||
message = ChatMessage.from_assistant(genuine_fc_message)
|
||||
result = validator.run([message], json_schema_github_compare_openai)
|
||||
|
||||
assert "validated" in result
|
||||
assert len(result["validated"]) == 1
|
||||
assert result["validated"][0] == message
|
||||
|
||||
# Validates multiple messages against an OpenAI function calling schema successfully.
|
||||
def test_validates_multiple_messages_against_openai_function_calling_schema(
|
||||
self, json_schema_github_compare_openai, genuine_fc_message
|
||||
):
|
||||
validator = JsonSchemaValidator()
|
||||
|
||||
messages = [
|
||||
ChatMessage.from_system("Common use case is that this is for example system message"),
|
||||
ChatMessage.from_assistant(genuine_fc_message),
|
||||
]
|
||||
|
||||
result = validator.run(messages, json_schema_github_compare_openai)
|
||||
|
||||
assert "validated" in result
|
||||
assert len(result["validated"]) == 2
|
||||
assert result["validated"] == messages
|
||||
|
||||
# Constructs a custom error recovery message when validation fails.
|
||||
def test_construct_custom_error_recovery_message(self):
|
||||
validator = JsonSchemaValidator()
|
||||
|
||||
new_error_template = (
|
||||
"Error details:\n- Message: {error_message}\n"
|
||||
"- Error Path in JSON: {error_path}\n"
|
||||
"- Schema Path: {error_schema_path}\n"
|
||||
"Please match the following schema:\n"
|
||||
"{json_schema}\n"
|
||||
)
|
||||
|
||||
recovery_message = validator.construct_error_recovery_message(
|
||||
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}
|
||||
)
|
||||
|
||||
expected_recovery_message = (
|
||||
"Error details:\n- Message: Error message\n"
|
||||
"- Error Path in JSON: Error path\n"
|
||||
"- Schema Path: Error schema path\n"
|
||||
"Please match the following schema:\n"
|
||||
"{'type': 'object'}\n"
|
||||
)
|
||||
assert recovery_message == expected_recovery_message
|
||||
|
||||
def test_schema_validator_in_pipeline_validated(self, json_schema_github_compare, genuine_fc_message):
|
||||
@component
|
||||
class ChatMessageProducer:
|
||||
@component.output_types(messages=List[ChatMessage])
|
||||
def run(self):
|
||||
return {"messages": [ChatMessage.from_assistant(genuine_fc_message)]}
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
|
||||
pipe.add_component(name="message_producer", instance=ChatMessageProducer())
|
||||
pipe.connect("message_producer", "schema_validator")
|
||||
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
|
||||
assert "validated" in result["schema_validator"]
|
||||
assert len(result["schema_validator"]["validated"]) == 1
|
||||
assert result["schema_validator"]["validated"][0].content == genuine_fc_message
|
||||
|
||||
def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare):
|
||||
@component
|
||||
class ChatMessageProducer:
|
||||
@component.output_types(messages=List[ChatMessage])
|
||||
def run(self):
|
||||
# example json string that is not valid
|
||||
simple_invalid_json = '{"key": "value"}'
|
||||
return {"messages": [ChatMessage.from_assistant(simple_invalid_json)]} # invalid message
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
|
||||
pipe.add_component(name="message_producer", instance=ChatMessageProducer())
|
||||
pipe.connect("message_producer", "schema_validator")
|
||||
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
|
||||
assert "validation_error" in result["schema_validator"]
|
||||
assert len(result["schema_validator"]["validation_error"]) > 1
|
||||
assert "Error details" in result["schema_validator"]["validation_error"][1].content
|
||||
@ -21,3 +21,6 @@ openai-whisper>=20231106 # LocalWhisperTranscriber
|
||||
# OpenAPI
|
||||
jsonref # OpenAPIServiceConnector, OpenAPIServiceToFunctions
|
||||
openapi3
|
||||
|
||||
# Validation
|
||||
jsonschema
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user