docs: review docstrings in haystack.components.validators (#7238)

* chore: make private

* docs: review and normalize docstrings

* docs: fix format and unused import
This commit is contained in:
Tobias Wochinger 2024-02-28 17:46:30 +01:00 committed by GitHub
parent c4b54bcac0
commit e5f0e248b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 28 deletions

View File

@ -12,14 +12,14 @@ with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
@component
class JsonSchemaValidator:
"""
JsonSchemaValidator validates JSON content of ChatMessage against a specified JSON schema.
Validates JSON content of `ChatMessage` against a specified [JSON Schema](https://json-schema.org/).
If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output.
If the JSON content does not conform to the schema, the message is passed along the "validation_error" output.
In the latter case, the error message is constructed using the provided error_template or a default template.
In the latter case, the error message is constructed using the provided `error_template` or a default template.
These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops.
Here is a small example of how to use this component in a pipeline implementing schema validation recovery loop:
Usage example:
```python
from typing import List
@ -52,15 +52,13 @@ class JsonSchemaValidator:
p.connect("llm.replies", "schema_validator.messages")
p.connect("schema_validator.validation_error", "mx_for_llm")
result = p.run(
data={"message_producer": {"messages":[ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]},
"schema_validator": {"json_schema": {"type": "object",
"properties": {"name": {"type": "string"},
"age": {"type": "integer"}}}}})
print(result)
>> {'schema_validator': {'validated': [ChatMessage(content='\n{\n "name": "John",\n "age": 30\n}',
>> {'schema_validator': {'validated': [ChatMessage(content='\\n{\\n "name": "John",\\n "age": 30\\n}',
role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0,
'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}}
```
@ -79,9 +77,8 @@ class JsonSchemaValidator:
def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
"""
Initializes a new JsonSchemaValidator instance.
:param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
:param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/) against which
the messages' content is validated.
:param error_template: A custom template string for formatting the error message in case of validation failure.
"""
jsonschema_import.check()
@ -94,16 +91,25 @@ class JsonSchemaValidator:
messages: List[ChatMessage],
json_schema: Optional[Dict[str, Any]] = None,
error_template: Optional[str] = None,
):
) -> Dict[str, List[ChatMessage]]:
"""
Checks if the last message and its content field conforms to json_schema. If it does, the message is passed
along the "validated" output. If it does not, the message is passed along the "validation_error" output.
Validates the last of the provided messages against the specified json schema.
If it does, the message is passed along the "validated" output. If it does not, the message is passed along
the "validation_error" output.
:param messages: A list of ChatMessage instances to be validated. The last message in this list is the one
that is validated.
:param json_schema: A dictionary representing the JSON schema against which the messages' content is validated.
:param error_template: A custom template string for formatting the error message in case of validation
failure, by default None.
that is validated.
:param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/)
against which the messages' content is validated. If not provided, the schema from the component init
is used.
:param error_template: A custom template string for formatting the error message in case of validation. If not
provided, the `error_template` from the component init is used.
:return: A dictionary with the following keys:
- "validated": A list of messages if the last message is valid.
- "validation_error": A list of messages if the last message is invalid.
:raises ValueError: If no JSON schema is provided or if the message content is not a dictionary or a list of
dictionaries.
"""
last_message = messages[-1]
last_message_content = json.loads(last_message.content)
@ -116,8 +122,8 @@ class JsonSchemaValidator:
# fc payload is json object but subtree `parameters` is string - we need to convert to json object
# we need complete json to validate it against schema
last_message_json = self.recursive_json_to_object(last_message_content)
using_openai_schema: bool = self.is_openai_function_calling_schema(json_schema)
last_message_json = self._recursive_json_to_object(last_message_content)
using_openai_schema: bool = self._is_openai_function_calling_schema(json_schema)
if using_openai_schema:
validation_schema = json_schema["parameters"]
else:
@ -137,13 +143,13 @@ class JsonSchemaValidator:
error_template = error_template or self.default_error_template
recovery_prompt = self.construct_error_recovery_message(
recovery_prompt = self._construct_error_recovery_message(
error_template, str(e), error_path, error_schema_path, validation_schema
)
complete_message_list = [ChatMessage.from_user(recovery_prompt)] + messages
return {"validation_error": complete_message_list}
def construct_error_recovery_message(
def _construct_error_recovery_message(
self,
error_template: str,
error_message: str,
@ -169,16 +175,16 @@ class JsonSchemaValidator:
json_schema=json_schema,
)
def is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
def _is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
"""
Checks if the provided schema is a valid OpenAI function calling schema.
:param json_schema: The JSON schema to check
:return: True if the schema is a valid OpenAI function calling schema; otherwise, False.
:return: `True` if the schema is a valid OpenAI function calling schema; otherwise, `False`.
"""
return all(key in json_schema for key in ["name", "description", "parameters"])
def recursive_json_to_object(self, data: Any) -> Any:
def _recursive_json_to_object(self, data: Any) -> Any:
"""
Recursively traverses a data structure (dictionary or list), converting any string values
that are valid JSON objects into dictionary objects, and returns a new data structure.
@ -187,7 +193,7 @@ class JsonSchemaValidator:
:return: A new data structure with JSON strings converted to dictionary objects.
"""
if isinstance(data, list):
return [self.recursive_json_to_object(item) for item in data]
return [self._recursive_json_to_object(item) for item in data]
if isinstance(data, dict):
new_dict = {}
@ -196,14 +202,14 @@ class JsonSchemaValidator:
try:
json_value = json.loads(value)
new_dict[key] = (
self.recursive_json_to_object(json_value)
self._recursive_json_to_object(json_value)
if isinstance(json_value, (dict, list))
else json_value
)
except json.JSONDecodeError:
new_dict[key] = value
elif isinstance(value, dict):
new_dict[key] = self.recursive_json_to_object(value)
new_dict[key] = self._recursive_json_to_object(value)
else:
new_dict[key] = value
return new_dict

View File

@ -92,7 +92,7 @@ class TestJsonSchemaValidator:
# but ensure_json_objects converts the string to a json object
validator = JsonSchemaValidator()
result = validator.recursive_json_to_object({"key": genuine_fc_message})
result = validator._recursive_json_to_object({"key": genuine_fc_message})
# we need this recursive json conversion to validate the message
assert result["key"][0]["function"]["arguments"]["basehead"] == "main...amzn_chat"
@ -154,7 +154,7 @@ class TestJsonSchemaValidator:
"{json_schema}\n"
)
recovery_message = validator.construct_error_recovery_message(
recovery_message = validator._construct_error_recovery_message(
new_error_template, "Error message", "Error path", "Error schema path", {"type": "object"}
)