feat: Add JsonSchemaValidator (#6937)

* Add JsonSchemaValidator
---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2024-02-15 14:07:01 +01:00 committed by GitHub
parent cf221a9701
commit 5a8d02064b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 457 additions and 0 deletions

View File

@ -0,0 +1,26 @@
# API-docs generation config for the "Validators" reference page
# (uses the haystack_pydoc_tools loader and renderer named below).

# Where to find the Python sources and which modules to document.
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/validators]
modules: ["json_schema"]
ignore_when_discovered: ["__init__"]
# Processing steps applied to the loaded API objects before rendering.
processors:
- type: filter
expression:
documented_only: true
do_not_filter_modules: false
skip_empty_modules: true
- type: smart
- type: crossref
# Page metadata (title, slug, ordering) and markdown output settings.
renderer:
type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer
excerpt: Validators validate LLM outputs
category_slug: haystack-api
title: Validators
slug: validators-api
order: 155
markdown:
descriptive_class_title: false
descriptive_module_title: true
add_method_class_prefix: true
add_member_class_prefix: false
# Name of the generated markdown file.
filename: validators_api.md

View File

@ -0,0 +1,3 @@
from haystack.components.validators.json_schema import JsonSchemaValidator
__all__ = ["JsonSchemaValidator"]

View File

@ -0,0 +1,212 @@
import json
from typing import List, Any, Dict, Optional
from haystack import component
from haystack.dataclasses import ChatMessage
from haystack.lazy_imports import LazyImport
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
from jsonschema import validate, ValidationError
@component
class JsonSchemaValidator:
    """
    JsonSchemaValidator validates JSON content of ChatMessage against a specified JSON schema.

    If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output.
    If the JSON content does not conform to the schema (or is not parseable as JSON at all), the message is passed
    along the "validation_error" output. In the latter case, the error message is constructed using the provided
    error_template or a default template. These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops.

    Here is a small example of how to use this component in a pipeline implementing schema validation recovery loop:

    ```python
    from typing import List

    from haystack import Pipeline
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.components.others import Multiplexer
    from haystack.components.validators import JsonSchemaValidator
    from haystack import component
    from haystack.dataclasses import ChatMessage


    @component
    class MessageProducer:

        @component.output_types(messages=List[ChatMessage])
        def run(self, messages: List[ChatMessage]) -> dict:
            return {"messages": messages}


    p = Pipeline()
    p.add_component("llm", OpenAIChatGenerator(model="gpt-4-1106-preview",
                                               generation_kwargs={"response_format": {"type": "json_object"}}))
    p.add_component("schema_validator", JsonSchemaValidator())
    p.add_component("mx_for_llm", Multiplexer(List[ChatMessage]))
    p.add_component("message_producer", MessageProducer())

    p.connect("message_producer.messages", "mx_for_llm")
    p.connect("mx_for_llm", "llm")
    p.connect("llm.replies", "schema_validator.messages")
    p.connect("schema_validator.validation_error", "mx_for_llm")

    result = p.run(
        data={"message_producer": {"messages":[ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]},
              "schema_validator": {"json_schema": {"type": "object",
                                                   "properties": {"name": {"type": "string"},
                                                                  "age": {"type": "integer"}}}}})
    print(result)
    >> {'schema_validator': {'validated': [ChatMessage(content='\n{\n "name": "John",\n "age": 30\n}',
    role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0,
    'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}}
    ```
    """

    # Default error description template
    default_error_template = (
        "The JSON content in the previous message does not conform to the provided schema.\n"
        "Error details:\n- Message: {error_message}\n"
        "- Error Path in JSON: {error_path}\n"
        "- Schema Path: {error_schema_path}\n"
        "Please match the following schema:\n"
        "{json_schema}\n"
        "and provide the corrected JSON content ONLY."
    )

    def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
        """
        Initializes a new JsonSchemaValidator instance.

        :param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
        :param error_template: A custom template string for formatting the error message in case of validation failure.
        """
        jsonschema_import.check()
        self.json_schema = json_schema
        self.error_template = error_template

    @component.output_types(validated=List[ChatMessage], validation_error=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        json_schema: Optional[Dict[str, Any]] = None,
        error_template: Optional[str] = None,
    ) -> Dict[str, List[ChatMessage]]:
        """
        Checks if the last message and its content field conforms to json_schema. If it does, the message is passed
        along the "validated" output. If it does not, the message is passed along the "validation_error" output.

        :param messages: A list of ChatMessage instances to be validated. The last message in this list is the one
            that is validated.
        :param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
            Falls back to the schema given at init time when omitted.
        :param error_template: A custom template string for formatting the error message in case of validation
            failure, by default None. Falls back to the init-time template, then to `default_error_template`.
        :return: `{"validated": messages}` when the last message conforms to the schema; otherwise
            `{"validation_error": messages + [recovery_message]}` where `recovery_message` is a user ChatMessage
            describing the problem.
        :raises ValueError: If no schema was provided (neither here nor at init), or if the decoded content is
            neither a dictionary nor a list of dictionaries.
        """
        last_message = messages[-1]
        json_schema = json_schema or self.json_schema
        error_template = error_template or self.error_template or self.default_error_template

        if not json_schema:
            raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.")

        using_openai_schema: bool = self.is_openai_function_calling_schema(json_schema)
        # For OpenAI function-calling definitions only the `parameters` sub-schema describes the payload.
        validation_schema = json_schema["parameters"] if using_openai_schema else json_schema

        try:
            last_message_content = json.loads(last_message.content)
        except json.JSONDecodeError as e:
            # Unparseable LLM output is a recoverable condition: route it to the same recovery output as a
            # schema violation instead of letting the exception abort the pipeline.
            recovery_prompt = self.construct_error_recovery_message(
                error_template, f"The content is not valid JSON: {e}", "N/A", "N/A", validation_schema
            )
            return {"validation_error": messages + [ChatMessage.from_user(recovery_prompt)]}

        # fc payload is json object but subtree `parameters` is string - we need to convert to json object
        # we need complete json to validate it against schema
        last_message_json = self.recursive_json_to_object(last_message_content)

        try:
            last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json
            for content in last_message_json:
                if using_openai_schema:
                    validate(instance=content["function"]["arguments"]["parameters"], schema=validation_schema)
                else:
                    validate(instance=content, schema=validation_schema)

            return {"validated": messages}
        except ValidationError as e:
            error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A"
            error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A"

            recovery_prompt = self.construct_error_recovery_message(
                error_template, str(e), error_path, error_schema_path, validation_schema
            )
            complete_message_list = messages + [ChatMessage.from_user(recovery_prompt)]
            return {"validation_error": complete_message_list}

    def construct_error_recovery_message(
        self,
        error_template: str,
        error_message: str,
        error_path: str,
        error_schema_path: str,
        json_schema: Dict[str, Any],
    ) -> str:
        """
        Constructs an error recovery message using a specified template or the default one if none is provided.

        :param error_template: A custom template string for formatting the error message in case of validation failure.
        :param error_message: The error message returned by the JSON schema validator.
        :param error_path: The path in the JSON content where the error occurred.
        :param error_schema_path: The path in the JSON schema where the error occurred.
        :param json_schema: The JSON schema against which the content is validated.
        :return: The template with its `{error_message}`, `{error_path}`, `{error_schema_path}` and `{json_schema}`
            placeholders filled in.
        """
        error_template = error_template or self.default_error_template

        return error_template.format(
            error_message=error_message,
            error_path=error_path,
            error_schema_path=error_schema_path,
            json_schema=json_schema,
        )

    def is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
        """
        Checks if the provided schema is a valid OpenAI function calling schema.

        :param json_schema: The JSON schema to check
        :return: True if the schema is a valid OpenAI function calling schema; otherwise, False.
        """
        return all(key in json_schema for key in ["name", "description", "parameters"])

    def recursive_json_to_object(self, data: Any) -> Any:
        """
        Recursively traverses a data structure (dictionary or list), converting any string values
        that are valid JSON objects or arrays into dictionary/list objects, and returns a new data structure.

        Strings that merely look like JSON scalars (for example "30", "true" or "null") are left untouched so that
        the type of a string field is not silently changed before schema validation.

        :param data: The data structure to be traversed.
        :return: A new data structure with JSON strings converted to dictionary objects.
        :raises ValueError: If `data` (or an item of a traversed list) is neither a dictionary nor a list.
        """
        if isinstance(data, list):
            return [self.recursive_json_to_object(item) for item in data]

        if isinstance(data, dict):
            new_dict = {}
            for key, value in data.items():
                if isinstance(value, str):
                    try:
                        json_value = json.loads(value)
                    except json.JSONDecodeError:
                        # Plain text, not embedded JSON
                        new_dict[key] = value
                        continue
                    if isinstance(json_value, (dict, list)):
                        new_dict[key] = self.recursive_json_to_object(json_value)
                    else:
                        # Preserve the original string: decoding scalars would change the field's type
                        new_dict[key] = value
                elif isinstance(value, dict):
                    new_dict[key] = self.recursive_json_to_object(value)
                else:
                    new_dict[key] = value
            return new_dict

        # Anything that is neither a list nor a dictionary is rejected
        raise ValueError("Input must be a dictionary or a list of dictionaries.")

View File

@ -0,0 +1,4 @@
---
features:
- |
Introduced JsonSchemaValidator to validate the JSON content of ChatMessage against a provided JSON schema. Valid messages are emitted through the 'validated' output, while messages failing validation are sent via the 'validation_error' output, along with useful error details for troubleshooting.

View File

@ -0,0 +1,209 @@
import json
from typing import List
from haystack import component, Pipeline
from haystack.components.validators import JsonSchemaValidator
import pytest
from haystack.dataclasses import ChatMessage
@pytest.fixture
def genuine_fc_message():
    """
    Serialized function-call reply (a JSON array with one call) whose ``function.arguments``
    field is itself a JSON-encoded string rather than a nested object.
    """
    fc_payload = """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"parameters\\": {\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }\\n}", "name": "compare_branches"}, "type": "function"}]"""
    return fc_payload
@pytest.fixture
def json_schema_github_compare():
    """
    Plain JSON schema describing one complete function-call object
    (``id`` / ``function`` / ``type``) for the ``compare_branches`` call.
    """
    # Innermost level: the actual call parameters.
    parameters_schema = {
        "type": "object",
        "properties": {
            "basehead": {
                "type": "string",
                "pattern": "^[^\\.]+(\\.{3}).+$",
                "description": "Branch names must be in the format 'base_branch...head_branch'",
            },
            "owner": {"type": "string", "description": "Owner of the repository"},
            "repo": {"type": "string", "description": "Name of the repository"},
        },
        "required": ["basehead", "owner", "repo"],
        "description": "Parameters for the function call",
    }

    # `arguments` wraps the parameters object.
    arguments_schema = {
        "type": "object",
        "properties": {"parameters": parameters_schema},
        "required": ["parameters"],
        "description": "Arguments for the function",
    }

    # `function` carries the arguments plus the function name.
    function_schema = {
        "type": "object",
        "properties": {
            "arguments": arguments_schema,
            "name": {"type": "string", "description": "Name of the function to be called"},
        },
        "required": ["arguments", "name"],
        "description": "Details of the function being called",
    }

    return {
        "type": "object",
        "properties": {
            "id": {"type": "string", "description": "A unique identifier for the call"},
            "function": function_schema,
            "type": {"type": "string", "description": "Type of the call (e.g., 'function')"},
        },
        "required": ["function", "type"],
        "description": "Structure representing a function call",
    }
@pytest.fixture
def json_schema_github_compare_openai():
    """
    OpenAI function-calling definition (``name`` / ``description`` / ``parameters``)
    for the ``compare_branches`` function.
    """
    property_schemas = {
        "basehead": {
            "type": "string",
            "pattern": "^[^\\.]+(\\.{3}).+$",
            "description": "Branch names must be in the format 'base_branch...head_branch'",
        },
        "owner": {"type": "string", "description": "Owner of the repository"},
        "repo": {"type": "string", "description": "Name of the repository"},
    }

    return {
        "name": "compare_branches",
        "description": "Compares two branches in a GitHub repository",
        "parameters": {
            "type": "object",
            "properties": property_schemas,
            "required": ["basehead", "owner", "repo"],
            "description": "Parameters for the function call",
        },
    }
class TestJsonSchemaValidator:
    """Unit-level and pipeline-level tests for JsonSchemaValidator."""

    def test_validates_message_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
        """A single conforming message comes back on the "validated" output."""
        validator = JsonSchemaValidator()
        fc_reply = ChatMessage.from_assistant(genuine_fc_message)

        outcome = validator.run([fc_reply], json_schema_github_compare)

        assert "validated" in outcome
        validated = outcome["validated"]
        assert len(validated) == 1
        assert validated[0] == fc_reply

    def test_recursive_json_to_object(self, genuine_fc_message):
        """recursive_json_to_object expands JSON strings nested inside the payload."""
        decoded_once = json.loads(genuine_fc_message)
        # After a single json.loads the arguments are still a JSON-encoded string ...
        assert isinstance(decoded_once[0]["function"]["arguments"], str)

        # ... whereas recursive_json_to_object turns that string into a nested object,
        # which is what schema validation of the message relies on.
        validator = JsonSchemaValidator()
        converted = validator.recursive_json_to_object({"key": genuine_fc_message})

        parameters = converted["key"][0]["function"]["arguments"]["parameters"]
        assert parameters["basehead"] == "main...amzn_chat"

    def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message):
        """Only the last message is validated; the whole list is passed along unchanged."""
        validator = JsonSchemaValidator()
        conversation = [
            ChatMessage.from_user("I'm not being validated, but the message after me is!"),
            ChatMessage.from_assistant(genuine_fc_message),
        ]

        outcome = validator.run(conversation, json_schema_github_compare)

        assert "validated" in outcome
        assert len(outcome["validated"]) == 2
        assert outcome["validated"] == conversation

    def test_validates_message_against_openai_function_calling_schema(
        self, json_schema_github_compare_openai, genuine_fc_message
    ):
        """A single message validates against an OpenAI function-calling definition."""
        validator = JsonSchemaValidator()
        fc_reply = ChatMessage.from_assistant(genuine_fc_message)

        outcome = validator.run([fc_reply], json_schema_github_compare_openai)

        assert "validated" in outcome
        validated = outcome["validated"]
        assert len(validated) == 1
        assert validated[0] == fc_reply

    def test_validates_multiple_messages_against_openai_function_calling_schema(
        self, json_schema_github_compare_openai, genuine_fc_message
    ):
        """With several messages, the last one is validated against the OpenAI function-calling definition."""
        validator = JsonSchemaValidator()
        conversation = [
            ChatMessage.from_system("Common use case is that this is for example system message"),
            ChatMessage.from_assistant(genuine_fc_message),
        ]

        outcome = validator.run(conversation, json_schema_github_compare_openai)

        assert "validated" in outcome
        assert len(outcome["validated"]) == 2
        assert outcome["validated"] == conversation

    def test_construct_custom_error_recovery_message(self):
        """A custom template is filled in with the supplied error details."""
        validator = JsonSchemaValidator()
        custom_template = (
            "Error details:\n- Message: {error_message}\n"
            "- Error Path in JSON: {error_path}\n"
            "- Schema Path: {error_schema_path}\n"
            "Please match the following schema:\n"
            "{json_schema}\n"
        )
        expected_recovery_message = (
            "Error details:\n- Message: Error message\n"
            "- Error Path in JSON: Error path\n"
            "- Schema Path: Error schema path\n"
            "Please match the following schema:\n"
            "{'type': 'object'}\n"
        )

        recovery_message = validator.construct_error_recovery_message(
            custom_template, "Error message", "Error path", "Error schema path", {"type": "object"}
        )

        assert recovery_message == expected_recovery_message

    def test_schema_validator_in_pipeline_validated(self, json_schema_github_compare, genuine_fc_message):
        """Inside a pipeline, a conforming message ends up on the "validated" output."""

        @component
        class ChatMessageProducer:
            @component.output_types(messages=List[ChatMessage])
            def run(self):
                return {"messages": [ChatMessage.from_assistant(genuine_fc_message)]}

        pipe = Pipeline()
        pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
        pipe.add_component(name="message_producer", instance=ChatMessageProducer())
        pipe.connect("message_producer", "schema_validator")

        result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})

        validator_output = result["schema_validator"]
        assert "validated" in validator_output
        assert len(validator_output["validated"]) == 1
        assert validator_output["validated"][0].content == genuine_fc_message

    def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare):
        """Inside a pipeline, a non-conforming message ends up on the "validation_error" output."""

        @component
        class ChatMessageProducer:
            @component.output_types(messages=List[ChatMessage])
            def run(self):
                # Well-formed JSON that does not satisfy the schema under test
                non_conforming_json = '{"key": "value"}'
                return {"messages": [ChatMessage.from_assistant(non_conforming_json)]}

        pipe = Pipeline()
        pipe.add_component(name="schema_validator", instance=JsonSchemaValidator())
        pipe.add_component(name="message_producer", instance=ChatMessageProducer())
        pipe.connect("message_producer", "schema_validator")

        result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})

        validator_output = result["schema_validator"]
        assert "validation_error" in validator_output
        # Original message plus the appended recovery prompt
        assert len(validator_output["validation_error"]) > 1
        assert "Error details" in validator_output["validation_error"][1].content

View File

@ -21,3 +21,6 @@ openai-whisper>=20231106 # LocalWhisperTranscriber
# OpenAPI
jsonref # OpenAPIServiceConnector, OpenAPIServiceToFunctions
openapi3
# Validation
jsonschema