diff --git a/docs/pydoc/config/validators_api.yml b/docs/pydoc/config/validators_api.yml new file mode 100644 index 000000000..06c37b24b --- /dev/null +++ b/docs/pydoc/config/validators_api.yml @@ -0,0 +1,26 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../../../haystack/components/validators] + modules: ["json_schema"] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Validators validate LLM outputs + category_slug: haystack-api + title: Validators + slug: validators-api + order: 155 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: validators_api.md diff --git a/haystack/components/validators/__init__.py b/haystack/components/validators/__init__.py new file mode 100644 index 000000000..46467862a --- /dev/null +++ b/haystack/components/validators/__init__.py @@ -0,0 +1,3 @@ +from haystack.components.validators.json_schema import JsonSchemaValidator + +__all__ = ["JsonSchemaValidator"] diff --git a/haystack/components/validators/json_schema.py b/haystack/components/validators/json_schema.py new file mode 100644 index 000000000..e3eed9542 --- /dev/null +++ b/haystack/components/validators/json_schema.py @@ -0,0 +1,212 @@ +import json +from typing import List, Any, Dict, Optional + +from haystack import component +from haystack.dataclasses import ChatMessage +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: + from jsonschema import validate, ValidationError + + +@component +class JsonSchemaValidator: + """ + JsonSchemaValidator validates JSON content of ChatMessage against a specified JSON schema. + + If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output. + If the JSON content does not conform to the schema, the message is passed along the "validation_error" output. + In the latter case, the error message is constructed using the provided error_template or a default template. + These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops. + + Here is a small example of how to use this component in a pipeline implementing schema validation recovery loop: + + ```python + from typing import List + + from haystack import Pipeline + from haystack.components.generators.chat import OpenAIChatGenerator + from haystack.components.others import Multiplexer + from haystack.components.validators import JsonSchemaValidator + from haystack import component + from haystack.dataclasses import ChatMessage + + + @component + class MessageProducer: + + @component.output_types(messages=List[ChatMessage]) + def run(self, messages: List[ChatMessage]) -> dict: + return {"messages": messages} + + + p = Pipeline() + p.add_component("llm", OpenAIChatGenerator(model="gpt-4-1106-preview", + generation_kwargs={"response_format": {"type": "json_object"}})) + p.add_component("schema_validator", JsonSchemaValidator()) + p.add_component("mx_for_llm", Multiplexer(List[ChatMessage])) + p.add_component("message_producer", MessageProducer()) + + p.connect("message_producer.messages", "mx_for_llm") + p.connect("mx_for_llm", "llm") + p.connect("llm.replies", "schema_validator.messages") + p.connect("schema_validator.validation_error", "mx_for_llm") + + + + result = p.run( + data={"message_producer": {"messages":[ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]}, + "schema_validator": {"json_schema": {"type": "object", + "properties": {"name": {"type": "string"}, + "age": {"type": "integer"}}}}}) + print(result) + >> {'schema_validator': {'validated': [ChatMessage(content='\n{\n "name": "John",\n "age": 30\n}', + role=, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0, + 'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}} + ``` + """ + + # Default error description template + default_error_template = ( + "The JSON content in the previous message does not conform to the provided schema.\n" + "Error details:\n- Message: {error_message}\n" + "- Error Path in JSON: {error_path}\n" + "- Schema Path: {error_schema_path}\n" + "Please match the following schema:\n" + "{json_schema}\n" + "and provide the corrected JSON content ONLY." + ) + + def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None): + """ + Initializes a new JsonSchemaValidator instance. + + :param json_schema: A dictionary representing the JSON schema against which the messages' content is validated. + :param error_template: A custom template string for formatting the error message in case of validation failure. + """ + jsonschema_import.check() + self.json_schema = json_schema + self.error_template = error_template + + @component.output_types(validated=List[ChatMessage], validation_error=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + json_schema: Optional[Dict[str, Any]] = None, + error_template: Optional[str] = None, + ): + """ + Checks if the last message and its content field conforms to json_schema. If it does, the message is passed + along the "validated" output. If it does not, the message is passed along the "validation_error" output. + + :param messages: A list of ChatMessage instances to be validated. The last message in this list is the one + that is validated. + :param json_schema: A dictionary representing the JSON schema against which the messages' content is validated. + :param error_template: A custom template string for formatting the error message in case of validation + failure, by default None. + """ + last_message = messages[-1] + last_message_content = json.loads(last_message.content) + + json_schema = json_schema or self.json_schema + error_template = error_template or self.error_template or self.default_error_template + + if not json_schema: + raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.") + + # fc payload is json object but subtree `parameters` is string - we need to convert to json object + # we need complete json to validate it against schema + last_message_json = self.recursive_json_to_object(last_message_content) + using_openai_schema: bool = self.is_openai_function_calling_schema(json_schema) + if using_openai_schema: + validation_schema = json_schema["parameters"] + else: + validation_schema = json_schema + try: + last_message_json = [last_message_json] if not isinstance(last_message_json, list) else last_message_json + for content in last_message_json: + if using_openai_schema: + validate(instance=content["function"]["arguments"]["parameters"], schema=validation_schema) + else: + validate(instance=content, schema=validation_schema) + + return {"validated": messages} + except ValidationError as e: + error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A" + error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A" + + error_template = error_template or self.default_error_template + + recovery_prompt = self.construct_error_recovery_message( + error_template, str(e), error_path, error_schema_path, validation_schema + ) + complete_message_list = messages + [ChatMessage.from_user(recovery_prompt)] + return {"validation_error": complete_message_list} + + def construct_error_recovery_message( + self, + error_template: str, + error_message: str, + error_path: str, + error_schema_path: str, + json_schema: Dict[str, Any], + ) -> str: + """ + Constructs an error recovery message using a specified template or the default one if none is provided. + + :param error_template: A custom template string for formatting the error message in case of validation failure. + :param error_message: The error message returned by the JSON schema validator. + :param error_path: The path in the JSON content where the error occurred. + :param error_schema_path: The path in the JSON schema where the error occurred. + :param json_schema: The JSON schema against which the content is validated. + """ + error_template = error_template or self.default_error_template + + return error_template.format( + error_message=error_message, + error_path=error_path, + error_schema_path=error_schema_path, + json_schema=json_schema, + ) + + def is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool: + """ + Checks if the provided schema is a valid OpenAI function calling schema. + + :param json_schema: The JSON schema to check + :return: True if the schema is a valid OpenAI function calling schema; otherwise, False. + """ + return all(key in json_schema for key in ["name", "description", "parameters"]) + + def recursive_json_to_object(self, data: Any) -> Any: + """ + Recursively traverses a data structure (dictionary or list), converting any string values + that are valid JSON objects into dictionary objects, and returns a new data structure. + + :param data: The data structure to be traversed. + :return: A new data structure with JSON strings converted to dictionary objects. + """ + if isinstance(data, list): + return [self.recursive_json_to_object(item) for item in data] + + if isinstance(data, dict): + new_dict = {} + for key, value in data.items(): + if isinstance(value, str): + try: + json_value = json.loads(value) + new_dict[key] = ( + self.recursive_json_to_object(json_value) + if isinstance(json_value, (dict, list)) + else json_value + ) + except json.JSONDecodeError: + new_dict[key] = value + elif isinstance(value, dict): + new_dict[key] = self.recursive_json_to_object(value) + else: + new_dict[key] = value + return new_dict + + # If it's neither a list nor a dictionary, return the value directly + raise ValueError("Input must be a dictionary or a list of dictionaries.") diff --git a/releasenotes/notes/introduce-jsonschema-validator-65debc51a3b64975.yaml b/releasenotes/notes/introduce-jsonschema-validator-65debc51a3b64975.yaml new file mode 100644 index 000000000..915b24bcd --- /dev/null +++ b/releasenotes/notes/introduce-jsonschema-validator-65debc51a3b64975.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Introduced JsonSchemaValidator to validate the JSON content of ChatMessage against a provided JSON schema. Valid messages are emitted through the 'validated' output, while messages failing validation are sent via the 'validation_error' output, along with useful error details for troubleshooting. diff --git a/test/components/validators/test_json_schema.py b/test/components/validators/test_json_schema.py new file mode 100644 index 000000000..e21bf6ebc --- /dev/null +++ b/test/components/validators/test_json_schema.py @@ -0,0 +1,209 @@ +import json +from typing import List + +from haystack import component, Pipeline +from haystack.components.validators import JsonSchemaValidator + +import pytest + +from haystack.dataclasses import ChatMessage + + +@pytest.fixture +def genuine_fc_message(): + return """[{"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", "function": {"arguments": "{\\n \\"parameters\\": {\\n \\"basehead\\": \\"main...amzn_chat\\",\\n \\"owner\\": \\"deepset-ai\\",\\n \\"repo\\": \\"haystack-core-integrations\\"\\n }\\n}", "name": "compare_branches"}, "type": "function"}]""" + + +@pytest.fixture +def json_schema_github_compare(): + json_schema = { + "type": "object", + "properties": { + "id": {"type": "string", "description": "A unique identifier for the call"}, + "function": { + "type": "object", + "properties": { + "arguments": { + "type": "object", + "properties": { + "parameters": { + "type": "object", + "properties": { + "basehead": { + "type": "string", + "pattern": "^[^\\.]+(\\.{3}).+$", + "description": "Branch names must be in the format 'base_branch...head_branch'", + }, + "owner": {"type": "string", "description": "Owner of the repository"}, + "repo": {"type": "string", "description": "Name of the repository"}, + }, + "required": ["basehead", "owner", "repo"], + "description": "Parameters for the function call", + } + }, + "required": ["parameters"], + "description": "Arguments for the function", + }, + "name": {"type": "string", "description": "Name of the function to be called"}, + }, + "required": ["arguments", "name"], + "description": "Details of the function being called", + }, + "type": {"type": "string", "description": "Type of the call (e.g., 'function')"}, + }, + "required": ["function", "type"], + "description": "Structure representing a function call", + } + return json_schema + + +@pytest.fixture +def json_schema_github_compare_openai(): + json_schema = { + "name": "compare_branches", + "description": "Compares two branches in a GitHub repository", + "parameters": { + "type": "object", + "properties": { + "basehead": { + "type": "string", + "pattern": "^[^\\.]+(\\.{3}).+$", + "description": "Branch names must be in the format 'base_branch...head_branch'", + }, + "owner": {"type": "string", "description": "Owner of the repository"}, + "repo": {"type": "string", "description": "Name of the repository"}, + }, + "required": ["basehead", "owner", "repo"], + "description": "Parameters for the function call", + }, + } + return json_schema + + +class TestJsonSchemaValidator: + # Validates a message against a provided JSON schema successfully. + def test_validates_message_against_json_schema(self, json_schema_github_compare, genuine_fc_message): + validator = JsonSchemaValidator() + message = ChatMessage.from_assistant(genuine_fc_message) + + result = validator.run([message], json_schema_github_compare) + + assert "validated" in result + assert len(result["validated"]) == 1 + assert result["validated"][0] == message + + # Validates recursive_json_to_object method + def test_recursive_json_to_object(self, genuine_fc_message): + arguments_is_string = json.loads(genuine_fc_message) + assert isinstance(arguments_is_string[0]["function"]["arguments"], str) + + # but ensure_json_objects converts the string to a json object + validator = JsonSchemaValidator() + result = validator.recursive_json_to_object({"key": genuine_fc_message}) + + # we need this recursive json conversion to validate the message + assert result["key"][0]["function"]["arguments"]["parameters"]["basehead"] == "main...amzn_chat" + + # Validates multiple messages against a provided JSON schema successfully. + def test_validates_multiple_messages_against_json_schema(self, json_schema_github_compare, genuine_fc_message): + validator = JsonSchemaValidator() + + messages = [ + ChatMessage.from_user("I'm not being validated, but the message after me is!"), + ChatMessage.from_assistant(genuine_fc_message), + ] + + result = validator.run(messages, json_schema_github_compare) + + assert "validated" in result + assert len(result["validated"]) == 2 + assert result["validated"] == messages + + # Validates a message against an OpenAI function calling schema successfully. + def test_validates_message_against_openai_function_calling_schema( + self, json_schema_github_compare_openai, genuine_fc_message + ): + validator = JsonSchemaValidator() + + message = ChatMessage.from_assistant(genuine_fc_message) + result = validator.run([message], json_schema_github_compare_openai) + + assert "validated" in result + assert len(result["validated"]) == 1 + assert result["validated"][0] == message + + # Validates multiple messages against an OpenAI function calling schema successfully. + def test_validates_multiple_messages_against_openai_function_calling_schema( + self, json_schema_github_compare_openai, genuine_fc_message + ): + validator = JsonSchemaValidator() + + messages = [ + ChatMessage.from_system("Common use case is that this is for example system message"), + ChatMessage.from_assistant(genuine_fc_message), + ] + + result = validator.run(messages, json_schema_github_compare_openai) + + assert "validated" in result + assert len(result["validated"]) == 2 + assert result["validated"] == messages + + # Constructs a custom error recovery message when validation fails. + def test_construct_custom_error_recovery_message(self): + validator = JsonSchemaValidator() + + new_error_template = ( + "Error details:\n- Message: {error_message}\n" + "- Error Path in JSON: {error_path}\n" + "- Schema Path: {error_schema_path}\n" + "Please match the following schema:\n" + "{json_schema}\n" + ) + + recovery_message = validator.construct_error_recovery_message( + new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"} + ) + + expected_recovery_message = ( + "Error details:\n- Message: Error message\n" + "- Error Path in JSON: Error path\n" + "- Schema Path: Error schema path\n" + "Please match the following schema:\n" + "{'type': 'object'}\n" + ) + assert recovery_message == expected_recovery_message + + def test_schema_validator_in_pipeline_validated(self, json_schema_github_compare, genuine_fc_message): + @component + class ChatMessageProducer: + @component.output_types(messages=List[ChatMessage]) + def run(self): + return {"messages": [ChatMessage.from_assistant(genuine_fc_message)]} + + pipe = Pipeline() + pipe.add_component(name="schema_validator", instance=JsonSchemaValidator()) + pipe.add_component(name="message_producer", instance=ChatMessageProducer()) + pipe.connect("message_producer", "schema_validator") + result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}}) + assert "validated" in result["schema_validator"] + assert len(result["schema_validator"]["validated"]) == 1 + assert result["schema_validator"]["validated"][0].content == genuine_fc_message + + def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare): + @component + class ChatMessageProducer: + @component.output_types(messages=List[ChatMessage]) + def run(self): + # example json string that is not valid + simple_invalid_json = '{"key": "value"}' + return {"messages": [ChatMessage.from_assistant(simple_invalid_json)]} # invalid message + + pipe = Pipeline() + pipe.add_component(name="schema_validator", instance=JsonSchemaValidator()) + pipe.add_component(name="message_producer", instance=ChatMessageProducer()) + pipe.connect("message_producer", "schema_validator") + result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}}) + assert "validation_error" in result["schema_validator"] + assert len(result["schema_validator"]["validation_error"]) > 1 + assert "Error details" in result["schema_validator"]["validation_error"][1].content diff --git a/test/test_requirements.txt b/test/test_requirements.txt index db6e3826d..336895054 100644 --- a/test/test_requirements.txt +++ b/test/test_requirements.txt @@ -21,3 +21,6 @@ openai-whisper>=20231106 # LocalWhisperTranscriber # OpenAPI jsonref # OpenAPIServiceConnector, OpenAPIServiceToFunctions openapi3 + +# Validation +jsonschema