# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from unittest.mock import Mock

import pytest

from haystack import Document, Pipeline
from haystack.components.extractors import LLMMetadataExtractor
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore


class TestLLMMetadataExtractor:
    def test_init(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
        extractor = LLMMetadataExtractor(
            prompt="prompt {{ document.content }}", expected_keys=["key1", "key2"], chat_generator=chat_generator
        )
        assert isinstance(extractor._chat_generator, OpenAIChatGenerator)
        assert extractor._chat_generator.model == "gpt-4o-mini"
        assert extractor._chat_generator.generation_kwargs == {"temperature": 0.5}
        assert extractor.expected_keys == ["key1", "key2"]
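
    # a prompt that does not reference the "document" variable should fail validation at init time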
    def test_init_missing_prompt_variable(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
        with pytest.raises(ValueError):
            _ = LLMMetadataExtractor(
                prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], chat_generator=chat_generator
            )
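
    # chat_generator is required; constructing the extractor without one should raise a TypeError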
    def test_init_fails_without_chat_generator(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        with pytest.raises(TypeError):
            _ = LLMMetadataExtractor(prompt="prompt {{ document.content }}", expected_keys=["key1", "key2"])
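
    # to_dict() should serialize every init parameter, including the nested chat generator
    # and defaults such as page_range and max_workers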
    def test_to_dict_openai(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
        extractor = LLMMetadataExtractor(
            prompt="some prompt that was used with the LLM {{ document.content }}",
            expected_keys=["key1", "key2"],
            chat_generator=chat_generator,
            raise_on_failure=True,
        )
        extractor_dict = extractor.to_dict()
        assert extractor_dict == {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{ document.content }}",
                "expected_keys": ["key1", "key2"],
                "raise_on_failure": True,
                "chat_generator": chat_generator.to_dict(),
                "page_range": None,
                "max_workers": 3,
            },
        }
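
    # from_dict() should rebuild the extractor together with its serialized chat generator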
    def test_from_dict_openai(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.5})
        extractor_dict = {
            "type": "haystack.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
            "init_parameters": {
                "prompt": "some prompt that was used with the LLM {{ document.content }}",
                "expected_keys": ["key1", "key2"],
                "chat_generator": chat_generator.to_dict(),
                "raise_on_failure": True,
            },
        }
        extractor = LLMMetadataExtractor.from_dict(extractor_dict)
        assert extractor.raise_on_failure is True
        assert extractor.expected_keys == ["key1", "key2"]
        assert extractor.prompt == "some prompt that was used with the LLM {{ document.content }}"
        assert extractor._chat_generator.to_dict() == chat_generator.to_dict()
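
    # warm_up() should be forwarded to the wrapped chat generator exactly once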
    def test_warm_up_with_chat_generator(self, monkeypatch):
        mock_chat_generator = Mock()
        extractor = LLMMetadataExtractor(prompt="prompt {{ document.content }}", chat_generator=mock_chat_generator)
        mock_chat_generator.warm_up.assert_not_called()
        extractor.warm_up()
        mock_chat_generator.warm_up.assert_called_once()
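
    # _extract_metadata parses the raw LLM answer string as JSON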
    def test_extract_metadata(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{ document.content }}", chat_generator=OpenAIChatGenerator())
        result = extractor._extract_metadata(llm_answer='{"output": "valid json"}')
        assert result == {"output": "valid json"}
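
    # with raise_on_failure=True, an unparsable LLM answer (note the missing quote) should raise a ValueError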
    def test_extract_metadata_invalid_json(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{ document.content }}", chat_generator=OpenAIChatGenerator(), raise_on_failure=True
        )
        with pytest.raises(ValueError):
            extractor._extract_metadata(llm_answer='{"output: "valid json"}')
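
    # an answer that parses but lacks an expected key should only be logged,
    # since raise_on_failure is left at its default here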
    def test_extract_metadata_missing_key(self, monkeypatch, caplog):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="prompt {{ document.content }}", chat_generator=OpenAIChatGenerator(), expected_keys=["key1"]
        )
        extractor._extract_metadata(llm_answer='{"output": "valid json"}')
        assert "Expected response from LLM to be a JSON with keys" in caplog.text
    def test_prepare_prompts(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{ document.content }}", chat_generator=OpenAIChatGenerator()
        )
        docs = [
            Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        prompts = extractor._prepare_prompts(docs)
        assert prompts == [
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
                        }
                    ],
                }
            ),
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
                        }
                    ],
                }
            ),
        ]
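
    # a document without content should yield a None prompt, keeping prompts aligned with the input documents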
    def test_prepare_prompts_empty_document(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{ document.content }}", chat_generator=OpenAIChatGenerator()
        )
        docs = [
            Document(content=""),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        prompts = extractor._prepare_prompts(docs)
        assert prompts == [
            None,
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
                        }
                    ],
                }
            ),
        ]
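
    # with a page_range, only the selected form-feed-separated pages should make it into the prompt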
    def test_prepare_prompts_expanded_range(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(
            prompt="some_user_definer_prompt {{ document.content }}",
            chat_generator=OpenAIChatGenerator(),
            page_range=["1-2"],
        )
        docs = [
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3"
            )
        ]
        prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
        assert prompts == [
            ChatMessage.from_dict(
                {
                    "_role": "user",
                    "_meta": {},
                    "_name": None,
                    "_content": [
                        {
                            "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"
                        }
                    ],
                }
            )
        ]
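
    # an empty input list should short-circuit: no LLM calls, empty outputs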
    def test_run_no_documents(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        extractor = LLMMetadataExtractor(prompt="prompt {{ document.content }}", chat_generator=OpenAIChatGenerator())
        result = extractor.run(documents=[])
        assert result["documents"] == []
        assert result["failed_documents"] == []
    def test_run_with_document_content_none(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        # mock the chat generator to prevent actual LLM calls
        mock_chat_generator = Mock(spec=OpenAIChatGenerator)
        extractor = LLMMetadataExtractor(
            prompt="prompt {{ document.content }}", chat_generator=mock_chat_generator, expected_keys=["some_key"]
        )
        # one document with None content and one with empty string content
        doc_with_none_content = Document(content=None)
        doc_with_empty_content = Document(content="")
        docs = [doc_with_none_content, doc_with_empty_content]
        result = extractor.run(documents=docs)
        # both documents should end up in failed_documents
        assert len(result["documents"]) == 0
        assert len(result["failed_documents"]) == 2
        failed_doc_none = result["failed_documents"][0]
        assert failed_doc_none.id == doc_with_none_content.id
        assert "metadata_extraction_error" in failed_doc_none.meta
        assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
        failed_doc_empty = result["failed_documents"][1]
        assert failed_doc_empty.id == doc_with_empty_content.id
        assert "metadata_extraction_error" in failed_doc_empty.meta
        assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
        # ensure no attempt was made to call the LLM
        mock_chat_generator.run.assert_not_called()
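
    # end-to-end run against the real OpenAI API: extracted metadata should be written to the document store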
    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_live_run(self):
        docs = [
            Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
            Document(
                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
            ),
        ]
        ner_prompt = """-Goal-
Given text and a list of entity types, identify all entities of those types from the text.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [organization, product, service, industry]
Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}
2. Return output in a single list with all the entities identified in step 1.
-Examples-
######################
Example 1:
entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
base and high cross-border usage.
We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
agreement with Emirates Skywards.
And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
issuers are equally
------------------------
output:
{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
#############################
-Real Data-
######################
entity_types: [company, organization, person, country, product, service]
text: {{ document.content }}
######################
output:
"""
        doc_store = InMemoryDocumentStore()
        extractor = LLMMetadataExtractor(
            prompt=ner_prompt, expected_keys=["entities"], chat_generator=OpenAIChatGenerator()
        )
        writer = DocumentWriter(document_store=doc_store)
        pipeline = Pipeline()
        pipeline.add_component("extractor", extractor)
        pipeline.add_component("doc_writer", writer)
        pipeline.connect("extractor.documents", "doc_writer.documents")
        pipeline.run(data={"documents": docs})
        doc_store_docs = doc_store.filter_documents()
        assert len(doc_store_docs) == 2
        assert "entities" in doc_store_docs[0].meta
        assert "entities" in doc_store_docs[1].meta