refactor: LLMMetadataExtractor - adopt ChatGenerator protocol: deprecate generator_api, generator_api_params and LLMProvider (#9099)

* draft

* improvements + tests

* release note

* mypy fixes

* improve relnote

* serialize chat_generator only

* small simplification

* clarify that LLMProvider is also deprecated

* revert from_dict

* test_from_dict_openai_using_chat_generator
Stefano Fiorucci 2025-03-24 18:38:09 +01:00 committed by GitHub
parent dae8c7baba
commit 6db8f0a40d
3 changed files with 197 additions and 60 deletions
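
For anyone migrating, the whole change amounts to constructing the LLM yourself and passing it in. A minimal before/after sketch (assuming `OPENAI_API_KEY` is set in the environment; the prompt string is only illustrative):

```python
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
from haystack.components.generators.chat import OpenAIChatGenerator

PROMPT = "Extract entities from: {{document.content}}"

# Before: deprecated style, still works until Haystack 2.13.0 but warns.
extractor = LLMMetadataExtractor(
    prompt=PROMPT,
    generator_api=LLMProvider.OPENAI,
    generator_api_params={"model": "gpt-4o-mini"},
)

# After: pass any pre-configured ChatGenerator instance.
extractor = LLMMetadataExtractor(
    prompt=PROMPT,
    chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
)
```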

haystack/components/extractors/llm_metadata_extractor.py

@@ -4,6 +4,7 @@
import copy
import json
+import warnings
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Any, Dict, List, Optional, Union
@@ -14,7 +15,9 @@ from jinja2.sandbox import SandboxedEnvironment
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.components.builders import PromptBuilder
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
+from haystack.components.generators.chat.types import ChatGenerator
from haystack.components.preprocessors import DocumentSplitter
+from haystack.core.serialization import import_class_by_name
from haystack.dataclasses import ChatMessage
from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_callable, deserialize_secrets_inplace, expand_page_range
@@ -76,7 +79,8 @@ class LLMMetadataExtractor:
```python
from haystack import Document
-from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
+from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
+from haystack.components.generators.chat import OpenAIChatGenerator
NER_PROMPT = '''
-Goal-
@@ -122,22 +126,24 @@ class LLMMetadataExtractor:
Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
]
+chat_generator = OpenAIChatGenerator(
+generation_kwargs={
+"max_tokens": 500,
+"temperature": 0.0,
+"seed": 0,
+"response_format": {"type": "json_object"},
+},
+max_retries=1,
+timeout=60.0,
+)
extractor = LLMMetadataExtractor(
prompt=NER_PROMPT,
generator_api="openai",
generator_api_params={
"generation_kwargs": {
"max_tokens": 500,
"temperature": 0.0,
"seed": 0,
"response_format": {"type": "json_object"},
},
"max_retries": 1,
"timeout": 60.0,
},
+chat_generator=chat_generator,
expected_keys=["entities"],
raise_on_failure=False,
)
extractor.warm_up()
extractor.run(documents=docs)
>> {'documents': [
@@ -159,8 +165,9 @@ class LLMMetadataExtractor:
def __init__( # pylint: disable=R0917
self,
prompt: str,
-generator_api: Union[str, LLMProvider],
+generator_api: Optional[Union[str, LLMProvider]] = None,
generator_api_params: Optional[Dict[str, Any]] = None,
+chat_generator: Optional[ChatGenerator] = None,
expected_keys: Optional[List[str]] = None,
page_range: Optional[List[Union[str, int]]] = None,
raise_on_failure: bool = False,
@@ -170,18 +177,20 @@
Initializes the LLMMetadataExtractor.
:param prompt: The prompt to be used for the LLM.
-:param generator_api: The API provider for the LLM. Currently supported providers are:
-"openai", "openai_azure", "aws_bedrock", "google_vertex"
-:param generator_api_params: The parameters for the LLM generator.
+:param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
+Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
+:param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
+the LLM.
+:param chat_generator: a ChatGenerator instance which represents the LLM. If provided, this will override
+settings in generator_api and generator_api_params.
:param expected_keys: The keys expected in the JSON output from the LLM.
:param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
-metadata from the first and third pages of each document. It also accepts printable range
-strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
-11, 12. If None, metadata will be extracted from the entire document for each document in the
-documents list.
-This parameter is optional and can be overridden in the `run` method.
+metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
+['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, 11, 12.
+If None, metadata will be extracted from the entire document for each document in the documents list.
+This parameter is optional and can be overridden in the `run` method.
:param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
validation of the JSON output.
:param max_workers: The maximum number of workers to use in the thread pool executor.
"""
self.prompt = prompt
@@ -195,11 +204,32 @@
self.builder = PromptBuilder(prompt, required_variables=variables)
self.raise_on_failure = raise_on_failure
self.expected_keys = expected_keys or []
-self.generator_api = (
-generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
-)
-self.generator_api_params = generator_api_params or {}
-self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
+generator_api_params = generator_api_params or {}
+if generator_api is None and chat_generator is None:
+raise ValueError("Either generator_api or chat_generator must be provided.")
+if chat_generator is not None:
+self._chat_generator = chat_generator
+if generator_api is not None:
+logger.warning(
+"Both chat_generator and generator_api are provided. "
+"chat_generator will be used. generator_api/generator_api_params/LLMProvider are deprecated and "
+"will be removed in Haystack 2.13.0."
+)
+else:
+warnings.warn(
+"generator_api, generator_api_params, and LLMProvider are deprecated and will be removed in Haystack "
+"2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
+"`chat_generator=OpenAIChatGenerator()`.",
+DeprecationWarning,
+)
+assert generator_api is not None # verified by the checks above
+generator_api = (
+generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
+)
+self._chat_generator = self._init_generator(generator_api, generator_api_params)
self.splitter = DocumentSplitter(split_by="page", split_length=1)
self.expanded_range = expand_page_range(page_range) if page_range else None
self.max_workers = max_workers
@@ -233,8 +263,8 @@
"""
Warm up the LLM provider component.
"""
-if hasattr(self.llm_provider, "warm_up"):
-self.llm_provider.warm_up()
+if hasattr(self._chat_generator, "warm_up"):
+self._chat_generator.warm_up()
def to_dict(self) -> Dict[str, Any]:
"""
@@ -244,13 +274,10 @@
Dictionary with serialized data.
"""
-llm_provider = self.llm_provider.to_dict()
return default_to_dict(
self,
prompt=self.prompt,
-generator_api=self.generator_api.value,
-generator_api_params=llm_provider["init_parameters"],
+chat_generator=self._chat_generator.to_dict(),
expected_keys=self.expected_keys,
page_range=self.expanded_range,
raise_on_failure=self.raise_on_failure,
@@ -270,6 +297,15 @@
init_parameters = data.get("init_parameters", {})
+# new deserialization with chat_generator
+if init_parameters.get("chat_generator") is not None:
+chat_generator_class = import_class_by_name(init_parameters["chat_generator"]["type"])
+assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
+chat_generator_instance = chat_generator_class.from_dict(init_parameters["chat_generator"])
+data["init_parameters"]["chat_generator"] = chat_generator_instance
+return default_from_dict(cls, data)
+# legacy deserialization
if "generator_api" in init_parameters:
data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
@@ -364,15 +400,15 @@
return {"replies": ["{}"]}
try:
-result = self.llm_provider.run(messages=[prompt])
+result = self._chat_generator.run(messages=[prompt])
except Exception as e:
-logger.error(
-"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
-class_name=self.llm_provider.__class__.__name__,
-error=e,
-)
if self.raise_on_failure:
raise e
+logger.error(
+"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
+class_name=self._chat_generator.__class__.__name__,
+error=e,
+)
result = {"error": "LLM failed with exception: " + str(e)}
return result
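
Before the release note below, it is worth seeing the new constructor branches in action: the legacy path still builds the generator via `_init_generator`, but now warns. A quick sketch of exercising that warning (a hedged example, assuming `OPENAI_API_KEY` is set so the fallback `OpenAIChatGenerator` can be constructed):

```python
import warnings

from haystack.components.extractors import LLMMetadataExtractor, LLMProvider

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Legacy construction: triggers the DeprecationWarning added above.
    LLMMetadataExtractor(
        prompt="prompt {{document.content}}",
        generator_api=LLMProvider.OPENAI,
    )

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```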

releasenotes/notes/… (new release note)

@@ -0,0 +1,14 @@
---
enhancements:
- |
`LLMMetadataExtractor` now accepts a `chat_generator` initialization parameter, consisting of a pre-configured
`ChatGenerator` instance.
Regardless of whether `LLMMetadataExtractor` is initialized using `generator_api` and `generator_api_params` or
the new `chat_generator` parameter, the serialization format will only include `chat_generator` in preparation
for the future removal of `generator_api` and `generator_api_params`.
deprecations:
- |
The `generator_api` and `generator_api_params` initialization parameters of `LLMMetadataExtractor` and the
`LLMProvider` enum are deprecated and will be removed in Haystack 2.13.0. Use `chat_generator` instead to configure
the underlying LLM. For example, change `generator_api=LLMProvider.OPENAI` to
`chat_generator=OpenAIChatGenerator()`.
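
Concretely, the serialization behavior described above can be checked with a small round trip; this sketch assumes `OPENAI_API_KEY` is set so the legacy path can build its `OpenAIChatGenerator`:

```python
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider

extractor = LLMMetadataExtractor(
    prompt="prompt {{document.content}}",
    generator_api=LLMProvider.OPENAI,  # legacy-style construction
)

data = extractor.to_dict()
# Only the new key is serialized; the deprecated ones never appear.
assert "chat_generator" in data["init_parameters"]
assert "generator_api" not in data["init_parameters"]

# from_dict rebuilds the ChatGenerator from its serialized "type".
restored = LLMMetadataExtractor.from_dict(data)
```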

test/components/extractors/test_llm_metadata_extractor.py

@@ -1,13 +1,16 @@
import os
import pytest
+from unittest.mock import Mock
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
+from haystack.components.generators.chat import OpenAIChatGenerator
class TestLLMMetadataExtractor:
@@ -17,7 +20,7 @@ class TestLLMMetadataExtractor:
prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
)
assert isinstance(extractor.builder, PromptBuilder)
-assert extractor.generator_api == LLMProvider.OPENAI
+assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is False
@@ -34,10 +37,23 @@ class TestLLMMetadataExtractor:
assert isinstance(extractor.builder, PromptBuilder)
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is True
-assert extractor.generator_api == LLMProvider.OPENAI
-assert extractor.generator_api_params == {"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}}
+assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+assert extractor._chat_generator.model == "gpt-3.5-turbo"
+assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
assert extractor.expanded_range == [1, 2, 3, 4, 5]
+def test_init_with_chat_generator(self, monkeypatch):
+monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+extractor = LLMMetadataExtractor(
+prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], chat_generator=chat_generator
+)
+assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+assert extractor._chat_generator.model == "gpt-4o-mini"
+assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
+assert extractor.expected_keys == ["key1", "key2"]
def test_init_missing_prompt_variable(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
with pytest.raises(ValueError):
@@ -45,6 +61,25 @@ class TestLLMMetadataExtractor:
prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
)
+def test_init_fails_without_generator_api_or_chat_generator(self, monkeypatch):
+monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+with pytest.raises(ValueError):
+_ = LLMMetadataExtractor(prompt="prompt {{document.content}}", expected_keys=["key1", "key2"])
+def test_init_chat_generator_overrides_generator_api(self, monkeypatch):
+monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+extractor = LLMMetadataExtractor(
+prompt="prompt {{document.content}}",
+expected_keys=["key1", "key2"],
+generator_api=LLMProvider.AWS_BEDROCK,
+chat_generator=chat_generator,
+)
+assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+assert extractor._chat_generator.model == "gpt-4o-mini"
+assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
def test_to_dict_openai(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
@@ -59,29 +94,55 @@ class TestLLMMetadataExtractor:
assert extractor_dict == {
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
"init_parameters": {
"chat_generator": {
"init_parameters": {
"api_base_url": None,
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"generation_kwargs": {"temperature": 0.5},
"max_retries": None,
"model": "gpt-4o-mini",
"organization": None,
"streaming_callback": None,
"timeout": None,
"tools": None,
"tools_strict": False,
},
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
},
"prompt": "some prompt that was used with the LLM {{document.content}}",
"expected_keys": ["key1", "key2"],
"raise_on_failure": True,
"generator_api": "openai",
"page_range": None,
"generator_api_params": {
"api_base_url": None,
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"generation_kwargs": {"temperature": 0.5},
"model": "gpt-4o-mini",
"organization": None,
"streaming_callback": None,
"max_retries": None,
"timeout": None,
"tools": None,
"tools_strict": False,
},
"max_workers": 3,
},
}
+def test_to_dict_openai_using_chat_generator(self, monkeypatch):
+monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+extractor = LLMMetadataExtractor(
+prompt="some prompt that was used with the LLM {{document.content}}",
+expected_keys=["key1", "key2"],
+chat_generator=chat_generator,
+raise_on_failure=True,
+)
+extractor_dict = extractor.to_dict()
+assert extractor_dict == {
+"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
+"init_parameters": {
+"prompt": "some prompt that was used with the LLM {{document.content}}",
+"expected_keys": ["key1", "key2"],
+"raise_on_failure": True,
+"chat_generator": chat_generator.to_dict(),
+"page_range": None,
+"max_workers": 3,
+},
+}
def test_from_dict_openai(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor_dict = {
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
"init_parameters": {
@@ -103,12 +164,38 @@ class TestLLMMetadataExtractor:
assert extractor.raise_on_failure is True
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
-assert extractor.generator_api == LLMProvider.OPENAI
+assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+assert extractor._chat_generator.model == "gpt-4o-mini"
-def test_warm_up(self, monkeypatch):
+def test_from_dict_openai_using_chat_generator(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
-extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
-assert extractor.warm_up() is None
+chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+extractor_dict = {
+"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
+"init_parameters": {
+"prompt": "some prompt that was used with the LLM {{document.content}}",
+"expected_keys": ["key1", "key2"],
+"chat_generator": chat_generator.to_dict(),
+"raise_on_failure": True,
+},
+}
+extractor = LLMMetadataExtractor.from_dict(extractor_dict)
+assert extractor.raise_on_failure is True
+assert extractor.expected_keys == ["key1", "key2"]
+assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
+assert extractor._chat_generator.to_dict() == chat_generator.to_dict()
+def test_warm_up_with_chat_generator(self, monkeypatch):
+mock_chat_generator = Mock()
+extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator)
+mock_chat_generator.warm_up.assert_not_called()
+extractor.warm_up()
+mock_chat_generator.warm_up.assert_called_once()
def test_extract_metadata(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
@@ -287,7 +374,7 @@
doc_store = InMemoryDocumentStore()
extractor = LLMMetadataExtractor(
prompt=ner_prompt, expected_keys=["entities"], generator_api=LLMProvider.OPENAI
prompt=ner_prompt, expected_keys=["entities"], chat_generator=OpenAIChatGenerator()
)
writer = DocumentWriter(document_store=doc_store)
pipeline = Pipeline()
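
The indexing test is cut off by the diff context; wired end to end it would continue roughly as below. This is a sketch, not part of the commit: the `extractor.documents` output socket matches the run output shown in the component docstring, and a real run of this pipeline would call the OpenAI API:

```python
docs = [Document(content="Hugging Face is a company that was founded in New York, USA")]

pipeline.add_component("extractor", extractor)
pipeline.add_component("writer", writer)
# Route the metadata-enriched documents into the DocumentWriter.
pipeline.connect("extractor.documents", "writer.documents")
pipeline.run(data={"extractor": {"documents": docs}})
assert doc_store.count_documents() == len(docs)
```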