haystack/test/components/extractors/test_llm_metadata_extractor.py
David S. Batista f798a9e935
feat: adding LLMMetadataExtractor (#8833)
* fixing linting

* adding release notes

* updating tests

* adding to pydocs

* fixing typing due to Optional

* fixing docstring
2025-02-10 16:54:25 +00:00

385 lines
18 KiB
Python

import os
from unittest.mock import MagicMock
import boto3
import pytest
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
class TestLLMMetadataExtractor:
    """Unit tests for LLMMetadataExtractor.

    Covers: construction defaults and parameters, prompt-variable validation,
    to_dict/from_dict round-trips for the OpenAI and AWS Bedrock providers,
    the private _extract_metadata and _prepare_prompts helpers, running with
    no documents, and one live OpenAI integration test.

    NOTE(review): the original paste had all leading indentation stripped
    (not valid Python); block structure reconstructed here. All string
    literals and expected dict values are reproduced byte-for-byte.
    """

    @pytest.fixture
    def boto3_session_mock(self, monkeypatch: pytest.MonkeyPatch) -> MagicMock:
        """Patch boto3.Session with a MagicMock so Bedrock tests need no real AWS credentials."""
        mock = MagicMock()
        monkeypatch.setattr(boto3, "Session", mock)
        return mock

    def test_init_default(self, monkeypatch):
        """Default init builds a PromptBuilder, keeps expected_keys, and raise_on_failure defaults to False."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
        )
        assert isinstance(extractor.builder, PromptBuilder)
        assert extractor.generator_api == LLMProvider.OPENAI
        assert extractor.expected_keys == ["key1", "key2"]
        assert extractor.raise_on_failure is False

    def test_init_with_parameters(self, monkeypatch):
        """Explicit init params are stored; page_range '1-5' expands to [1, 2, 3, 4, 5]."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}",
            expected_keys=["key1", "key2"],
            raise_on_failure=True,
            generator_api=LLMProvider.OPENAI,
            generator_api_params={"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}},
            page_range=["1-5"],
        )
        assert isinstance(extractor.builder, PromptBuilder)
        assert extractor.expected_keys == ["key1", "key2"]
        assert extractor.raise_on_failure is True
        assert extractor.generator_api == LLMProvider.OPENAI
        assert extractor.generator_api_params == {"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}}
        assert extractor.expanded_range == [1, 2, 3, 4, 5]

    def test_init_missing_prompt_variable(self, monkeypatch):
        """A prompt without the required 'document' template variable raises ValueError."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        with pytest.raises(ValueError):
            _ = LLMMetadataExtractor(
                prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
            )

    def test_to_dict_openai(self, monkeypatch):
        """to_dict serializes the OpenAI provider config, including the env-var API key reference."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some prompt that was used with the LLM {{document.content}}",
            expected_keys=["key1", "key2"],
            generator_api=LLMProvider.OPENAI,
            generator_api_params={"model": "gpt-4o-mini", "generation_kwargs": {"temperature": 0.5}},
            raise_on_failure=True,
        )
        extractor_dict = extractor.to_dict()
        assert extractor_dict == {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{document.content}}",
                "expected_keys": ["key1", "key2"],
                "raise_on_failure": True,
                "generator_api": "openai",
                "page_range": None,
                "generator_api_params": {
                    "api_base_url": None,
                    "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                    "generation_kwargs": {"temperature": 0.5},
                    "model": "gpt-4o-mini",
                    "organization": None,
                    "streaming_callback": None,
                    "max_retries": None,
                    "timeout": None,
                    "tools": None,
                    "tools_strict": False,
                },
                "max_workers": 3,
            },
        }

    def test_to_dict_aws_bedrock(self, boto3_session_mock):
        """to_dict serializes the AWS Bedrock provider config with non-strict AWS env-var secrets."""
        extractor = LLMMetadataExtractor(
            prompt="some prompt that was used with the LLM {{document.content}}",
            expected_keys=["key1", "key2"],
            generator_api=LLMProvider.AWS_BEDROCK,
            generator_api_params={"model": "meta.llama.test"},
            raise_on_failure=True,
        )
        extractor_dict = extractor.to_dict()
        assert extractor_dict == {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{document.content}}",
                "generator_api": "aws_bedrock",
                "generator_api_params": {
                    "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
                    "aws_secret_access_key": {
                        "type": "env_var",
                        "env_vars": ["AWS_SECRET_ACCESS_KEY"],
                        "strict": False,
                    },
                    "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
                    "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
                    "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                    "model": "meta.llama.test",
                    "stop_words": [],
                    "generation_kwargs": {},
                    "streaming_callback": None,
                    "boto3_config": None,
                    "tools": None,
                },
                "expected_keys": ["key1", "key2"],
                "page_range": None,
                "raise_on_failure": True,
                "max_workers": 3,
            },
        }

    def test_from_dict_openai(self, monkeypatch):
        """from_dict reconstructs an OpenAI-backed extractor with the serialized settings."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor_dict = {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{document.content}}",
                "expected_keys": ["key1", "key2"],
                "raise_on_failure": True,
                "generator_api": "openai",
                "generator_api_params": {
                    "api_base_url": None,
                    "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                    "generation_kwargs": {},
                    "model": "gpt-4o-mini",
                    "organization": None,
                    "streaming_callback": None,
                },
            },
        }
        extractor = LLMMetadataExtractor.from_dict(extractor_dict)
        assert extractor.raise_on_failure is True
        assert extractor.expected_keys == ["key1", "key2"]
        assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
        assert extractor.generator_api == LLMProvider.OPENAI

    def test_from_dict_aws_bedrock(self, boto3_session_mock):
        """from_dict reconstructs a Bedrock-backed extractor, including the provider model name."""
        extractor_dict = {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{document.content}}",
                "generator_api": "aws_bedrock",
                "generator_api_params": {
                    "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
                    "aws_secret_access_key": {
                        "type": "env_var",
                        "env_vars": ["AWS_SECRET_ACCESS_KEY"],
                        "strict": False,
                    },
                    "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
                    "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
                    "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
                    "model": "meta.llama.test",
                    "stop_words": [],
                    "generation_kwargs": {},
                    "streaming_callback": None,
                    "boto3_config": None,
                    "tools": None,
                },
                "expected_keys": ["key1", "key2"],
                "page_range": None,
                "raise_on_failure": True,
                "max_workers": 3,
            },
        }
        extractor = LLMMetadataExtractor.from_dict(extractor_dict)
        assert extractor.raise_on_failure is True
        assert extractor.expected_keys == ["key1", "key2"]
        assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
        assert extractor.generator_api == LLMProvider.AWS_BEDROCK
        assert extractor.llm_provider.model == "meta.llama.test"

    def test_warm_up(self, monkeypatch):
        """warm_up completes and returns None."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
        assert extractor.warm_up() is None

    def test_extract_metadata(self, monkeypatch):
        """_extract_metadata parses a valid JSON LLM answer into a dict."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
        result = extractor._extract_metadata(llm_answer='{"output": "valid json"}')
        assert result == {"output": "valid json"}

    def test_extract_metadata_invalid_json(self, monkeypatch):
        """With raise_on_failure=True, malformed JSON from the LLM raises ValueError."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI, raise_on_failure=True
        )
        with pytest.raises(ValueError):
            # note: deliberately broken JSON (missing closing quote on the key)
            extractor._extract_metadata(llm_answer='{"output: "valid json"}')

    def test_extract_metadata_missing_key(self, monkeypatch, caplog):
        """A JSON answer missing an expected key is logged rather than raised by default."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI, expected_keys=["key1"]
        )
        extractor._extract_metadata(llm_answer='{"output": "valid json"}')
        assert "Expected response from LLM to be a JSON with keys" in caplog.text

    def test_prepare_prompts(self, monkeypatch):
        """_prepare_prompts renders one user ChatMessage per document with its content injected."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{document.content}}", generator_api=LLMProvider.OPENAI
        )
        docs = [
            Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        prompts = extractor._prepare_prompts(docs)
        assert prompts == [
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
                        }
                    ],
                }
            ),
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
                        }
                    ],
                }
            ),
        ]

    def test_prepare_prompts_empty_document(self, monkeypatch):
        """An empty document yields None in the prompts list instead of a ChatMessage."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{document.content}}", generator_api=LLMProvider.OPENAI
        )
        docs = [
            Document(content=""),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        prompts = extractor._prepare_prompts(docs)
        assert prompts == [
            None,
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
                        }
                    ],
                }
            ),
        ]

    def test_prepare_prompts_expanded_range(self, monkeypatch):
        """With expanded_range=[1, 2] only the first two form-feed-separated pages are kept in the prompt."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{document.content}}", generator_api=LLMProvider.OPENAI, page_range=["1-2"]
        )
        docs = [
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3"
            )
        ]
        prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
        assert prompts == [
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"
                        }
                    ],
                }
            )
        ]

    def test_run_no_documents(self, monkeypatch):
        """run with an empty document list returns empty 'documents' and 'failed_documents'."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
        result = extractor.run(documents=[])
        assert result["documents"] == []
        assert result["failed_documents"] == []

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_live_run(self):
        """Live pipeline: NER-style extraction via OpenAI writes 'entities' metadata on every document."""
        docs = [
            Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        # NOTE(review): the interior lines of this prompt are part of the string
        # literal and are kept flush-left to preserve the exact prompt content.
        ner_prompt = """-Goal-
Given text and a list of entity types, identify all entities of those types from the text.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [organization, product, service, industry]
Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}
2. Return output in a single list with all the entities identified in steps 1.
-Examples-
######################
Example 1:
entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
base and high cross-border usage.
We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
agreement with Emirates Skywards.
And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
issuers are equally
------------------------
output:
{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
#############################
-Real Data-
######################
entity_types: [company, organization, person, country, product, service]
text: {{ document.content }}
######################
output:
"""
        doc_store = InMemoryDocumentStore()
        extractor = LLMMetadataExtractor(
            prompt=ner_prompt, expected_keys=["entities"], generator_api=LLMProvider.OPENAI
        )
        writer = DocumentWriter(document_store=doc_store)
        pipeline = Pipeline()
        pipeline.add_component("extractor", extractor)
        pipeline.add_component("doc_writer", writer)
        pipeline.connect("extractor.documents", "doc_writer.documents")
        pipeline.run(data={"documents": docs})
        doc_store_docs = doc_store.filter_documents()
        assert len(doc_store_docs) == 2
        assert "entities" in doc_store_docs[0].meta
        assert "entities" in doc_store_docs[1].meta