Compare commits

...

13 Commits

Author SHA1 Message Date
Sebastian Husch Lee
16fc41cd95
feat: Relax requirement for creating a ToolCallDelta dataclass (#9582)
* Relax our requirement for ToolCallDelta to better match ChoiceDeltaToolCall and ChoiceDeltaToolCallFunction from OpenAI

* Add reno

* Update tests
2025-07-03 08:50:29 +02:00
Amna Mubashar
9fd552f906
chore: remove deprecated async_executor param from ToolInvoker (#9571)
* Remove async executor

* Add release notes

* Linting

* update release notes
2025-07-02 14:02:51 +02:00
Amna Mubashar
adb2759d00
chore: remove deprecated State from haystack.dataclasses (#9578)
* Remove deprecated class

* Remove state from pydocs
2025-07-02 12:19:06 +02:00
Stefano Fiorucci
848115c65e
fix: fix print_streaming_chunk + add tests (#9579)
* fix: fix print_streaming_chunk + add tests

* rel note
2025-07-01 16:49:18 +02:00
Sebastian Husch Lee
3aaa201ed6
feat: Add tool_invoker_kwargs to Agent (#9574)
* Add new param to agent to pass any kwargs to tool invoker

* Add reno
2025-07-01 11:09:58 +02:00
Amna Mubashar
f11870b212
fix: adjust the async executor in ToolInvoker (#9562)
* Fix bug

* Add a new test

* PR comments

* Add another test

* Small fix

* Fix linting

* Update tests
2025-06-30 15:13:01 +02:00
Sebastian Husch Lee
97e72b9693
feat: Add to_dict and from_dict to ByteStream (#9568)
* Add to_dict and from_dict to ByteStream

* Add reno

* Add unit tests

* Fix and expand tests

* Fix typing

* PR comments
2025-06-30 11:57:22 +00:00
Sebastian Husch Lee
fc64884819
fix: Fix _convert_streaming_chunks_to_chat_message (#9566)
* Fix conversion

* Add reno

* Add unit test
2025-06-30 11:51:25 +02:00
mathislucka
c54a68ab63
fix: files should not be passed as single string (#9559)
* fix: files should not be passed as single string

* chore: we want word splitting in this case

* fix: place directive before command

* fix: find correct directive placement
2025-06-27 11:17:42 +02:00
Stefano Fiorucci
c18f81283c
chore: fix deepset_sync.py for pylint + general linting improvements (#9558)
* chore: fix deepset_sync.py for pylint

* check .github with ruff

* fix

* Update .github/utils/pyproject_to_requirements.py

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>

---------

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
2025-06-27 07:54:22 +00:00
mathislucka
101e9cdc34
docs: sync code to deepset workspace (#9555)
* docs: sync code to deepset workspace

* fix: naming

* fix: actionlint
2025-06-27 07:51:59 +02:00
Stefano Fiorucci
bcaef53cbc
test: export HF_TOKEN env var in e2e environment (#9551)
* try to fix e2e tests for private NER models

* explanatory comment

* extend skipif condition
2025-06-25 15:00:28 +02:00
Haystack Bot
85e8493f4f
Update unstable version to 2.16.0-rc0 (#9554)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-06-25 14:57:16 +02:00
40 changed files with 752 additions and 425 deletions

View File

@@ -1,14 +1,17 @@
+import importlib
 import os
 import sys
-import importlib
 import traceback
 from pathlib import Path
 from typing import List, Optional

 from haystack import logging  # pylint: disable=unused-import # this is needed to avoid circular imports


 def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]] = None) -> tuple[list, list]:
     """
     Recursively search for all Python modules and attempt to import them.
     This includes both packages (directories with __init__.py) and individual Python files.
     """
     imported = []
@@ -25,7 +28,7 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]
         # Convert path to module format
         module_path = ".".join(Path(root).relative_to(base_path.parent).parts)

-        python_files = [f for f in files if f.endswith('.py')]
+        python_files = [f for f in files if f.endswith(".py")]

         # Try importing package and individual files
         for file in python_files:
@@ -39,16 +42,15 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]
                 importlib.import_module(module_to_import)
                 imported.append(module_to_import)
             except:
-                failed.append({
-                    'module': module_to_import,
-                    'traceback': traceback.format_exc()
-                })
+                failed.append({"module": module_to_import, "traceback": traceback.format_exc()})

     return imported, failed


 def main():
     """
     This script checks that all Haystack modules can be imported successfully.
     This includes both packages and individual Python files.

     This can detect several issues, such as:
     - Syntax errors in Python files
@@ -80,5 +82,5 @@ def main():
         sys.exit(1)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

View File

@@ -1,14 +1,16 @@
+import argparse
 import re
 import sys
-import argparse

-from readme_api import get_versions, create_new_unstable
+from readme_api import create_new_unstable, get_versions

 VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def calculate_new_unstable(version: str):
+    """
+    Calculate the new unstable version based on the given version.
+    """
     # version must be formatted like so <major>.<minor>
     major, minor = version.split(".")
     return f"{major}.{int(minor) + 1}-unstable"

.github/utils/deepset_sync.py (vendored, new file, 181 lines added)
View File

@@ -0,0 +1,181 @@
# /// script
# dependencies = [
#   "requests",
# ]
# ///
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional

import requests


def transform_filename(filepath: Path) -> str:
    """
    Transform a file path to the required format:
    - Replace path separators with underscores
    """
    # Convert to string and replace path separators with underscores
    transformed = str(filepath).replace("/", "_").replace("\\", "_")
    return transformed


def upload_file_to_deepset(filepath: Path, api_key: str, workspace: str) -> bool:
    """
    Upload a single file to Deepset API.
    """
    # Read file content
    try:
        content = filepath.read_text(encoding="utf-8")
    except Exception as e:
        print(f"Error reading file {filepath}: {e}")
        return False

    # Transform filename
    transformed_name = transform_filename(filepath)

    # Prepare metadata
    metadata: dict[str, str] = {"original_file_path": str(filepath)}

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"
    params: dict[str, str] = {"file_name": transformed_name, "write_mode": "OVERWRITE"}
    headers: dict[str, str] = {"accept": "application/json", "authorization": f"Bearer {api_key}"}

    # Prepare multipart form data
    files: dict[str, tuple[None, str, str]] = {
        "meta": (None, json.dumps(metadata), "application/json"),
        "text": (None, content, "text/plain"),
    }

    try:
        response = requests.post(url, params=params, headers=headers, files=files, timeout=300)
        response.raise_for_status()
        print(f"Successfully uploaded: {filepath} as {transformed_name}")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to upload {filepath}: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to upload {filepath}: {e}")
        return False


def delete_files_from_deepset(filepaths: list[Path], api_key: str, workspace: str) -> bool:
    """
    Delete multiple files from Deepset API.
    """
    if not filepaths:
        return True

    # Transform filenames
    transformed_names: list[str] = [transform_filename(fp) for fp in filepaths]

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"
    headers: dict[str, str] = {
        "accept": "application/json",
        "authorization": f"Bearer {api_key}",
        "content-type": "application/json",
    }
    data: dict[str, list[str]] = {"names": transformed_names}

    try:
        response = requests.delete(url, headers=headers, json=data, timeout=300)
        response.raise_for_status()
        print(f"Successfully deleted {len(transformed_names)} file(s):")
        for original, transformed in zip(filepaths, transformed_names):
            print(f"  - {original} (as {transformed})")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to delete files: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to delete files: {e}")
        return False


def main() -> None:
    """
    Main function to process and upload/delete files.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Upload/delete Python files to/from Deepset")
    parser.add_argument("--changed", nargs="*", default=[], help="Changed or added files")
    parser.add_argument("--deleted", nargs="*", default=[], help="Deleted files")
    args = parser.parse_args()

    # Get environment variables
    api_key: Optional[str] = os.environ.get("DEEPSET_API_KEY")
    workspace: str = os.environ.get("DEEPSET_WORKSPACE")

    if not api_key:
        print("Error: DEEPSET_API_KEY environment variable not set")
        sys.exit(1)

    # Process arguments and convert to Path objects
    changed_files: list[Path] = [Path(f.strip()) for f in args.changed if f.strip()]
    deleted_files: list[Path] = [Path(f.strip()) for f in args.deleted if f.strip()]

    if not changed_files and not deleted_files:
        print("No files to process")
        sys.exit(0)

    print(f"Processing files in Deepset workspace: {workspace}")
    print("-" * 50)

    # Track results
    upload_success: int = 0
    upload_failed: list[Path] = []
    delete_success: bool = False

    # Handle deletions first
    if deleted_files:
        print(f"\nDeleting {len(deleted_files)} file(s)...")
        delete_success = delete_files_from_deepset(deleted_files, api_key, workspace)

    # Upload changed/new files
    if changed_files:
        print(f"\nUploading {len(changed_files)} file(s)...")
        for filepath in changed_files:
            if filepath.exists():
                if upload_file_to_deepset(filepath, api_key, workspace):
                    upload_success += 1
                else:
                    upload_failed.append(filepath)
            else:
                print(f"Skipping non-existent file: {filepath}")

    # Summary
    print("-" * 50)
    print("Processing Summary:")
    if changed_files:
        print(f"  Uploads - Successful: {upload_success}, Failed: {len(upload_failed)}")
    if deleted_files:
        print(f"  Deletions - {'Successful' if delete_success else 'Failed'}: {len(deleted_files)} file(s)")
    if upload_failed:
        print("\nFailed uploads:")
        for f in upload_failed:
            print(f"  - {f}")

    # Exit with error if any operation failed
    if upload_failed or (deleted_files and not delete_success):
        sys.exit(1)

    print("\nAll operations completed successfully!")


if __name__ == "__main__":
    main()
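
For reference, a small usage sketch (not part of the diff) of what `transform_filename` produces; the repository path in the example is hypothetical:

```python
from pathlib import Path

def transform_filename(filepath: Path) -> str:
    # Same logic as the script above: flatten path separators into underscores
    return str(filepath).replace("/", "_").replace("\\", "_")

# A hypothetical repository path becomes a flat workspace file name:
assert transform_filename(Path("haystack/core/pipeline/pipeline.py")) == "haystack_core_pipeline_pipeline.py"
```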

View File

@@ -12,6 +12,9 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("README_API_KEY", None)
     if not api_key:
         raise Exception("README_API_KEY env var is not set")
@@ -21,6 +24,9 @@ def readme_token():


 def create_headers(version: str):
+    """
+    Create headers for the Readme API.
+    """
     return {"authorization": f"Basic {readme_token()}", "x-readme-version": version}
@@ -35,6 +41,9 @@ def get_docs_in_category(category_slug: str, version: str) -> List[str]:


 def delete_doc(slug: str, version: str):
+    """
+    Delete a document from Readme, based on the slug and version.
+    """
     url = f"https://dash.readme.com/api/v1/docs/{slug}"
     headers = create_headers(version)
     res = requests.delete(url, headers=headers, timeout=10)

View File

@@ -1,11 +1,13 @@
+import ast
+import hashlib
 from pathlib import Path
 from typing import Iterator

-import ast
-import hashlib

 def docstrings_checksum(python_files: Iterator[Path]):
+    """
+    Calculate the checksum of the docstrings in the given Python files.
+    """
     files_content = (f.read_text() for f in python_files)
     trees = (ast.parse(c) for c in files_content)

View File

@@ -1,6 +1,6 @@
+import argparse
 import re
 import sys
-import argparse

 from readme_api import get_versions, promote_unstable_to_stable

@@ -8,9 +8,7 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True
-    )
+    parser.add_argument("-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True)
     args = parser.parse_args()

     if VERSION_VALIDATOR.match(args.version) is None:

View File

@@ -3,7 +3,8 @@ import re
 import sys
 from pathlib import Path

-import toml
+# toml is available in the default environment but not in the test environment, so pylint complains
+import toml  # pylint: disable=import-error

 matcher = re.compile(r"farm-haystack\[(.+)\]")

 parser = argparse.ArgumentParser(
@@ -14,6 +15,9 @@ parser.add_argument("--extra", default="")


 def resolve(target: str, extras: dict, results: set):
+    """
+    Resolve the dependencies for a given target.
+    """
     if target not in extras:
         results.add(target)
         return
@@ -28,6 +32,9 @@ def resolve(target: str, extras: dict, results: set):


 def main(pyproject_path: Path, extra: str = ""):
+    """
+    Convert a pyproject.toml file to a requirements.txt file.
+    """
     content = toml.load(pyproject_path)
     # basic set of dependencies
     deps = set(content["project"]["dependencies"])

View File

@@ -1,16 +1,23 @@
-import os
 import base64
+import os

 import requests


 class ReadmeAuth(requests.auth.AuthBase):
-    def __call__(self, r):
+    """
+    Custom authentication class for Readme API.
+    """
+
+    def __call__(self, r):  # noqa: D102
         r.headers["authorization"] = f"Basic {readme_token()}"
         return r


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("RDME_API_KEY", None)
     if not api_key:
         raise Exception("RDME_API_KEY env var is not set")

View File

@@ -18,7 +18,9 @@ env:
   PYTHON_VERSION: "3.9"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   HATCH_VERSION: "1.14.1"
-  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}
+  # we use HF_TOKEN instead of HF_API_TOKEN to work around a Hugging Face bug
+  # see https://github.com/deepset-ai/haystack/issues/9552
+  HF_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

 jobs:
   run:

View File

@@ -0,0 +1,56 @@
name: Upload Code to Deepset

on:
  push:
    branches:
      - main

jobs:
  upload-files:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0 # Fetch all history for proper diff

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        run: |
          pip install uv

      - name: Get changed files
        id: changed-files
        uses: tj-actions/changed-files@v46
        with:
          files: |
            haystack/**/*.py
          separator: ' '

      - name: Upload files to Deepset
        if: steps.changed-files.outputs.any_changed == 'true' || steps.changed-files.outputs.any_deleted == 'true'
        env:
          DEEPSET_API_KEY: ${{ secrets.DEEPSET_API_KEY }}
          DEEPSET_WORKSPACE: haystack-code
        run: |
          # Combine added and modified files for upload
          CHANGED_FILES=""
          if [ -n "${{ steps.changed-files.outputs.added_files }}" ]; then
            CHANGED_FILES="${{ steps.changed-files.outputs.added_files }}"
          fi
          if [ -n "${{ steps.changed-files.outputs.modified_files }}" ]; then
            if [ -n "$CHANGED_FILES" ]; then
              CHANGED_FILES="$CHANGED_FILES ${{ steps.changed-files.outputs.modified_files }}"
            else
              CHANGED_FILES="${{ steps.changed-files.outputs.modified_files }}"
            fi
          fi

          # Run the script with changed and deleted files
          # shellcheck disable=SC2086
          uv run --no-project --no-config --no-cache .github/utils/deepset_sync.py \
            --changed $CHANGED_FILES \
            --deleted ${{ steps.changed-files.outputs.deleted_files }}

View File

@@ -19,7 +19,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -16,7 +16,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 jobs:
   check_if_changed:
@@ -39,7 +39,7 @@ jobs:
           - "haystack/core/pipeline/predefined/*"
           - "test/**/*.py"
           - "pyproject.toml"
-          - ".github/utils/check_imports.py"
+          - ".github/utils/*.py"

   trigger-catch-all:
     name: Tests completed

View File

@@ -1 +1 @@
-2.15.0-rc0
+2.16.0-rc0

View File

@@ -2,7 +2,7 @@ loaders:
 - type: haystack_pydoc_tools.loaders.CustomPythonLoader
   search_path: [../../../haystack/dataclasses]
   modules:
-    ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"]
+    ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "streaming_chunk"]
   ignore_when_discovered: ["__init__"]
 processors:
   - type: filter

View File

@@ -68,8 +68,8 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
 @pytest.mark.parametrize("batch_size", [1, 3])
 @pytest.mark.skipif(
-    not os.environ.get("HF_API_TOKEN", None),
-    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    not os.environ.get("HF_API_TOKEN", None) and not os.environ.get("HF_TOKEN", None),
+    reason="Export an env var called HF_API_TOKEN or HF_TOKEN containing the Hugging Face token to run this test.",
 )
 def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")

View File

@@ -67,8 +67,9 @@ class Agent:
         exit_conditions: Optional[List[str]] = None,
         state_schema: Optional[Dict[str, Any]] = None,
         max_agent_steps: int = 100,
-        raise_on_tool_invocation_failure: bool = False,
         streaming_callback: Optional[StreamingCallbackT] = None,
+        raise_on_tool_invocation_failure: bool = False,
+        tool_invoker_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Initialize the agent component.
@@ -82,10 +83,11 @@ class Agent:
         :param state_schema: The schema for the runtime state used by the tools.
         :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
             If the agent exceeds this number of steps, it will stop and return the current state.
-        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
-            If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
             The same callback can be configured to emit tool results when a tool is called.
+        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
+            If set to False, the exception will be turned into a chat message and passed to the LLM.
+        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         :raises ValueError: If the exit_conditions are not valid.
         """
@@ -135,9 +137,15 @@ class Agent:
             component.set_input_type(self, name=param, type=config["type"], default=None)
         component.set_output_types(self, **output_types)

+        self.tool_invoker_kwargs = tool_invoker_kwargs
         self._tool_invoker = None
         if self.tools:
-            self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
+            resolved_tool_invoker_kwargs = {
+                "tools": self.tools,
+                "raise_on_failure": self.raise_on_tool_invocation_failure,
+                **(tool_invoker_kwargs or {}),
+            }
+            self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
         else:
             logger.warning(
                 "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
@@ -175,8 +183,9 @@ class Agent:
             # We serialize the original state schema, not the resolved one to reflect the original user input
             state_schema=_schema_to_dict(self._state_schema),
             max_agent_steps=self.max_agent_steps,
-            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
             streaming_callback=streaming_callback,
+            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
+            tool_invoker_kwargs=self.tool_invoker_kwargs,
         )

     @classmethod

View File

@@ -44,7 +44,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
         if chunk.index and tool_call.index > chunk.index:
             print("\n\n", flush=True, end="")

-        print("[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
+        print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")

         # print the tool arguments
         if tool_call.arguments:
@@ -84,24 +84,20 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
     tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
     for chunk in chunks:
         if chunk.tool_calls:
-            # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
-            # tool_call is present
-            assert chunk.index is not None
-
             for tool_call in chunk.tool_calls:
                 # We use the index of the tool_call to track the tool call across chunks since the ID is not always
                 # provided
                 if tool_call.index not in tool_call_data:
-                    tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
+                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}

                 # Save the ID if present
                 if tool_call.id is not None:
-                    tool_call_data[chunk.index]["id"] = tool_call.id
+                    tool_call_data[tool_call.index]["id"] = tool_call.id

                 if tool_call.tool_name is not None:
-                    tool_call_data[chunk.index]["name"] += tool_call.tool_name
+                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
                 if tool_call.arguments is not None:
-                    tool_call_data[chunk.index]["arguments"] += tool_call.arguments
+                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments

     # Convert accumulated tool call data into ToolCall objects
     sorted_keys = sorted(tool_call_data.keys())

View File

@@ -5,7 +5,6 @@
 import asyncio
 import inspect
 import json
-import warnings
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Any, Dict, List, Optional, Set, Union
@@ -171,7 +170,6 @@ class ToolInvoker:
         *,
         enable_streaming_callback_passthrough: bool = False,
         max_workers: int = 4,
-        async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
         Initialize the ToolInvoker component.
@@ -197,13 +195,7 @@ class ToolInvoker:
             If False, the `streaming_callback` will not be passed to the tool invocation.
         :param max_workers:
             The maximum number of workers to use in the thread pool executor.
-        :param async_executor:
-            Optional `ThreadPoolExecutor` to use for asynchronous calls.
-            Note: As of Haystack 2.15.0, you no longer need to explicitly pass
-            `async_executor`. Instead, you can provide the `max_workers` parameter,
-            and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
-            Support for `async_executor` will be removed in Haystack 2.16.0.
-            Please migrate to using `max_workers` instead.
+            This also decides the maximum number of concurrent tool invocations.
         :raises ValueError:
             If no tools are provided or if duplicate tool names are found.
         """
@@ -231,37 +223,6 @@ class ToolInvoker:
         self._tools_with_names = dict(zip(tool_names, converted_tools))
         self.raise_on_failure = raise_on_failure
         self.convert_result_to_json_string = convert_result_to_json_string
-        self._owns_executor = async_executor is None
-        if self._owns_executor:
-            warnings.warn(
-                "'async_executor' is deprecated in favor of the 'max_workers' parameter. "
-                "ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
-                "Support for 'async_executor' will be removed in Haystack 2.16.0. "
-                "Please update your usage to pass 'max_workers' instead.",
-                DeprecationWarning,
-            )
-        self.executor = (
-            ThreadPoolExecutor(
-                thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
-            )
-            if async_executor is None
-            else async_executor
-        )
-
-    def __del__(self):
-        """
-        Cleanup when the instance is being destroyed.
-        """
-        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
-            self.executor.shutdown(wait=True)
-
-    def shutdown(self):
-        """
-        Explicitly shutdown the executor if we own it.
-        """
-        if self._owns_executor:
-            self.executor.shutdown(wait=True)

     def _handle_error(self, error: Exception) -> str:
         """
@@ -655,7 +616,7 @@ class ToolInvoker:
             return e

     @component.output_types(tool_messages=List[ChatMessage], state=State)
-    async def run_async(
+    async def run_async(  # noqa: PLR0915
         self,
         messages: List[ChatMessage],
         state: Optional[State] = None,
@@ -718,10 +679,10 @@ class ToolInvoker:
         # 2) Execute valid tool calls in parallel
         if tool_call_params:
-            with self.executor as executor:
-                tool_call_tasks = []
-                valid_tool_calls = []
+            tool_call_tasks = []
+            valid_tool_calls = []

+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                 # 3) Create async tasks for valid tool calls
                 for params in tool_call_params:
                     task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
@@ -749,7 +710,7 @@ class ToolInvoker:
                 try:
                     error_message = self._handle_error(
                         ToolOutputMergeError(
-                            f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
+                            f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
                         )
                     )
                     tool_messages.append(

View File

@@ -38,7 +38,6 @@ if TYPE_CHECKING:
     from .chat_message import ToolCallResult as ToolCallResult
     from .document import Document as Document
     from .sparse_embedding import SparseEmbedding as SparseEmbedding
-    from .state import State as State
     from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
     from .streaming_chunk import ComponentInfo as ComponentInfo
     from .streaming_chunk import FinishReason as FinishReason

View File

@@ -79,3 +79,24 @@ class ByteStream:
             fields.append(f"mime_type={self.mime_type!r}")
         fields_str = ", ".join(fields)
         return f"{self.__class__.__name__}({fields_str})"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the ByteStream to a dictionary representation.
+
+        :returns: A dictionary with keys 'data', 'meta', and 'mime_type'.
+        """
+        # Note: The data is converted to a list of integers for serialization since JSON does not support bytes
+        # directly.
+        return {"data": list(self.data), "meta": self.meta, "mime_type": self.mime_type}
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ByteStream":
+        """
+        Create a ByteStream from a dictionary representation.
+
+        :param data: A dictionary with keys 'data', 'meta', and 'mime_type'.
+        :returns: A ByteStream instance.
+        """
+        return ByteStream(data=bytes(data["data"]), meta=data.get("meta", {}), mime_type=data.get("mime_type"))

View File

@@ -127,8 +127,12 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
             Whether to flatten `meta` field or not. Defaults to `True` to be backward-compatible with Haystack 1.x.
         """
         data = asdict(self)
-        if (blob := data.get("blob")) is not None:
-            data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}
+
+        # Use `ByteStream` and `SparseEmbedding`'s to_dict methods to convert them to JSON-serializable types.
+        if self.blob is not None:
+            data["blob"] = self.blob.to_dict()
+        if self.sparse_embedding is not None:
+            data["sparse_embedding"] = self.sparse_embedding.to_dict()

         if flatten:
             meta = data.pop("meta")
@@ -144,7 +148,7 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
         The `blob` field is converted to its original type.
         """
         if blob := data.get("blob"):
-            data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
+            data["blob"] = ByteStream.from_dict(blob)
         if sparse_embedding := data.get("sparse_embedding"):
             data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)

View File

@@ -1,43 +0,0 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, Optional

from haystack.components.agents import State as Utils_State


class State(Utils_State):
    """
    A class that wraps a StateSchema and maintains an internal _data dictionary.

    Deprecated in favor of `haystack.components.agents.State`. It will be removed in Haystack 2.16.0.

    Each schema entry has:
    "parameter_name": {
        "type": SomeType,
        "handler": Optional[Callable[[Any, Any], Any]]
    }
    """

    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
        """
        Initialize a State object with a schema and optional data.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
            - For list types: `haystack.components.agents.state.state_utils.merge_lists`
            - For all other types: `haystack.components.agents.state.state_utils.replace_values`
        :param data: Optional dictionary of initial data to populate the state
        """
        warnings.warn(
            "`haystack.dataclasses.State` is deprecated and will be removed in Haystack 2.16.0. "
            "Use `haystack.components.agents.State` instead.",
            DeprecationWarning,
        )
        super().__init__(schema, data)

View File

@@ -1,52 +0,0 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, List, TypeVar, Union

T = TypeVar("T")


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
    """
    Merges two values into a single list.

    Deprecated in favor of `haystack.components.agents.state.merge_lists`. It will be removed in Haystack 2.16.0.

    If either `current` or `new` is not already a list, it is converted into one.
    The function ensures that both inputs are treated as lists and concatenates them.

    If `current` is None, it is treated as an empty list.

    :param current: The existing value(s), either a single item or a list.
    :param new: The new value(s) to merge, either a single item or a list.
    :return: A list containing elements from both `current` and `new`.
    """
    warnings.warn(
        "`haystack.dataclasses.state_utils.merge_lists` is deprecated and will be removed in Haystack 2.16.0. "
        "Use `haystack.components.agents.state.merge_lists` instead.",
        DeprecationWarning,
    )
    current_list = [] if current is None else current if isinstance(current, list) else [current]
    new_list = new if isinstance(new, list) else [new]
    return current_list + new_list


def replace_values(current: Any, new: Any) -> Any:
    """
    Replace the `current` value with the `new` value.

    :param current: The existing value
    :param new: The new value to replace
    :return: The new value
    """
    warnings.warn(
        "`haystack.dataclasses.state_utils.replace_values` is deprecated and will be removed in Haystack 2.16.0. "
        "Use `haystack.components.agents.state.replace_values` instead.",
        DeprecationWarning,
    )
    return new

View File

@@ -30,12 +30,6 @@ class ToolCallDelta:
     arguments: Optional[str] = field(default=None)
     id: Optional[str] = field(default=None)  # noqa: A003

-    def __post_init__(self):
-        # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the
-        # name and full arguments in one chunk
-        if self.tool_name is None and self.arguments is None:
-            raise ValueError("At least one of tool_name or arguments must be provided.")
-

 @dataclass
 class ComponentInfo:

View File

@@ -283,7 +283,7 @@ disallow_incomplete_defs = false

 [tool.ruff]
 line-length = 120
-exclude = [".github", "proposals"]
+exclude = ["proposals"]

 [tool.ruff.format]
 skip-magic-trailing-comma = true

View File

@@ -0,0 +1,4 @@
---
features:
  - |
    Add `to_dict` and `from_dict` to ByteStream so it is consistent with our other dataclasses in having serialization and deserialization methods.
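
A minimal usage sketch (not part of the diff) of the new round-trip, based on the implementation added above; the string content is illustrative:

```python
from haystack.dataclasses import ByteStream

stream = ByteStream.from_string("Hello, world!", mime_type="text/plain", meta={"foo": "bar"})

# to_dict stores the raw bytes as a list of integers, so the result is JSON-serializable
data = stream.to_dict()
assert data["data"] == list(b"Hello, world!")
assert data["mime_type"] == "text/plain"

# from_dict restores an equivalent ByteStream
restored = ByteStream.from_dict(data)
assert restored.data == stream.data and restored.meta == stream.meta
```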

View File

@@ -0,0 +1,4 @@
---
features:
  - |
    Added the `tool_invoker_kwargs` param to `Agent` so that additional kwargs, such as `max_workers` and `enable_streaming_callback_passthrough`, can be passed to the `ToolInvoker`.
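
A sketch of how the new parameter might be used, assuming `OPENAI_API_KEY` is set; the weather tool here is a made-up example, not part of the diff:

```python
from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.tools import Tool

def get_weather(city: str) -> str:
    # Hypothetical tool function for illustration only
    return f"Sunny in {city}"

weather_tool = Tool(
    name="weather",
    description="Get the weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

# The extra kwargs are merged into the arguments of the internally created ToolInvoker
agent = Agent(
    chat_generator=OpenAIChatGenerator(),
    tools=[weather_tool],
    tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
)
```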

View File

@@ -0,0 +1,4 @@
---
fixes:
  - |
    Fixed `_convert_streaming_chunks_to_chat_message`, which converts Haystack StreamingChunks into a Haystack ChatMessage. Previously, when a single StreamingChunk contained two ToolCallDeltas in `StreamingChunk.tool_calls`, the deltas overwrote each other; with this fix both are saved correctly. This only occurs with some LLM providers, such as Mistral (and not OpenAI), due to how those providers return tool calls.
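
A condensed sketch of the fixed behavior, adapted from the new Mistral-style test further down in this diff; the exact `meta` keys required may vary slightly from the real function:

```python
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta

# A single chunk carrying two ToolCallDeltas, as Mistral may return
chunk = StreamingChunk(
    content="",
    meta={"model": "mistral-small-latest", "finish_reason": "tool_calls", "usage": None},
    component_info=ComponentInfo(name="test", type="test"),
    index=0,
    start=True,
    finish_reason="tool_calls",
    tool_calls=[
        ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="A"),
        ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="B"),
    ],
)

message = _convert_streaming_chunks_to_chat_message(chunks=[chunk])
# Previously the second delta overwrote the first; both tool calls now survive
assert len(message.tool_calls) == 2
```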

View File

@@ -0,0 +1,3 @@
---
fixes:
  - Fixed a bug in the `print_streaming_chunk` utility function that prevented the tool call name from being printed.
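
A small sketch of the fixed output, mirroring the new unit test later in this diff:

```python
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta

chunk = StreamingChunk(
    content="",
    meta={"model": "test-model"},
    component_info=ComponentInfo(name="test", type="test"),
    start=True,
    index=0,
    tool_calls=[ToolCallDelta(id="call_123", tool_name="test_tool", arguments='{"param": "value"}', index=0)],
)

# With the missing f-string prefix restored, this prints the real tool name:
#   [TOOL CALL]
#   Tool: test_tool
#   Arguments: {"param": "value"}
print_streaming_chunk(chunk)
```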

View File

@@ -0,0 +1,4 @@
---
enhancements:
  - |
    Relaxed the requirement, introduced in Haystack 2.15, that either `arguments` or `tool_name` must be populated to create a `ToolCallDelta` dataclass. The requirement was removed to be more in line with OpenAI's SDK and because it caused errors for some hosted versions of open-source models that follow OpenAI's SDK specification.
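
In practice the relaxation means an empty delta no longer raises; a quick sketch:

```python
from haystack.dataclasses import ToolCallDelta

# Before this release, constructing a delta with neither tool_name nor arguments
# raised ValueError("At least one of tool_name or arguments must be provided.")
delta = ToolCallDelta(index=0)
assert delta.tool_name is None and delta.arguments is None
```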

View File

@@ -0,0 +1,6 @@
---
upgrade:
  - |
    The deprecated `async_executor` parameter has been removed from the `ToolInvoker` class.
    Use the `max_workers` parameter instead; a `ThreadPoolExecutor` with that many workers
    will be created automatically for parallel tool invocations.
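
A migration sketch under the assumption of a simple custom tool (the addition tool is illustrative, not part of the diff):

```python
from haystack.components.tools.tool_invoker import ToolInvoker
from haystack.tools import Tool

def add(a: int, b: int) -> int:
    return a + b

addition_tool = Tool(
    name="addition_tool",
    description="Add two numbers",
    parameters={
        "type": "object",
        "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
        "required": ["a", "b"],
    },
    function=add,
)

# Before (removed in this release):
#   ToolInvoker(tools=[addition_tool], async_executor=ThreadPoolExecutor(max_workers=4))
# After: pass max_workers; run_async creates a ThreadPoolExecutor internally
invoker = ToolInvoker(tools=[addition_tool], max_workers=4)
```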

View File

@@ -0,0 +1,5 @@
---
upgrade:
  - |
    The deprecated `State` class has been removed from the `haystack.dataclasses` module.
    The `State` class is now part of the `haystack.components.agents` module.
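
The import change in a nutshell; the schema shown is a minimal example:

```python
# Old (removed):
#   from haystack.dataclasses import State
# New:
from haystack.components.agents import State

state = State(schema={"counter": {"type": int}}, data={"counter": 0})
assert state.get("counter") == 0
```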

View File

@@ -174,6 +174,7 @@ class TestAgent:
             tools=[weather_tool, component_tool],
             exit_conditions=["text", "weather_tool"],
             state_schema={"foo": {"type": str}},
+            tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
         )

         serialized_agent = agent.to_dict()
         assert serialized_agent == {
@@ -236,8 +237,9 @@ class TestAgent:
                 "exit_conditions": ["text", "weather_tool"],
                 "state_schema": {"foo": {"type": "str"}},
                 "max_agent_steps": 100,
-                "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "raise_on_tool_invocation_failure": False,
+                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
             },
         }
@@ -294,6 +296,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": None,
             },
         }
@@ -361,6 +364,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
             },
         }
         agent = Agent.from_dict(data)
@@ -375,6 +379,9 @@ class TestAgent:
             "foo": {"type": str},
             "messages": {"handler": merge_lists, "type": List[ChatMessage]},
         }
+        assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True}
+        assert agent._tool_invoker.max_workers == 5
+        assert agent._tool_invoker.enable_streaming_callback_passthrough is True

     def test_from_dict_with_toolset(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@@ -426,6 +433,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": None,
             },
         }
         agent = Agent.from_dict(data)

View File

@@ -1177,6 +1177,32 @@ class TestChatCompletionChunkConversion:
             assert stream_chunk == haystack_chunk
             previous_chunks.append(stream_chunk)

+    def test_convert_chat_completion_chunk_with_empty_tool_calls(self):
+        # This can happen with some LLM providers where tool calls are not present but the pydantic models are still
+        # initialized.
+        chunk = ChatCompletionChunk(
+            id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
+            choices=[
+                chat_completion_chunk.Choice(
+                    delta=chat_completion_chunk.ChoiceDelta(
+                        tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction())]
+                    ),
+                    index=0,
+                )
+            ],
+            created=1742207200,
+            model="gpt-4o-mini-2024-07-18",
+            object="chat.completion.chunk",
+        )
+        result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])
+        assert result.content == ""
+        assert result.start is False
+        assert result.tool_calls == [ToolCallDelta(index=0)]
+        assert result.tool_call_result is None
+        assert result.index == 0
+        assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
+        assert result.meta["received_at"] is not None
+
     def test_handle_stream_response(self, chat_completion_chunks):
         openai_chunks = chat_completion_chunks
         comp = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

View File

@@ -3,9 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0

 from openai.types.chat import chat_completion_chunk
+from unittest.mock import patch, call

-from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
-from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta
+from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, print_streaming_chunk
+from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta, ToolCallResult


 def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
@@ -325,3 +326,256 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
         },
         "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
     }
+
+
+def test_convert_streaming_chunk_to_chat_message_two_tool_calls_in_same_chunk():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "usage": None,
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "finish_reason": "tool_calls",
+                "usage": {
+                    "completion_tokens": 35,
+                    "prompt_tokens": 77,
+                    "total_tokens": 112,
+                    "completion_tokens_details": None,
+                    "prompt_tokens_details": None,
+                },
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+            index=0,
+            tool_calls=[
+                ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
+                ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
+            ],
+            start=True,
+            finish_reason="tool_calls",
+        ),
+    ]
+
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    assert not result.texts
+    assert not result.text
+
+    # Verify both tool calls were found and processed
+    assert len(result.tool_calls) == 2
+    assert result.tool_calls[0].id == "FL1FFlqUG"
+    assert result.tool_calls[0].tool_name == "weather"
+    assert result.tool_calls[0].arguments == {"city": "Paris"}
+    assert result.tool_calls[1].id == "xSuhp66iB"
+    assert result.tool_calls[1].tool_name == "weather"
+    assert result.tool_calls[1].arguments == {"city": "Berlin"}
+
+
+def test_convert_streaming_chunk_to_chat_message_empty_tool_call_delta():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.910076",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        id="call_ZOj5l67zhZOx6jqjg7ATQwb6",
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments='{"query":', name="rag_pipeline_tool"
+                        ),
+                        type="function",
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.913919",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            start=True,
+            tool_calls=[
+                ToolCallDelta(
+                    id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments='{"query":', index=0
+                )
+            ],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments=' "Where does Mark live?"}'
+                        ),
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.924420",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            tool_calls=[ToolCallDelta(arguments=' "Where does Mark live?"}', index=0)],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction()
+                    )
+                ],
+                "finish_reason": "tool_calls",
+                "received_at": "2025-02-19T16:02:55.948772",
+            },
+            tool_calls=[ToolCallDelta(index=0)],
+            component_info=ComponentInfo(name="test", type="test"),
+            finish_reason="tool_calls",
+            index=0,
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.948772",
+                "usage": {
+                    "completion_tokens": 42,
+                    "prompt_tokens": 282,
+                    "total_tokens": 324,
+                    "completion_tokens_details": {
+                        "accepted_prediction_tokens": 0,
+                        "audio_tokens": 0,
+                        "reasoning_tokens": 0,
+                        "rejected_prediction_tokens": 0,
+                    },
+                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                },
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+    ]
+
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    assert not result.texts
+    assert not result.text
+
+    # Verify both tool calls were found and processed
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6"
+    assert result.tool_calls[0].tool_name == "rag_pipeline_tool"
+    assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"}
+    assert result.meta["finish_reason"] == "tool_calls"
+
+
+def test_print_streaming_chunk_content_only():
+    chunk = StreamingChunk(
+        content="Hello, world!",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+    )
+
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[ASSISTANT]\n", flush=True, end=""), call("Hello, world!", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_tool_call():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        index=0,
+        tool_calls=[ToolCallDelta(id="call_123", tool_name="test_tool", arguments='{"param": "value"}', index=0)],
+    )
+
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[TOOL CALL]\nTool: test_tool \nArguments: ", flush=True, end=""),
+            call('{"param": "value"}', flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_tool_call_result():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        index=0,
+        tool_call_result=ToolCallResult(
+            result="Tool execution completed successfully",
+            origin=ToolCall(id="call_123", tool_name="test_tool", arguments={}),
+            error=False,
+        ),
+    )
+
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[TOOL RESULT]\nTool execution completed successfully", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_with_finish_reason():
+    chunk = StreamingChunk(
+        content="Final content.",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        finish_reason="stop",
+    )
+
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[ASSISTANT]\n", flush=True, end=""),
+            call("Final content.", flush=True, end=""),
+            call("\n\n", flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_empty_chunk():
+    chunk = StreamingChunk(
+        content="", meta={"model": "test-model"}, component_info=ComponentInfo(name="test", type="test")
+    )
+
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        mock_print.assert_not_called()

View File

@ -14,7 +14,7 @@ from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk from haystack.components.generators.utils import print_streaming_chunk
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
from haystack.dataclasses.state import State from haystack.components.agents.state import State
from haystack.tools import ComponentTool, Tool, Toolset from haystack.tools import ComponentTool, Tool, Toolset
from haystack.tools.errors import ToolInvocationError from haystack.tools.errors import ToolInvocationError
from haystack.dataclasses import StreamingChunk from haystack.dataclasses import StreamingChunk
@ -100,11 +100,6 @@ def faulty_invoker(faulty_tool):
return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False) return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False)
@pytest.fixture
def thread_executor():
return ThreadPoolExecutor(thread_name_prefix=f"async-test-executor", max_workers=2)
class TestToolInvoker: class TestToolInvoker:
def test_init(self, weather_tool): def test_init(self, weather_tool):
invoker = ToolInvoker(tools=[weather_tool]) invoker = ToolInvoker(tools=[weather_tool])
@ -227,7 +222,7 @@ class TestToolInvoker:
assert final_chunk.content == "" assert final_chunk.content == ""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_with_streaming_callback(self, thread_executor, weather_tool): async def test_run_async_with_streaming_callback(self, weather_tool):
streaming_callback_called = False streaming_callback_called = False
async def streaming_callback(chunk: StreamingChunk) -> None: async def streaming_callback(chunk: StreamingChunk) -> None:
@ -235,12 +230,7 @@ class TestToolInvoker:
nonlocal streaming_callback_called nonlocal streaming_callback_called
streaming_callback_called = True streaming_callback_called = True
tool_invoker = ToolInvoker( tool_invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
tools=[weather_tool],
raise_on_failure=True,
convert_result_to_json_string=False,
async_executor=thread_executor,
)
tool_calls = [ tool_calls = [
ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}), ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
@ -269,18 +259,13 @@ class TestToolInvoker:
assert streaming_callback_called assert streaming_callback_called
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_with_streaming_callback_finish_reason(self, thread_executor, weather_tool): async def test_run_async_with_streaming_callback_finish_reason(self, weather_tool):
streaming_chunks = [] streaming_chunks = []
async def streaming_callback(chunk: StreamingChunk) -> None: async def streaming_callback(chunk: StreamingChunk) -> None:
streaming_chunks.append(chunk) streaming_chunks.append(chunk)
tool_invoker = ToolInvoker( tool_invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
tools=[weather_tool],
raise_on_failure=True,
convert_result_to_json_string=False,
async_executor=thread_executor,
)
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
message = ChatMessage.from_assistant(tool_calls=[tool_call]) message = ChatMessage.from_assistant(tool_calls=[tool_call])
@@ -319,10 +304,8 @@ class TestToolInvoker:
        assert not tool_call_result.error
    @pytest.mark.asyncio
-    async def test_run_async_with_toolset(self, tool_set, thread_executor):
-        tool_invoker = ToolInvoker(
-            tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False, async_executor=thread_executor
-        )
+    async def test_run_async_with_toolset(self, tool_set):
+        tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
        tool_calls = [
            ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3}),
            ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3}),
@@ -818,6 +801,55 @@ class TestToolInvoker:
        assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
        assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
+    def test_call_invoker_two_subsequent_run_calls(self, invoker: ToolInvoker):
+        tool_calls = [
+            ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
+        ]
+        message = ChatMessage.from_assistant(tool_calls=tool_calls)
+        streaming_callback_called = False
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+        # First call
+        result_1 = invoker.run(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_1
+        assert len(result_1["tool_messages"]) == 3
+        # Second call
+        result_2 = invoker.run(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_2
+        assert len(result_2["tool_messages"]) == 3
+    @pytest.mark.asyncio
+    async def test_call_invoker_two_subsequent_run_async_calls(self, invoker: ToolInvoker):
+        tool_calls = [
+            ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
+        ]
+        message = ChatMessage.from_assistant(tool_calls=tool_calls)
+        streaming_callback_called = False
+        async def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+        # First call
+        result_1 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_1
+        assert len(result_1["tool_messages"]) == 3
+        # Second call
+        result_2 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_2
+        assert len(result_2["tool_messages"]) == 3
class TestMergeToolOutputs:
    def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
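Note: with the deprecated async_executor parameter removed, ToolInvoker manages its own execution internally, which is what the two new subsequent-run tests above exercise. A minimal sketch of the resulting call pattern, assuming only the public API visible in this diff (weather_tool stands in for any Tool fixture):

    import asyncio
    from haystack.components.tools import ToolInvoker
    from haystack.dataclasses import ChatMessage, ToolCall

    invoker = ToolInvoker(tools=[weather_tool])  # no async_executor argument anymore
    message = ChatMessage.from_assistant(
        tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
    )
    result = invoker.run(messages=[message])                     # sync entry point
    result = asyncio.run(invoker.run_async(messages=[message]))  # async entry point, safe to call repeatedly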

View File

@@ -81,3 +81,23 @@ def test_str_truncation():
    assert len(string_repr) < 200
    assert "text/plain" in string_repr
    assert "foo" in string_repr
+def test_to_dict():
+    test_str = "Hello, world!"
+    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
+    d = b.to_dict()
+    assert d["data"] == list(test_str.encode())
+    assert d["mime_type"] == "text/plain"
+    assert d["meta"] == {"foo": "bar"}
+def test_from_dict():
+    test_str = "Hello, world!"
+    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
+    d = b.to_dict()
+    b2 = ByteStream.from_dict(d)
+    assert b2.data == b.data
+    assert b2.mime_type == b.mime_type
+    assert b2.meta == b.meta
+    assert str(b2) == str(b)
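Note: the new tests above pin down ByteStream's serialization contract: to_dict() emits the raw bytes as a list of ints plus mime_type and meta, and from_dict() restores an equal object. A short round-trip sketch, assuming nothing beyond what the tests show:

    from haystack.dataclasses import ByteStream

    b = ByteStream.from_string("Hello, world!", mime_type="text/plain", meta={"foo": "bar"})
    d = b.to_dict()  # {"data": [72, 101, ...], "mime_type": "text/plain", "meta": {"foo": "bar"}}
    b2 = ByteStream.from_dict(d)
    assert b2.data == b.data and b2.mime_type == b.mime_type and b2.meta == b.meta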

View File

@@ -146,7 +146,7 @@ def test_to_dict_without_flattening():
def test_to_dict_with_custom_parameters():
    doc = Document(
        content="test text",
-        blob=ByteStream(b"some bytes", mime_type="application/pdf"),
+        blob=ByteStream(b"some bytes", mime_type="application/pdf", meta={"foo": "bar"}),
        meta={"some": "values", "test": 10},
        score=0.99,
        embedding=[10.0, 10.0],
@@ -156,7 +156,7 @@ def test_to_dict_with_custom_parameters():
    assert doc.to_dict() == {
        "id": doc.id,
        "content": "test text",
-        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf"},
+        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf", "meta": {"foo": "bar"}},
        "some": "values",
        "test": 10,
        "score": 0.99,
@@ -178,10 +178,10 @@ def test_to_dict_with_custom_parameters_without_flattening():
    assert doc.to_dict(flatten=False) == {
        "id": doc.id,
        "content": "test text",
-        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf"},
+        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf", "meta": {}},
        "meta": {"some": "values", "test": 10},
        "score": 0.99,
-        "embedding": [10, 10],
+        "embedding": [10.0, 10.0],
        "sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
    }
@@ -212,7 +212,7 @@ def from_from_dict_with_parameters():
    assert Document.from_dict(
        {
            "content": "test text",
-            "blob": {"data": list(blob_data), "mime_type": "text/markdown"},
+            "blob": {"data": list(blob_data), "mime_type": "text/markdown", "meta": {"text": "test text"}},
            "meta": {"text": "test text"},
            "score": 0.812,
            "embedding": [0.1, 0.2, 0.3],
@@ -220,7 +220,7 @@ def from_from_dict_with_parameters():
        }
    ) == Document(
        content="test text",
-        blob=ByteStream(blob_data, mime_type="text/markdown"),
+        blob=ByteStream(blob_data, mime_type="text/markdown", meta={"text": "test text"}),
        meta={"text": "test text"},
        score=0.812,
        embedding=[0.1, 0.2, 0.3],
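Note: the adjusted assertions reflect that a blob's meta now survives Document serialization (and an empty meta serializes as {} rather than being dropped). A sketch of the round trip, under the same assumption:

    from haystack import Document
    from haystack.dataclasses import ByteStream

    doc = Document(
        content="test text",
        blob=ByteStream(b"some bytes", mime_type="application/pdf", meta={"foo": "bar"}),
    )
    d = doc.to_dict()
    assert d["blob"]["meta"] == {"foo": "bar"}
    assert Document.from_dict(d).blob.meta == {"foo": "bar"}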

View File

@@ -1,193 +0,0 @@
-# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
-#
-# SPDX-License-Identifier: Apache-2.0
-import pytest
-from typing import List, Dict
-from haystack.dataclasses import ChatMessage
-from haystack.dataclasses.state import State
-from haystack.components.agents.state.state import _validate_schema, _schema_to_dict, _schema_from_dict, merge_lists
-@pytest.fixture
-def basic_schema():
-    return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}}
-def numbers_handler(current, new):
-    if current is None:
-        return sorted(set(new))
-    return sorted(set(current + new))
-@pytest.fixture
-def complex_schema():
-    return {"numbers": {"type": list, "handler": numbers_handler}, "metadata": {"type": dict}, "name": {"type": str}}
-def test_validate_schema_valid(basic_schema):
-    # Should not raise any exceptions
-    _validate_schema(basic_schema)
-def test_validate_schema_invalid_type():
-    invalid_schema = {"test": {"type": "not_a_type"}}
-    with pytest.raises(ValueError, match="must be a Python type"):
-        _validate_schema(invalid_schema)
-def test_validate_schema_missing_type():
-    invalid_schema = {"test": {"handler": lambda x, y: x + y}}
-    with pytest.raises(ValueError, match="missing a 'type' entry"):
-        _validate_schema(invalid_schema)
-def test_validate_schema_invalid_handler():
-    invalid_schema = {"test": {"type": list, "handler": "not_callable"}}
-    with pytest.raises(ValueError, match="must be callable or None"):
-        _validate_schema(invalid_schema)
-def test_state_initialization(basic_schema):
-    # Test empty initialization
-    state = State(basic_schema)
-    assert state.data == {}
-    # Test initialization with data
-    initial_data = {"numbers": [1, 2, 3], "name": "test"}
-    state = State(basic_schema, initial_data)
-    assert state.data["numbers"] == [1, 2, 3]
-    assert state.data["name"] == "test"
-def test_state_get(basic_schema):
-    state = State(basic_schema, {"name": "test"})
-    assert state.get("name") == "test"
-    assert state.get("non_existent") is None
-    assert state.get("non_existent", "default") == "default"
-def test_state_set_basic(basic_schema):
-    state = State(basic_schema)
-    # Test setting new values
-    state.set("numbers", [1, 2])
-    assert state.get("numbers") == [1, 2]
-    # Test updating existing values
-    state.set("numbers", [3, 4])
-    assert state.get("numbers") == [1, 2, 3, 4]
-def test_state_set_with_handler(complex_schema):
-    state = State(complex_schema)
-    # Test custom handler for numbers
-    state.set("numbers", [3, 2, 1])
-    assert state.get("numbers") == [1, 2, 3]
-    state.set("numbers", [6, 5, 4])
-    assert state.get("numbers") == [1, 2, 3, 4, 5, 6]
-def test_state_set_with_handler_override(basic_schema):
-    state = State(basic_schema)
-    # Custom handler that concatenates strings
-    custom_handler = lambda current, new: f"{current}-{new}" if current else new
-    state.set("name", "first")
-    state.set("name", "second", handler_override=custom_handler)
-    assert state.get("name") == "first-second"
-def test_state_has(basic_schema):
-    state = State(basic_schema, {"name": "test"})
-    assert state.has("name") is True
-    assert state.has("non_existent") is False
-def test_state_empty_schema():
-    state = State({})
-    assert state.data == {}
-    assert state.schema == {"messages": {"type": List[ChatMessage], "handler": merge_lists}}
-    with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
-        state.set("any_key", "value")
-def test_state_none_values(basic_schema):
-    state = State(basic_schema)
-    state.set("name", None)
-    assert state.get("name") is None
-    state.set("name", "value")
-    assert state.get("name") == "value"
-def test_state_merge_lists(basic_schema):
-    state = State(basic_schema)
-    state.set("numbers", "not_a_list")
-    assert state.get("numbers") == ["not_a_list"]
-    state.set("numbers", [1, 2])
-    assert state.get("numbers") == ["not_a_list", 1, 2]
-def test_state_nested_structures():
-    schema = {
-        "complex": {
-            "type": Dict[str, List[int]],
-            "handler": lambda current, new: {
-                k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys())
-            }
-            if current
-            else new,
-        }
-    }
-    state = State(schema)
-    state.set("complex", {"a": [1, 2], "b": [3, 4]})
-    state.set("complex", {"b": [5, 6], "c": [7, 8]})
-    expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}
-    assert state.get("complex") == expected
-def test_schema_to_dict(basic_schema):
-    expected_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
-    result = _schema_to_dict(basic_schema)
-    assert result == expected_dict
-def test_schema_to_dict_with_handlers(complex_schema):
-    expected_dict = {
-        "numbers": {"type": "list", "handler": "test_state.numbers_handler"},
-        "metadata": {"type": "dict"},
-        "name": {"type": "str"},
-    }
-    result = _schema_to_dict(complex_schema)
-    assert result == expected_dict
-def test_schema_from_dict(basic_schema):
-    schema_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
-    result = _schema_from_dict(schema_dict)
-    assert result == basic_schema
-def test_schema_from_dict_with_handlers(complex_schema):
-    schema_dict = {
-        "numbers": {"type": "list", "handler": "test_state.numbers_handler"},
-        "metadata": {"type": "dict"},
-        "name": {"type": "str"},
-    }
-    result = _schema_from_dict(schema_dict)
-    assert result == complex_schema
-def test_state_mutability():
-    state = State({"my_list": {"type": list}}, {"my_list": [1, 2]})
-    my_list = state.get("my_list")
-    my_list.append(3)
-    assert state.get("my_list") == [1, 2]

View File

@@ -99,11 +99,6 @@ def test_tool_call_delta():
    assert tool_call.index == 0
-def test_tool_call_delta_with_missing_fields():
-    with pytest.raises(ValueError):
-        _ = ToolCallDelta(id="123", index=0)
def test_create_chunk_with_finish_reason():
    """Test creating a chunk with the new finish_reason field."""
    chunk = StreamingChunk(content="Test content", finish_reason="stop")