import os
from unittest.mock import Mock

import pytest

from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.extractors import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator


class TestLLMMetadataExtractor:
    """Tests for ``LLMMetadataExtractor``.

    Most tests set a dummy ``OPENAI_API_KEY`` through ``monkeypatch`` before
    constructing an ``OpenAIChatGenerator``; only ``test_live_run`` is marked
    as an integration test and is skipped unless a key is exported.
    """

    def test_init(self, monkeypatch):
        """The constructor keeps the supplied chat generator and expected keys."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], chat_generator=chat_generator
        )
        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
        assert extractor._chat_generator.model == "gpt-4o-mini"
        assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
        assert extractor.expected_keys == ["key1", "key2"]

    def test_init_missing_prompt_variable(self, monkeypatch):
        """A prompt that uses a variable other than ``document`` raises ``ValueError``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
        with pytest.raises(ValueError):
            _ = LLMMetadataExtractor(
                prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], chat_generator=chat_generator
            )

    def test_init_fails_without_chat_generator(self, monkeypatch):
        """Omitting ``chat_generator`` raises ``TypeError``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        with pytest.raises(TypeError):
            _ = LLMMetadataExtractor(prompt="prompt {{document.content}}", expected_keys=["key1", "key2"])

    def test_to_dict_openai(self, monkeypatch):
        """``to_dict`` emits the init parameters, including ``page_range`` and ``max_workers``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
        extractor = LLMMetadataExtractor(
            prompt="some prompt that was used with the LLM {{document.content}}",
            expected_keys=["key1", "key2"],
            chat_generator=chat_generator,
            raise_on_failure=True,
        )
        extractor_dict = extractor.to_dict()

        assert extractor_dict == {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{document.content}}",
                "expected_keys": ["key1", "key2"],
                "raise_on_failure": True,
                "chat_generator": chat_generator.to_dict(),
                "page_range": None,
                "max_workers": 3,
            },
        }

    def test_from_dict_openai(self, monkeypatch):
        """``from_dict`` rebuilds an extractor whose attributes match the serialized parameters."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
        extractor_dict = {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{document.content}}",
                "expected_keys": ["key1", "key2"],
                "chat_generator": chat_generator.to_dict(),
                "raise_on_failure": True,
            },
        }
        extractor = LLMMetadataExtractor.from_dict(extractor_dict)

        assert extractor.raise_on_failure is True
        assert extractor.expected_keys == ["key1", "key2"]
        assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
        assert extractor._chat_generator.to_dict() == chat_generator.to_dict()

    def test_warm_up_with_chat_generator(self, monkeypatch):
        """``warm_up`` calls the chat generator's ``warm_up`` once; construction alone does not."""
        mock_chat_generator = Mock()
        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator)

        mock_chat_generator.warm_up.assert_not_called()
        extractor.warm_up()
        mock_chat_generator.warm_up.assert_called_once()

    def test_extract_metadata(self, monkeypatch):
        """A valid JSON answer is returned as the parsed dictionary."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=OpenAIChatGenerator())
        result = extractor._extract_metadata(llm_answer='{"output": "valid json"}')
        assert result == {"output": "valid json"}

    def test_extract_metadata_invalid_json(self, monkeypatch):
        """With ``raise_on_failure=True`` a malformed JSON answer raises ``ValueError``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", chat_generator=OpenAIChatGenerator(), raise_on_failure=True
        )
        with pytest.raises(ValueError):
            extractor._extract_metadata(llm_answer='{"output: "valid json"}')

    def test_extract_metadata_missing_key(self, monkeypatch, caplog):
        """An answer lacking an expected key produces a log message about the expected keys."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", chat_generator=OpenAIChatGenerator(), expected_keys=["key1"]
        )
        extractor._extract_metadata(llm_answer='{"output": "valid json"}')
        assert "Expected response from LLM to be a JSON with keys" in caplog.text

    def test_prepare_prompts(self, monkeypatch):
        """Each document is rendered into one user ``ChatMessage`` containing its content."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{document.content}}", chat_generator=OpenAIChatGenerator()
        )
        docs = [
            Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        prompts = extractor._prepare_prompts(docs)

        assert prompts == [
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
                        }
                    ],
                }
            ),
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
                        }
                    ],
                }
            ),
        ]

    def test_prepare_prompts_empty_document(self, monkeypatch):
        """A document with empty content yields ``None`` in place of a prompt."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{document.content}}", chat_generator=OpenAIChatGenerator()
        )
        docs = [
            Document(content=""),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        prompts = extractor._prepare_prompts(docs)

        assert prompts == [
            None,
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
                        }
                    ],
                }
            ),
        ]

    def test_prepare_prompts_expanded_range(self, monkeypatch):
        """With ``expanded_range=[1, 2]`` the rendered prompt omits the text of the third page."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{document.content}}",
            chat_generator=OpenAIChatGenerator(),
            page_range=["1-2"],
        )
        # Pages are separated by form-feed characters in the document content.
        docs = [
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3"
            )
        ]
        prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])

        assert prompts == [
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"
                        }
                    ],
                }
            )
        ]

    def test_run_no_documents(self, monkeypatch):
        """Running on an empty list returns empty ``documents`` and ``failed_documents``."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=OpenAIChatGenerator())
        result = extractor.run(documents=[])
        assert result["documents"] == []
        assert result["failed_documents"] == []

    def test_run_with_document_content_none(self, monkeypatch):
        """Documents with ``None`` or empty content end up in ``failed_documents`` without an LLM call."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        # Mock the chat generator to prevent actual LLM calls
        mock_chat_generator = Mock(spec=OpenAIChatGenerator)

        extractor = LLMMetadataExtractor(
            prompt="prompt {{document.content}}", chat_generator=mock_chat_generator, expected_keys=["some_key"]
        )

        # Document with None content
        doc_with_none_content = Document(content=None)
        # also test with empty string content
        doc_with_empty_content = Document(content="")
        docs = [doc_with_none_content, doc_with_empty_content]

        result = extractor.run(documents=docs)

        # Assert that the documents are in failed_documents
        assert len(result["documents"]) == 0
        assert len(result["failed_documents"]) == 2

        failed_doc_none = result["failed_documents"][0]
        assert failed_doc_none.id == doc_with_none_content.id
        assert "metadata_extraction_error" in failed_doc_none.meta
        assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."

        failed_doc_empty = result["failed_documents"][1]
        assert failed_doc_empty.id == doc_with_empty_content.id
        assert "metadata_extraction_error" in failed_doc_empty.meta
        assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."

        # Ensure no attempt was made to call the LLM
        mock_chat_generator.run.assert_not_called()

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY"),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_live_run(self):
        """Run the extractor in a pipeline and check both stored documents carry ``entities``.

        Skipped unless ``OPENAI_API_KEY`` is set in the environment.
        """
        docs = [
            Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]

        # NOTE(review): the "Format each entity as" line previously had empty value slots
        # ({"entity": , "entity_type": }); the placeholders below are taken from the field
        # names listed in step 1 of this same prompt -- confirm against the intended wording.
        ner_prompt = """-Goal-
Given text and a list of entity types, identify all entities of those types from the text.

-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [organization, product, service, industry]
Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}

2. Return output in a single list with all the entities identified in steps 1.

-Examples-
######################
Example 1:
entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage.
We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent agreement with Emirates Skywards.
And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally
------------------------
output:
{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
#############################
-Real Data-
######################
entity_types: [company, organization, person, country, product, service]
text: {{ document.content }}
######################
output:
"""

        doc_store = InMemoryDocumentStore()
        extractor = LLMMetadataExtractor(
            prompt=ner_prompt, expected_keys=["entities"], chat_generator=OpenAIChatGenerator()
        )
        writer = DocumentWriter(document_store=doc_store)

        pipeline = Pipeline()
        pipeline.add_component("extractor", extractor)
        pipeline.add_component("doc_writer", writer)
        pipeline.connect("extractor.documents", "doc_writer.documents")
        pipeline.run(data={"documents": docs})

        doc_store_docs = doc_store.filter_documents()
        assert len(doc_store_docs) == 2
        assert "entities" in doc_store_docs[0].meta
        assert "entities" in doc_store_docs[1].meta