"""Haystack pipeline demo: loop LLM output through a validator to auto-correct JSON.

The pipeline asks an LLM to extract structured data (cities) from a passage.
A custom ``OutputParser`` component validates each reply against a pydantic
schema; an invalid reply — together with the validation error message — is fed
back into the prompt builder so the LLM can correct itself, for at most
``max_loops_allowed`` iterations.

Requires the ``OPENAI_API_KEY`` environment variable to be set.
"""

import json
import logging
import os
import random
from typing import List, Optional, Type

import pydantic
from pydantic import BaseModel, ValidationError

from haystack import Pipeline, component
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.openai import GPTGenerator

logging.basicConfig()
logging.getLogger("canals.pipeline.pipeline").setLevel(logging.DEBUG)


# Let's define a simple schema for the data we want to extract from a passage
# via the LLM. We want the output from our LLM to be always compliant with this.
class City(BaseModel):
    name: str
    country: str
    population: int


class CitiesData(BaseModel):
    cities: List[City]


# NOTE: schema_json / parse_obj below are the pydantic v1 API
# (v2 equivalents: model_json_schema / model_validate).
schema = CitiesData.schema_json(indent=2)


@component
class OutputParser:
    """Validate LLM replies against a pydantic model.

    A valid reply is emitted on the ``valid`` edge; an invalid one is emitted on
    ``invalid`` together with ``error_message``, so that the pipeline can loop
    back to the prompt builder with a better chance of correcting it next time.
    """

    def __init__(self, pydantic_model: Type[pydantic.BaseModel]):
        # The model *class* to validate against (not an instance).
        self.pydantic_model = pydantic_model
        # Counts how many times run() was invoked, for demo logging.
        self.iteration_counter = 0

    @component.output_types(
        valid=List[str],
        invalid=Optional[List[str]],
        error_message=Optional[str],
    )
    def run(self, replies: List[str]):
        self.iteration_counter += 1

        # Simulate a corrupt JSON with 30% probability by prepending extra
        # brackets (demo only, so the correction loop actually triggers).
        if random.random() < 0.3:
            replies[0] = "{{" + replies[0]

        try:
            output_dict = json.loads(replies[0])
            self.pydantic_model.parse_obj(output_dict)
            print(
                f"OutputParser at Iteration {self.iteration_counter}: "
                f"Valid JSON from LLM - No need for looping: {replies[0]}"
            )
            return {"valid": replies}
        except (ValueError, ValidationError) as e:
            # json.loads raises JSONDecodeError (a ValueError subclass);
            # parse_obj raises ValidationError. Either way, hand the raw reply
            # and the error back so the next prompt can ask for a correction.
            print(
                f"OutputParser at Iteration {self.iteration_counter}: "
                f"Invalid JSON from LLM - Let's try again.\n"
                f"Output from LLM:\n {replies[0]} \n"
                f"Error from OutputParser: {e}"
            )
            return {"invalid": replies, "error_message": str(e)}


# Let's create a prompt that always includes the basic instructions for creating
# our JSON, and optionally, information from any previously failed attempt
# (corrupt JSON + error message from parsing it). The Jinja2 templating language
# gives us full flexibility here to adjust the prompt dynamically depending on
# which inputs are available.
prompt_template = """
    Create a JSON object from the information present in this passage: {{passage}}.
    Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:"
    {{schema}}
    Make sure your response is a dict and not a list.
    {% if replies and error_message %}
    You already created the following output in a previous attempt: {{replies}}
    However, this doesn't comply with the format requirements from above and triggered this Python exception: {{ error_message}}
    Correct the output and try again. Just return the corrected output without any extra explanations.
    {% endif %}
"""

# Let's build the pipeline (Make sure to set OPENAI_API_KEY as an environment variable)
pipeline = Pipeline(max_loops_allowed=5)
pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
pipeline.add_component(instance=GPTGenerator(), name="llm")
pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser")

pipeline.connect("prompt_builder", "llm")
pipeline.connect("llm", "output_parser")
# Loop edges: feed the invalid output and the parsing error back into the prompt.
pipeline.connect("output_parser.invalid", "prompt_builder.replies")
pipeline.connect("output_parser.error_message", "prompt_builder.error_message")

# Now, let's run our pipeline with an example passage that we want to convert
# into our JSON format.
passage = "Berlin is the capital of Germany. It has a population of 3,850,809"
result = pipeline.run({"prompt_builder": {"passage": passage, "schema": schema}})

print(result)