Mirror of https://github.com/deepset-ai/haystack.git — synced 2025-11-25 14:36:05 +00:00
refactor: LLMMetadataExtractor - adopt ChatGenerator protocol: deprecate generator_api, generator_api_params and LLMProvider (#9099)
* draft
* improvements + tests
* release note
* mypy fixes
* improve relnote
* serialize chat_generator only
* small simplification
* clarify that also LLMProvider is deprecated
* revert from_dict
* test_from_dict_openai_using_chat_generator
This commit is contained in:
parent dae8c7baba
commit 6db8f0a40d
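In practice, the migration this commit enables looks like the following; a minimal sketch, assuming OPENAI_API_KEY is set in the environment (prompt and model name are illustrative):

```python
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator

PROMPT = "Extract entities from: {{document.content}}"  # illustrative prompt

# Deprecated style: emits a DeprecationWarning, removed in Haystack 2.13.0.
old_extractor = LLMMetadataExtractor(
    prompt=PROMPT,
    generator_api="openai",
    generator_api_params={"model": "gpt-4o-mini"},
)

# New style: pass any pre-configured ChatGenerator directly.
new_extractor = LLMMetadataExtractor(
    prompt=PROMPT,
    chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
)
```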
haystack/components/extractors/llm_metadata_extractor.py:

@@ -4,6 +4,7 @@
 import copy
 import json
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
@@ -14,7 +15,9 @@ from jinja2.sandbox import SandboxedEnvironment
 from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.components.builders import PromptBuilder
 from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
+from haystack.components.generators.chat.types import ChatGenerator
 from haystack.components.preprocessors import DocumentSplitter
+from haystack.core.serialization import import_class_by_name
 from haystack.dataclasses import ChatMessage
 from haystack.lazy_imports import LazyImport
 from haystack.utils import deserialize_callable, deserialize_secrets_inplace, expand_page_range
@@ -76,7 +79,8 @@ class LLMMetadataExtractor:
     ```python
     from haystack import Document
-    from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
+    from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
+    from haystack.components.generators.chat import OpenAIChatGenerator

     NER_PROMPT = '''
     -Goal-
@@ -122,22 +126,24 @@ class LLMMetadataExtractor:
     Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
     ]

+    chat_generator = OpenAIChatGenerator(
+        generation_kwargs={
+            "max_tokens": 500,
+            "temperature": 0.0,
+            "seed": 0,
+            "response_format": {"type": "json_object"},
+        },
+        max_retries=1,
+        timeout=60.0,
+    )
+
     extractor = LLMMetadataExtractor(
         prompt=NER_PROMPT,
-        generator_api="openai",
-        generator_api_params={
-            "generation_kwargs": {
-                "max_tokens": 500,
-                "temperature": 0.0,
-                "seed": 0,
-                "response_format": {"type": "json_object"},
-            },
-            "max_retries": 1,
-            "timeout": 60.0,
-        },
+        chat_generator=chat_generator,
         expected_keys=["entities"],
         raise_on_failure=False,
     )

     extractor.warm_up()
     extractor.run(documents=docs)
     >> {'documents': [
@@ -159,8 +165,9 @@ class LLMMetadataExtractor:
     def __init__(  # pylint: disable=R0917
         self,
         prompt: str,
-        generator_api: Union[str, LLMProvider],
+        generator_api: Optional[Union[str, LLMProvider]] = None,
         generator_api_params: Optional[Dict[str, Any]] = None,
+        chat_generator: Optional[ChatGenerator] = None,
         expected_keys: Optional[List[str]] = None,
         page_range: Optional[List[Union[str, int]]] = None,
         raise_on_failure: bool = False,
@@ -170,18 +177,20 @@ class LLMMetadataExtractor:
         Initializes the LLMMetadataExtractor.

         :param prompt: The prompt to be used for the LLM.
-        :param generator_api: The API provider for the LLM. Currently supported providers are:
-            "openai", "openai_azure", "aws_bedrock", "google_vertex"
-        :param generator_api_params: The parameters for the LLM generator.
+        :param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
+            Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
+        :param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
+            the LLM.
+        :param chat_generator: a ChatGenerator instance which represents the LLM. If provided, this will override
+            settings in generator_api and generator_api_params.
         :param expected_keys: The keys expected in the JSON output from the LLM.
         :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
-            metadata from the first and third pages of each document. It also accepts printable range
-            strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
-            11, 12. If None, metadata will be extracted from the entire document for each document in the
-            documents list.
-            This parameter is optional and can be overridden in the `run` method.
+            metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
+            ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, 11, 12.
+            If None, metadata will be extracted from the entire document for each document in the documents list.
+            This parameter is optional and can be overridden in the `run` method.
         :param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
-            validation of the JSON output.
+            validation of the JSON output.
         :param max_workers: The maximum number of workers to use in the thread pool executor.
         """
         self.prompt = prompt
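For the page_range semantics described in the docstring above, a small sketch of how ranges expand; values mirror the docstring's own examples:

```python
from haystack.utils import expand_page_range

# Printable range strings expand to concrete page numbers, per the docstring.
assert expand_page_range(["1", "3"]) == [1, 3]
assert expand_page_range(["1-3", "5", "8", "10-12"]) == [1, 2, 3, 5, 8, 10, 11, 12]
```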
@@ -195,11 +204,32 @@ class LLMMetadataExtractor:
         self.builder = PromptBuilder(prompt, required_variables=variables)
         self.raise_on_failure = raise_on_failure
         self.expected_keys = expected_keys or []
-        self.generator_api = (
-            generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
-        )
-        self.generator_api_params = generator_api_params or {}
-        self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
+        generator_api_params = generator_api_params or {}
+
+        if generator_api is None and chat_generator is None:
+            raise ValueError("Either generator_api or chat_generator must be provided.")
+
+        if chat_generator is not None:
+            self._chat_generator = chat_generator
+            if generator_api is not None:
+                logger.warning(
+                    "Both chat_generator and generator_api are provided. "
+                    "chat_generator will be used. generator_api/generator_api_params/LLMProvider are deprecated and "
+                    "will be removed in Haystack 2.13.0."
+                )
+        else:
+            warnings.warn(
+                "generator_api, generator_api_params, and LLMProvider are deprecated and will be removed in Haystack "
+                "2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
+                "`chat_generator=OpenAIChatGenerator()`.",
+                DeprecationWarning,
+            )
+            assert generator_api is not None  # verified by the checks above
+            generator_api = (
+                generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
+            )
+            self._chat_generator = self._init_generator(generator_api, generator_api_params)
+
         self.splitter = DocumentSplitter(split_by="page", split_length=1)
         self.expanded_range = expand_page_range(page_range) if page_range else None
         self.max_workers = max_workers
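The new __init__ logic above distinguishes three construction paths; a minimal sketch of each, assuming OPENAI_API_KEY is set (prompt and generator configuration are illustrative):

```python
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator

prompt = "prompt {{document.content}}"

# 1. chat_generator only: no warning, the instance is used as-is.
ok = LLMMetadataExtractor(prompt=prompt, chat_generator=OpenAIChatGenerator())

# 2. Both provided: chat_generator wins; a deprecation warning is logged.
both = LLMMetadataExtractor(
    prompt=prompt, generator_api="openai", chat_generator=OpenAIChatGenerator()
)

# 3. Neither provided: ValueError.
try:
    LLMMetadataExtractor(prompt=prompt)
except ValueError as err:
    print(err)  # "Either generator_api or chat_generator must be provided."
```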
@@ -233,8 +263,8 @@ class LLMMetadataExtractor:
         """
         Warm up the LLM provider component.
         """
-        if hasattr(self.llm_provider, "warm_up"):
-            self.llm_provider.warm_up()
+        if hasattr(self._chat_generator, "warm_up"):
+            self._chat_generator.warm_up()

     def to_dict(self) -> Dict[str, Any]:
         """
@@ -244,13 +274,10 @@ class LLMMetadataExtractor:
             Dictionary with serialized data.
         """

-        llm_provider = self.llm_provider.to_dict()
-
         return default_to_dict(
             self,
             prompt=self.prompt,
-            generator_api=self.generator_api.value,
-            generator_api_params=llm_provider["init_parameters"],
+            chat_generator=self._chat_generator.to_dict(),
             expected_keys=self.expected_keys,
             page_range=self.expanded_range,
             raise_on_failure=self.raise_on_failure,
@@ -270,6 +297,15 @@ class LLMMetadataExtractor:
         init_parameters = data.get("init_parameters", {})

+        # new deserialization with chat_generator
+        if init_parameters.get("chat_generator") is not None:
+            chat_generator_class = import_class_by_name(init_parameters["chat_generator"]["type"])
+            assert hasattr(chat_generator_class, "from_dict")  # we know but mypy doesn't
+            chat_generator_instance = chat_generator_class.from_dict(init_parameters["chat_generator"])
+            data["init_parameters"]["chat_generator"] = chat_generator_instance
+            return default_from_dict(cls, data)
+
+        # legacy deserialization
         if "generator_api" in init_parameters:
             data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
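With the to_dict/from_dict changes above, serialization round-trips through chat_generator alone; a minimal sketch, assuming OPENAI_API_KEY is set:

```python
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator

extractor = LLMMetadataExtractor(
    prompt="prompt {{document.content}}",
    chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
)

data = extractor.to_dict()
# Only chat_generator is serialized; no generator_api/generator_api_params keys.
assert "chat_generator" in data["init_parameters"]
assert "generator_api" not in data["init_parameters"]

# from_dict re-imports the generator class from its "type" entry and rebuilds it.
restored = LLMMetadataExtractor.from_dict(data)
assert restored._chat_generator.to_dict() == extractor._chat_generator.to_dict()
```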
@@ -364,15 +400,15 @@ class LLMMetadataExtractor:
             return {"replies": ["{}"]}

         try:
-            result = self.llm_provider.run(messages=[prompt])
+            result = self._chat_generator.run(messages=[prompt])
         except Exception as e:
-            logger.error(
-                "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
-                class_name=self.llm_provider.__class__.__name__,
-                error=e,
-            )
             if self.raise_on_failure:
                 raise e
+            logger.error(
+                "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
+                class_name=self._chat_generator.__class__.__name__,
+                error=e,
+            )
             result = {"error": "LLM failed with exception: " + str(e)}
         return result
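The reordered error handling above only logs when the error is not re-raised. A sketch with a mocked, always-failing generator; this is illustrative, and the documents/failed_documents split assumed here follows the component's documented run output:

```python
from unittest.mock import Mock

from haystack import Document
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor

failing_generator = Mock()
failing_generator.run.side_effect = RuntimeError("boom")

# raise_on_failure=False: the exception is logged and the document is expected to
# surface under "failed_documents" rather than aborting the run.
extractor = LLMMetadataExtractor(
    prompt="prompt {{document.content}}",
    chat_generator=failing_generator,
    raise_on_failure=False,
)
result = extractor.run(documents=[Document(content="some text")])
print(result["failed_documents"])

# With raise_on_failure=True, the same RuntimeError would propagate instead.
```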
Release note (new file):

@@ -0,0 +1,14 @@
+---
+enhancements:
+  - |
+    `LLMMetadataExtractor` now accepts a `chat_generator` initialization parameter, consisting of a pre-configured
+    `ChatGenerator` instance.
+    Regardless of whether `LLMMetadataExtractor` is initialized using `generator_api` and `generator_api_params` or
+    the new `chat_generator` parameter, the serialization format will only include `chat_generator` in preparation
+    for the future removal of `generator_api` and `generator_api_params`.
+deprecations:
+  - |
+    The `generator_api` and `generator_api_params` initialization parameters of `LLMMetadataExtractor` and the
+    `LLMProvider` enum are deprecated and will be removed in Haystack 2.13.0. Use `chat_generator` instead to configure
+    the underlying LLM. For example, change `generator_api=LLMProvider.OPENAI` to
+    `chat_generator=OpenAIChatGenerator()`.
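The deprecation described in the note is observable with the standard warnings machinery; a minimal sketch, assuming OPENAI_API_KEY is set:

```python
import warnings

from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Legacy initialization triggers the DeprecationWarning added in this commit.
    LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api="openai")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```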
Tests (test_llm_metadata_extractor.py):

@@ -1,13 +1,16 @@
 import os

 import pytest
+from unittest.mock import Mock
+
 from haystack import Document, Pipeline
 from haystack.components.builders import PromptBuilder
 from haystack.components.writers import DocumentWriter
 from haystack.dataclasses import ChatMessage
 from haystack.document_stores.in_memory import InMemoryDocumentStore


 from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
+from haystack.components.generators.chat import OpenAIChatGenerator


 class TestLLMMetadataExtractor:
@@ -17,7 +20,7 @@ class TestLLMMetadataExtractor:
             prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
         )
         assert isinstance(extractor.builder, PromptBuilder)
-        assert extractor.generator_api == LLMProvider.OPENAI
+        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
         assert extractor.expected_keys == ["key1", "key2"]
         assert extractor.raise_on_failure is False
@@ -34,10 +37,23 @@ class TestLLMMetadataExtractor:
         assert isinstance(extractor.builder, PromptBuilder)
         assert extractor.expected_keys == ["key1", "key2"]
         assert extractor.raise_on_failure is True
-        assert extractor.generator_api == LLMProvider.OPENAI
-        assert extractor.generator_api_params == {"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}}
+        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+        assert extractor._chat_generator.model == "gpt-3.5-turbo"
+        assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
         assert extractor.expanded_range == [1, 2, 3, 4, 5]

+    def test_init_with_chat_generator(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], chat_generator=chat_generator
+        )
+        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+        assert extractor._chat_generator.model == "gpt-4o-mini"
+        assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
+        assert extractor.expected_keys == ["key1", "key2"]
+
     def test_init_missing_prompt_variable(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         with pytest.raises(ValueError):
@@ -45,6 +61,25 @@ class TestLLMMetadataExtractor:
                 prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
             )

+    def test_init_fails_without_generator_api_or_chat_generator(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        with pytest.raises(ValueError):
+            _ = LLMMetadataExtractor(prompt="prompt {{document.content}}", expected_keys=["key1", "key2"])
+
+    def test_init_chat_generator_overrides_generator_api(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{document.content}}",
+            expected_keys=["key1", "key2"],
+            generator_api=LLMProvider.AWS_BEDROCK,
+            chat_generator=chat_generator,
+        )
+
+        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+        assert extractor._chat_generator.model == "gpt-4o-mini"
+        assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
+
     def test_to_dict_openai(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
@@ -59,29 +94,55 @@ class TestLLMMetadataExtractor:
         assert extractor_dict == {
             "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
             "init_parameters": {
+                "chat_generator": {
+                    "init_parameters": {
+                        "api_base_url": None,
+                        "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
+                        "generation_kwargs": {"temperature": 0.5},
+                        "max_retries": None,
+                        "model": "gpt-4o-mini",
+                        "organization": None,
+                        "streaming_callback": None,
+                        "timeout": None,
+                        "tools": None,
+                        "tools_strict": False,
+                    },
+                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
+                },
                 "prompt": "some prompt that was used with the LLM {{document.content}}",
                 "expected_keys": ["key1", "key2"],
                 "raise_on_failure": True,
-                "generator_api": "openai",
                 "page_range": None,
-                "generator_api_params": {
-                    "api_base_url": None,
-                    "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
-                    "generation_kwargs": {"temperature": 0.5},
-                    "model": "gpt-4o-mini",
-                    "organization": None,
-                    "streaming_callback": None,
-                    "max_retries": None,
-                    "timeout": None,
-                    "tools": None,
-                    "tools_strict": False,
-                },
                 "max_workers": 3,
             },
         }

+    def test_to_dict_openai_using_chat_generator(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+        extractor = LLMMetadataExtractor(
+            prompt="some prompt that was used with the LLM {{document.content}}",
+            expected_keys=["key1", "key2"],
+            chat_generator=chat_generator,
+            raise_on_failure=True,
+        )
+        extractor_dict = extractor.to_dict()
+
+        assert extractor_dict == {
+            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
+            "init_parameters": {
+                "prompt": "some prompt that was used with the LLM {{document.content}}",
+                "expected_keys": ["key1", "key2"],
+                "raise_on_failure": True,
+                "chat_generator": chat_generator.to_dict(),
+                "page_range": None,
+                "max_workers": 3,
+            },
+        }
+
     def test_from_dict_openai(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

         extractor_dict = {
             "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
             "init_parameters": {
@@ -103,12 +164,38 @@ class TestLLMMetadataExtractor:
         assert extractor.raise_on_failure is True
         assert extractor.expected_keys == ["key1", "key2"]
         assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
-        assert extractor.generator_api == LLMProvider.OPENAI
+        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
+        assert extractor._chat_generator.model == "gpt-4o-mini"

-    def test_warm_up(self, monkeypatch):
+    def test_from_dict_openai_using_chat_generator(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
-        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
-        assert extractor.warm_up() is None
+        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
+
+        extractor_dict = {
+            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
+            "init_parameters": {
+                "prompt": "some prompt that was used with the LLM {{document.content}}",
+                "expected_keys": ["key1", "key2"],
+                "chat_generator": chat_generator.to_dict(),
+                "raise_on_failure": True,
+            },
+        }
+        extractor = LLMMetadataExtractor.from_dict(extractor_dict)
+        assert extractor.raise_on_failure is True
+        assert extractor.expected_keys == ["key1", "key2"]
+        assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
+        assert extractor._chat_generator.to_dict() == chat_generator.to_dict()
+
+    def test_warm_up_with_chat_generator(self, monkeypatch):
+        mock_chat_generator = Mock()
+
+        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator)
+
+        mock_chat_generator.warm_up.assert_not_called()
+
+        extractor.warm_up()
+
+        mock_chat_generator.warm_up.assert_called_once()

     def test_extract_metadata(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
@@ -287,7 +374,7 @@ output:
         doc_store = InMemoryDocumentStore()
         extractor = LLMMetadataExtractor(
-            prompt=ner_prompt, expected_keys=["entities"], generator_api=LLMProvider.OPENAI
+            prompt=ner_prompt, expected_keys=["entities"], chat_generator=OpenAIChatGenerator()
         )
         writer = DocumentWriter(document_store=doc_store)
         pipeline = Pipeline()
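The test above is cut off right after pipeline = Pipeline(); a hedged sketch of how such an indexing pipeline is typically wired, with component names and a shortened stand-in prompt that are illustrative, not taken from the test:

```python
from haystack import Document, Pipeline
from haystack.components.extractors import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore

docs = [Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework")]

doc_store = InMemoryDocumentStore()
extractor = LLMMetadataExtractor(
    prompt="Extract entities from: {{document.content}}",  # stand-in for the full NER prompt
    expected_keys=["entities"],
    chat_generator=OpenAIChatGenerator(),
)
writer = DocumentWriter(document_store=doc_store)

pipeline = Pipeline()
pipeline.add_component("extractor", extractor)
pipeline.add_component("writer", writer)
# Successfully enriched documents flow from the extractor into the writer.
pipeline.connect("extractor.documents", "writer.documents")
pipeline.run(data={"extractor": {"documents": docs}})
```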