Compare commits


17 Commits

Author SHA1 Message Date
Stefano Fiorucci
2693f39e44
docs: discourage usage of HuggingFaceAPIGenerator with the HF Inference API (#9590)
* docs: discourage usage of HuggingFaceAPIGenerator with the HF Inference API

* small fixes

* Apply suggestions from code review

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* fix fmt

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
2025-07-04 14:27:30 +00:00
Stefano Fiorucci
646eedf26a
chore: reenable HF API Embedders tests + improve HFAPIChatGenerator docstrings (#9589)
* chore: reenable some HF API tests + improve docstrings

* revert deletion
2025-07-04 09:39:43 +02:00
Amna Mubashar
050c987946
chore: remove backward compatibility for State deserialization (#9585)
* remove backward compatability

* Fix linting
2025-07-03 13:20:34 +02:00
Sebastian Husch Lee
85258f0654
fix: Fix types and formatting pipeline test_run.py (#9575)
* Fix types in test_run.py

* Get test_run.py to pass fmt-check

* Add test_run to mypy checks

* Update test folder to pass ruff linting

* Fix merge

* Fix HF tests

* Fix hf test

* Try to fix tests

* Another attempt

* minor fix

* fix SentenceTransformersDiversityRanker

* skip integrations tests due to model unavailable on HF inference

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
2025-07-03 09:49:09 +02:00
Sebastian Husch Lee
16fc41cd95
feat: Relax requirement for creating a ToolCallDelta dataclass (#9582)
* Relax our requirement for ToolCallDelta to better match ChoiceDeltaToolCall and ChoiceDeltaToolCallFunction from OpenAI

* Add reno

* Update tests
2025-07-03 08:50:29 +02:00
Amna Mubashar
9fd552f906
chore: remove deprecated async_executor param from ToolInvoker (#9571)
* Remove async executor

* Add release notes

* Linting

* update release notes
2025-07-02 14:02:51 +02:00
Amna Mubashar
adb2759d00
chore: remove deprecated State from haystack.dataclasses (#9578)
* Remove deprecated class

* Remove state from pydocs
2025-07-02 12:19:06 +02:00
Stefano Fiorucci
848115c65e
fix: fix print_streaming_chunk + add tests (#9579)
* fix: fix print_streaming_chunk + add tests

* rel note
2025-07-01 16:49:18 +02:00
Sebastian Husch Lee
3aaa201ed6
feat: Add tool_invoker_kwargs to Agent (#9574)
* Add new param to agent to pass any kwargs to tool invoker

* Add reno
2025-07-01 11:09:58 +02:00
Amna Mubashar
f11870b212
fix: adjust the async executor in ToolInvoker (#9562)
* Fix bug

* Add a new test

* PR comments

* Add another test

* Small fix

* Fix linting

* Update tests
2025-06-30 15:13:01 +02:00
Sebastian Husch Lee
97e72b9693
feat: Add to_dict and from_dict to ByteStream (#9568)
* Add to_dict and from_dict to ByteStream

* Add reno

* Add unit tests

* Fix and expand tests

* Fix typing

* PR comments
2025-06-30 11:57:22 +00:00
Sebastian Husch Lee
fc64884819
fix: Fix _convert_streaming_chunks_to_chat_message (#9566)
* Fix conversion

* Add reno

* Add unit test
2025-06-30 11:51:25 +02:00
mathislucka
c54a68ab63
fix: files should not be passed as single string (#9559)
* fix: files should not be passed as single string

* chore: we want word splitting in this case

* fix: place directive before command

* fix: find correct directive placement
2025-06-27 11:17:42 +02:00
Stefano Fiorucci
c18f81283c
chore: fix deepset_sync.py for pylint + general linting improvements (#9558)
* chore: fix deepset_sync.py for pylint

* check .github with ruff

* fix

* Update .github/utils/pyproject_to_requirements.py

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>

---------

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
2025-06-27 07:54:22 +00:00
mathislucka
101e9cdc34
docs: sync code to deepset workspace (#9555)
* docs: sync code to deepset workspace

* fix: naming

* fix: actionlint
2025-06-27 07:51:59 +02:00
Stefano Fiorucci
bcaef53cbc
test: export HF_TOKEN env var in e2e environment (#9551)
* try to fix e2e tests for private NER models

* explanatory comment

* extend skipif condition
2025-06-25 15:00:28 +02:00
Haystack Bot
85e8493f4f
Update unstable version to 2.16.0-rc0 (#9554)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-06-25 14:57:16 +02:00
170 changed files with 1954 additions and 1520 deletions

View File

@@ -1,14 +1,17 @@
+import importlib
 import os
 import sys
-import importlib
 import traceback
 from pathlib import Path
 from typing import List, Optional

 from haystack import logging  # pylint: disable=unused-import # this is needed to avoid circular imports


 def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]] = None) -> tuple[list, list]:
     """
     Recursively search for all Python modules and attempt to import them.
     This includes both packages (directories with __init__.py) and individual Python files.
     """
     imported = []
@@ -25,7 +28,7 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]
         # Convert path to module format
         module_path = ".".join(Path(root).relative_to(base_path.parent).parts)

-        python_files = [f for f in files if f.endswith('.py')]
+        python_files = [f for f in files if f.endswith(".py")]

         # Try importing package and individual files
         for file in python_files:
@@ -39,16 +42,15 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]
                 importlib.import_module(module_to_import)
                 imported.append(module_to_import)
             except:
-                failed.append({
-                    'module': module_to_import,
-                    'traceback': traceback.format_exc()
-                })
+                failed.append({"module": module_to_import, "traceback": traceback.format_exc()})

     return imported, failed


 def main():
     """
     This script checks that all Haystack modules can be imported successfully.
     This includes both packages and individual Python files.

     This can detect several issues, such as:
     - Syntax errors in Python files
@@ -80,5 +82,5 @@ def main():
         sys.exit(1)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

View File

@@ -1,14 +1,16 @@
+import argparse
 import re
 import sys
-import argparse

-from readme_api import get_versions, create_new_unstable
+from readme_api import create_new_unstable, get_versions

 VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def calculate_new_unstable(version: str):
+    """
+    Calculate the new unstable version based on the given version.
+    """
     # version must be formatted like so <major>.<minor>
     major, minor = version.split(".")
     return f"{major}.{int(minor) + 1}-unstable"

.github/utils/deepset_sync.py (vendored, new file, 181 lines)
View File

@@ -0,0 +1,181 @@
# /// script
# dependencies = [
#   "requests",
# ]
# ///
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional

import requests


def transform_filename(filepath: Path) -> str:
    """
    Transform a file path to the required format:
    - Replace path separators with underscores
    """
    # Convert to string and replace path separators with underscores
    transformed = str(filepath).replace("/", "_").replace("\\", "_")
    return transformed


def upload_file_to_deepset(filepath: Path, api_key: str, workspace: str) -> bool:
    """
    Upload a single file to Deepset API.
    """
    # Read file content
    try:
        content = filepath.read_text(encoding="utf-8")
    except Exception as e:
        print(f"Error reading file {filepath}: {e}")
        return False

    # Transform filename
    transformed_name = transform_filename(filepath)

    # Prepare metadata
    metadata: dict[str, str] = {"original_file_path": str(filepath)}

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"
    params: dict[str, str] = {"file_name": transformed_name, "write_mode": "OVERWRITE"}
    headers: dict[str, str] = {"accept": "application/json", "authorization": f"Bearer {api_key}"}

    # Prepare multipart form data
    files: dict[str, tuple[None, str, str]] = {
        "meta": (None, json.dumps(metadata), "application/json"),
        "text": (None, content, "text/plain"),
    }

    try:
        response = requests.post(url, params=params, headers=headers, files=files, timeout=300)
        response.raise_for_status()
        print(f"Successfully uploaded: {filepath} as {transformed_name}")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to upload {filepath}: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to upload {filepath}: {e}")
        return False


def delete_files_from_deepset(filepaths: list[Path], api_key: str, workspace: str) -> bool:
    """
    Delete multiple files from Deepset API.
    """
    if not filepaths:
        return True

    # Transform filenames
    transformed_names: list[str] = [transform_filename(fp) for fp in filepaths]

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"
    headers: dict[str, str] = {
        "accept": "application/json",
        "authorization": f"Bearer {api_key}",
        "content-type": "application/json",
    }
    data: dict[str, list[str]] = {"names": transformed_names}

    try:
        response = requests.delete(url, headers=headers, json=data, timeout=300)
        response.raise_for_status()
        print(f"Successfully deleted {len(transformed_names)} file(s):")
        for original, transformed in zip(filepaths, transformed_names):
            print(f"  - {original} (as {transformed})")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to delete files: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to delete files: {e}")
        return False


def main() -> None:
    """
    Main function to process and upload/delete files.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Upload/delete Python files to/from Deepset")
    parser.add_argument("--changed", nargs="*", default=[], help="Changed or added files")
    parser.add_argument("--deleted", nargs="*", default=[], help="Deleted files")
    args = parser.parse_args()

    # Get environment variables
    api_key: Optional[str] = os.environ.get("DEEPSET_API_KEY")
    workspace: str = os.environ.get("DEEPSET_WORKSPACE")

    if not api_key:
        print("Error: DEEPSET_API_KEY environment variable not set")
        sys.exit(1)

    # Process arguments and convert to Path objects
    changed_files: list[Path] = [Path(f.strip()) for f in args.changed if f.strip()]
    deleted_files: list[Path] = [Path(f.strip()) for f in args.deleted if f.strip()]

    if not changed_files and not deleted_files:
        print("No files to process")
        sys.exit(0)

    print(f"Processing files in Deepset workspace: {workspace}")
    print("-" * 50)

    # Track results
    upload_success: int = 0
    upload_failed: list[Path] = []
    delete_success: bool = False

    # Handle deletions first
    if deleted_files:
        print(f"\nDeleting {len(deleted_files)} file(s)...")
        delete_success = delete_files_from_deepset(deleted_files, api_key, workspace)

    # Upload changed/new files
    if changed_files:
        print(f"\nUploading {len(changed_files)} file(s)...")
        for filepath in changed_files:
            if filepath.exists():
                if upload_file_to_deepset(filepath, api_key, workspace):
                    upload_success += 1
                else:
                    upload_failed.append(filepath)
            else:
                print(f"Skipping non-existent file: {filepath}")

    # Summary
    print("-" * 50)
    print("Processing Summary:")
    if changed_files:
        print(f"  Uploads - Successful: {upload_success}, Failed: {len(upload_failed)}")
    if deleted_files:
        print(f"  Deletions - {'Successful' if delete_success else 'Failed'}: {len(deleted_files)} file(s)")
    if upload_failed:
        print("\nFailed uploads:")
        for f in upload_failed:
            print(f"  - {f}")

    # Exit with error if any operation failed
    if upload_failed or (deleted_files and not delete_success):
        sys.exit(1)

    print("\nAll operations completed successfully!")


if __name__ == "__main__":
    main()
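For illustration, a minimal sketch of the flattening `transform_filename` applies, using an example path that is not taken from this diff:

```python
from pathlib import Path

# transform_filename flattens a repository path into a single Deepset file name
filepath = Path("haystack/core/pipeline/pipeline.py")
flattened = str(filepath).replace("/", "_").replace("\\", "_")
assert flattened == "haystack_core_pipeline_pipeline.py"
```

The upload workflow added later in this diff invokes the script as `uv run .github/utils/deepset_sync.py --changed <files> --deleted <files>`.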

View File

@@ -12,6 +12,9 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")

 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("README_API_KEY", None)
     if not api_key:
         raise Exception("README_API_KEY env var is not set")
@@ -21,6 +24,9 @@ def readme_token():

 def create_headers(version: str):
+    """
+    Create headers for the Readme API.
+    """
     return {"authorization": f"Basic {readme_token()}", "x-readme-version": version}
@@ -35,6 +41,9 @@ def get_docs_in_category(category_slug: str, version: str) -> List[str]:

 def delete_doc(slug: str, version: str):
+    """
+    Delete a document from Readme, based on the slug and version.
+    """
     url = f"https://dash.readme.com/api/v1/docs/{slug}"
     headers = create_headers(version)
     res = requests.delete(url, headers=headers, timeout=10)

View File

@@ -1,11 +1,13 @@
+import ast
+import hashlib
 from pathlib import Path
 from typing import Iterator

-import ast
-import hashlib

 def docstrings_checksum(python_files: Iterator[Path]):
+    """
+    Calculate the checksum of the docstrings in the given Python files.
+    """
     files_content = (f.read_text() for f in python_files)
     trees = (ast.parse(c) for c in files_content)

View File

@@ -1,6 +1,6 @@
+import argparse
 import re
 import sys
-import argparse

 from readme_api import get_versions, promote_unstable_to_stable

@@ -8,9 +8,7 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True
-    )
+    parser.add_argument("-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True)
     args = parser.parse_args()

     if VERSION_VALIDATOR.match(args.version) is None:

View File

@@ -3,7 +3,8 @@ import re
 import sys
 from pathlib import Path

-import toml
+# toml is available in the default environment but not in the test environment, so pylint complains
+import toml  # pylint: disable=import-error

 matcher = re.compile(r"farm-haystack\[(.+)\]")

 parser = argparse.ArgumentParser(
@@ -14,6 +15,9 @@ parser.add_argument("--extra", default="")

 def resolve(target: str, extras: dict, results: set):
+    """
+    Resolve the dependencies for a given target.
+    """
     if target not in extras:
         results.add(target)
         return
@@ -28,6 +32,9 @@ def resolve(target: str, extras: dict, results: set):

 def main(pyproject_path: Path, extra: str = ""):
+    """
+    Convert a pyproject.toml file to a requirements.txt file.
+    """
     content = toml.load(pyproject_path)
     # basic set of dependencies
     deps = set(content["project"]["dependencies"])

View File

@@ -1,16 +1,23 @@
-import os
 import base64
+import os

 import requests


 class ReadmeAuth(requests.auth.AuthBase):
-    def __call__(self, r):
+    """
+    Custom authentication class for Readme API.
+    """
+
+    def __call__(self, r):  # noqa: D102
         r.headers["authorization"] = f"Basic {readme_token()}"
         return r


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("RDME_API_KEY", None)
     if not api_key:
         raise Exception("RDME_API_KEY env var is not set")

View File

@@ -18,7 +18,9 @@ env:
   PYTHON_VERSION: "3.9"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   HATCH_VERSION: "1.14.1"
-  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}
+  # we use HF_TOKEN instead of HF_API_TOKEN to work around a Hugging Face bug
+  # see https://github.com/deepset-ai/haystack/issues/9552
+  HF_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

 jobs:
   run:

View File

@@ -0,0 +1,56 @@
name: Upload Code to Deepset

on:
  push:
    branches:
      - main

jobs:
  upload-files:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0  # Fetch all history for proper diff

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        run: |
          pip install uv

      - name: Get changed files
        id: changed-files
        uses: tj-actions/changed-files@v46
        with:
          files: |
            haystack/**/*.py
          separator: ' '

      - name: Upload files to Deepset
        if: steps.changed-files.outputs.any_changed == 'true' || steps.changed-files.outputs.any_deleted == 'true'
        env:
          DEEPSET_API_KEY: ${{ secrets.DEEPSET_API_KEY }}
          DEEPSET_WORKSPACE: haystack-code
        run: |
          # Combine added and modified files for upload
          CHANGED_FILES=""
          if [ -n "${{ steps.changed-files.outputs.added_files }}" ]; then
            CHANGED_FILES="${{ steps.changed-files.outputs.added_files }}"
          fi
          if [ -n "${{ steps.changed-files.outputs.modified_files }}" ]; then
            if [ -n "$CHANGED_FILES" ]; then
              CHANGED_FILES="$CHANGED_FILES ${{ steps.changed-files.outputs.modified_files }}"
            else
              CHANGED_FILES="${{ steps.changed-files.outputs.modified_files }}"
            fi
          fi

          # Run the script with changed and deleted files
          # shellcheck disable=SC2086
          uv run --no-project --no-config --no-cache .github/utils/deepset_sync.py \
            --changed $CHANGED_FILES \
            --deleted ${{ steps.changed-files.outputs.deleted_files }}

View File

@@ -19,7 +19,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -16,7 +16,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 jobs:
   check_if_changed:
@@ -39,7 +39,7 @@ jobs:
           - "haystack/core/pipeline/predefined/*"
           - "test/**/*.py"
           - "pyproject.toml"
-          - ".github/utils/check_imports.py"
+          - ".github/utils/*.py"

   trigger-catch-all:
     name: Tests completed

View File

@@ -1 +1 @@
-2.15.0-rc0
+2.16.0-rc0

View File

@@ -2,7 +2,7 @@ loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../../../haystack/dataclasses]
     modules:
-      ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"]
+      ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "streaming_chunk"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter

View File

@@ -68,8 +68,8 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
 @pytest.mark.parametrize("batch_size", [1, 3])
 @pytest.mark.skipif(
-    not os.environ.get("HF_API_TOKEN", None),
-    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    not os.environ.get("HF_API_TOKEN", None) and not os.environ.get("HF_TOKEN", None),
+    reason="Export an env var called HF_API_TOKEN or HF_TOKEN containing the Hugging Face token to run this test.",
 )
 def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")

View File

@@ -67,8 +67,9 @@ class Agent:
         exit_conditions: Optional[List[str]] = None,
         state_schema: Optional[Dict[str, Any]] = None,
         max_agent_steps: int = 100,
-        raise_on_tool_invocation_failure: bool = False,
         streaming_callback: Optional[StreamingCallbackT] = None,
+        raise_on_tool_invocation_failure: bool = False,
+        tool_invoker_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Initialize the agent component.
@@ -82,10 +83,11 @@ class Agent:
         :param state_schema: The schema for the runtime state used by the tools.
         :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
             If the agent exceeds this number of steps, it will stop and return the current state.
-        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
-            If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
             The same callback can be configured to emit tool results when a tool is called.
+        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
+            If set to False, the exception will be turned into a chat message and passed to the LLM.
+        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         :raises ValueError: If the exit_conditions are not valid.
         """
@@ -135,9 +137,15 @@ class Agent:
             component.set_input_type(self, name=param, type=config["type"], default=None)
         component.set_output_types(self, **output_types)

+        self.tool_invoker_kwargs = tool_invoker_kwargs
         self._tool_invoker = None
         if self.tools:
-            self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
+            resolved_tool_invoker_kwargs = {
+                "tools": self.tools,
+                "raise_on_failure": self.raise_on_tool_invocation_failure,
+                **(tool_invoker_kwargs or {}),
+            }
+            self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
         else:
             logger.warning(
                 "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
@@ -175,8 +183,9 @@ class Agent:
             # We serialize the original state schema, not the resolved one to reflect the original user input
             state_schema=_schema_to_dict(self._state_schema),
             max_agent_steps=self.max_agent_steps,
-            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
             streaming_callback=streaming_callback,
+            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
+            tool_invoker_kwargs=self.tool_invoker_kwargs,
         )

     @classmethod

View File

@@ -193,13 +193,13 @@ class HuggingFaceAPIChatGenerator:
     HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
     format for input and output. Use it to generate text with Hugging Face APIs:
-    - [Free Serverless Inference API](https://huggingface.co/inference-api)
+    - [Serverless Inference API (Inference Providers)](https://huggingface.co/docs/inference-providers)
     - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
     - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

     ### Usage examples

-    #### With the free serverless inference API
+    #### With the serverless inference API (Inference Providers) - free tier available

     ```python
     from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
@@ -215,7 +215,8 @@ class HuggingFaceAPIChatGenerator:
     api_type = "serverless_inference_api" # this is equivalent to the above

     generator = HuggingFaceAPIChatGenerator(api_type=api_type,
-                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+                                            api_params={"model": "microsoft/Phi-3.5-mini-instruct",
+                                                        "provider": "featherless-ai"},
                                             token=Secret.from_token("<your-api-key>"))

     result = generator.run(messages)
@@ -273,13 +274,15 @@ class HuggingFaceAPIChatGenerator:
         The type of Hugging Face API to use. Available types:
         - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
         - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
-        - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
+        - `serverless_inference_api`: See
+          [Serverless Inference API - Inference Providers](https://huggingface.co/docs/inference-providers).
     :param api_params:
         A dictionary with the following keys:
         - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
+        - `provider`: Provider name. Recommended when `api_type` is `SERVERLESS_INFERENCE_API`.
         - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
           `TEXT_GENERATION_INFERENCE`.
-        - Other parameters specific to the chosen API type, such as `timeout`, `headers`, `provider` etc.
+        - Other parameters specific to the chosen API type, such as `timeout`, `headers`, etc.
     :param token:
         The Hugging Face token to use as HTTP bearer authorization.
         Check your HF token in your [account settings](https://huggingface.co/settings/tokens).

View File

@@ -6,7 +6,7 @@ from dataclasses import asdict
 from datetime import datetime
 from typing import Any, Dict, Iterable, List, Optional, Union, cast

-from haystack import component, default_from_dict, default_to_dict
+from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import (
     ComponentInfo,
     FinishReason,
@@ -29,33 +29,26 @@ with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
     )

+logger = logging.getLogger(__name__)

 @component
 class HuggingFaceAPIGenerator:
     """
     Generates text using Hugging Face APIs.

     Use it with the following Hugging Face APIs:
-    - [Free Serverless Inference API]((https://huggingface.co/inference-api)
     - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
     - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

+    **Note:** As of July 2025, the Hugging Face Inference API no longer offers generative models through the
+    `text_generation` endpoint. Generative models are now only available through providers supporting the
+    `chat_completion` endpoint. As a result, this component might no longer work with the Hugging Face Inference API.
+    Use the `HuggingFaceAPIChatGenerator` component, which supports the `chat_completion` endpoint.

     ### Usage examples

-    #### With the free serverless inference API
-
-    ```python
-    from haystack.components.generators import HuggingFaceAPIGenerator
-    from haystack.utils import Secret
-
-    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
-                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
-                                        token=Secret.from_token("<your-api-key>"))
-
-    result = generator.run(prompt="What's Natural Language Processing?")
-    print(result)
-    ```
-
-    #### With paid inference endpoints
+    #### With Hugging Face Inference Endpoints

     ```python
     from haystack.components.generators import HuggingFaceAPIGenerator
@@ -75,6 +68,24 @@ class HuggingFaceAPIGenerator:
     generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                         api_params={"url": "http://localhost:8080"})

+    result = generator.run(prompt="What's Natural Language Processing?")
+    print(result)
+    ```
+
+    #### With the free serverless inference API
+
+    Be aware that this example might not work as the Hugging Face Inference API no longer offer models that support the
+    `text_generation` endpoint. Use the `HuggingFaceAPIChatGenerator` for generative models through the
+    `chat_completion` endpoint.
+
+    ```python
+    from haystack.components.generators import HuggingFaceAPIGenerator
+    from haystack.utils import Secret
+
+    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
+                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+                                        token=Secret.from_token("<your-api-key>"))
+
     result = generator.run(prompt="What's Natural Language Processing?")
     print(result)
     ```
@@ -97,6 +108,8 @@ class HuggingFaceAPIGenerator:
         - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
         - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
         - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
+          This might no longer work due to changes in the models offered in the Hugging Face Inference API.
+          Please use the `HuggingFaceAPIChatGenerator` component instead.
     :param api_params:
         A dictionary with the following keys:
         - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
@@ -120,6 +133,11 @@ class HuggingFaceAPIGenerator:
             api_type = HFGenerationAPIType.from_str(api_type)

         if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
+            logger.warning(
+                "Due to changes in the models offered in Hugging Face Inference API, using this component with the "
+                "Serverless Inference API might no longer work. "
+                "Please use the `HuggingFaceAPIChatGenerator` component instead."
+            )
             model = api_params.get("model")
             if model is None:
                 raise ValueError(

View File

@@ -44,7 +44,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
         if chunk.index and tool_call.index > chunk.index:
             print("\n\n", flush=True, end="")

-        print("[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
+        print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")

         # print the tool arguments
         if tool_call.arguments:
@@ -84,24 +84,20 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
     tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index

     for chunk in chunks:
         if chunk.tool_calls:
-            # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
-            # tool_call is present
-            assert chunk.index is not None
-
             for tool_call in chunk.tool_calls:
                 # We use the index of the tool_call to track the tool call across chunks since the ID is not always
                 # provided
                 if tool_call.index not in tool_call_data:
-                    tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
+                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}

                 # Save the ID if present
                 if tool_call.id is not None:
-                    tool_call_data[chunk.index]["id"] = tool_call.id
+                    tool_call_data[tool_call.index]["id"] = tool_call.id

                 if tool_call.tool_name is not None:
-                    tool_call_data[chunk.index]["name"] += tool_call.tool_name
+                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
                 if tool_call.arguments is not None:
-                    tool_call_data[chunk.index]["arguments"] += tool_call.arguments
+                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments

     # Convert accumulated tool call data into ToolCall objects
     sorted_keys = sorted(tool_call_data.keys())

View File

@@ -328,8 +328,8 @@ class SentenceTransformersDiversityRanker:
         # Normalize embeddings to unit length for computing cosine similarity
         if self.similarity == DiversityRankingSimilarity.COSINE:
-            doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
-            query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
+            doc_embeddings = doc_embeddings / torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
+            query_embedding = query_embedding / torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)

         return doc_embeddings, query_embedding

     def _maximum_margin_relevance(

View File

@@ -5,7 +5,6 @@
 import asyncio
 import inspect
 import json
-import warnings
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Any, Dict, List, Optional, Set, Union
@@ -171,7 +170,6 @@ class ToolInvoker:
         *,
         enable_streaming_callback_passthrough: bool = False,
         max_workers: int = 4,
-        async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
         Initialize the ToolInvoker component.
@@ -197,13 +195,7 @@ class ToolInvoker:
             If False, the `streaming_callback` will not be passed to the tool invocation.
         :param max_workers:
             The maximum number of workers to use in the thread pool executor.
-        :param async_executor:
-            Optional `ThreadPoolExecutor` to use for asynchronous calls.
-            Note: As of Haystack 2.15.0, you no longer need to explicitly pass
-            `async_executor`. Instead, you can provide the `max_workers` parameter,
-            and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
-            Support for `async_executor` will be removed in Haystack 2.16.0.
-            Please migrate to using `max_workers` instead.
+            This also decides the maximum number of concurrent tool invocations.
         :raises ValueError:
             If no tools are provided or if duplicate tool names are found.
         """
@@ -231,37 +223,6 @@ class ToolInvoker:
         self._tools_with_names = dict(zip(tool_names, converted_tools))
         self.raise_on_failure = raise_on_failure
         self.convert_result_to_json_string = convert_result_to_json_string
-        self._owns_executor = async_executor is None
-        if self._owns_executor:
-            warnings.warn(
-                "'async_executor' is deprecated in favor of the 'max_workers' parameter. "
-                "ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
-                "Support for 'async_executor' will be removed in Haystack 2.16.0. "
-                "Please update your usage to pass 'max_workers' instead.",
-                DeprecationWarning,
-            )
-        self.executor = (
-            ThreadPoolExecutor(
-                thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
-            )
-            if async_executor is None
-            else async_executor
-        )
-
-    def __del__(self):
-        """
-        Cleanup when the instance is being destroyed.
-        """
-        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
-            self.executor.shutdown(wait=True)
-
-    def shutdown(self):
-        """
-        Explicitly shutdown the executor if we own it.
-        """
-        if self._owns_executor:
-            self.executor.shutdown(wait=True)

     def _handle_error(self, error: Exception) -> str:
         """
@@ -655,7 +616,7 @@ class ToolInvoker:
             return e

     @component.output_types(tool_messages=List[ChatMessage], state=State)
-    async def run_async(
+    async def run_async(  # noqa: PLR0915
         self,
         messages: List[ChatMessage],
         state: Optional[State] = None,
@@ -718,10 +679,10 @@ class ToolInvoker:
         # 2) Execute valid tool calls in parallel
         if tool_call_params:
-            with self.executor as executor:
             tool_call_tasks = []
             valid_tool_calls = []
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                 # 3) Create async tasks for valid tool calls
                 for params in tool_call_params:
                     task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
@@ -749,7 +710,7 @@ class ToolInvoker:
                 try:
                     error_message = self._handle_error(
                         ToolOutputMergeError(
-                            f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
+                            f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
                         )
                     )
                     tool_messages.append(

View File

@@ -38,7 +38,6 @@ if TYPE_CHECKING:
     from .chat_message import ToolCallResult as ToolCallResult
     from .document import Document as Document
     from .sparse_embedding import SparseEmbedding as SparseEmbedding
-    from .state import State as State
     from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
     from .streaming_chunk import ComponentInfo as ComponentInfo
     from .streaming_chunk import FinishReason as FinishReason

View File

@@ -79,3 +79,24 @@ class ByteStream:
             fields.append(f"mime_type={self.mime_type!r}")
         fields_str = ", ".join(fields)
         return f"{self.__class__.__name__}({fields_str})"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the ByteStream to a dictionary representation.
+
+        :returns: A dictionary with keys 'data', 'meta', and 'mime_type'.
+        """
+        # Note: The data is converted to a list of integers for serialization since JSON does not support bytes
+        # directly.
+        return {"data": list(self.data), "meta": self.meta, "mime_type": self.mime_type}
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ByteStream":
+        """
+        Create a ByteStream from a dictionary representation.
+
+        :param data: A dictionary with keys 'data', 'meta', and 'mime_type'.
+        :returns: A ByteStream instance.
+        """
+        return ByteStream(data=bytes(data["data"]), meta=data.get("meta", {}), mime_type=data.get("mime_type"))

View File

@@ -127,8 +127,12 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
             Whether to flatten `meta` field or not. Defaults to `True` to be backward-compatible with Haystack 1.x.
         """
         data = asdict(self)
-        if (blob := data.get("blob")) is not None:
-            data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}
+
+        # Use `ByteStream` and `SparseEmbedding`'s to_dict methods to convert them to JSON-serializable types.
+        if self.blob is not None:
+            data["blob"] = self.blob.to_dict()
+        if self.sparse_embedding is not None:
+            data["sparse_embedding"] = self.sparse_embedding.to_dict()

         if flatten:
             meta = data.pop("meta")
@@ -144,7 +148,7 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
         The `blob` field is converted to its original type.
         """
         if blob := data.get("blob"):
-            data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
+            data["blob"] = ByteStream.from_dict(blob)
         if sparse_embedding := data.get("sparse_embedding"):
             data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)

View File

@@ -1,43 +0,0 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, Optional

from haystack.components.agents import State as Utils_State


class State(Utils_State):
    """
    A class that wraps a StateSchema and maintains an internal _data dictionary.

    Deprecated in favor of `haystack.components.agents.State`. It will be removed in Haystack 2.16.0.

    Each schema entry has:
    "parameter_name": {
        "type": SomeType,
        "handler": Optional[Callable[[Any, Any], Any]]
    }
    """

    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
        """
        Initialize a State object with a schema and optional data.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
             - For list types: `haystack.components.agents.state.state_utils.merge_lists`
             - For all other types: `haystack.components.agents.state.state_utils.replace_values`
        :param data: Optional dictionary of initial data to populate the state
        """
        warnings.warn(
            "`haystack.dataclasses.State` is deprecated and will be removed in Haystack 2.16.0. "
            "Use `haystack.components.agents.State` instead.",
            DeprecationWarning,
        )
        super().__init__(schema, data)

View File

@@ -1,52 +0,0 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, List, TypeVar, Union

T = TypeVar("T")


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
    """
    Merges two values into a single list.

    Deprecated in favor of `haystack.components.agents.state.merge_lists`. It will be removed in Haystack 2.16.0.

    If either `current` or `new` is not already a list, it is converted into one.
    The function ensures that both inputs are treated as lists and concatenates them.
    If `current` is None, it is treated as an empty list.

    :param current: The existing value(s), either a single item or a list.
    :param new: The new value(s) to merge, either a single item or a list.
    :return: A list containing elements from both `current` and `new`.
    """
    warnings.warn(
        "`haystack.dataclasses.state_utils.merge_lists` is deprecated and will be removed in Haystack 2.16.0. "
        "Use `haystack.components.agents.state.merge_lists` instead.",
        DeprecationWarning,
    )
    current_list = [] if current is None else current if isinstance(current, list) else [current]
    new_list = new if isinstance(new, list) else [new]
    return current_list + new_list


def replace_values(current: Any, new: Any) -> Any:
    """
    Replace the `current` value with the `new` value.

    :param current: The existing value
    :param new: The new value to replace
    :return: The new value
    """
    warnings.warn(
        "`haystack.dataclasses.state_utils.replace_values` is deprecated and will be removed in Haystack 2.16.0. "
        "Use `haystack.components.agents.state.replace_values` instead.",
        DeprecationWarning,
    )
    return new

View File

@@ -30,12 +30,6 @@ class ToolCallDelta:
     arguments: Optional[str] = field(default=None)
     id: Optional[str] = field(default=None)  # noqa: A003

-    def __post_init__(self):
-        # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the
-        # name and full arguments in one chunk
-        if self.tool_name is None and self.arguments is None:
-            raise ValueError("At least one of tool_name or arguments must be provided.")

 @dataclass
 class ComponentInfo:

View File

@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import warnings
 from typing import Any, Dict

 from haystack.core.errors import DeserializationError, SerializationError
@@ -210,14 +209,12 @@ def _deserialize_value_with_schema(serialized: Dict[str, Any]) -> Any:  # pylint
     if not schema_type:
-        # for backward comaptability till Haystack 2.16 we use legacy implementation
-        warnings.warn(
+        raise DeserializationError(
             "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
             "State object created with a version of Haystack older than 2.15.0. "
-            "Support for the old serialization format will be removed in Haystack 2.16.0. "
-            "Please upgrade to the new serialization format to ensure forward compatibility.",
-            DeprecationWarning,
+            "Support for the old serialization format is removed in Haystack 2.16.0. "
+            "Please upgrade to the new serialization format to ensure forward compatibility."
         )
-        return _deserialize_value_with_schema_legacy(serialized)

     # Handle object case (dictionary with properties)
     if schema_type == "object":
@@ -331,61 +328,3 @@ def _deserialize_value(value: Any) -> Any:  # pylint: disable=too-many-return-statements
     # 4) Fallback (shouldn't usually happen with our schema)
     return value
-
-
-def _deserialize_value_with_schema_legacy(serialized: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Legacy function for deserializing a dictionary with schema information and data to original values.
-    Kept for backward compatibility till Haystack 2.16.0.
-
-    Takes a dict of the form:
-    {
-        "schema": {
-            "numbers": {"type": "integer"},
-            "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}},
-        },
-        "data": {
-            "numbers": 1,
-            "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}],
-        }
-    }
-
-    :param serialized: The serialized dict with schema and data.
-    :returns: The deserialized dict with original values.
-    """
-    schema = serialized.get("serialization_schema", {})
-    data = serialized.get("serialized_data", {})
-    result: Dict[str, Any] = {}
-
-    for field, raw in data.items():
-        info = schema.get(field)
-        # no schema entry → just deep-deserialize whatever we have
-        if not info:
-            result[field] = _deserialize_value(raw)
-            continue
-
-        t = info["type"]
-        # ARRAY case
-        if t == "array":
-            item_type = info["items"]["type"]
-            reconstructed = []
-            for item in raw:
-                envelope = {"type": item_type, "data": item}
-                reconstructed.append(_deserialize_value(envelope))
-            result[field] = reconstructed
-        # PRIMITIVE case
-        elif t in ("null", "boolean", "integer", "number", "string"):
-            result[field] = raw
-        # GENERIC OBJECT
-        elif t == "object":
-            envelope = {"type": "object", "data": raw}
-            result[field] = _deserialize_value(envelope)
-        # CUSTOM CLASS
-        else:
-            envelope = {"type": t, "data": raw}
-            result[field] = _deserialize_value(envelope)
-
-    return result

View File

@@ -153,7 +153,8 @@ integration-only-fast = 'pytest --maxfail=5 -m "integration and not slow" {args:test}'
 integration-only-slow = 'pytest --maxfail=5 -m "integration and slow" {args:test}'
 all = 'pytest {args:test}'

-types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack}"
+# TODO We want to eventually type the whole test folder
+types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack test/core/pipeline/features/test_run.py}"
 lint = "pylint -ry -j 0 {args:haystack}"

 [tool.hatch.envs.e2e]
@@ -283,14 +284,24 @@ disallow_incomplete_defs = false

 [tool.ruff]
 line-length = 120
-exclude = [".github", "proposals"]
+exclude = ["proposals"]

 [tool.ruff.format]
 skip-magic-trailing-comma = true

+[tool.ruff.lint.per-file-ignores]
+"test/**" = [
+    "D102",    # Missing docstring in public method
+    "D103",    # Missing docstring in public function
+    "D205",    # 1 blank line required between summary line and description
+    "PLC0206", # Extracting value from dictionary without calling `.items()`
+    "SIM105",  # try-except-pass instead of contextlib-suppress
+    "SIM117",  # Use a single `with` statement with multiple contexts instead of nested `with` statements
+]
+
 [tool.ruff.lint]
 isort.split-on-trailing-comma = false
-exclude = ["test/**", "e2e/**"]
+exclude = ["e2e/**"]
 select = [
     "ASYNC", # flake8-async
     "C4",    # flake8-comprehensions

View File

@@ -0,0 +1,4 @@
---
features:
  - |
    Add `to_dict` and `from_dict` to ByteStream so it is consistent with our other dataclasses in having serialization and deserialization methods.
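A minimal round-trip sketch of the new methods; the sample bytes and metadata are illustrative:

```python
from haystack.dataclasses import ByteStream

stream = ByteStream(data=b"hello", meta={"source": "example"}, mime_type="text/plain")

# to_dict turns the bytes into a list of integers so the result is JSON-serializable
as_dict = stream.to_dict()
assert as_dict["data"] == [104, 101, 108, 108, 111]

# from_dict restores the original bytes
restored = ByteStream.from_dict(as_dict)
assert restored.data == b"hello" and restored.mime_type == "text/plain"
```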

View File

@@ -0,0 +1,4 @@
---
features:
  - |
    Added the `tool_invoker_kwargs` param to Agent so additional kwargs can be passed to the ToolInvoker like `max_workers` and `enable_streaming_callback_passthrough`.
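A short sketch of the new parameter; the tool and model are placeholders, and an OpenAI API key is assumed to be configured:

```python
from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.tools import Tool

def get_weather(city: str) -> str:
    """Toy tool function used only for this example."""
    return f"Sunny in {city}"

weather_tool = Tool(
    name="weather",
    description="Get the weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

# The extra kwargs are merged into the arguments of the internally created ToolInvoker
agent = Agent(
    chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
    tools=[weather_tool],
    tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
)
```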

View File

@@ -0,0 +1,13 @@
---
upgrade:
  - |
    `HuggingFaceAPIGenerator` might no longer work with the Hugging Face Inference API.
    As of July 2025, the Hugging Face Inference API no longer offers generative models that support the
    `text_generation` endpoint. Generative models are now only available through providers that support the
    `chat_completion` endpoint.
    As a result, the `HuggingFaceAPIGenerator` component might not work with the Hugging Face Inference API.
    It still works with Hugging Face Inference Endpoints and self-hosted TGI instances.
    To use generative models via the Hugging Face Inference API, please use the `HuggingFaceAPIChatGenerator`
    component, which supports the `chat_completion` endpoint.
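For reference, a migration sketch mirroring the updated docstring above; model, provider, and token are illustrative:

```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret

# Use the chat generator, which targets the chat_completion endpoint,
# instead of HuggingFaceAPIGenerator with the Serverless Inference API
generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
    token=Secret.from_token("<your-api-key>"),
)

result = generator.run(messages=[ChatMessage.from_user("What's Natural Language Processing?")])
print(result["replies"][0].text)
```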

View File

@@ -0,0 +1,4 @@
---
fixes:
  - |
    Fix `_convert_streaming_chunks_to_chat_message`, which converts Haystack StreamingChunks into a Haystack ChatMessage. This fixes the scenario where one StreamingChunk contains two ToolCallDeltas in StreamingChunk.tool_calls. With this fix, both ToolCallDeltas are saved correctly, whereas before they overwrote each other. This only occurs with some LLM providers, like Mistral (and not OpenAI), due to how the provider returns tool calls.
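A sketch of the scenario being fixed, built directly with the dataclasses from `haystack.dataclasses.streaming_chunk`; the tool names and arguments are illustrative:

```python
from haystack.dataclasses.streaming_chunk import StreamingChunk, ToolCallDelta

# One chunk carrying two tool-call deltas, as Mistral-style providers emit them.
# The converter previously keyed both deltas by chunk.index, so the second
# overwrote the first; it now keys them by each delta's own tool_call.index.
chunk = StreamingChunk(
    content="",
    index=0,
    tool_calls=[
        ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="call_0"),
        ToolCallDelta(index=1, tool_name="math", arguments='{"expression": "7*6"}', id="call_1"),
    ],
)
```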

View File

@@ -0,0 +1,3 @@
---
fixes:
  - Fixed a bug in the `print_streaming_chunk` utility function that prevented the tool call name from being printed.
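Typical usage, for context; the model choice is illustrative and an OpenAI API key is assumed:

```python
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

# With the f-string fix, a streamed tool call now prints the actual tool name
# after "[TOOL CALL]\nTool:" instead of the literal placeholder text
generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=print_streaming_chunk)
generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])
```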

View File

@@ -0,0 +1,4 @@
---
enhancements:
  - |
    We relaxed the requirement in `ToolCallDelta` (introduced in Haystack 2.15) that either `arguments` or `tool_name` must be populated to create the dataclass. We removed this requirement to be more in line with OpenAI's SDK, since it was causing errors for some hosted versions of open-source models that follow OpenAI's SDK specification.
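A small sketch of what the relaxed dataclass now accepts:

```python
from haystack.dataclasses.streaming_chunk import ToolCallDelta

# Valid after this change: a delta carrying only an id, with the name and
# arguments arriving in later chunks. Previously this raised
# "At least one of tool_name or arguments must be provided."
delta = ToolCallDelta(index=0, id="call_123")
```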


@@ -0,0 +1,6 @@
---
upgrade:
- |
The deprecated `async_executor` parameter has been removed from the `ToolInvoker` class.
Use the `max_workers` parameter instead; a `ThreadPoolExecutor` with that many workers is
created automatically for parallel tool invocations.
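A before/after sketch (`my_tool` is a placeholder for any `Tool` instance):

```python
from haystack.components.tools import ToolInvoker

# Before (no longer supported):
# invoker = ToolInvoker(tools=[my_tool], async_executor=executor)

# After: pass max_workers and let ToolInvoker manage its own ThreadPoolExecutor.
invoker = ToolInvoker(tools=[my_tool], max_workers=4)
```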


@@ -0,0 +1,5 @@
---
upgrade:
- |
The deprecated `State` class has been removed from the `haystack.dataclasses` module.
The `State` class is now part of the `haystack.components.agents` module.
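The import change in code (assuming `State` is re-exported at the package level, as the note above describes):

```python
# Removed:
# from haystack.dataclasses import State

# Use instead:
from haystack.components.agents import State
```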


@@ -0,0 +1,6 @@
---
upgrade:
- |
Removed the `deserialize_value_with_schema_legacy` function from the `base_serialization` module.
This function was used to deserialize `State` objects created with Haystack 2.14.0 or older;
support for the old serialization format is removed in Haystack 2.16.0.


@@ -5,17 +5,16 @@
 import logging
 import os
 from datetime import datetime
-from typing import Iterator, Dict, Any, List, Optional, Union
-from unittest.mock import MagicMock, patch, AsyncMock
+from typing import Any, Dict, Iterator, List, Optional, Union
+from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from openai import Stream
 from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
-from haystack.tracing.logging_tracer import LoggingTracer
 from haystack import Pipeline, tracing
 from haystack.components.agents import Agent
+from haystack.components.agents.state import merge_lists
 from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
@@ -24,11 +23,10 @@ from haystack.core.component.types import OutputSocket
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.dataclasses.chat_message import ChatRole, TextContent
 from haystack.dataclasses.streaming_chunk import StreamingChunk
-from haystack.tools import Tool, ComponentTool
+from haystack.tools import ComponentTool, Tool
 from haystack.tools.toolset import Toolset
-from haystack.utils import serialize_callable, Secret
-from haystack.components.agents.state import merge_lists
+from haystack.tracing.logging_tracer import LoggingTracer
+from haystack.utils import Secret, serialize_callable

 def streaming_callback_for_serde(chunk: StreamingChunk):

@@ -174,6 +172,7 @@ class TestAgent:
 tools=[weather_tool, component_tool],
 exit_conditions=["text", "weather_tool"],
 state_schema={"foo": {"type": str}},
+tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
 )
 serialized_agent = agent.to_dict()
 assert serialized_agent == {
@@ -236,8 +235,9 @@ class TestAgent:
 "exit_conditions": ["text", "weather_tool"],
 "state_schema": {"foo": {"type": "str"}},
 "max_agent_steps": 100,
-"raise_on_tool_invocation_failure": False,
 "streaming_callback": None,
+"raise_on_tool_invocation_failure": False,
+"tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
 },
 }
@@ -294,6 +294,7 @@ class TestAgent:
 "max_agent_steps": 100,
 "raise_on_tool_invocation_failure": False,
 "streaming_callback": None,
+"tool_invoker_kwargs": None,
 },
 }
@@ -361,6 +362,7 @@ class TestAgent:
 "max_agent_steps": 100,
 "raise_on_tool_invocation_failure": False,
 "streaming_callback": None,
+"tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
 },
 }
 agent = Agent.from_dict(data)
@@ -375,6 +377,9 @@ class TestAgent:
 "foo": {"type": str},
 "messages": {"handler": merge_lists, "type": List[ChatMessage]},
 }
+assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True}
+assert agent._tool_invoker.max_workers == 5
+assert agent._tool_invoker.enable_streaming_callback_passthrough is True

 def test_from_dict_with_toolset(self, monkeypatch):
 monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@@ -426,6 +431,7 @@ class TestAgent:
 "max_agent_steps": 100,
 "raise_on_tool_invocation_failure": False,
 "streaming_callback": None,
+"tool_invoker_kwargs": None,
 },
 }
 agent = Agent.from_dict(data)
@@ -882,15 +888,15 @@ class TestAgentTracing:
 '{"messages": "list", "tools": "list"}',
 "{}",
 "{}",
-'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',
+'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',  # noqa: E501
 1,
 '{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',
 100,
-'[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',
+'[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',  # noqa: E501
 '["text"]',
-'{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',
+'{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',  # noqa: E501
-'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',
+'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',  # noqa: E501
-'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',
+'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',  # noqa: E501
 1,
 ]
 for idx, record in enumerate(tags_records):
@@ -942,15 +948,15 @@ class TestAgentTracing:
 '{"messages": "list", "tools": "list"}',
 "{}",
 "{}",
-'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',
+'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',  # noqa: E501
 1,
-'{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',
+'{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',  # noqa: E501
 100,
-'[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',
+'[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',  # noqa: E501
 '["text"]',
-'{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',
+'{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',  # noqa: E501
-'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',
+'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',  # noqa: E501
-'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',
+'{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',  # noqa: E501
 1,
 ]
 for idx, record in enumerate(tags_records):


@@ -2,21 +2,22 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import pytest
+import inspect
 from dataclasses import dataclass
+from typing import Dict, Generic, List, Optional, TypeVar, Union
+import pytest
-from haystack.dataclasses import ChatMessage
 from haystack.components.agents.state.state import (
 State,
-_validate_schema,
-_schema_to_dict,
-_schema_from_dict,
 _is_list_type,
-merge_lists,
 _is_valid_type,
+_schema_from_dict,
+_schema_to_dict,
+_validate_schema,
+merge_lists,
 )
-from typing import List, Dict, Optional, Union, TypeVar, Generic
-import inspect
+from haystack.dataclasses import ChatMessage

 @pytest.fixture

@@ -432,45 +433,3 @@ class TestState:
 assert state.data["numbers"] == 1
 assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
 assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}
-def test_state_from_dict_legacy(self):
-# this is the old format of the state dictionary
-# it is kept for backward compatibility
-# it will be removed in Haystack 2.16.0
-state_dict = {
-"schema": {
-"numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"},
-"messages": {
-"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]",
-"handler": "haystack.components.agents.state.state_utils.merge_lists",
-},
-"dict_of_lists": {
-"type": "dict",
-"handler": "haystack.components.agents.state.state_utils.replace_values",
-},
-},
-"data": {
-"serialization_schema": {
-"numbers": {"type": "integer"},
-"messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}},
-"dict_of_lists": {"type": "object"},
-},
-"serialized_data": {
-"numbers": 1,
-"messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}],
-"dict_of_lists": {"numbers": [1, 2, 3]},
-},
-},
-}
-state = State.from_dict(state_dict)
-# Check types are correctly converted
-assert state.schema["numbers"]["type"] == int
-assert state.schema["dict_of_lists"]["type"] == dict
-# Check handlers are functions, not comparing exact functions as they might be different references
-assert callable(state.schema["numbers"]["handler"])
-assert callable(state.schema["messages"]["handler"])
-assert callable(state.schema["dict_of_lists"]["handler"])
-# Check data is correct
-assert state.data["numbers"] == 1
-assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
-assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}


@@ -4,18 +4,17 @@
 import sys
 from pathlib import Path
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
 import pytest
 import torch
 from haystack import Pipeline
-from haystack.components.fetchers import LinkContentFetcher
-from haystack.dataclasses import Document, ByteStream
 from haystack.components.audio import LocalWhisperTranscriber
+from haystack.components.fetchers import LinkContentFetcher
+from haystack.dataclasses import ByteStream, Document
 from haystack.utils.device import ComponentDevice, Device

 SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files"

@@ -192,12 +191,12 @@ class TestLocalWhisperTranscriber:
 docs = output["documents"]
 assert len(docs) == 3
-assert all(word in docs[0].content.strip().lower() for word in {"content", "the", "document"}), (
+assert all(word in docs[0].content.strip().lower() for word in ("content", "the", "document")), (
 f"Expected words not found in: {docs[0].content.strip().lower()}"
 )
 assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"]
-assert all(word in docs[1].content.strip().lower() for word in {"context", "answer"}), (
+assert all(word in docs[1].content.strip().lower() for word in ("context", "answer")), (
 f"Expected words not found in: {docs[1].content.strip().lower()}"
 )
 path = test_files_path / "audio" / "the context for this answer is here.wav"


@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
 import pytest
 from haystack import Pipeline


@@ -2,14 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, List, Optional
-from jinja2 import TemplateSyntaxError
-import arrow
 import logging
-import pytest
+from typing import Any, Dict, List, Optional
+import arrow
+import pytest
+from jinja2 import TemplateSyntaxError
-from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
 from haystack import component
+from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
 from haystack.core.pipeline.pipeline import Pipeline
 from haystack.dataclasses.chat_message import ChatMessage
 from haystack.dataclasses.document import Document


@@ -2,11 +2,11 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+import logging
 from typing import Any, Dict, List, Optional
 from unittest.mock import patch
 import arrow
-import logging
 import pytest
 from jinja2 import TemplateSyntaxError


@@ -2,13 +2,14 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+from unittest.mock import MagicMock
 import pytest
-from haystack import Document, DeserializationError
+from haystack import DeserializationError, Document
-from haystack.testing.factory import document_store_class
-from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.components.caching.cache_checker import CacheChecker
-from unittest.mock import MagicMock
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.testing.factory import document_store_class

 class TestCacheChecker:


@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
+import pytest
 from haystack import Document


@@ -2,10 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import pytest
 from unittest.mock import patch
+import pytest
 from haystack import Document, Pipeline
 from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
 from haystack.components.retrievers import InMemoryBM25Retriever
@@ -30,7 +30,7 @@ class TestTransformersZeroShotDocumentClassifier:
 )
 component_dict = component.to_dict()
 assert component_dict == {
-"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
+"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
 "init_parameters": {
 "model": "cross-encoder/nli-deberta-v3-xsmall",
 "labels": ["positive", "negative"],
@@ -47,7 +47,7 @@ class TestTransformersZeroShotDocumentClassifier:
 monkeypatch.delenv("HF_API_TOKEN", raising=False)
 monkeypatch.delenv("HF_TOKEN", raising=False)
 data = {
-"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
+"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
 "init_parameters": {
 "model": "cross-encoder/nli-deberta-v3-xsmall",
 "labels": ["positive", "negative"],
@@ -76,7 +76,7 @@ class TestTransformersZeroShotDocumentClassifier:
 monkeypatch.delenv("HF_API_TOKEN", raising=False)
 monkeypatch.delenv("HF_TOKEN", raising=False)
 data = {
-"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
+"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
 "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]},
 }
 component = TransformersZeroShotDocumentClassifier.from_dict(data)


@@ -6,9 +6,10 @@ import os
 from unittest.mock import Mock, patch
 import pytest
 from haystack import Pipeline
-from haystack.utils import Secret
 from haystack.components.connectors.openapi import OpenAPIConnector
+from haystack.utils import Secret

 # Mock OpenAPI spec for testing
 MOCK_OPENAPI_SPEC = """


@@ -7,19 +7,18 @@ import os
 from typing import Any, Dict, List
 from unittest.mock import MagicMock, Mock, patch
+import pytest
 import requests
-from openapi3 import OpenAPI
 from haystack import Pipeline
-import pytest
+from haystack.components.connectors import OpenAPIServiceConnector
 from haystack.components.converters.openapi_functions import OpenAPIServiceToFunctions
 from haystack.components.converters.output_adapter import OutputAdapter
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses.byte_stream import ByteStream
+from openapi3 import OpenAPI
-from haystack.components.connectors import OpenAPIServiceConnector
 from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.dataclasses.byte_stream import ByteStream

 @pytest.fixture


@@ -7,8 +7,8 @@ import os
 import pytest
-from haystack.dataclasses import ByteStream
 from haystack.components.converters.csv import CSVToDocument
+from haystack.dataclasses import ByteStream

 @pytest.fixture

@@ -32,7 +32,7 @@ class TestCSVToDocument:
 output = converter.run(sources=files)
 docs = output["documents"]
 assert len(docs) == 3
-assert "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n" == docs[0].content
+assert docs[0].content == "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n"
 assert isinstance(docs[0].content, str)
 assert docs[0].meta == {"file_path": os.path.basename(bytestream.meta["file_path"]), "key": "value"}
 assert docs[1].meta["file_path"] == os.path.basename(files[1])
@@ -50,7 +50,7 @@ class TestCSVToDocument:
 output = converter.run(sources=files)
 docs = output["documents"]
 assert len(docs) == 3
-assert "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n" == docs[0].content
+assert docs[0].content == "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n"
 assert isinstance(docs[0].content, str)
 assert docs[0].meta["file_path"] == "sample_1.csv"
 assert docs[0].meta["key"] == "value"
assert docs[0].meta["key"] == "value" assert docs[0].meta["key"] == "value"


@@ -2,15 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import json
-import os
-import logging
-import pytest
 import csv
+import json
+import logging
+import os
 from io import StringIO
+import pytest
 from haystack import Document, Pipeline
-from haystack.components.converters.docx import DOCXMetadata, DOCXToDocument, DOCXTableFormat, DOCXLinkFormat
+from haystack.components.converters.docx import DOCXLinkFormat, DOCXMetadata, DOCXTableFormat, DOCXToDocument
 from haystack.dataclasses import ByteStream


@@ -4,9 +4,9 @@
 import logging
 from pathlib import Path
+from unittest.mock import patch
 import pytest
-from unittest.mock import patch
 from haystack.components.converters import HTMLToDocument
 from haystack.dataclasses import ByteStream


@@ -3,17 +3,16 @@
 # SPDX-License-Identifier: Apache-2.0
 import json
-import os
-from unittest.mock import patch
-from pathlib import Path
 import logging
+import os
+from pathlib import Path
+from unittest.mock import patch
 import pytest
 from haystack.components.converters import JSONConverter
 from haystack.dataclasses import ByteStream

 test_data = [
 {
 "year": "1997",
@@ -23,7 +22,8 @@ test_data = [
 "id": "674",
 "firstname": "Dario",
 "surname": "Fokin",
-"motivation": "who emulates the jesters of the Middle Ages in scourging authority and upholding the dignity of the downtrodden",
+"motivation": "who emulates the jesters of the Middle Ages in scourging authority and upholding the "
+"dignity of the downtrodden",
 "share": "1",
 }
 ],
@@ -56,7 +56,9 @@ test_data = [
 "id": "46",
 "firstname": "Enrico",
 "surname": "Fermi",
-"motivation": "for his demonstrations of the existence of new radioactive elements produced by neutron irradiation, and for his related discovery of nuclear reactions brought about by slow neutrons",
+"motivation": "for his demonstrations of the existence of new radioactive elements produced by neutron "
+"irradiation, and for his related discovery of nuclear reactions brought about by slow "
+"neutrons",
 "share": "1",
 }
 ],
@@ -211,7 +213,8 @@ def test_run_with_non_json_file(tmpdir, caplog):
 assert len(records) == 1
 assert (
 records[0].msg
-== f"Failed to extract text from {test_file}. Skipping it. Error: parse error: Invalid numeric literal at line 1, column 5"
+== f"Failed to extract text from {test_file}. Skipping it. Error: parse error: Invalid numeric literal at "
+f"line 1, column 5"
 )
 assert result == {"documents": []}
@@ -481,7 +484,8 @@ def test_run_with_jq_schema_content_key_and_extra_meta_fields_literal(tmpdir):
 assert len(result["documents"]) == 4
 assert (
 result["documents"][0].content
-== "who emulates the jesters of the Middle Ages in scourging authority and upholding the dignity of the downtrodden"
+== "who emulates the jesters of the Middle Ages in scourging authority and upholding the dignity of the "
+"downtrodden"
 )
 assert result["documents"][0].meta == {
 "file_path": os.path.basename(first_test_file),


@@ -3,8 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 from unittest.mock import patch
 import pytest
 from haystack.components.converters.markdown import MarkdownToDocument


@@ -5,10 +5,10 @@
 import pytest
 from haystack import Document, Pipeline
-from haystack.core.pipeline.base import component_to_dict, component_from_dict
-from haystack.core.component.component import Component
-from haystack.dataclasses import ByteStream
 from haystack.components.converters.multi_file_converter import MultiFileConverter
+from haystack.core.component.component import Component
+from haystack.core.pipeline.base import component_from_dict, component_to_dict
+from haystack.dataclasses import ByteStream

 @pytest.fixture


@@ -2,15 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import List
 import json
+from typing import List
 import pytest
 from haystack import Pipeline, component
-from haystack.dataclasses import Document
 from haystack.components.converters import OutputAdapter
 from haystack.components.converters.output_adapter import OutputAdaptationException
+from haystack.dataclasses import Document

 def custom_filter_to_sede(value):


@@ -8,9 +8,9 @@ from unittest.mock import patch
 import pytest
 from haystack import Document
+from haystack.components.converters.pdfminer import PDFMinerToDocument
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.dataclasses import ByteStream
-from haystack.components.converters.pdfminer import PDFMinerToDocument

 class TestPDFMinerToDocument:


@@ -5,8 +5,8 @@
 import logging
 import os
-from haystack.dataclasses import ByteStream
 from haystack.components.converters.pptx import PPTXToDocument
+from haystack.dataclasses import ByteStream

 class TestPPTXToDocument:


@@ -3,12 +3,12 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
 import pytest
 from haystack import Document
-from haystack.components.converters.pypdf import PyPDFToDocument, PyPDFExtractionMode
+from haystack.components.converters.pypdf import PyPDFExtractionMode, PyPDFToDocument
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.dataclasses import ByteStream


@@ -7,8 +7,8 @@ import os
 import pytest
-from haystack.dataclasses import ByteStream
 from haystack.components.converters.txt import TextFileToDocument
+from haystack.dataclasses import ByteStream

 class TestTextfileToDocument:


@@ -2,12 +2,13 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+import sys
 from unittest.mock import patch
 import pytest
-import sys
-from haystack.dataclasses import ByteStream
 from haystack.components.converters.tika import TikaDocumentConverter
+from haystack.dataclasses import ByteStream

 class TestTikaDocumentConverter:


@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import pytest
 from haystack.components.converters.utils import normalize_metadata


@@ -3,16 +3,15 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
+from unittest.mock import Mock, patch
-from openai import APIError
-from haystack.utils.auth import Secret
 import pytest
+from openai import APIError
 from haystack import Document
 from haystack.components.embedders import AzureOpenAIDocumentEmbedder
+from haystack.utils.auth import Secret
 from haystack.utils.azure import default_azure_ad_token_provider
-from unittest.mock import Mock, patch

 class TestAzureOpenAIDocumentEmbedder:


@@ -3,12 +3,11 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-import random
 from unittest.mock import MagicMock, patch
+import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
 from numpy import array
 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
@@ -370,13 +369,13 @@ class TestHuggingFaceAPIDocumentEmbedder:
 assert truncate is True
 assert normalize is False
-@pytest.mark.flaky(reruns=5, reruns_delay=5)
 @pytest.mark.integration
 @pytest.mark.slow
 @pytest.mark.skipif(
 not os.environ.get("HF_API_TOKEN", None),
 reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
 )
+@pytest.mark.flaky(reruns=2, reruns_delay=10)
 def test_live_run_serverless(self):
 docs = [
 Document(content="I love cheese", meta={"topic": "Cuisine"}),


@@ -3,12 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-import random
 from unittest.mock import MagicMock, patch
+import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
 from numpy import array
 from haystack.components.embedders import HuggingFaceAPITextEmbedder
 from haystack.utils.auth import Secret
 from haystack.utils.hf import HFEmbeddingAPIType
@@ -213,9 +214,9 @@ class TestHuggingFaceAPITextEmbedder:
 with pytest.raises(ValueError):
 embedder.run(text="The food was delicious")
-@pytest.mark.flaky(reruns=5, reruns_delay=5)
 @pytest.mark.integration
 @pytest.mark.slow
+@pytest.mark.flaky(reruns=2, reruns_delay=10)
 @pytest.mark.skipif(
 not os.environ.get("HF_API_TOKEN", None),
 reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
@@ -233,6 +234,7 @@ class TestHuggingFaceAPITextEmbedder:
 @pytest.mark.integration
 @pytest.mark.asyncio
 @pytest.mark.slow
+@pytest.mark.flaky(reruns=2, reruns_delay=10)
 @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
 async def test_live_run_async_serverless(self):
 model_name = "sentence-transformers/all-MiniLM-L6-v2"


@@ -15,20 +15,6 @@ from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
 from haystack.utils.auth import Secret

-def mock_openai_response(input: List[str], model: str = "text-embedding-ada-002", **kwargs) -> dict:
-dict_response = {
-"object": "list",
-"data": [
-{"object": "embedding", "index": i, "embedding": [random.random() for _ in range(1536)]}
-for i in range(len(input))
-],
-"model": model,
-"usage": {"prompt_tokens": 4, "total_tokens": 4},
-}
-return dict_response

 class TestOpenAIDocumentEmbedder:
 def test_init_default(self, monkeypatch):
 monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")


@@ -5,10 +5,10 @@
 import os
 import pytest
+from openai.types import CreateEmbeddingResponse, Embedding
 from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
 from haystack.utils.auth import Secret
-from openai.types import CreateEmbeddingResponse, Embedding

 class TestOpenAITextEmbedder:

@@ -149,8 +149,8 @@ class TestOpenAITextEmbedder:
 monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
 embedder = OpenAITextEmbedder(dimensions=1536)
-input = "The food was delicious"
-prepared_input = embedder._prepare_input(input)
+inp = "The food was delicious"
+prepared_input = embedder._prepare_input(inp)
 assert prepared_input == {
 "model": "text-embedding-ada-002",
 "input": "The food was delicious",


@@ -67,7 +67,7 @@ class TestSentenceTransformersDocumentEmbedder:
 component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
 data = component.to_dict()
 assert data == {
-"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
 "init_parameters": {
 "model": "model",
 "device": ComponentDevice.from_str("cpu").to_dict(),
@@ -115,7 +115,7 @@ class TestSentenceTransformersDocumentEmbedder:
 data = component.to_dict()
 assert data == {
-"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
 "init_parameters": {
 "model": "model",
 "device": ComponentDevice.from_str("cuda:0").to_dict(),
@@ -161,7 +161,7 @@ class TestSentenceTransformersDocumentEmbedder:
 }
 component = SentenceTransformersDocumentEmbedder.from_dict(
 {
-"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
 "init_parameters": init_parameters,
 }
 )
@@ -186,7 +186,7 @@ class TestSentenceTransformersDocumentEmbedder:
 def test_from_dict_no_default_parameters(self):
 component = SentenceTransformersDocumentEmbedder.from_dict(
 {
-"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
 "init_parameters": {},
 }
 )
@@ -224,7 +224,7 @@ class TestSentenceTransformersDocumentEmbedder:
 }
 component = SentenceTransformersDocumentEmbedder.from_dict(
 {
-"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
 "init_parameters": init_parameters,
 }
 )
@@ -382,9 +382,9 @@ class TestSentenceTransformersDocumentEmbedder:
 model="sentence-transformers/all-MiniLM-L6-v2",
 token=None,
 device=ComponentDevice.from_str("cpu"),
-model_kwargs={
-"file_name": "onnx/model.onnx"
-},  # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
+# setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent
+# a HF warning
+model_kwargs={"file_name": "onnx/model.onnx"},
 backend="onnx",
 )
 onnx_embedder.warm_up()
@@ -410,9 +410,9 @@ class TestSentenceTransformersDocumentEmbedder:
 model="sentence-transformers/all-MiniLM-L6-v2",
 token=None,
 device=ComponentDevice.from_str("cpu"),
-model_kwargs={
-"file_name": "openvino/openvino_model.xml"
-},  # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
+# setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is
+# to prevent a HF warning
+model_kwargs={"file_name": "openvino/openvino_model.xml"},
 backend="openvino",
 )
 openvino_embedder.warm_up()


@@ -60,7 +60,7 @@ class TestSentenceTransformersTextEmbedder:
 component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
 data = component.to_dict()
 assert data == {
-"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
 "init_parameters": {
 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 "model": "model",
@@ -103,7 +103,7 @@ class TestSentenceTransformersTextEmbedder:
 )
 data = component.to_dict()
 assert data == {
-"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
 "init_parameters": {
 "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
 "model": "model",
@@ -132,7 +132,7 @@ class TestSentenceTransformersTextEmbedder:
 def test_from_dict(self):
 data = {
-"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
 "init_parameters": {
 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 "model": "model",
@@ -170,7 +170,7 @@ class TestSentenceTransformersTextEmbedder:
 def test_from_dict_no_default_parameters(self):
 data = {
-"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
 "init_parameters": {},
 }
 component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -189,7 +189,7 @@ class TestSentenceTransformersTextEmbedder:
 def test_from_dict_none_device(self):
 data = {
-"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
 "init_parameters": {
 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 "model": "model",
@@ -283,7 +283,8 @@ class TestSentenceTransformersTextEmbedder:
 @pytest.mark.slow
 def test_run_trunc(self, monkeypatch):
 """
-sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector space
+sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector
+space
 """
 monkeypatch.delenv("HF_API_TOKEN", raising=False)  # https://github.com/deepset-ai/haystack/issues/8811
 checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
@@ -306,7 +307,8 @@ class TestSentenceTransformersTextEmbedder:
 @pytest.mark.slow
 def test_run_quantization(self):
 """
-sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector space
+sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector
+space
 """
 checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
 text = "a nice text to embed"
@@ -341,9 +343,9 @@ class TestSentenceTransformersTextEmbedder:
 model="sentence-transformers/all-MiniLM-L6-v2",
 token=None,
 device=ComponentDevice.from_str("cpu"),
-model_kwargs={
-"file_name": "onnx/model.onnx"
-},  # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
+# setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent
+# a HF warning
+model_kwargs={"file_name": "onnx/model.onnx"},
 backend="onnx",
 )
 onnx_embedder.warm_up()
@@ -369,9 +371,9 @@ class TestSentenceTransformersTextEmbedder:
 model="sentence-transformers/all-MiniLM-L6-v2",
 token=None,
 device=ComponentDevice.from_str("cpu"),
-model_kwargs={
-"file_name": "openvino/openvino_model.xml"
-},  # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
+# setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but
+# this is to prevent a HF warning
+model_kwargs={"file_name": "openvino/openvino_model.xml"},
 backend="openvino",
 )
 openvino_embedder.warm_up()


@@ -2,18 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+import math
 import os
 from typing import List
-import math
 import pytest
 from haystack import Pipeline
 from haystack.components.evaluators import ContextRelevanceEvaluator
-from haystack.utils.auth import Secret
-from haystack.dataclasses.chat_message import ChatMessage
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses.chat_message import ChatMessage
+from haystack.utils.auth import Secret

 class TestContextRelevanceEvaluator:


@@ -122,7 +122,7 @@ def test_run_empty_retrieved_and_empty_ground_truth():

 def test_run_no_retrieved():
 evaluator = DocumentNDCGEvaluator()
 with pytest.raises(ValueError):
-result = evaluator.run(ground_truth_documents=[[Document(content="France")]], retrieved_documents=[])
+_ = evaluator.run(ground_truth_documents=[[Document(content="France")]], retrieved_documents=[])

 def test_run_no_ground_truth():


@@ -4,9 +4,9 @@
 import pytest
+from haystack import default_from_dict
 from haystack.components.evaluators.document_recall import DocumentRecallEvaluator, RecallMode
 from haystack.dataclasses import Document
-from haystack import default_from_dict

 def test_init_with_unknown_mode_string():


@@ -2,17 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import os
 import math
+import os
 from typing import List
 import pytest
 from haystack import Pipeline
 from haystack.components.evaluators import FaithfulnessEvaluator
-from haystack.utils.auth import Secret
-from haystack.dataclasses.chat_message import ChatMessage
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses.chat_message import ChatMessage
+from haystack.utils.auth import Secret

 class TestFaithfulnessEvaluator:


@@ -8,8 +8,8 @@ import pytest
 from haystack import Pipeline
 from haystack.components.evaluators import LLMEvaluator
-from haystack.dataclasses.chat_message import ChatMessage
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses.chat_message import ChatMessage

 class TestLLMEvaluator:

@@ -349,7 +349,11 @@ class TestLLMEvaluator:
 template = component.prepare_template()
 assert (
 template
-== 'Instructions:\ntest-instruction\n\nGenerate the response in JSON format with the following keys:\n["score"]\nConsider the instructions and the examples below to determine those values.\n\nExamples:\nInputs:\n{"predicted_answers": "Damn, this is straight outta hell!!!"}\nOutputs:\n{"score": 1}\nInputs:\n{"predicted_answers": "Football is the most popular sport."}\nOutputs:\n{"score": 0}\n\nInputs:\n{"predicted_answers": {{ predicted_answers }}}\nOutputs:\n'
+== "Instructions:\ntest-instruction\n\nGenerate the response in JSON format with the following keys:"
+'\n["score"]\nConsider the instructions and the examples below to determine those values.\n\n'
+'Examples:\nInputs:\n{"predicted_answers": "Damn, this is straight outta hell!!!"}\nOutputs:'
+'\n{"score": 1}\nInputs:\n{"predicted_answers": "Football is the most popular sport."}\nOutputs:'
+'\n{"score": 0}\n\nInputs:\n{"predicted_answers": {{ predicted_answers }}}\nOutputs:\n'
 )

 def test_invalid_input_parameters(self, monkeypatch):
def test_invalid_input_parameters(self, monkeypatch): def test_invalid_input_parameters(self, monkeypatch):


@@ -3,19 +3,18 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
+from unittest.mock import Mock
 import pytest
-from unittest.mock import Mock
 from haystack import Document, Pipeline
+from haystack.components.extractors import LLMMetadataExtractor
+from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.components.writers import DocumentWriter
 from haystack.dataclasses import ChatMessage
 from haystack.document_stores.in_memory import InMemoryDocumentStore
-from haystack.components.extractors import LLMMetadataExtractor
-from haystack.components.generators.chat import OpenAIChatGenerator
 class TestLLMMetadataExtractor:
     def test_init(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
@@ -139,7 +138,8 @@ class TestLLMMetadataExtractor:
             "_name": None,
             "_content": [
                 {
-                    "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
+                    "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for "
+                    "its Haystack framework"
                 }
             ],
         }
@@ -151,7 +151,8 @@ class TestLLMMetadataExtractor:
             "_name": None,
             "_content": [
                 {
-                    "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
+                    "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and "
+                    "is known for its Transformers library"
                 }
             ],
         }
@@ -179,7 +180,8 @@ class TestLLMMetadataExtractor:
             "_name": None,
             "_content": [
                 {
-                    "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
+                    "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, "
+                    "France and is known for its Transformers library"
                 }
             ],
         }
@@ -195,7 +197,8 @@ class TestLLMMetadataExtractor:
         )
         docs = [
             Document(
-                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3"
+                content="Hugging Face is a company founded in Paris, France and is known for its Transformers "
+                "library\fPage 2\fPage 3"
             )
         ]
         prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
@@ -208,7 +211,8 @@ class TestLLMMetadataExtractor:
             "_name": None,
             "_content": [
                 {
-                    "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"
+                    "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and "
+                    "is known for its Transformers library\x0cPage 2\x0c"
                 }
             ],
         }
@@ -305,7 +309,7 @@ entity_types: [company, organization, person, country, product, service]
 text: {{ document.content }}
 ######################
 output:
-"""
+"""  # noqa: E501
         doc_store = InMemoryDocumentStore()
         extractor = LLMMetadataExtractor(
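The long-string rewrites in these hunks all rely on Python's implicit concatenation of adjacent string literals: the wrapped pieces are joined at compile time, so the expected values in the assertions are byte-for-byte unchanged. A minimal illustration of the mechanism (the strings here are arbitrary):

    text = (
        "deepset was founded in 2018 in Berlin, "
        "and is known for its Haystack framework"
    )
    # Adjacent literals inside parentheses form a single string object:
    assert text == "deepset was founded in 2018 in Berlin, and is known for its Haystack framework"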

View File

@@ -6,11 +6,11 @@
 # Spacy is not installed in the test environment to keep the CI fast.
 # We test the Spacy backend in e2e/pipelines/test_named_entity_extractor.py.
-from haystack.utils.auth import Secret
 import pytest
 from haystack import ComponentError, DeserializationError, Pipeline
 from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
+from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice

View File

@@ -2,16 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
-import pytest
 import httpx
+import pytest
 from haystack.components.fetchers.link_content import (
-    LinkContentFetcher,
-    _text_content_handler,
-    _binary_content_handler,
     DEFAULT_USER_AGENT,
+    LinkContentFetcher,
+    _binary_content_handler,
+    _text_content_handler,
 )
 HTML_URL = "https://docs.haystack.deepset.ai/docs"
@@ -30,11 +30,9 @@ def mock_get_link_text_content():
 @pytest.fixture
 def mock_get_link_content(test_files_path):
     with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get:
-        mock_response = Mock(
-            status_code=200,
-            content=open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb").read(),
-            headers={"Content-Type": "application/pdf"},
-        )
+        with open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb") as f1:
+            file_bytes = f1.read()
+        mock_response = Mock(status_code=200, content=file_bytes, headers={"Content-Type": "application/pdf"})
         mock_get.return_value = mock_response
         yield mock_get
@@ -110,7 +108,8 @@ class TestLinkContentFetcher:
     def test_run_binary(self, test_files_path):
         """Test fetching binary content"""
-        file_bytes = open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb").read()
+        with open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb") as f1:
+            file_bytes = f1.read()
         with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get:
             mock_response = Mock(status_code=200, content=file_bytes, headers={"Content-Type": "application/pdf"})
             mock_get.return_value = mock_response
@@ -143,8 +142,8 @@ class TestLinkContentFetcher:
     def test_bad_request_exception_raised(self):
         """
-        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to
-        do so.
+        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is
+        configured to do so.
         """
         fetcher = LinkContentFetcher(raise_on_failure=True, retry_attempts=0)
@@ -247,7 +246,8 @@ class TestLinkContentFetcherAsync:
         streams = (await fetcher.run_async(urls=["https://www.example.com"]))["streams"]
         first_stream = streams[0]
-        assert first_stream.data == b"Example test response"
+        expected_content = b"Example test response"
+        assert first_stream.data == expected_content
         assert first_stream.meta["content_type"] == "text/plain"
         assert first_stream.mime_type == "text/plain"
@@ -265,7 +265,8 @@ class TestLinkContentFetcherAsync:
         assert len(streams) == 2
         for stream in streams:
-            assert stream.data == b"Example test response"
+            expected_data = b"Example test response"
+            assert stream.data == expected_data
             assert stream.meta["content_type"] == "text/plain"
             assert stream.mime_type == "text/plain"
@@ -324,7 +325,8 @@ class TestLinkContentFetcherAsync:
         # Should succeed on the second attempt with the second user agent
         streams = (await fetcher.run_async(urls=["https://www.example.com"]))["streams"]
         assert len(streams) == 1
-        assert streams[0].data == b"Success"
+        expected_result = b"Success"
+        assert streams[0].data == expected_result
         mock_sleep.assert_called_once()
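The open(...).read() rewrites above close the file handle deterministically instead of leaving it to the garbage collector, which is what resource-handling lints (e.g. ruff's SIM115 rule) flag. A minimal sketch of the pattern, assuming a local sample.pdf exists:

    from pathlib import Path

    pdf_path = Path("sample.pdf")  # hypothetical file, for illustration only
    # The context manager closes the handle even if read() raises:
    with open(pdf_path, "rb") as f:
        file_bytes = f.read()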

View File

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
 import pytest
 from openai import OpenAIError
-from haystack import component, Pipeline
+from haystack import Pipeline, component
 from haystack.components.generators.chat import AzureOpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ToolCall

View File

@@ -2,18 +2,14 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from datetime import datetime
 import os
+from datetime import datetime
 from typing import Any, Dict
-from unittest.mock import MagicMock, Mock, AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
 import pytest
-from haystack import Pipeline
-from haystack.dataclasses import StreamingChunk
-from haystack.utils.auth import Secret
-from haystack.utils.hf import HFGenerationAPIType
 from huggingface_hub import (
+    ChatCompletionInputStreamOptions,
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
     ChatCompletionOutputFunctionDefinition,
@@ -23,21 +19,22 @@ from huggingface_hub import (
     ChatCompletionStreamOutput,
     ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
-    ChatCompletionInputStreamOptions,
     ChatCompletionStreamOutputUsage,
 )
 from huggingface_hub.errors import RepositoryNotFoundError
+from haystack import Pipeline
 from haystack.components.generators.chat.hugging_face_api import (
     HuggingFaceAPIChatGenerator,
+    _convert_chat_completion_stream_output_to_streaming_chunk,
     _convert_hfapi_tool_calls,
     _convert_tools_to_hfapi_tools,
-    _convert_chat_completion_stream_output_to_streaming_chunk,
 )
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
 from haystack.tools import Tool
-from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools.toolset import Toolset
+from haystack.utils.auth import Secret
+from haystack.utils.hf import HFGenerationAPIType
 @pytest.fixture
@@ -753,11 +750,11 @@ class TestHuggingFaceAPIChatGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     def test_live_run_serverless(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
+            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
             generation_kwargs={"max_tokens": 20},
         )
@@ -788,11 +785,11 @@ class TestHuggingFaceAPIChatGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     def test_live_run_serverless_streaming(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
+            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
             generation_kwargs={"max_tokens": 20},
             streaming_callback=streaming_callback_handler,
         )
@@ -837,7 +834,7 @@ class TestHuggingFaceAPIChatGenerator:
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "Qwen/Qwen2.5-72B-Instruct", "provider": "hf-inference"},
+            api_params={"model": "Qwen/Qwen2.5-72B-Instruct", "provider": "together"},
             generation_kwargs={"temperature": 0.5},
         )
@@ -851,7 +848,7 @@ class TestHuggingFaceAPIChatGenerator:
         assert tool_call.tool_name == "weather"
         assert "city" in tool_call.arguments
         assert "Paris" in tool_call.arguments["city"]
-        assert message.meta["finish_reason"] == "stop"
+        assert message.meta["finish_reason"] == "tool_calls"
         new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]
@@ -1024,12 +1021,12 @@ class TestHuggingFaceAPIChatGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     @pytest.mark.asyncio
     async def test_live_run_async_serverless(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
+            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
             generation_kwargs={"max_tokens": 20},
         )

View File

@@ -4,20 +4,20 @@
 import asyncio
 import gc
-from typing import Optional, List
+from typing import List, Optional
 from unittest.mock import Mock, patch
-from haystack.utils.hf import AsyncHFTokenStreamingHandler
 import pytest
 from transformers import PreTrainedTokenizer
 from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
 from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
-from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT
+from haystack.dataclasses.streaming_chunk import AsyncStreamingCallbackT, StreamingChunk
 from haystack.tools import Tool
+from haystack.tools.toolset import Toolset
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
-from haystack.tools.toolset import Toolset
+from haystack.utils.hf import AsyncHFTokenStreamingHandler
 # used to test serialization of streaming_callback
@@ -558,7 +558,7 @@ class TestHuggingFaceLocalChatGeneratorAsync:
     )
     def test_executor_shutdown(self, model_info_mock, mock_pipeline_with_tokenizer):
-        with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
+        with patch("haystack.components.generators.chat.hugging_face_local.pipeline"):
             generator = HuggingFaceLocalChatGenerator(model="mocked-model")
             executor = generator.executor
             with patch.object(executor, "shutdown", wraps=executor.shutdown) as mock_shutdown:

View File

@@ -2,35 +2,37 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import patch, ANY, MagicMock
-import pytest
 import logging
 import os
 from datetime import datetime
 from typing import Any, Dict, List, Optional
+from unittest.mock import ANY, MagicMock, patch
+import pytest
 from openai import OpenAIError
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    chat_completion_chunk,
+)
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
-from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 from openai.types.chat.chat_completion_message_tool_call import Function
-from openai.types.chat import chat_completion_chunk
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 from haystack import component
-from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses import StreamingChunk, ToolCallDelta
-from haystack.utils.auth import Secret
-from haystack.dataclasses import ChatMessage, ToolCall
-from haystack.tools import ComponentTool, Tool
 from haystack.components.generators.chat.openai import (
     OpenAIChatGenerator,
     _check_finish_reason,
     _convert_chat_completion_chunk_to_streaming_chunk,
 )
+from haystack.components.generators.utils import print_streaming_chunk
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, ToolCallDelta
+from haystack.tools import ComponentTool, Tool
 from haystack.tools.toolset import Toolset
+from haystack.utils.auth import Secret
 @pytest.fixture
@@ -1177,6 +1179,32 @@ class TestChatCompletionChunkConversion:
         assert stream_chunk == haystack_chunk
         previous_chunks.append(stream_chunk)
+    def test_convert_chat_completion_chunk_with_empty_tool_calls(self):
+        # This can happen with some LLM providers where tool calls are not present but the pydantic models are still
+        # initialized.
+        chunk = ChatCompletionChunk(
+            id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
+            choices=[
+                chat_completion_chunk.Choice(
+                    delta=chat_completion_chunk.ChoiceDelta(
+                        tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction())]
+                    ),
+                    index=0,
+                )
+            ],
+            created=1742207200,
+            model="gpt-4o-mini-2024-07-18",
+            object="chat.completion.chunk",
+        )
+        result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])
+        assert result.content == ""
+        assert result.start is False
+        assert result.tool_calls == [ToolCallDelta(index=0)]
+        assert result.tool_call_result is None
+        assert result.index == 0
+        assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
+        assert result.meta["received_at"] is not None
     def test_handle_stream_response(self, chat_completion_chunks):
         openai_chunks = chat_completion_chunks
         comp = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
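The new test pins down an edge case: some providers stream a ChoiceDeltaToolCall whose nested function object is instantiated but carries no name or arguments, and the converter must still emit a ToolCallDelta holding only the index. A minimal sketch of that normalization (the helper name is hypothetical, not Haystack's actual implementation):

    from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

    def tool_call_delta_fields(tc: ChoiceDeltaToolCall) -> dict:
        # An empty ChoiceDeltaToolCallFunction() has name=None and arguments=None,
        # so only the index (and possibly the id) survives the conversion.
        fn = tc.function
        return {
            "index": tc.index,
            "id": tc.id,
            "tool_name": fn.name if fn is not None else None,
            "arguments": fn.arguments if fn is not None else None,
        }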

View File

@@ -2,24 +2,27 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from unittest.mock import AsyncMock, patch, MagicMock
-from openai import AsyncOpenAI, OpenAIError
-import pytest
-from datetime import datetime
 import os
+from datetime import datetime
+from unittest.mock import AsyncMock, MagicMock, patch
+import pytest
+from openai import AsyncOpenAI, OpenAIError
-from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, ChatCompletionChunk
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    chat_completion_chunk,
+)
 from openai.types.chat.chat_completion import Choice
-from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 from openai.types.chat.chat_completion_message_tool_call import Function
-from openai.types.chat import chat_completion_chunk
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
-from haystack.dataclasses import StreamingChunk
-from haystack.utils.auth import Secret
-from haystack.dataclasses import ChatMessage, ToolCall
-from haystack.tools import Tool
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.tools import Tool
+from haystack.utils.auth import Secret
 @pytest.fixture

View File

@@ -4,12 +4,11 @@
 from datetime import datetime
 from typing import Iterator
-from unittest.mock import MagicMock, patch, AsyncMock
+from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from openai import AsyncStream, Stream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
-from openai.types.chat import chat_completion_chunk
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, chat_completion_chunk
 @pytest.fixture

View File

@@ -4,14 +4,13 @@
 import os
-from haystack import Pipeline
-from haystack.utils.auth import Secret
 import pytest
 from openai import OpenAIError
+from haystack import Pipeline
 from haystack.components.generators import AzureOpenAIGenerator
 from haystack.components.generators.utils import print_streaming_chunk
+from haystack.utils.auth import Secret
 from haystack.utils.azure import default_azure_ad_token_provider

View File

@@ -3,8 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-from unittest.mock import MagicMock, Mock, patch
 from datetime import datetime
+from unittest.mock import MagicMock, Mock, patch
 import pytest
 from huggingface_hub import (
@@ -297,6 +297,10 @@ class TestHuggingFaceAPIGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
+    @pytest.mark.skip(
+        reason="HF Inference API is not currently serving these models. "
+        "See https://github.com/deepset-ai/haystack/issues/9586."
+    )
     def test_run_serverless(self):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
@@ -307,7 +311,8 @@ class TestHuggingFaceAPIGenerator:
         # You must include the instruction tokens in the prompt. HF does not add them automatically.
         # Without them the model will behave erratically.
         response = generator.run(
-            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else.<|end|>\n<|assistant|>\n"
+            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else.<|end|>"
+            "\n<|assistant|>\n"
         )
         # Assert that the response contains the generated replies
@@ -329,6 +334,10 @@ class TestHuggingFaceAPIGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
+    @pytest.mark.skip(
+        reason="HF Inference API is not currently serving these models. "
+        "See https://github.com/deepset-ai/haystack/issues/9586."
+    )
     def test_live_run_streaming_check_completion_start_time(self):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
@@ -338,7 +347,8 @@ class TestHuggingFaceAPIGenerator:
         )
         results = generator.run(
-            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else.<|end|>\n<|assistant|>\n"
+            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else."
+            "<|end|>\n<|assistant|>\n"
        )
         # Assert that the response contains the generated replies

View File

@@ -4,10 +4,11 @@
 import os
 from datetime import datetime
+from unittest.mock import MagicMock, patch
 import pytest
 from openai import OpenAIError
 from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
-from unittest.mock import MagicMock, patch
 from haystack.components.generators import OpenAIGenerator
 from haystack.components.generators.utils import print_streaming_chunk

View File

@@ -2,13 +2,14 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import pytest
 from unittest.mock import Mock, patch
-from haystack.utils import Secret
-from openai.types.image import Image
+import pytest
 from openai.types import ImagesResponse
+from openai.types.image import Image
 from haystack.components.generators.openai_dalle import DALLEImageGenerator
+from haystack.utils import Secret
 @pytest.fixture
@@ -30,7 +31,7 @@ class TestDALLEImageGenerator:
         assert component.api_base_url is None
         assert component.organization is None
         assert pytest.approx(component.timeout) == 30.0
-        assert component.max_retries is 5
+        assert component.max_retries == 5
         assert component.http_client_kwargs is None
     def test_init_with_params(self, monkeypatch):
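The `is 5` fix is more than style: `is` tests object identity, not value. CPython interns small integers, so `max_retries is 5` happens to pass there, but that is an implementation detail, and CPython 3.8+ emits a SyntaxWarning for `is` with a literal (pyflakes reports it as F632). A quick sketch of the difference:

    x = 5
    assert x == 5  # value equality: the right check for configuration values
    # Identity depends on interning and is not guaranteed for larger ints:
    big = int("1000000")  # constructed at runtime, not interned
    print(big == 10**6)  # True
    print(big is 10**6)  # False on CPython: a fresh int is a different object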

View File

@@ -2,10 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+from unittest.mock import call, patch
 from openai.types.chat import chat_completion_chunk
-from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
-from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta
+from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, print_streaming_chunk
+from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta, ToolCallResult
 def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
@@ -325,3 +327,256 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
         },
         "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
     }
+def test_convert_streaming_chunk_to_chat_message_two_tool_calls_in_same_chunk():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "usage": None,
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "finish_reason": "tool_calls",
+                "usage": {
+                    "completion_tokens": 35,
+                    "prompt_tokens": 77,
+                    "total_tokens": 112,
+                    "completion_tokens_details": None,
+                    "prompt_tokens_details": None,
+                },
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+            index=0,
+            tool_calls=[
+                ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
+                ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
+            ],
+            start=True,
+            finish_reason="tool_calls",
+        ),
+    ]
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+    assert not result.texts
+    assert not result.text
+    # Verify both tool calls were found and processed
+    assert len(result.tool_calls) == 2
+    assert result.tool_calls[0].id == "FL1FFlqUG"
+    assert result.tool_calls[0].tool_name == "weather"
+    assert result.tool_calls[0].arguments == {"city": "Paris"}
+    assert result.tool_calls[1].id == "xSuhp66iB"
+    assert result.tool_calls[1].tool_name == "weather"
+    assert result.tool_calls[1].arguments == {"city": "Berlin"}
+def test_convert_streaming_chunk_to_chat_message_empty_tool_call_delta():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.910076",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        id="call_ZOj5l67zhZOx6jqjg7ATQwb6",
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments='{"query":', name="rag_pipeline_tool"
+                        ),
+                        type="function",
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.913919",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            start=True,
+            tool_calls=[
+                ToolCallDelta(
+                    id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments='{"query":', index=0
+                )
+            ],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments=' "Where does Mark live?"}'
+                        ),
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.924420",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            tool_calls=[ToolCallDelta(arguments=' "Where does Mark live?"}', index=0)],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction()
+                    )
+                ],
+                "finish_reason": "tool_calls",
+                "received_at": "2025-02-19T16:02:55.948772",
+            },
+            tool_calls=[ToolCallDelta(index=0)],
+            component_info=ComponentInfo(name="test", type="test"),
+            finish_reason="tool_calls",
+            index=0,
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.948772",
+                "usage": {
+                    "completion_tokens": 42,
+                    "prompt_tokens": 282,
+                    "total_tokens": 324,
+                    "completion_tokens_details": {
+                        "accepted_prediction_tokens": 0,
+                        "audio_tokens": 0,
+                        "reasoning_tokens": 0,
+                        "rejected_prediction_tokens": 0,
+                    },
+                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                },
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+    ]
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+    assert not result.texts
+    assert not result.text
+    # Verify the tool call was found and processed
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6"
+    assert result.tool_calls[0].tool_name == "rag_pipeline_tool"
+    assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"}
+    assert result.meta["finish_reason"] == "tool_calls"
+def test_print_streaming_chunk_content_only():
+    chunk = StreamingChunk(
+        content="Hello, world!",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[ASSISTANT]\n", flush=True, end=""), call("Hello, world!", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+def test_print_streaming_chunk_tool_call():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        index=0,
+        tool_calls=[ToolCallDelta(id="call_123", tool_name="test_tool", arguments='{"param": "value"}', index=0)],
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[TOOL CALL]\nTool: test_tool \nArguments: ", flush=True, end=""),
+            call('{"param": "value"}', flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+def test_print_streaming_chunk_tool_call_result():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        index=0,
+        tool_call_result=ToolCallResult(
+            result="Tool execution completed successfully",
+            origin=ToolCall(id="call_123", tool_name="test_tool", arguments={}),
+            error=False,
+        ),
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[TOOL RESULT]\nTool execution completed successfully", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+def test_print_streaming_chunk_with_finish_reason():
+    chunk = StreamingChunk(
+        content="Final content.",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        finish_reason="stop",
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[ASSISTANT]\n", flush=True, end=""),
+            call("Final content.", flush=True, end=""),
+            call("\n\n", flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+def test_print_streaming_chunk_empty_chunk():
+    chunk = StreamingChunk(
+        content="", meta={"model": "test-model"}, component_info=ComponentInfo(name="test", type="test")
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        mock_print.assert_not_called()
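In normal use, print_streaming_chunk is not called directly but passed as a component's streaming_callback; a minimal usage sketch (the model name is an assumption, and OPENAI_API_KEY must be set in the environment):

    from haystack.components.generators.chat.openai import OpenAIChatGenerator
    from haystack.components.generators.utils import print_streaming_chunk
    from haystack.dataclasses import ChatMessage

    # Tokens, tool calls, and tool results are printed to stdout as they stream in.
    generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=print_streaming_chunk)
    result = generator.run(messages=[ChatMessage.from_user("Say hello.")])
    print(result["replies"][0].text)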

View File

@@ -5,8 +5,8 @@
 import pytest
 from haystack import Document
-from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer
 from haystack.components.joiners.answer_joiner import AnswerJoiner, JoinMode
+from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer
 class TestAnswerJoiner:

View File

@@ -3,16 +3,17 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import List
 import pytest
 from haystack import Document, Pipeline
-from haystack.dataclasses import ChatMessage
-from haystack.dataclasses.answer import GeneratedAnswer
 from haystack.components.builders import AnswerBuilder, ChatPromptBuilder
+from haystack.components.embedders import SentenceTransformersTextEmbedder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.joiners.list_joiner import ListJoiner
-from haystack.components.embedders import SentenceTransformersTextEmbedder
 from haystack.core.errors import PipelineConnectError
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.answer import GeneratedAnswer
 from haystack.utils.auth import Secret

View File

@@ -2,8 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from haystack.core.serialization import component_from_dict, component_to_dict
 from haystack.components.joiners.string_joiner import StringJoiner
+from haystack.core.serialization import component_from_dict, component_to_dict
 class TestStringJoiner:

View File

@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from haystack import Document
 from haystack.components.preprocessors.csv_document_cleaner import CSVDocumentCleaner

View File

@@ -2,13 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import pytest
 import logging
-from pandas import read_csv
 from io import StringIO
+import pytest
+from pandas import read_csv
 from haystack import Document
-from haystack.core.serialization import component_from_dict, component_to_dict
 from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter
+from haystack.core.serialization import component_from_dict, component_to_dict
 @pytest.fixture

View File

@@ -7,8 +7,8 @@ import logging
 import pytest
 from haystack import Document
-from haystack.dataclasses import ByteStream, SparseEmbedding
 from haystack.components.preprocessors import DocumentCleaner
+from haystack.dataclasses import ByteStream, SparseEmbedding
 class TestDocumentCleaner:

View File

@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from unittest.mock import patch
 import pytest
 from haystack import Document, Pipeline

View File

@@ -2,9 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import List
 import re
+from typing import List
 import pytest
@@ -27,8 +26,7 @@ def merge_documents(documents):
         start = doc.meta["split_idx_start"]  # start of the current content
         # if the start of the current content is before the end of the last appended content, adjust it
-        if start < last_idx_end:
-            start = last_idx_end
+        start = max(start, last_idx_end)
         # append the non-overlapping part to the merged text
         merged_text += doc.content[start - doc.meta["split_idx_start"] :]
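The max() rewrite collapses the two-line clamp into one expression with identical behavior. A self-contained sketch of the overlap-merging idea, using hypothetical split metadata:

    # Hypothetical split documents with character offsets and overlapping content:
    docs = [
        {"content": "Hello wor", "split_idx_start": 0},
        {"content": "world!", "split_idx_start": 6},
    ]
    merged_text = ""
    last_idx_end = 0
    for doc in docs:
        # Clamp the start so overlapping characters are not appended twice.
        start = max(doc["split_idx_start"], last_idx_end)
        merged_text += doc["content"][start - doc["split_idx_start"]:]
        last_idx_end = doc["split_idx_start"] + len(doc["content"])
    print(merged_text)  # Hello world!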
@@ -40,7 +38,7 @@ def merge_documents(documents):
 class TestSplittingByFunctionOrCharacterRegex:
-    def test_non_text_document(self):
+    def test_non_text_document(self, caplog):
         with pytest.raises(
             ValueError, match="DocumentSplitter only works with text documents but content for document ID"
         ):
@@ -110,7 +108,10 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_split_by_word_multiple_input_docs(self):
         splitter = DocumentSplitter(split_by="word", split_length=10)
         text1 = "This is a text with some words. There is a second sentence. And there is a third sentence."
-        text2 = "This is a different text with some words. There is a second sentence. And there is a third sentence. And there is a fourth sentence."
+        text2 = (
+            "This is a different text with some words. There is a second sentence. And there is a third sentence. "
+            "And there is a fourth sentence."
+        )
         splitter.warm_up()
         result = splitter.run(documents=[Document(content=text1), Document(content=text2)])
         docs = result["documents"]
@@ -155,7 +156,10 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_split_by_passage(self):
         splitter = DocumentSplitter(split_by="passage", split_length=1)
-        text = "This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n And another passage."
+        text = (
+            "This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n "
+            "And another passage."
+        )
         splitter.warm_up()
         result = splitter.run(documents=[Document(content=text)])
         docs = result["documents"]
@@ -172,7 +176,10 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_split_by_page(self):
         splitter = DocumentSplitter(split_by="page", split_length=1)
-        text = "This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+        text = (
+            "This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And "
+            "another passage."
+        )
         splitter.warm_up()
         result = splitter.run(documents=[Document(content=text)])
         docs = result["documents"]
@@ -310,7 +317,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_no_overlap_passage_split(self):
         splitter = DocumentSplitter(split_by="passage", split_length=1)
         doc1 = Document(
-            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence.\n\nAnd more passages.\n\n\f And another passage."
+            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence."
+            "\n\nAnd more passages.\n\n\f And another passage."
         )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -322,7 +330,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_no_overlap_page_split(self):
         splitter = DocumentSplitter(split_by="page", split_length=1)
         doc1 = Document(
-            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f "
+            "And another passage."
        )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -332,7 +341,8 @@ class TestSplittingByFunctionOrCharacterRegex:
         splitter = DocumentSplitter(split_by="page", split_length=2)
         doc1 = Document(
-            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f "
+            "And another passage."
        )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -366,7 +376,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_overlap_passage_split(self):
         splitter = DocumentSplitter(split_by="passage", split_length=2, split_overlap=1)
         doc1 = Document(
-            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence.\n\nAnd more passages.\n\n\f And another passage."
+            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence."
+            "\n\nAnd more passages.\n\n\f And another passage."
        )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -378,7 +389,8 @@ class TestSplittingByFunctionOrCharacterRegex:
    def test_add_page_number_to_metadata_with_overlap_page_split(self):
         splitter = DocumentSplitter(split_by="page", split_length=2, split_overlap=1)
         doc1 = Document(
-            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f "
+            "And another passage."
        )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])

View File

@@ -120,7 +120,7 @@ class TestHierarchicalDocumentSplitter:
             "max_runs_per_component": 100,
             "components": {
                 "hierarchical_document_splitter": {
-                    "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
+                    "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",  # noqa: E501
                     "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
                 },
                 "doc_writer": {
@@ -203,7 +203,8 @@ class TestHierarchicalDocumentSplitter:
     def test_hierarchical_splitter_multiple_block_sizes(self):
         # Test with three different block sizes
         doc = Document(
-            content="This is a simple test document with multiple sentences. It should be split into various sizes. This helps test the hierarchy."
+            content="This is a simple test document with multiple sentences. It should be split into various sizes. "
+            "This helps test the hierarchy."
        )
         # Using three block sizes: 10, 5, 2 words

View File

@@ -58,8 +58,8 @@ def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):
     chunks = ["chunk1", "chunk2", "chunk3", "chunk4"]
     _ = splitter._apply_overlap(chunks)
     assert (
-        "Overlap is the same as the previous chunk. Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter."
-        in caplog.text
+        "Overlap is the same as the previous chunk. Consider increasing the `split_length` parameter or decreasing "
+        "the `split_overlap` parameter." in caplog.text
     )
@@ -134,12 +134,15 @@ AI technology is widely used throughout industry, government, and science. Some
     assert (
         chunks[1].content
         == "AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\n"
-    )  # noqa: E501
-    assert chunks[2].content == "AI technology is widely used throughout industry, government, and science."  # noqa: E501
+    )
+    assert chunks[2].content == "AI technology is widely used throughout industry, government, and science."
     assert (
         chunks[3].content
-        == "Some high-profile applications include advanced web search engines (e.g., Google Search); recommendation systems (used by YouTube, Amazon, and Netflix); interacting via human speech (e.g., Google Assistant, Siri, and Alexa); autonomous vehicles (e.g., Waymo); generative and creative tools (e.g., ChatGPT and AI art); and superhuman play and analysis in strategy games (e.g., chess and Go)."
-    )  # noqa: E501
+        == "Some high-profile applications include advanced web search engines (e.g., Google Search); recommendation "
+        "systems (used by YouTube, Amazon, and Netflix); interacting via human speech (e.g., Google Assistant, "
+        "Siri, and Alexa); autonomous vehicles (e.g., Waymo); generative and creative tools (e.g., ChatGPT and "
+        "AI art); and superhuman play and analysis in strategy games (e.g., chess and Go)."
+    )
 def test_run_split_by_dot_count_page_breaks_split_unit_char() -> None:

View File

@@ -3,15 +3,14 @@
 # SPDX-License-Identifier: Apache-2.0
 import time
-import pytest
-from unittest.mock import patch
 from pathlib import Path
+from unittest.mock import patch
+import pytest
-from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter
-from haystack.components.preprocessors.sentence_tokenizer import QUOTE_SPANS_RE
 from pytest import LogCaptureFixture
+from haystack.components.preprocessors.sentence_tokenizer import QUOTE_SPANS_RE, SentenceSplitter
 def test_apply_split_rules_no_join() -> None:
     text = "This is a test. This is another test. And a third test."
@@ -56,7 +55,7 @@ def mock_file_content():
 def test_read_abbreviations_existing_file(tmp_path, mock_file_content):
     abbrev_dir = tmp_path / "data" / "abbreviations"
     abbrev_dir.mkdir(parents=True)
-    abbrev_file = abbrev_dir / f"en.txt"
+    abbrev_file = abbrev_dir / "en.txt"
     abbrev_file.write_text(mock_file_content)
     with patch("haystack.components.preprocessors.sentence_tokenizer.Path") as mock_path:

View File

@@ -2,11 +2,11 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-import pytest
-from unittest.mock import patch, MagicMock
-import requests
+from unittest.mock import MagicMock, patch
 import httpx
+import pytest
+import requests
 from haystack import Document
 from haystack.components.rankers.hugging_face_tei import HuggingFaceTEIRanker, TruncationDirection

Some files were not shown because too many files have changed in this diff.