mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-29 08:26:19 +00:00
refactor: LLMMetadataExtractor - adopt ChatGenerator protocol: deprecate generator_api, generator_api_params and LLMProvider (#9099)
* draft * improvements + tests * release note * mypy fixes * improve relnote * serialize chat_generator only * small simplification * clarify that also LLMProvider is deprecated * revert from_dict * test_from_dict_openai_using_chat_generator
This commit is contained in:
parent
dae8c7baba
commit
6db8f0a40d
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
@ -14,7 +15,9 @@ from jinja2.sandbox import SandboxedEnvironment
|
|||||||
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||||
from haystack.components.builders import PromptBuilder
|
from haystack.components.builders import PromptBuilder
|
||||||
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
|
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
|
||||||
|
from haystack.components.generators.chat.types import ChatGenerator
|
||||||
from haystack.components.preprocessors import DocumentSplitter
|
from haystack.components.preprocessors import DocumentSplitter
|
||||||
|
from haystack.core.serialization import import_class_by_name
|
||||||
from haystack.dataclasses import ChatMessage
|
from haystack.dataclasses import ChatMessage
|
||||||
from haystack.lazy_imports import LazyImport
|
from haystack.lazy_imports import LazyImport
|
||||||
from haystack.utils import deserialize_callable, deserialize_secrets_inplace, expand_page_range
|
from haystack.utils import deserialize_callable, deserialize_secrets_inplace, expand_page_range
|
||||||
@ -76,7 +79,8 @@ class LLMMetadataExtractor:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from haystack import Document
|
from haystack import Document
|
||||||
from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
|
from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
|
||||||
|
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||||
|
|
||||||
NER_PROMPT = '''
|
NER_PROMPT = '''
|
||||||
-Goal-
|
-Goal-
|
||||||
@ -122,22 +126,24 @@ class LLMMetadataExtractor:
|
|||||||
Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
|
Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
|
||||||
]
|
]
|
||||||
|
|
||||||
extractor = LLMMetadataExtractor(
|
chat_generator = OpenAIChatGenerator(
|
||||||
prompt=NER_PROMPT,
|
generation_kwargs={
|
||||||
generator_api="openai",
|
|
||||||
generator_api_params={
|
|
||||||
"generation_kwargs": {
|
|
||||||
"max_tokens": 500,
|
"max_tokens": 500,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
"response_format": {"type": "json_object"},
|
"response_format": {"type": "json_object"},
|
||||||
},
|
},
|
||||||
"max_retries": 1,
|
max_retries=1,
|
||||||
"timeout": 60.0,
|
timeout=60.0,
|
||||||
},
|
)
|
||||||
|
|
||||||
|
extractor = LLMMetadataExtractor(
|
||||||
|
prompt=NER_PROMPT,
|
||||||
|
chat_generator=generator,
|
||||||
expected_keys=["entities"],
|
expected_keys=["entities"],
|
||||||
raise_on_failure=False,
|
raise_on_failure=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor.warm_up()
|
extractor.warm_up()
|
||||||
extractor.run(documents=docs)
|
extractor.run(documents=docs)
|
||||||
>> {'documents': [
|
>> {'documents': [
|
||||||
@ -159,8 +165,9 @@ class LLMMetadataExtractor:
|
|||||||
def __init__( # pylint: disable=R0917
|
def __init__( # pylint: disable=R0917
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
generator_api: Union[str, LLMProvider],
|
generator_api: Optional[Union[str, LLMProvider]] = None,
|
||||||
generator_api_params: Optional[Dict[str, Any]] = None,
|
generator_api_params: Optional[Dict[str, Any]] = None,
|
||||||
|
chat_generator: Optional[ChatGenerator] = None,
|
||||||
expected_keys: Optional[List[str]] = None,
|
expected_keys: Optional[List[str]] = None,
|
||||||
page_range: Optional[List[Union[str, int]]] = None,
|
page_range: Optional[List[Union[str, int]]] = None,
|
||||||
raise_on_failure: bool = False,
|
raise_on_failure: bool = False,
|
||||||
@ -170,15 +177,17 @@ class LLMMetadataExtractor:
|
|||||||
Initializes the LLMMetadataExtractor.
|
Initializes the LLMMetadataExtractor.
|
||||||
|
|
||||||
:param prompt: The prompt to be used for the LLM.
|
:param prompt: The prompt to be used for the LLM.
|
||||||
:param generator_api: The API provider for the LLM. Currently supported providers are:
|
:param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
|
||||||
"openai", "openai_azure", "aws_bedrock", "google_vertex"
|
Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
|
||||||
:param generator_api_params: The parameters for the LLM generator.
|
:param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
|
||||||
|
the LLM.
|
||||||
|
:param chat_generator: a ChatGenerator instance which represents the LLM. If provided, this will override
|
||||||
|
settings in generator_api and generator_api_params.
|
||||||
:param expected_keys: The keys expected in the JSON output from the LLM.
|
:param expected_keys: The keys expected in the JSON output from the LLM.
|
||||||
:param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
|
:param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
|
||||||
metadata from the first and third pages of each document. It also accepts printable range
|
metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
|
||||||
strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
|
['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,11, 12.
|
||||||
11, 12. If None, metadata will be extracted from the entire document for each document in the
|
If None, metadata will be extracted from the entire document for each document in the documents list.
|
||||||
documents list.
|
|
||||||
This parameter is optional and can be overridden in the `run` method.
|
This parameter is optional and can be overridden in the `run` method.
|
||||||
:param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
|
:param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
|
||||||
validation of the JSON output.
|
validation of the JSON output.
|
||||||
@ -195,11 +204,32 @@ class LLMMetadataExtractor:
|
|||||||
self.builder = PromptBuilder(prompt, required_variables=variables)
|
self.builder = PromptBuilder(prompt, required_variables=variables)
|
||||||
self.raise_on_failure = raise_on_failure
|
self.raise_on_failure = raise_on_failure
|
||||||
self.expected_keys = expected_keys or []
|
self.expected_keys = expected_keys or []
|
||||||
self.generator_api = (
|
generator_api_params = generator_api_params or {}
|
||||||
|
|
||||||
|
if generator_api is None and chat_generator is None:
|
||||||
|
raise ValueError("Either generator_api or chat_generator must be provided.")
|
||||||
|
|
||||||
|
if chat_generator is not None:
|
||||||
|
self._chat_generator = chat_generator
|
||||||
|
if generator_api is not None:
|
||||||
|
logger.warning(
|
||||||
|
"Both chat_generator and generator_api are provided. "
|
||||||
|
"chat_generator will be used. generator_api/generator_api_params/LLMProvider are deprecated and "
|
||||||
|
"will be removed in Haystack 2.13.0."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"generator_api, generator_api_params, and LLMProvider are deprecated and will be removed in Haystack "
|
||||||
|
"2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
|
||||||
|
"`chat_generator=OpenAIChatGenerator()`.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
assert generator_api is not None # verified by the checks above
|
||||||
|
generator_api = (
|
||||||
generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
|
generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
|
||||||
)
|
)
|
||||||
self.generator_api_params = generator_api_params or {}
|
self._chat_generator = self._init_generator(generator_api, generator_api_params)
|
||||||
self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
|
|
||||||
self.splitter = DocumentSplitter(split_by="page", split_length=1)
|
self.splitter = DocumentSplitter(split_by="page", split_length=1)
|
||||||
self.expanded_range = expand_page_range(page_range) if page_range else None
|
self.expanded_range = expand_page_range(page_range) if page_range else None
|
||||||
self.max_workers = max_workers
|
self.max_workers = max_workers
|
||||||
@ -233,8 +263,8 @@ class LLMMetadataExtractor:
|
|||||||
"""
|
"""
|
||||||
Warm up the LLM provider component.
|
Warm up the LLM provider component.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.llm_provider, "warm_up"):
|
if hasattr(self._chat_generator, "warm_up"):
|
||||||
self.llm_provider.warm_up()
|
self._chat_generator.warm_up()
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -244,13 +274,10 @@ class LLMMetadataExtractor:
|
|||||||
Dictionary with serialized data.
|
Dictionary with serialized data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm_provider = self.llm_provider.to_dict()
|
|
||||||
|
|
||||||
return default_to_dict(
|
return default_to_dict(
|
||||||
self,
|
self,
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
generator_api=self.generator_api.value,
|
chat_generator=self._chat_generator.to_dict(),
|
||||||
generator_api_params=llm_provider["init_parameters"],
|
|
||||||
expected_keys=self.expected_keys,
|
expected_keys=self.expected_keys,
|
||||||
page_range=self.expanded_range,
|
page_range=self.expanded_range,
|
||||||
raise_on_failure=self.raise_on_failure,
|
raise_on_failure=self.raise_on_failure,
|
||||||
@ -270,6 +297,15 @@ class LLMMetadataExtractor:
|
|||||||
|
|
||||||
init_parameters = data.get("init_parameters", {})
|
init_parameters = data.get("init_parameters", {})
|
||||||
|
|
||||||
|
# new deserialization with chat_generator
|
||||||
|
if init_parameters.get("chat_generator") is not None:
|
||||||
|
chat_generator_class = import_class_by_name(init_parameters["chat_generator"]["type"])
|
||||||
|
assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
|
||||||
|
chat_generator_instance = chat_generator_class.from_dict(init_parameters["chat_generator"])
|
||||||
|
data["init_parameters"]["chat_generator"] = chat_generator_instance
|
||||||
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
# legacy deserialization
|
||||||
if "generator_api" in init_parameters:
|
if "generator_api" in init_parameters:
|
||||||
data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
|
data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
|
||||||
|
|
||||||
@ -364,15 +400,15 @@ class LLMMetadataExtractor:
|
|||||||
return {"replies": ["{}"]}
|
return {"replies": ["{}"]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.llm_provider.run(messages=[prompt])
|
result = self._chat_generator.run(messages=[prompt])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
|
||||||
"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
|
|
||||||
class_name=self.llm_provider.__class__.__name__,
|
|
||||||
error=e,
|
|
||||||
)
|
|
||||||
if self.raise_on_failure:
|
if self.raise_on_failure:
|
||||||
raise e
|
raise e
|
||||||
|
logger.error(
|
||||||
|
"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
|
||||||
|
class_name=self._chat_generator.__class__.__name__,
|
||||||
|
error=e,
|
||||||
|
)
|
||||||
result = {"error": "LLM failed with exception: " + str(e)}
|
result = {"error": "LLM failed with exception: " + str(e)}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,14 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
`LLMMetadataExtractor` now accepts a `chat_generator` initialization parameter, consisting of a pre-configured
|
||||||
|
`ChatGenerator` instance.
|
||||||
|
Regardless of whether `LLMMetadataExtractor` is initialized using `generator_api` and `generator_api_params` or
|
||||||
|
the new `chat_generator` parameter, the serialization format will only include `chat_generator` in preparation
|
||||||
|
for the future removal of `generator_api` and `generator_api_params`.
|
||||||
|
deprecations:
|
||||||
|
- |
|
||||||
|
The `generator_api` and `generator_api_params` initialization parameters of `LLMMetadataExtractor` and the
|
||||||
|
`LLMProvider` enum are deprecated and will be removed in Haystack 2.13.0. Use `chat_generator` instead to configure
|
||||||
|
the underlying LLM. For example, change `generator_api=LLMProvider.OPENAI` to
|
||||||
|
`chat_generator=OpenAIChatGenerator()`.
|
||||||
@ -1,13 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
from haystack import Document, Pipeline
|
from haystack import Document, Pipeline
|
||||||
from haystack.components.builders import PromptBuilder
|
from haystack.components.builders import PromptBuilder
|
||||||
from haystack.components.writers import DocumentWriter
|
from haystack.components.writers import DocumentWriter
|
||||||
from haystack.dataclasses import ChatMessage
|
from haystack.dataclasses import ChatMessage
|
||||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||||
|
|
||||||
|
|
||||||
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
|
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
|
||||||
|
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||||
|
|
||||||
|
|
||||||
class TestLLMMetadataExtractor:
|
class TestLLMMetadataExtractor:
|
||||||
@ -17,7 +20,7 @@ class TestLLMMetadataExtractor:
|
|||||||
prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
|
prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
|
||||||
)
|
)
|
||||||
assert isinstance(extractor.builder, PromptBuilder)
|
assert isinstance(extractor.builder, PromptBuilder)
|
||||||
assert extractor.generator_api == LLMProvider.OPENAI
|
assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
|
||||||
assert extractor.expected_keys == ["key1", "key2"]
|
assert extractor.expected_keys == ["key1", "key2"]
|
||||||
assert extractor.raise_on_failure is False
|
assert extractor.raise_on_failure is False
|
||||||
|
|
||||||
@ -34,10 +37,23 @@ class TestLLMMetadataExtractor:
|
|||||||
assert isinstance(extractor.builder, PromptBuilder)
|
assert isinstance(extractor.builder, PromptBuilder)
|
||||||
assert extractor.expected_keys == ["key1", "key2"]
|
assert extractor.expected_keys == ["key1", "key2"]
|
||||||
assert extractor.raise_on_failure is True
|
assert extractor.raise_on_failure is True
|
||||||
assert extractor.generator_api == LLMProvider.OPENAI
|
assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
|
||||||
assert extractor.generator_api_params == {"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}}
|
assert extractor._chat_generator.model == "gpt-3.5-turbo"
|
||||||
|
assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
|
||||||
assert extractor.expanded_range == [1, 2, 3, 4, 5]
|
assert extractor.expanded_range == [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
def test_init_with_chat_generator(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
|
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
|
||||||
|
|
||||||
|
extractor = LLMMetadataExtractor(
|
||||||
|
prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], chat_generator=chat_generator
|
||||||
|
)
|
||||||
|
assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
|
||||||
|
assert extractor._chat_generator.model == "gpt-4o-mini"
|
||||||
|
assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
|
||||||
|
assert extractor.expected_keys == ["key1", "key2"]
|
||||||
|
|
||||||
def test_init_missing_prompt_variable(self, monkeypatch):
|
def test_init_missing_prompt_variable(self, monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -45,6 +61,25 @@ class TestLLMMetadataExtractor:
|
|||||||
prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
|
prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_init_fails_without_generator_api_or_chat_generator(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = LLMMetadataExtractor(prompt="prompt {{document.content}}", expected_keys=["key1", "key2"])
|
||||||
|
|
||||||
|
def test_init_chat_generator_overrides_generator_api(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
|
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
|
||||||
|
extractor = LLMMetadataExtractor(
|
||||||
|
prompt="prompt {{document.content}}",
|
||||||
|
expected_keys=["key1", "key2"],
|
||||||
|
generator_api=LLMProvider.AWS_BEDROCK,
|
||||||
|
chat_generator=chat_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
|
||||||
|
assert extractor._chat_generator.model == "gpt-4o-mini"
|
||||||
|
assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
|
||||||
|
|
||||||
def test_to_dict_openai(self, monkeypatch):
|
def test_to_dict_openai(self, monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
extractor = LLMMetadataExtractor(
|
extractor = LLMMetadataExtractor(
|
||||||
@ -59,29 +94,55 @@ class TestLLMMetadataExtractor:
|
|||||||
assert extractor_dict == {
|
assert extractor_dict == {
|
||||||
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
|
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"prompt": "some prompt that was used with the LLM {{document.content}}",
|
"chat_generator": {
|
||||||
"expected_keys": ["key1", "key2"],
|
"init_parameters": {
|
||||||
"raise_on_failure": True,
|
|
||||||
"generator_api": "openai",
|
|
||||||
"page_range": None,
|
|
||||||
"generator_api_params": {
|
|
||||||
"api_base_url": None,
|
"api_base_url": None,
|
||||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||||
"generation_kwargs": {"temperature": 0.5},
|
"generation_kwargs": {"temperature": 0.5},
|
||||||
|
"max_retries": None,
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"organization": None,
|
"organization": None,
|
||||||
"streaming_callback": None,
|
"streaming_callback": None,
|
||||||
"max_retries": None,
|
|
||||||
"timeout": None,
|
"timeout": None,
|
||||||
"tools": None,
|
"tools": None,
|
||||||
"tools_strict": False,
|
"tools_strict": False,
|
||||||
},
|
},
|
||||||
|
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||||
|
},
|
||||||
|
"prompt": "some prompt that was used with the LLM {{document.content}}",
|
||||||
|
"expected_keys": ["key1", "key2"],
|
||||||
|
"raise_on_failure": True,
|
||||||
|
"page_range": None,
|
||||||
|
"max_workers": 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_to_dict_openai_using_chat_generator(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
|
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
|
||||||
|
extractor = LLMMetadataExtractor(
|
||||||
|
prompt="some prompt that was used with the LLM {{document.content}}",
|
||||||
|
expected_keys=["key1", "key2"],
|
||||||
|
chat_generator=chat_generator,
|
||||||
|
raise_on_failure=True,
|
||||||
|
)
|
||||||
|
extractor_dict = extractor.to_dict()
|
||||||
|
|
||||||
|
assert extractor_dict == {
|
||||||
|
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
|
||||||
|
"init_parameters": {
|
||||||
|
"prompt": "some prompt that was used with the LLM {{document.content}}",
|
||||||
|
"expected_keys": ["key1", "key2"],
|
||||||
|
"raise_on_failure": True,
|
||||||
|
"chat_generator": chat_generator.to_dict(),
|
||||||
|
"page_range": None,
|
||||||
"max_workers": 3,
|
"max_workers": 3,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_from_dict_openai(self, monkeypatch):
|
def test_from_dict_openai(self, monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
|
|
||||||
extractor_dict = {
|
extractor_dict = {
|
||||||
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
|
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
@ -103,12 +164,38 @@ class TestLLMMetadataExtractor:
|
|||||||
assert extractor.raise_on_failure is True
|
assert extractor.raise_on_failure is True
|
||||||
assert extractor.expected_keys == ["key1", "key2"]
|
assert extractor.expected_keys == ["key1", "key2"]
|
||||||
assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
|
assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
|
||||||
assert extractor.generator_api == LLMProvider.OPENAI
|
assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
|
||||||
|
assert extractor._chat_generator.model == "gpt-4o-mini"
|
||||||
|
|
||||||
def test_warm_up(self, monkeypatch):
|
def test_from_dict_openai_using_chat_generator(self, monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
|
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
|
||||||
assert extractor.warm_up() is None
|
|
||||||
|
extractor_dict = {
|
||||||
|
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
|
||||||
|
"init_parameters": {
|
||||||
|
"prompt": "some prompt that was used with the LLM {{document.content}}",
|
||||||
|
"expected_keys": ["key1", "key2"],
|
||||||
|
"chat_generator": chat_generator.to_dict(),
|
||||||
|
"raise_on_failure": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
extractor = LLMMetadataExtractor.from_dict(extractor_dict)
|
||||||
|
assert extractor.raise_on_failure is True
|
||||||
|
assert extractor.expected_keys == ["key1", "key2"]
|
||||||
|
assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
|
||||||
|
assert extractor._chat_generator.to_dict() == chat_generator.to_dict()
|
||||||
|
|
||||||
|
def test_warm_up_with_chat_generator(self, monkeypatch):
|
||||||
|
mock_chat_generator = Mock()
|
||||||
|
|
||||||
|
extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator)
|
||||||
|
|
||||||
|
mock_chat_generator.warm_up.assert_not_called()
|
||||||
|
|
||||||
|
extractor.warm_up()
|
||||||
|
|
||||||
|
mock_chat_generator.warm_up.assert_called_once()
|
||||||
|
|
||||||
def test_extract_metadata(self, monkeypatch):
|
def test_extract_metadata(self, monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||||
@ -287,7 +374,7 @@ output:
|
|||||||
|
|
||||||
doc_store = InMemoryDocumentStore()
|
doc_store = InMemoryDocumentStore()
|
||||||
extractor = LLMMetadataExtractor(
|
extractor = LLMMetadataExtractor(
|
||||||
prompt=ner_prompt, expected_keys=["entities"], generator_api=LLMProvider.OPENAI
|
prompt=ner_prompt, expected_keys=["entities"], chat_generator=OpenAIChatGenerator()
|
||||||
)
|
)
|
||||||
writer = DocumentWriter(document_store=doc_store)
|
writer = DocumentWriter(document_store=doc_store)
|
||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user