haystack/test/components/extractors/test_llm_metadata_extractor.py
Sebastian Husch Lee 4edefe3e56
Feat: Support Azure Workload Identity Credential (#9012)
* Start adding support for passing callable to Azure components

* Add to chat version

* Fix test

* Add reno

* Add support to azure doc and text embedder

* Rename

* update llm metadata extractor

* Add tests for text embedder

* Update tests

* Remove unused fixture and import

* Update reno
2025-03-12 13:45:40 +01:00

303 lines
14 KiB
Python

import os
import pytest
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.extractors import LLMMetadataExtractor, LLMProvider
class TestLLMMetadataExtractor:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
)
assert isinstance(extractor.builder, PromptBuilder)
assert extractor.generator_api == LLMProvider.OPENAI
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is False
def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{document.content}}",
expected_keys=["key1", "key2"],
raise_on_failure=True,
generator_api=LLMProvider.OPENAI,
generator_api_params={"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}},
page_range=["1-5"],
)
assert isinstance(extractor.builder, PromptBuilder)
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is True
assert extractor.generator_api == LLMProvider.OPENAI
assert extractor.generator_api_params == {"model": "gpt-3.5-turbo", "generation_kwargs": {"temperature": 0.5}}
assert extractor.expanded_range == [1, 2, 3, 4, 5]
def test_init_missing_prompt_variable(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
with pytest.raises(ValueError):
_ = LLMMetadataExtractor(
prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI
)
def test_to_dict_openai(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="some prompt that was used with the LLM {{document.content}}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
generator_api_params={"model": "gpt-4o-mini", "generation_kwargs": {"temperature": 0.5}},
raise_on_failure=True,
)
extractor_dict = extractor.to_dict()
assert extractor_dict == {
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
"init_parameters": {
"prompt": "some prompt that was used with the LLM {{document.content}}",
"expected_keys": ["key1", "key2"],
"raise_on_failure": True,
"generator_api": "openai",
"page_range": None,
"generator_api_params": {
"api_base_url": None,
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"generation_kwargs": {"temperature": 0.5},
"model": "gpt-4o-mini",
"organization": None,
"streaming_callback": None,
"max_retries": None,
"timeout": None,
"tools": None,
"tools_strict": False,
},
"max_workers": 3,
},
}
def test_from_dict_openai(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor_dict = {
"type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
"init_parameters": {
"prompt": "some prompt that was used with the LLM {{document.content}}",
"expected_keys": ["key1", "key2"],
"raise_on_failure": True,
"generator_api": "openai",
"generator_api_params": {
"api_base_url": None,
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"generation_kwargs": {},
"model": "gpt-4o-mini",
"organization": None,
"streaming_callback": None,
},
},
}
extractor = LLMMetadataExtractor.from_dict(extractor_dict)
assert extractor.raise_on_failure is True
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
assert extractor.generator_api == LLMProvider.OPENAI
def test_warm_up(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
assert extractor.warm_up() is None
def test_extract_metadata(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
result = extractor._extract_metadata(llm_answer='{"output": "valid json"}')
assert result == {"output": "valid json"}
def test_extract_metadata_invalid_json(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI, raise_on_failure=True
)
with pytest.raises(ValueError):
extractor._extract_metadata(llm_answer='{"output: "valid json"}')
def test_extract_metadata_missing_key(self, monkeypatch, caplog):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI, expected_keys=["key1"]
)
extractor._extract_metadata(llm_answer='{"output": "valid json"}')
assert "Expected response from LLM to be a JSON with keys" in caplog.text
def test_prepare_prompts(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="some_user_definer_prompt {{document.content}}", generator_api=LLMProvider.OPENAI
)
docs = [
Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
Document(
content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
),
]
prompts = extractor._prepare_prompts(docs)
assert prompts == [
ChatMessage.from_dict(
{
"_role": "user",
"_meta": {},
"_name": None,
"_content": [
{
"text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
}
],
}
),
ChatMessage.from_dict(
{
"_role": "user",
"_meta": {},
"_name": None,
"_content": [
{
"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
}
],
}
),
]
def test_prepare_prompts_empty_document(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="some_user_definer_prompt {{document.content}}", generator_api=LLMProvider.OPENAI
)
docs = [
Document(content=""),
Document(
content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
),
]
prompts = extractor._prepare_prompts(docs)
assert prompts == [
None,
ChatMessage.from_dict(
{
"_role": "user",
"_meta": {},
"_name": None,
"_content": [
{
"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
}
],
}
),
]
def test_prepare_prompts_expanded_range(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="some_user_definer_prompt {{document.content}}", generator_api=LLMProvider.OPENAI, page_range=["1-2"]
)
docs = [
Document(
content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3"
)
]
prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
assert prompts == [
ChatMessage.from_dict(
{
"_role": "user",
"_meta": {},
"_name": None,
"_content": [
{
"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"
}
],
}
)
]
def test_run_no_documents(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI)
result = extractor.run(documents=[])
assert result["documents"] == []
assert result["failed_documents"] == []
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_live_run(self):
docs = [
Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
Document(
content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
),
]
ner_prompt = """-Goal-
Given text and a list of entity types, identify all entities of those types from the text.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [organization, product, service, industry]
Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}
2. Return output in a single list with all the entities identified in steps 1.
-Examples-
######################
Example 1:
entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
base and high cross-border usage.
We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
agreement with Emirates Skywards.
And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
issuers are equally
------------------------
output:
{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
#############################
-Real Data-
######################
entity_types: [company, organization, person, country, product, service]
text: {{ document.content }}
######################
output:
"""
doc_store = InMemoryDocumentStore()
extractor = LLMMetadataExtractor(
prompt=ner_prompt, expected_keys=["entities"], generator_api=LLMProvider.OPENAI
)
writer = DocumentWriter(document_store=doc_store)
pipeline = Pipeline()
pipeline.add_component("extractor", extractor)
pipeline.add_component("doc_writer", writer)
pipeline.connect("extractor.documents", "doc_writer.documents")
pipeline.run(data={"documents": docs})
doc_store_docs = doc_store.filter_documents()
assert len(doc_store_docs) == 2
assert "entities" in doc_store_docs[0].meta
assert "entities" in doc_store_docs[1].meta