chore: make Haystack warnings consistent (#9083)

* chore: make Haystack warnings consistent

* more structured logging

* small fixes
This commit is contained in:
Stefano Fiorucci 2025-03-21 18:18:55 +01:00 committed by GitHub
parent 3e435439d9
commit 1c1030efc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 65 additions and 50 deletions

View File

@ -84,7 +84,7 @@ class OpenAPIServiceToFunctions:
"IO error reading OpenAPI specification file: {source}. Error: {e}", source=source, e=e "IO error reading OpenAPI specification file: {source}. Error: {e}", source=source, e=e
) )
else: else:
logger.warning(f"OpenAPI specification file not found: {source}") logger.warning("OpenAPI specification file not found: {source}", source=source)
elif isinstance(source, ByteStream): elif isinstance(source, ByteStream):
openapi_spec_content = source.data.decode("utf-8") openapi_spec_content = source.data.decode("utf-8")
if not openapi_spec_content: if not openapi_spec_content:

View File

@ -5,7 +5,6 @@
import ast import ast
import contextlib import contextlib
from typing import Any, Callable, Dict, Optional, Set from typing import Any, Callable, Dict, Optional, Set
from warnings import warn
import jinja2.runtime import jinja2.runtime
from jinja2 import Environment, TemplateSyntaxError, meta from jinja2 import Environment, TemplateSyntaxError, meta
@ -13,9 +12,11 @@ from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
logger = logging.getLogger(__name__)
class OutputAdaptationException(Exception): class OutputAdaptationException(Exception):
"""Exception raised when there is an error during output adaptation.""" """Exception raised when there is an error during output adaptation."""
@ -76,7 +77,7 @@ class OutputAdapter:
"Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. " "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
"Use this only if you trust the source of the template." "Use this only if you trust the source of the template."
) )
warn(msg) logger.warning(msg)
self._env = ( self._env = (
NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined) NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
) )

View File

@ -2,12 +2,11 @@
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from tqdm import tqdm from tqdm import tqdm
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document from haystack.dataclasses import Document
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils import Secret, deserialize_secrets_inplace
@ -17,6 +16,8 @@ from haystack.utils.url_validation import is_valid_http_url
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
logger = logging.getLogger(__name__)
@component @component
class HuggingFaceAPIDocumentEmbedder: class HuggingFaceAPIDocumentEmbedder:
@ -241,11 +242,11 @@ class HuggingFaceAPIDocumentEmbedder:
if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
if truncate is not None: if truncate is not None:
msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg) logger.warning(msg)
truncate = None truncate = None
if normalize is not None: if normalize is not None:
msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg) logger.warning(msg)
normalize = None normalize = None
all_embeddings = [] all_embeddings = []

View File

@ -2,10 +2,9 @@
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
@ -14,6 +13,8 @@ from haystack.utils.url_validation import is_valid_http_url
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
logger = logging.getLogger(__name__)
@component @component
class HuggingFaceAPITextEmbedder: class HuggingFaceAPITextEmbedder:
@ -200,11 +201,11 @@ class HuggingFaceAPITextEmbedder:
if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
if truncate is not None: if truncate is not None:
msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg) logger.warning(msg)
truncate = None truncate = None
if normalize is not None: if normalize is not None:
msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
warnings.warn(msg) logger.warning(msg)
normalize = None normalize = None
text_to_embed = self.prefix + text + self.suffix text_to_embed = self.prefix + text + self.suffix

View File

@ -4,15 +4,16 @@
import json import json
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
from warnings import warn
from tqdm import tqdm from tqdm import tqdm
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.builders import PromptBuilder from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret, deserialize_secrets_inplace, deserialize_type, serialize_type from haystack.utils import Secret, deserialize_secrets_inplace, deserialize_type, serialize_type
logger = logging.getLogger(__name__)
@component @component
class LLMEvaluator: class LLMEvaluator:
@ -206,10 +207,9 @@ class LLMEvaluator:
try: try:
result = self.generator.run(prompt=prompt["prompt"]) result = self.generator.run(prompt=prompt["prompt"])
except Exception as e: except Exception as e:
msg = f"Error while generating response for prompt: {prompt}. Error: {e}"
if self.raise_on_failure: if self.raise_on_failure:
raise ValueError(msg) raise ValueError(f"Error while generating response for prompt: {prompt}. Error: {e}")
warn(msg) logger.warning("Error while generating response for prompt: {prompt}. Error: {e}", prompt=prompt, e=e)
results.append(None) results.append(None)
errors += 1 errors += 1
continue continue
@ -225,8 +225,11 @@ class LLMEvaluator:
metadata = result["meta"] metadata = result["meta"]
if errors > 0: if errors > 0:
msg = f"LLM evaluator failed for {errors} out of {len(list_of_input_names_to_values)} inputs." logger.warning(
warn(msg) "LLM evaluator failed for {errors} out of {len(list_of_input_names_to_values)} inputs.",
errors=errors,
len=len(list_of_input_names_to_values),
)
return {"results": results, "meta": metadata} return {"results": results, "meta": metadata}
@ -374,14 +377,19 @@ class LLMEvaluator:
msg = "Response from LLM evaluator is not a valid JSON." msg = "Response from LLM evaluator is not a valid JSON."
if self.raise_on_failure: if self.raise_on_failure:
raise ValueError(msg) raise ValueError(msg)
warn(msg) logger.warning(msg)
return False return False
if not all(output in parsed_output for output in expected): if not all(output in parsed_output for output in expected):
msg = f"Expected response from LLM evaluator to be JSON with keys {expected}, got {received}."
if self.raise_on_failure: if self.raise_on_failure:
raise ValueError(msg) raise ValueError(
warn(msg) f"Expected response from LLM evaluator to be JSON with keys {expected}, got {{received}}."
)
logger.warning(
"Expected response from LLM evaluator to be JSON with keys {expected}, got {received}.",
expected=expected,
received=received,
)
return False return False
return True return True

View File

@ -3,7 +3,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
@ -215,9 +214,10 @@ class ExtractiveReader:
document_contents = [] document_contents = []
for i, doc in enumerate(documents): for i, doc in enumerate(documents):
if doc.content is None: if doc.content is None:
warnings.warn( logger.warning(
f"Document with id {doc.id} was passed to ExtractiveReader. The Document doesn't " "Document with id {doc_id} was passed to ExtractiveReader. The Document doesn't "
f"contain any text and it will be ignored." "contain any text and it will be ignored.",
doc_id=doc.id,
) )
continue continue
texts.append(doc.content) texts.append(doc.content)

View File

@ -5,7 +5,6 @@
import ast import ast
import contextlib import contextlib
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin
from warnings import warn
from jinja2 import Environment, TemplateSyntaxError, meta from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment from jinja2.nativetypes import NativeEnvironment
@ -192,7 +191,7 @@ class ConditionalRouter:
"Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. " "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
"Use this only if you trust the source of the template." "Use this only if you trust the source of the template."
) )
warn(msg) logger.warning(msg)
self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment() self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment()
self._env.filters.update(self.custom_filters) self._env.filters.update(self.custom_filters)
@ -216,13 +215,11 @@ class ConditionalRouter:
# warn about unused optional variables # warn about unused optional variables
unused_optional_vars = set(self.optional_variables) - input_types if self.optional_variables else None unused_optional_vars = set(self.optional_variables) - input_types if self.optional_variables else None
if unused_optional_vars: if unused_optional_vars:
msg = ( logger.warning(
f"The following optional variables are specified but not used in any route: {unused_optional_vars}. " "The following optional variables are specified but not used in any route: {unused_optional_vars}. "
"Check if there's a typo in variable names." "Check if there's a typo in variable names.",
unused_optional_vars=unused_optional_vars,
) )
# intentionally using both warn and logger
warn(msg, UserWarning)
logger.warning(msg)
# add mandatory input types # add mandatory input types
component.set_input_types(self, **dict.fromkeys(mandatory_input_types, Any)) component.set_input_types(self, **dict.fromkeys(mandatory_input_types, Any))

View File

@ -5,13 +5,15 @@
import csv import csv
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from warnings import warn
from haystack import logging
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
with LazyImport("Run 'pip install pandas'") as pandas_import: with LazyImport("Run 'pip install pandas'") as pandas_import:
from pandas import DataFrame from pandas import DataFrame
logger = logging.getLogger(__name__)
class EvaluationRunResult: class EvaluationRunResult:
""" """
@ -188,10 +190,15 @@ class EvaluationRunResult:
raise ValueError("The 'other' parameter must have 'run_name', 'inputs', and 'results' attributes.") raise ValueError("The 'other' parameter must have 'run_name', 'inputs', and 'results' attributes.")
if self.run_name == other.run_name: if self.run_name == other.run_name:
warn(f"The run names of the two evaluation results are the same ('{self.run_name}')") logger.warning(
"The run names of the two evaluation results are the same ('{run_name}')", run_name=self.run_name
)
if self.inputs.keys() != other.inputs.keys(): if self.inputs.keys() != other.inputs.keys():
warn(f"The input columns differ between the results; using the input columns of '{self.run_name}'.") logger.warning(
"The input columns differ between the results; using the input columns of '{run_name}'",
run_name=self.run_name,
)
# got both detailed reports # got both detailed reports
detailed_a = self.detailed_report(output_format="json") detailed_a = self.detailed_report(output_format="json")

View File

@ -203,7 +203,7 @@ class TestHuggingFaceAPIDocumentEmbedder:
"my_prefix document number 4 my_suffix", "my_prefix document number 4 my_suffix",
] ]
def test_embed_batch(self, mock_check_valid_model, recwarn): def test_embed_batch(self, mock_check_valid_model, caplog):
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
@ -225,10 +225,10 @@ class TestHuggingFaceAPIDocumentEmbedder:
assert len(embedding) == 384 assert len(embedding) == 384
assert all(isinstance(x, float) for x in embedding) assert all(isinstance(x, float) for x in embedding)
# Check that warnings about ignoring truncate and normalize are raised # Check that logger warnings about ignoring truncate and normalize are raised
assert len(recwarn) == 2 assert len(caplog.records) == 2
assert "truncate" in str(recwarn[0].message) assert "truncate" in caplog.records[0].message
assert "normalize" in str(recwarn[1].message) assert "normalize" in caplog.records[1].message
def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model): def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model):
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

View File

@ -136,7 +136,7 @@ class TestHuggingFaceAPITextEmbedder:
with pytest.raises(TypeError): with pytest.raises(TypeError):
embedder.run(text=list_integers_input) embedder.run(text=list_integers_input)
def test_run(self, mock_check_valid_model, recwarn): def test_run(self, mock_check_valid_model, caplog):
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]]) mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])
@ -158,9 +158,9 @@ class TestHuggingFaceAPITextEmbedder:
assert all(isinstance(x, float) for x in result["embedding"]) assert all(isinstance(x, float) for x in result["embedding"])
# Check that warnings about ignoring truncate and normalize are raised # Check that warnings about ignoring truncate and normalize are raised
assert len(recwarn) == 2 assert len(caplog.records) == 2
assert "truncate" in str(recwarn[0].message) assert "truncate" in caplog.records[0].message
assert "normalize" in str(recwarn[1].message) assert "normalize" in caplog.records[1].message
def test_run_wrong_embedding_shape(self, mock_check_valid_model): def test_run_wrong_embedding_shape(self, mock_check_valid_model):
# embedding ndim > 2 # embedding ndim > 2

View File

@ -525,7 +525,7 @@ class TestRouter:
result = pipe.run(data={"router": {"question": "What?", "mode": "chat", "language": "en", "source": "doc"}}) result = pipe.run(data={"router": {"question": "What?", "mode": "chat", "language": "en", "source": "doc"}})
assert result["router"] == {"en_doc_chat": "What?"}, "Pipeline should handle all parameters" assert result["router"] == {"en_doc_chat": "What?"}, "Pipeline should handle all parameters"
def test_warns_on_unused_optional_variables(self): def test_warns_on_unused_optional_variables(self, caplog):
""" """
Test that a warning is raised when optional_variables contains variables Test that a warning is raised when optional_variables contains variables
that are not used in any route conditions or outputs. that are not used in any route conditions or outputs.
@ -536,8 +536,8 @@ class TestRouter:
] ]
# Initialize with unused optional variables and capture warning # Initialize with unused optional variables and capture warning
with pytest.warns(UserWarning, match="optional variables"):
router = ConditionalRouter(routes=routes, optional_variables=["unused_var1", "unused_var2"]) router = ConditionalRouter(routes=routes, optional_variables=["unused_var1", "unused_var2"])
assert "optional variables" in caplog.records[0].message
# Verify router still works normally # Verify router still works normally
result = router.run(question="What?", mode="chat") result = router.run(question="What?", mode="chat")