Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-07-03 07:04:01 +00:00)

Compare commits
13 Commits
SHA1
16fc41cd95
9fd552f906
adb2759d00
848115c65e
3aaa201ed6
f11870b212
97e72b9693
fc64884819
c54a68ab63
c18f81283c
101e9cdc34
bcaef53cbc
85e8493f4f
.github/utils/check_imports.py (vendored, 16 changes)

@@ -1,14 +1,17 @@
+import importlib
 import os
 import sys
-import importlib
 import traceback
 from pathlib import Path
 from typing import List, Optional

 from haystack import logging  # pylint: disable=unused-import # this is needed to avoid circular imports


 def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]] = None) -> tuple[list, list]:
     """
     Recursively search for all Python modules and attempt to import them.

     This includes both packages (directories with __init__.py) and individual Python files.
     """
     imported = []
@@ -25,7 +28,7 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]

         # Convert path to module format
         module_path = ".".join(Path(root).relative_to(base_path.parent).parts)
-        python_files = [f for f in files if f.endswith('.py')]
+        python_files = [f for f in files if f.endswith(".py")]

         # Try importing package and individual files
         for file in python_files:
@@ -39,16 +42,15 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]
                 importlib.import_module(module_to_import)
                 imported.append(module_to_import)
             except:
-                failed.append({
-                    'module': module_to_import,
-                    'traceback': traceback.format_exc()
-                })
+                failed.append({"module": module_to_import, "traceback": traceback.format_exc()})

     return imported, failed


 def main():
     """
     This script checks that all Haystack modules can be imported successfully.

     This includes both packages and individual Python files.
     This can detect several issues, such as:
     - Syntax errors in Python files
@@ -80,5 +82,5 @@ def main():
         sys.exit(1)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
.github/utils/create_unstable_docs.py (vendored, 8 changes)

@@ -1,14 +1,16 @@
+import argparse
 import re
 import sys
-import argparse

-from readme_api import get_versions, create_new_unstable
+from readme_api import create_new_unstable, get_versions

 VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def calculate_new_unstable(version: str):
+    """
+    Calculate the new unstable version based on the given version.
+    """
     # version must be formatted like so <major>.<minor>
     major, minor = version.split(".")
     return f"{major}.{int(minor) + 1}-unstable"
.github/utils/deepset_sync.py (vendored, new file, 181 lines)

@@ -0,0 +1,181 @@ (new file; all lines added)
# /// script
# dependencies = [
#     "requests",
# ]
# ///

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional

import requests


def transform_filename(filepath: Path) -> str:
    """
    Transform a file path to the required format:

    - Replace path separators with underscores
    """
    # Convert to string and replace path separators with underscores
    transformed = str(filepath).replace("/", "_").replace("\\", "_")

    return transformed


def upload_file_to_deepset(filepath: Path, api_key: str, workspace: str) -> bool:
    """
    Upload a single file to Deepset API.
    """
    # Read file content
    try:
        content = filepath.read_text(encoding="utf-8")
    except Exception as e:
        print(f"Error reading file {filepath}: {e}")
        return False

    # Transform filename
    transformed_name = transform_filename(filepath)

    # Prepare metadata
    metadata: dict[str, str] = {"original_file_path": str(filepath)}

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"
    params: dict[str, str] = {"file_name": transformed_name, "write_mode": "OVERWRITE"}

    headers: dict[str, str] = {"accept": "application/json", "authorization": f"Bearer {api_key}"}

    # Prepare multipart form data
    files: dict[str, tuple[None, str, str]] = {
        "meta": (None, json.dumps(metadata), "application/json"),
        "text": (None, content, "text/plain"),
    }

    try:
        response = requests.post(url, params=params, headers=headers, files=files, timeout=300)
        response.raise_for_status()
        print(f"Successfully uploaded: {filepath} as {transformed_name}")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to upload {filepath}: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to upload {filepath}: {e}")
        return False


def delete_files_from_deepset(filepaths: list[Path], api_key: str, workspace: str) -> bool:
    """
    Delete multiple files from Deepset API.
    """
    if not filepaths:
        return True

    # Transform filenames
    transformed_names: list[str] = [transform_filename(fp) for fp in filepaths]

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"

    headers: dict[str, str] = {
        "accept": "application/json",
        "authorization": f"Bearer {api_key}",
        "content-type": "application/json",
    }

    data: dict[str, list[str]] = {"names": transformed_names}

    try:
        response = requests.delete(url, headers=headers, json=data, timeout=300)
        response.raise_for_status()
        print(f"Successfully deleted {len(transformed_names)} file(s):")
        for original, transformed in zip(filepaths, transformed_names):
            print(f"  - {original} (as {transformed})")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to delete files: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to delete files: {e}")
        return False


def main() -> None:
    """
    Main function to process and upload/delete files.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Upload/delete Python files to/from Deepset")
    parser.add_argument("--changed", nargs="*", default=[], help="Changed or added files")
    parser.add_argument("--deleted", nargs="*", default=[], help="Deleted files")
    args = parser.parse_args()

    # Get environment variables
    api_key: Optional[str] = os.environ.get("DEEPSET_API_KEY")
    workspace: str = os.environ.get("DEEPSET_WORKSPACE")

    if not api_key:
        print("Error: DEEPSET_API_KEY environment variable not set")
        sys.exit(1)

    # Process arguments and convert to Path objects
    changed_files: list[Path] = [Path(f.strip()) for f in args.changed if f.strip()]
    deleted_files: list[Path] = [Path(f.strip()) for f in args.deleted if f.strip()]

    if not changed_files and not deleted_files:
        print("No files to process")
        sys.exit(0)

    print(f"Processing files in Deepset workspace: {workspace}")
    print("-" * 50)

    # Track results
    upload_success: int = 0
    upload_failed: list[Path] = []
    delete_success: bool = False

    # Handle deletions first
    if deleted_files:
        print(f"\nDeleting {len(deleted_files)} file(s)...")
        delete_success = delete_files_from_deepset(deleted_files, api_key, workspace)

    # Upload changed/new files
    if changed_files:
        print(f"\nUploading {len(changed_files)} file(s)...")
        for filepath in changed_files:
            if filepath.exists():
                if upload_file_to_deepset(filepath, api_key, workspace):
                    upload_success += 1
                else:
                    upload_failed.append(filepath)
            else:
                print(f"Skipping non-existent file: {filepath}")

    # Summary
    print("-" * 50)
    print("Processing Summary:")
    if changed_files:
        print(f"  Uploads - Successful: {upload_success}, Failed: {len(upload_failed)}")
    if deleted_files:
        print(f"  Deletions - {'Successful' if delete_success else 'Failed'}: {len(deleted_files)} file(s)")

    if upload_failed:
        print("\nFailed uploads:")
        for f in upload_failed:
            print(f"  - {f}")

    # Exit with error if any operation failed
    if upload_failed or (deleted_files and not delete_success):
        sys.exit(1)

    print("\nAll operations completed successfully!")


if __name__ == "__main__":
    main()
.github/utils/delete_outdated_docs.py (vendored, 9 changes)

@@ -12,6 +12,9 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("README_API_KEY", None)
     if not api_key:
         raise Exception("README_API_KEY env var is not set")
@@ -21,6 +24,9 @@ def readme_token():


 def create_headers(version: str):
+    """
+    Create headers for the Readme API.
+    """
     return {"authorization": f"Basic {readme_token()}", "x-readme-version": version}


@@ -35,6 +41,9 @@ def get_docs_in_category(category_slug: str, version: str) -> List[str]:


 def delete_doc(slug: str, version: str):
+    """
+    Delete a document from Readme, based on the slug and version.
+    """
     url = f"https://dash.readme.com/api/v1/docs/{slug}"
     headers = create_headers(version)
     res = requests.delete(url, headers=headers, timeout=10)
.github/utils/docstrings_checksum.py (vendored, 8 changes)

@@ -1,11 +1,13 @@
+import ast
+import hashlib
 from pathlib import Path
 from typing import Iterator
-import ast
-import hashlib


 def docstrings_checksum(python_files: Iterator[Path]):
+    """
+    Calculate the checksum of the docstrings in the given Python files.
+    """
     files_content = (f.read_text() for f in python_files)
     trees = (ast.parse(c) for c in files_content)
.github/utils/promote_unstable_docs.py (vendored, 6 changes)

@@ -1,6 +1,6 @@
+import argparse
 import re
 import sys
-import argparse

 from readme_api import get_versions, promote_unstable_to_stable

@@ -8,9 +8,7 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True
-    )
+    parser.add_argument("-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True)
     args = parser.parse_args()

     if VERSION_VALIDATOR.match(args.version) is None:
.github/utils/pyproject_to_requirements.py (vendored, 9 changes)

@@ -3,7 +3,8 @@ import re
 import sys
 from pathlib import Path

-import toml
+# toml is available in the default environment but not in the test environment, so pylint complains
+import toml  # pylint: disable=import-error

 matcher = re.compile(r"farm-haystack\[(.+)\]")
 parser = argparse.ArgumentParser(
@@ -14,6 +15,9 @@ parser.add_argument("--extra", default="")


 def resolve(target: str, extras: dict, results: set):
+    """
+    Resolve the dependencies for a given target.
+    """
     if target not in extras:
         results.add(target)
         return
@@ -28,6 +32,9 @@ def resolve(target: str, extras: dict, results: set):


 def main(pyproject_path: Path, extra: str = ""):
+    """
+    Convert a pyproject.toml file to a requirements.txt file.
+    """
     content = toml.load(pyproject_path)
     # basic set of dependencies
     deps = set(content["project"]["dependencies"])
.github/utils/readme_api.py (vendored, 13 changes)

@@ -1,16 +1,23 @@
-import os
 import base64
+import os

 import requests


 class ReadmeAuth(requests.auth.AuthBase):
-    def __call__(self, r):
+    """
+    Custom authentication class for Readme API.
+    """
+
+    def __call__(self, r):  # noqa: D102
         r.headers["authorization"] = f"Basic {readme_token()}"
         return r


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("RDME_API_KEY", None)
     if not api_key:
         raise Exception("RDME_API_KEY env var is not set")
.github/workflows/e2e.yml (vendored, 4 changes)

@@ -18,7 +18,9 @@ env:
   PYTHON_VERSION: "3.9"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   HATCH_VERSION: "1.14.1"
-  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}
+  # we use HF_TOKEN instead of HF_API_TOKEN to work around a Hugging Face bug
+  # see https://github.com/deepset-ai/haystack/issues/9552
+  HF_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

 jobs:
   run:
.github/workflows/sync_code_to_deepset.yml (vendored, new file, 56 lines)

@@ -0,0 +1,56 @@ (new file; all lines added)
name: Upload Code to Deepset
on:
  push:
    branches:
      - main
jobs:
  upload-files:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0  # Fetch all history for proper diff

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        run: |
          pip install uv

      - name: Get changed files
        id: changed-files
        uses: tj-actions/changed-files@v46
        with:
          files: |
            haystack/**/*.py
          separator: ' '

      - name: Upload files to Deepset
        if: steps.changed-files.outputs.any_changed == 'true' || steps.changed-files.outputs.any_deleted == 'true'
        env:
          DEEPSET_API_KEY: ${{ secrets.DEEPSET_API_KEY }}
          DEEPSET_WORKSPACE: haystack-code
        run: |
          # Combine added and modified files for upload
          CHANGED_FILES=""
          if [ -n "${{ steps.changed-files.outputs.added_files }}" ]; then
            CHANGED_FILES="${{ steps.changed-files.outputs.added_files }}"
          fi
          if [ -n "${{ steps.changed-files.outputs.modified_files }}" ]; then
            if [ -n "$CHANGED_FILES" ]; then
              CHANGED_FILES="$CHANGED_FILES ${{ steps.changed-files.outputs.modified_files }}"
            else
              CHANGED_FILES="${{ steps.changed-files.outputs.modified_files }}"
            fi
          fi

          # Run the script with changed and deleted files
          # shellcheck disable=SC2086
          uv run --no-project --no-config --no-cache .github/utils/deepset_sync.py \
            --changed $CHANGED_FILES \
            --deleted ${{ steps.changed-files.outputs.deleted_files }}
.github/workflows/tests.yml (vendored, 2 changes)

@@ -19,7 +19,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
.github/workflows/tests_skipper_trigger.yml (vendored, 4 changes)

@@ -16,7 +16,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 jobs:
   check_if_changed:
@@ -39,7 +39,7 @@ jobs:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

   trigger-catch-all:
     name: Tests completed
@@ -1 +1 @@
-2.15.0-rc0
+2.16.0-rc0
@@ -2,7 +2,7 @@ loaders:
 - type: haystack_pydoc_tools.loaders.CustomPythonLoader
   search_path: [../../../haystack/dataclasses]
   modules:
-    ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"]
+    ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "streaming_chunk"]
   ignore_when_discovered: ["__init__"]
 processors:
 - type: filter
@@ -68,8 +68,8 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):

 @pytest.mark.parametrize("batch_size", [1, 3])
 @pytest.mark.skipif(
-    not os.environ.get("HF_API_TOKEN", None),
-    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    not os.environ.get("HF_API_TOKEN", None) and not os.environ.get("HF_TOKEN", None),
+    reason="Export an env var called HF_API_TOKEN or HF_TOKEN containing the Hugging Face token to run this test.",
 )
 def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
@@ -67,8 +67,9 @@ class Agent:
         exit_conditions: Optional[List[str]] = None,
         state_schema: Optional[Dict[str, Any]] = None,
         max_agent_steps: int = 100,
-        raise_on_tool_invocation_failure: bool = False,
         streaming_callback: Optional[StreamingCallbackT] = None,
+        raise_on_tool_invocation_failure: bool = False,
+        tool_invoker_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Initialize the agent component.
@@ -82,10 +83,11 @@ class Agent:
         :param state_schema: The schema for the runtime state used by the tools.
         :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
             If the agent exceeds this number of steps, it will stop and return the current state.
-        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
-            If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
             The same callback can be configured to emit tool results when a tool is called.
+        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
+            If set to False, the exception will be turned into a chat message and passed to the LLM.
+        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         :raises ValueError: If the exit_conditions are not valid.
         """
@@ -135,9 +137,15 @@ class Agent:
             component.set_input_type(self, name=param, type=config["type"], default=None)
         component.set_output_types(self, **output_types)

+        self.tool_invoker_kwargs = tool_invoker_kwargs
         self._tool_invoker = None
         if self.tools:
-            self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
+            resolved_tool_invoker_kwargs = {
+                "tools": self.tools,
+                "raise_on_failure": self.raise_on_tool_invocation_failure,
+                **(tool_invoker_kwargs or {}),
+            }
+            self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
         else:
             logger.warning(
                 "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
@@ -175,8 +183,9 @@ class Agent:
             # We serialize the original state schema, not the resolved one to reflect the original user input
             state_schema=_schema_to_dict(self._state_schema),
             max_agent_steps=self.max_agent_steps,
-            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
             streaming_callback=streaming_callback,
+            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
+            tool_invoker_kwargs=self.tool_invoker_kwargs,
         )

     @classmethod
@@ -44,7 +44,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
         if chunk.index and tool_call.index > chunk.index:
             print("\n\n", flush=True, end="")

-        print("[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
+        print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")

         # print the tool arguments
         if tool_call.arguments:
@@ -84,24 +84,20 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
     for chunk in chunks:
         if chunk.tool_calls:
-            # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
-            # tool_call is present
-            assert chunk.index is not None
-
             for tool_call in chunk.tool_calls:
                 # We use the index of the tool_call to track the tool call across chunks since the ID is not always
                 # provided
                 if tool_call.index not in tool_call_data:
-                    tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
+                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}

                 # Save the ID if present
                 if tool_call.id is not None:
-                    tool_call_data[chunk.index]["id"] = tool_call.id
+                    tool_call_data[tool_call.index]["id"] = tool_call.id

                 if tool_call.tool_name is not None:
-                    tool_call_data[chunk.index]["name"] += tool_call.tool_name
+                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
                 if tool_call.arguments is not None:
-                    tool_call_data[chunk.index]["arguments"] += tool_call.arguments
+                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments

     # Convert accumulated tool call data into ToolCall objects
     sorted_keys = sorted(tool_call_data.keys())
@@ -5,7 +5,6 @@
 import asyncio
 import inspect
 import json
-import warnings
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Any, Dict, List, Optional, Set, Union
@@ -171,7 +170,6 @@ class ToolInvoker:
         *,
         enable_streaming_callback_passthrough: bool = False,
         max_workers: int = 4,
-        async_executor: Optional[ThreadPoolExecutor] = None,
     ):
         """
         Initialize the ToolInvoker component.
@@ -197,13 +195,7 @@ class ToolInvoker:
             If False, the `streaming_callback` will not be passed to the tool invocation.
         :param max_workers:
             The maximum number of workers to use in the thread pool executor.
-        :param async_executor:
-            Optional `ThreadPoolExecutor` to use for asynchronous calls.
-            Note: As of Haystack 2.15.0, you no longer need to explicitly pass
-            `async_executor`. Instead, you can provide the `max_workers` parameter,
-            and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
-            Support for `async_executor` will be removed in Haystack 2.16.0.
-            Please migrate to using `max_workers` instead.
+            This also decides the maximum number of concurrent tool invocations.
         :raises ValueError:
             If no tools are provided or if duplicate tool names are found.
         """
@@ -231,37 +223,6 @@ class ToolInvoker:
         self._tools_with_names = dict(zip(tool_names, converted_tools))
         self.raise_on_failure = raise_on_failure
         self.convert_result_to_json_string = convert_result_to_json_string
-        self._owns_executor = async_executor is None
-        if self._owns_executor:
-            warnings.warn(
-                "'async_executor' is deprecated in favor of the 'max_workers' parameter. "
-                "ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
-                "Support for 'async_executor' will be removed in Haystack 2.16.0. "
-                "Please update your usage to pass 'max_workers' instead.",
-                DeprecationWarning,
-            )
-
-        self.executor = (
-            ThreadPoolExecutor(
-                thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
-            )
-            if async_executor is None
-            else async_executor
-        )
-
-    def __del__(self):
-        """
-        Cleanup when the instance is being destroyed.
-        """
-        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
-            self.executor.shutdown(wait=True)
-
-    def shutdown(self):
-        """
-        Explicitly shutdown the executor if we own it.
-        """
-        if self._owns_executor:
-            self.executor.shutdown(wait=True)

     def _handle_error(self, error: Exception) -> str:
         """
@@ -655,7 +616,7 @@ class ToolInvoker:
             return e

     @component.output_types(tool_messages=List[ChatMessage], state=State)
-    async def run_async(
+    async def run_async(  # noqa: PLR0915
         self,
         messages: List[ChatMessage],
         state: Optional[State] = None,
@@ -718,10 +679,10 @@ class ToolInvoker:

         # 2) Execute valid tool calls in parallel
         if tool_call_params:
-            with self.executor as executor:
-                tool_call_tasks = []
-                valid_tool_calls = []
+            tool_call_tasks = []
+            valid_tool_calls = []
+
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                 # 3) Create async tasks for valid tool calls
                 for params in tool_call_params:
                     task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
@@ -749,7 +710,7 @@ class ToolInvoker:
                 try:
                     error_message = self._handle_error(
                         ToolOutputMergeError(
-                            f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
+                            f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
                         )
                     )
                     tool_messages.append(
@@ -38,7 +38,6 @@ if TYPE_CHECKING:
     from .chat_message import ToolCallResult as ToolCallResult
     from .document import Document as Document
     from .sparse_embedding import SparseEmbedding as SparseEmbedding
-    from .state import State as State
    from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
    from .streaming_chunk import ComponentInfo as ComponentInfo
    from .streaming_chunk import FinishReason as FinishReason
@@ -79,3 +79,24 @@ class ByteStream:
         fields.append(f"mime_type={self.mime_type!r}")
         fields_str = ", ".join(fields)
         return f"{self.__class__.__name__}({fields_str})"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the ByteStream to a dictionary representation.
+
+        :returns: A dictionary with keys 'data', 'meta', and 'mime_type'.
+        """
+        # Note: The data is converted to a list of integers for serialization since JSON does not support bytes
+        # directly.
+        return {"data": list(self.data), "meta": self.meta, "mime_type": self.mime_type}
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ByteStream":
+        """
+        Create a ByteStream from a dictionary representation.
+
+        :param data: A dictionary with keys 'data', 'meta', and 'mime_type'.
+
+        :returns: A ByteStream instance.
+        """
+        return ByteStream(data=bytes(data["data"]), meta=data.get("meta", {}), mime_type=data.get("mime_type"))
@@ -127,8 +127,12 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
             Whether to flatten `meta` field or not. Defaults to `True` to be backward-compatible with Haystack 1.x.
         """
         data = asdict(self)
-        if (blob := data.get("blob")) is not None:
-            data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}
+        # Use `ByteStream` and `SparseEmbedding`'s to_dict methods to convert them to JSON-serializable types.
+        if self.blob is not None:
+            data["blob"] = self.blob.to_dict()
+        if self.sparse_embedding is not None:
+            data["sparse_embedding"] = self.sparse_embedding.to_dict()

         if flatten:
             meta = data.pop("meta")
@@ -144,7 +148,7 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
         The `blob` field is converted to its original type.
         """
         if blob := data.get("blob"):
-            data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
+            data["blob"] = ByteStream.from_dict(blob)
         if sparse_embedding := data.get("sparse_embedding"):
             data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)
@@ -1,43 +0,0 @@ (deleted file; all lines removed)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, Optional

from haystack.components.agents import State as Utils_State


class State(Utils_State):
    """
    A class that wraps a StateSchema and maintains an internal _data dictionary.

    Deprecated in favor of `haystack.components.agents.State`. It will be removed in Haystack 2.16.0.

    Each schema entry has:
    "parameter_name": {
        "type": SomeType,
        "handler": Optional[Callable[[Any, Any], Any]]
    }
    """

    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
        """
        Initialize a State object with a schema and optional data.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
                - For list types: `haystack.components.agents.state.state_utils.merge_lists`
                - For all other types: `haystack.components.agents.state.state_utils.replace_values`
        :param data: Optional dictionary of initial data to populate the state
        """
        warnings.warn(
            "`haystack.dataclasses.State` is deprecated and will be removed in Haystack 2.16.0. "
            "Use `haystack.components.agents.State` instead.",
            DeprecationWarning,
        )
        super().__init__(schema, data)
@@ -1,52 +0,0 @@ (deleted file; all lines removed)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, List, TypeVar, Union

T = TypeVar("T")


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
    """
    Merges two values into a single list.

    Deprecated in favor of `haystack.components.agents.state.merge_lists`. It will be removed in Haystack 2.16.0.

    If either `current` or `new` is not already a list, it is converted into one.
    The function ensures that both inputs are treated as lists and concatenates them.

    If `current` is None, it is treated as an empty list.

    :param current: The existing value(s), either a single item or a list.
    :param new: The new value(s) to merge, either a single item or a list.
    :return: A list containing elements from both `current` and `new`.
    """
    warnings.warn(
        "`haystack.dataclasses.state_utils.merge_lists` is deprecated and will be removed in Haystack 2.16.0. "
        "Use `haystack.components.agents.state.merge_lists` instead.",
        DeprecationWarning,
    )
    current_list = [] if current is None else current if isinstance(current, list) else [current]
    new_list = new if isinstance(new, list) else [new]
    return current_list + new_list


def replace_values(current: Any, new: Any) -> Any:
    """
    Replace the `current` value with the `new` value.

    :param current: The existing value
    :param new: The new value to replace
    :return: The new value
    """
    warnings.warn(
        "`haystack.dataclasses.state_utils.replace_values` is deprecated and will be removed in Haystack 2.16.0. "
        "Use `haystack.components.agents.state.replace_values` instead.",
        DeprecationWarning,
    )
    return new
@@ -30,12 +30,6 @@ class ToolCallDelta:
     arguments: Optional[str] = field(default=None)
     id: Optional[str] = field(default=None)  # noqa: A003

-    def __post_init__(self):
-        # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the
-        # name and full arguments in one chunk
-        if self.tool_name is None and self.arguments is None:
-            raise ValueError("At least one of tool_name or arguments must be provided.")
-

 @dataclass
 class ComponentInfo:
@@ -283,7 +283,7 @@ disallow_incomplete_defs = false

 [tool.ruff]
 line-length = 120
-exclude = [".github", "proposals"]
+exclude = ["proposals"]

 [tool.ruff.format]
 skip-magic-trailing-comma = true
@@ -0,0 +1,4 @@ (new release note)
---
features:
  - |
    Add `to_dict` and `from_dict` to ByteStream so it is consistent with our other dataclasses in having serialization and deserialization methods.
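A minimal round-trip sketch based on the `ByteStream.to_dict`/`from_dict` implementations shown in the hunk above (the sample bytes and metadata are illustrative):

from haystack.dataclasses import ByteStream

stream = ByteStream(data=b"hello", meta={"source": "example"}, mime_type="text/plain")

# to_dict stores the bytes as a list of integers, since JSON does not support bytes directly
as_dict = stream.to_dict()
assert as_dict == {"data": [104, 101, 108, 108, 111], "meta": {"source": "example"}, "mime_type": "text/plain"}

# from_dict restores the original bytes
restored = ByteStream.from_dict(as_dict)
assert restored.data == b"hello"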
@@ -0,0 +1,4 @@ (new release note)
---
features:
  - |
    Added the `tool_invoker_kwargs` param to Agent so additional kwargs can be passed to the ToolInvoker like `max_workers` and `enable_streaming_callback_passthrough`.
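Based on the Agent test changes in this comparison, usage looks roughly like this (`weather_tool` is a placeholder borrowed from the tests, and an OPENAI_API_KEY is assumed to be set in the environment):

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator

agent = Agent(
    chat_generator=OpenAIChatGenerator(),
    tools=[weather_tool],  # placeholder: any Tool instance
    # merged into the internal ToolInvoker together with tools and raise_on_failure
    tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
)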
@@ -0,0 +1,4 @@ (new release note)
---
fixes:
  - |
    Fix `_convert_streaming_chunks_to_chat_message`, which is used to convert Haystack StreamingChunks into a Haystack ChatMessage. This fixes the scenario where one StreamingChunk contains two ToolCallDeltas in StreamingChunk.tool_calls. With this fix, both ToolCallDeltas are now saved correctly, whereas before they overwrote each other. This only occurs with some LLM providers, like Mistral (and not OpenAI), due to how the provider returns tool calls.
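Schematically, a single chunk like the one in the Mistral-style test added below now yields two separate tool calls instead of the second overwriting the first (a sketch using the dataclasses from this diff):

from haystack.dataclasses import StreamingChunk, ToolCallDelta

chunk = StreamingChunk(
    content="",
    index=0,
    start=True,
    tool_calls=[
        ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
        ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
    ],
)
# _convert_streaming_chunks_to_chat_message now keys its accumulator on tool_call.index
# rather than chunk.index, so both deltas survive the merge.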
@@ -0,0 +1,3 @@ (new release note)
---
fixes:
  - Fixed a bug in the `print_streaming_chunk` utility function that prevented the tool call name from being printed.
@@ -0,0 +1,4 @@ (new release note)
---
enhancements:
  - |
    We relaxed the requirement in ToolCallDelta (introduced in Haystack 2.15) that either the `arguments` or `name` parameter must be populated to create the dataclass. We removed this requirement to be more in line with OpenAI's SDK, and because it was causing errors for some hosted versions of open-source models that follow OpenAI's SDK specification.
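With the `__post_init__` check removed (see the ToolCallDelta hunk above), an empty delta carrying only an index is now valid:

from haystack.dataclasses import ToolCallDelta

# Previously this raised ValueError("At least one of tool_name or arguments must be provided.")
delta = ToolCallDelta(index=0)
assert delta.tool_name is None and delta.arguments is None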
@@ -0,0 +1,6 @@ (new release note)
---
upgrade:
  - |
    The deprecated `async_executor` parameter has been removed from the `ToolInvoker` class.
    Please use the `max_workers` parameter instead; a `ThreadPoolExecutor` with
    these workers will be created automatically for parallel tool invocations.
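Migration is a matter of dropping the executor argument; a sketch, with `my_tools` standing in for a real tool list:

from haystack.components.tools import ToolInvoker

# Before (parameter removed in this release):
# invoker = ToolInvoker(tools=my_tools, async_executor=my_executor)

# After: a ThreadPoolExecutor is created per run_async call using max_workers
invoker = ToolInvoker(tools=my_tools, max_workers=8)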
@@ -0,0 +1,5 @@ (new release note)
---
upgrade:
  - |
    The deprecated `State` class has been removed from the `haystack.dataclasses` module.
    The `State` class is now part of the `haystack.components.agents` module.
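The only change needed is the import path; the schema format is unchanged (the `"foo"` entry mirrors the schema used in the Agent tests below):

# Before (removed):
# from haystack.dataclasses import State

from haystack.components.agents import State

state = State(schema={"foo": {"type": str}})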
@@ -174,6 +174,7 @@ class TestAgent:
             tools=[weather_tool, component_tool],
             exit_conditions=["text", "weather_tool"],
             state_schema={"foo": {"type": str}},
+            tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
         )
         serialized_agent = agent.to_dict()
         assert serialized_agent == {
@@ -236,8 +237,9 @@ class TestAgent:
                 "exit_conditions": ["text", "weather_tool"],
                 "state_schema": {"foo": {"type": "str"}},
                 "max_agent_steps": 100,
-                "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "raise_on_tool_invocation_failure": False,
+                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
             },
         }

@@ -294,6 +296,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": None,
             },
         }

@@ -361,6 +364,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
             },
         }
         agent = Agent.from_dict(data)
@@ -375,6 +379,9 @@ class TestAgent:
             "foo": {"type": str},
             "messages": {"handler": merge_lists, "type": List[ChatMessage]},
         }
+        assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True}
+        assert agent._tool_invoker.max_workers == 5
+        assert agent._tool_invoker.enable_streaming_callback_passthrough is True

     def test_from_dict_with_toolset(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@@ -426,6 +433,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": None,
             },
         }
         agent = Agent.from_dict(data)
@@ -1177,6 +1177,32 @@ class TestChatCompletionChunkConversion:
             assert stream_chunk == haystack_chunk
             previous_chunks.append(stream_chunk)

+    def test_convert_chat_completion_chunk_with_empty_tool_calls(self):
+        # This can happen with some LLM providers where tool calls are not present but the pydantic models are still
+        # initialized.
+        chunk = ChatCompletionChunk(
+            id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
+            choices=[
+                chat_completion_chunk.Choice(
+                    delta=chat_completion_chunk.ChoiceDelta(
+                        tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction())]
+                    ),
+                    index=0,
+                )
+            ],
+            created=1742207200,
+            model="gpt-4o-mini-2024-07-18",
+            object="chat.completion.chunk",
+        )
+        result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])
+        assert result.content == ""
+        assert result.start is False
+        assert result.tool_calls == [ToolCallDelta(index=0)]
+        assert result.tool_call_result is None
+        assert result.index == 0
+        assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
+        assert result.meta["received_at"] is not None
+
     def test_handle_stream_response(self, chat_completion_chunks):
         openai_chunks = chat_completion_chunks
         comp = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
@ -3,9 +3,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from openai.types.chat import chat_completion_chunk
|
from openai.types.chat import chat_completion_chunk
|
||||||
|
from unittest.mock import patch, call
|
||||||
|
|
||||||
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
|
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, print_streaming_chunk
|
||||||
from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta
|
from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta, ToolCallResult
|
||||||
|
|
||||||
|
|
||||||
def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
|
def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
|
||||||
@@ -325,3 +326,256 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
         },
         "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
     }
+
+
+def test_convert_streaming_chunk_to_chat_message_two_tool_calls_in_same_chunk():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "usage": None,
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "finish_reason": "tool_calls",
+                "usage": {
+                    "completion_tokens": 35,
+                    "prompt_tokens": 77,
+                    "total_tokens": 112,
+                    "completion_tokens_details": None,
+                    "prompt_tokens_details": None,
+                },
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+            index=0,
+            tool_calls=[
+                ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
+                ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
+            ],
+            start=True,
+            finish_reason="tool_calls",
+        ),
+    ]
+
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    assert not result.texts
+    assert not result.text
+
+    # Verify both tool calls were found and processed
+    assert len(result.tool_calls) == 2
+    assert result.tool_calls[0].id == "FL1FFlqUG"
+    assert result.tool_calls[0].tool_name == "weather"
+    assert result.tool_calls[0].arguments == {"city": "Paris"}
+    assert result.tool_calls[1].id == "xSuhp66iB"
+    assert result.tool_calls[1].tool_name == "weather"
+    assert result.tool_calls[1].arguments == {"city": "Berlin"}
+
+
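The test above feeds two ToolCallDelta entries with distinct `index` values inside a single chunk. An illustrative sketch of the bookkeeping such a converter needs (plain dicts, not Haystack's actual implementation): group fragments by `index`, concatenate the argument strings, and only parse them as JSON once the stream is complete.

    import json
    from typing import Dict, List

    def merge_tool_call_deltas(deltas: List[dict]) -> List[dict]:
        # One accumulator slot per tool-call index; later fragments fill in
        # missing ids/names and extend the argument string.
        merged: Dict[int, dict] = {}
        for delta in deltas:
            slot = merged.setdefault(delta["index"], {"id": None, "tool_name": None, "arguments": ""})
            slot["id"] = delta.get("id") or slot["id"]
            slot["tool_name"] = delta.get("tool_name") or slot["tool_name"]
            slot["arguments"] += delta.get("arguments") or ""
        return [
            {"id": s["id"], "tool_name": s["tool_name"], "arguments": json.loads(s["arguments"])}
            for s in merged.values()
        ]

    deltas = [
        {"index": 0, "id": "FL1FFlqUG", "tool_name": "weather", "arguments": '{"city": "Paris"}'},
        {"index": 1, "id": "xSuhp66iB", "tool_name": "weather", "arguments": '{"city": "Berlin"}'},
    ]
    assert [m["arguments"] for m in merge_tool_call_deltas(deltas)] == [{"city": "Paris"}, {"city": "Berlin"}]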
+def test_convert_streaming_chunk_to_chat_message_empty_tool_call_delta():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.910076",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        id="call_ZOj5l67zhZOx6jqjg7ATQwb6",
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments='{"query":', name="rag_pipeline_tool"
+                        ),
+                        type="function",
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.913919",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            start=True,
+            tool_calls=[
+                ToolCallDelta(
+                    id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments='{"query":', index=0
+                )
+            ],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments=' "Where does Mark live?"}'
+                        ),
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.924420",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            tool_calls=[ToolCallDelta(arguments=' "Where does Mark live?"}', index=0)],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction()
+                    )
+                ],
+                "finish_reason": "tool_calls",
+                "received_at": "2025-02-19T16:02:55.948772",
+            },
+            tool_calls=[ToolCallDelta(index=0)],
+            component_info=ComponentInfo(name="test", type="test"),
+            finish_reason="tool_calls",
+            index=0,
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.948772",
+                "usage": {
+                    "completion_tokens": 42,
+                    "prompt_tokens": 282,
+                    "total_tokens": 324,
+                    "completion_tokens_details": {
+                        "accepted_prediction_tokens": 0,
+                        "audio_tokens": 0,
+                        "reasoning_tokens": 0,
+                        "rejected_prediction_tokens": 0,
+                    },
+                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                },
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+    ]
+
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    assert not result.texts
+    assert not result.text
+
+    # Verify the tool call was found and processed
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6"
+    assert result.tool_calls[0].tool_name == "rag_pipeline_tool"
+    assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"}
+    assert result.meta["finish_reason"] == "tool_calls"
+
+
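The point of the empty-delta test above is that providers can emit a trailing ChoiceDeltaToolCallFunction() carrying neither name nor arguments, and that individual argument fragments are not valid JSON on their own. A tiny sketch of why buffering is required:

    # Each fragment alone fails json.loads; an empty trailing delta must
    # contribute nothing to the buffer.
    import json

    fragments = ['{"query":', ' "Where does Mark live?"}', None]  # None mimics an empty delta
    buffer = "".join(f for f in fragments if f)
    assert json.loads(buffer) == {"query": "Where does Mark live?"}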
+def test_print_streaming_chunk_content_only():
+    chunk = StreamingChunk(
+        content="Hello, world!",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[ASSISTANT]\n", flush=True, end=""), call("Hello, world!", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_tool_call():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        index=0,
+        tool_calls=[ToolCallDelta(id="call_123", tool_name="test_tool", arguments='{"param": "value"}', index=0)],
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[TOOL CALL]\nTool: test_tool \nArguments: ", flush=True, end=""),
+            call('{"param": "value"}', flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_tool_call_result():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        index=0,
+        tool_call_result=ToolCallResult(
+            result="Tool execution completed successfully",
+            origin=ToolCall(id="call_123", tool_name="test_tool", arguments={}),
+            error=False,
+        ),
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[TOOL RESULT]\nTool execution completed successfully", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_with_finish_reason():
+    chunk = StreamingChunk(
+        content="Final content.",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        finish_reason="stop",
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[ASSISTANT]\n", flush=True, end=""),
+            call("Final content.", flush=True, end=""),
+            call("\n\n", flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_empty_chunk():
+    chunk = StreamingChunk(
+        content="", meta={"model": "test-model"}, component_info=ComponentInfo(name="test", type="test")
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        mock_print.assert_not_called()
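For context on print_streaming_chunk itself: it is the stock callback that is typically passed as streaming_callback. A usage sketch, assuming a valid OPENAI_API_KEY in the environment (the model name is illustrative, and the call performs a network request):

    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.components.generators.utils import print_streaming_chunk
    from haystack.dataclasses import ChatMessage

    # Chunks are printed to stdout as they arrive from the provider.
    generator = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=print_streaming_chunk)
    generator.run(messages=[ChatMessage.from_user("Say hello.")])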
@@ -14,7 +14,7 @@ from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
 from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
-from haystack.dataclasses.state import State
+from haystack.components.agents.state import State
 from haystack.tools import ComponentTool, Tool, Toolset
 from haystack.tools.errors import ToolInvocationError
 from haystack.dataclasses import StreamingChunk
@@ -100,11 +100,6 @@ def faulty_invoker(faulty_tool):
     return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False)
 
 
-@pytest.fixture
-def thread_executor():
-    return ThreadPoolExecutor(thread_name_prefix=f"async-test-executor", max_workers=2)
-
-
 class TestToolInvoker:
     def test_init(self, weather_tool):
        invoker = ToolInvoker(tools=[weather_tool])
@@ -227,7 +222,7 @@ class TestToolInvoker:
         assert final_chunk.content == ""
 
     @pytest.mark.asyncio
-    async def test_run_async_with_streaming_callback(self, thread_executor, weather_tool):
+    async def test_run_async_with_streaming_callback(self, weather_tool):
         streaming_callback_called = False
 
         async def streaming_callback(chunk: StreamingChunk) -> None:
@@ -235,12 +230,7 @@ class TestToolInvoker:
             nonlocal streaming_callback_called
             streaming_callback_called = True
 
-        tool_invoker = ToolInvoker(
-            tools=[weather_tool],
-            raise_on_failure=True,
-            convert_result_to_json_string=False,
-            async_executor=thread_executor,
-        )
+        tool_invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
 
         tool_calls = [
             ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
@@ -269,18 +259,13 @@ class TestToolInvoker:
         assert streaming_callback_called
 
     @pytest.mark.asyncio
-    async def test_run_async_with_streaming_callback_finish_reason(self, thread_executor, weather_tool):
+    async def test_run_async_with_streaming_callback_finish_reason(self, weather_tool):
         streaming_chunks = []
 
         async def streaming_callback(chunk: StreamingChunk) -> None:
             streaming_chunks.append(chunk)
 
-        tool_invoker = ToolInvoker(
-            tools=[weather_tool],
-            raise_on_failure=True,
-            convert_result_to_json_string=False,
-            async_executor=thread_executor,
-        )
+        tool_invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
 
         tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
         message = ChatMessage.from_assistant(tool_calls=[tool_call])
@@ -319,10 +304,8 @@ class TestToolInvoker:
         assert not tool_call_result.error
 
     @pytest.mark.asyncio
-    async def test_run_async_with_toolset(self, tool_set, thread_executor):
-        tool_invoker = ToolInvoker(
-            tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False, async_executor=thread_executor
-        )
+    async def test_run_async_with_toolset(self, tool_set):
+        tool_invoker = ToolInvoker(tools=tool_set, raise_on_failure=True, convert_result_to_json_string=False)
         tool_calls = [
             ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3}),
             ToolCall(tool_name="addition_tool", arguments={"num1": 5, "num2": 3}),
@@ -818,6 +801,55 @@ class TestToolInvoker:
         assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
         assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
 
+    def test_call_invoker_two_subsequent_run_calls(self, invoker: ToolInvoker):
+        tool_calls = [
+            ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
+        ]
+        message = ChatMessage.from_assistant(tool_calls=tool_calls)
+
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        # First call
+        result_1 = invoker.run(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_1
+        assert len(result_1["tool_messages"]) == 3
+
+        # Second call
+        result_2 = invoker.run(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_2
+        assert len(result_2["tool_messages"]) == 3
+
+    @pytest.mark.asyncio
+    async def test_call_invoker_two_subsequent_run_async_calls(self, invoker: ToolInvoker):
+        tool_calls = [
+            ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
+            ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
+        ]
+        message = ChatMessage.from_assistant(tool_calls=tool_calls)
+
+        streaming_callback_called = False
+
+        async def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        # First call
+        result_1 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_1
+        assert len(result_1["tool_messages"]) == 3
+
+        # Second call
+        result_2 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback)
+        assert "tool_messages" in result_2
+        assert len(result_2["tool_messages"]) == 3
+
 
 class TestMergeToolOutputs:
     def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
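With the thread_executor fixture gone, ToolInvoker is constructed without an explicit async_executor and manages its own executor for run_async. A sketch of the pattern these tests exercise; the weather tool below is a stand-in for the weather_tool fixture, not code from this diff:

    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.tools import Tool

    def get_weather(location: str) -> dict:
        # Stub result; a real tool would call an API here.
        return {"location": location, "weather": "sunny"}

    weather_tool = Tool(
        name="weather_tool",
        description="Report the weather for a location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=get_weather,
    )

    invoker = ToolInvoker(tools=[weather_tool])
    message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})])
    result = invoker.run(messages=[message])
    print(result["tool_messages"][0])  # a tool-role ChatMessage carrying the tool's output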
@@ -81,3 +81,23 @@ def test_str_truncation():
     assert len(string_repr) < 200
     assert "text/plain" in string_repr
     assert "foo" in string_repr
+
+
+def test_to_dict():
+    test_str = "Hello, world!"
+    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
+    d = b.to_dict()
+    assert d["data"] == list(test_str.encode())
+    assert d["mime_type"] == "text/plain"
+    assert d["meta"] == {"foo": "bar"}
+
+
+def test_from_dict():
+    test_str = "Hello, world!"
+    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
+    d = b.to_dict()
+    b2 = ByteStream.from_dict(d)
+    assert b2.data == b.data
+    assert b2.mime_type == b.mime_type
+    assert b2.meta == b.meta
+    assert str(b2) == str(b)
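A sketch of the round trip the new ByteStream tests pin down, valid after this change: the payload is serialized as a list of ints, so from_dict(to_dict(x)) reproduces x, including its meta.

    from haystack.dataclasses import ByteStream

    b = ByteStream.from_string("Hello, world!", mime_type="text/plain", meta={"foo": "bar"})
    d = b.to_dict()
    assert d["data"] == list("Hello, world!".encode())  # bytes become a JSON-friendly int list
    assert ByteStream.from_dict(d).data == b.data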
@@ -146,7 +146,7 @@ def test_to_dict_without_flattening():
 def test_to_dict_with_custom_parameters():
     doc = Document(
         content="test text",
-        blob=ByteStream(b"some bytes", mime_type="application/pdf"),
+        blob=ByteStream(b"some bytes", mime_type="application/pdf", meta={"foo": "bar"}),
         meta={"some": "values", "test": 10},
         score=0.99,
         embedding=[10.0, 10.0],
@@ -156,7 +156,7 @@ def test_to_dict_with_custom_parameters():
     assert doc.to_dict() == {
         "id": doc.id,
         "content": "test text",
-        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf"},
+        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf", "meta": {"foo": "bar"}},
         "some": "values",
         "test": 10,
         "score": 0.99,
@@ -178,10 +178,10 @@ def test_to_dict_with_custom_parameters_without_flattening():
     assert doc.to_dict(flatten=False) == {
         "id": doc.id,
         "content": "test text",
-        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf"},
+        "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf", "meta": {}},
         "meta": {"some": "values", "test": 10},
         "score": 0.99,
-        "embedding": [10, 10],
+        "embedding": [10.0, 10.0],
         "sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
     }
 
@@ -212,7 +212,7 @@ def from_from_dict_with_parameters():
     assert Document.from_dict(
         {
             "content": "test text",
-            "blob": {"data": list(blob_data), "mime_type": "text/markdown"},
+            "blob": {"data": list(blob_data), "mime_type": "text/markdown", "meta": {"text": "test text"}},
             "meta": {"text": "test text"},
             "score": 0.812,
             "embedding": [0.1, 0.2, 0.3],
@@ -220,7 +220,7 @@ def from_from_dict_with_parameters():
         }
     ) == Document(
         content="test text",
-        blob=ByteStream(blob_data, mime_type="text/markdown"),
+        blob=ByteStream(blob_data, mime_type="text/markdown", meta={"text": "test text"}),
         meta={"text": "test text"},
         score=0.812,
         embedding=[0.1, 0.2, 0.3],
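These Document test updates follow from ByteStream.to_dict now emitting a "meta" entry. A sketch, assuming the post-change serialization, of the blob's own meta surviving a Document round trip instead of being dropped:

    from haystack.dataclasses import ByteStream, Document

    doc = Document(content="test text", blob=ByteStream(b"some bytes", mime_type="application/pdf", meta={"foo": "bar"}))
    restored = Document.from_dict(doc.to_dict())
    assert restored.blob.meta == {"foo": "bar"}  # previously lost in the round trip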
@@ -1,193 +0,0 @@
-# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
-#
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-from typing import List, Dict
-
-from haystack.dataclasses import ChatMessage
-from haystack.dataclasses.state import State
-from haystack.components.agents.state.state import _validate_schema, _schema_to_dict, _schema_from_dict, merge_lists
-
-
-@pytest.fixture
-def basic_schema():
-    return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}}
-
-
-def numbers_handler(current, new):
-    if current is None:
-        return sorted(set(new))
-    return sorted(set(current + new))
-
-
-@pytest.fixture
-def complex_schema():
-    return {"numbers": {"type": list, "handler": numbers_handler}, "metadata": {"type": dict}, "name": {"type": str}}
-
-
-def test_validate_schema_valid(basic_schema):
-    # Should not raise any exceptions
-    _validate_schema(basic_schema)
-
-
-def test_validate_schema_invalid_type():
-    invalid_schema = {"test": {"type": "not_a_type"}}
-    with pytest.raises(ValueError, match="must be a Python type"):
-        _validate_schema(invalid_schema)
-
-
-def test_validate_schema_missing_type():
-    invalid_schema = {"test": {"handler": lambda x, y: x + y}}
-    with pytest.raises(ValueError, match="missing a 'type' entry"):
-        _validate_schema(invalid_schema)
-
-
-def test_validate_schema_invalid_handler():
-    invalid_schema = {"test": {"type": list, "handler": "not_callable"}}
-    with pytest.raises(ValueError, match="must be callable or None"):
-        _validate_schema(invalid_schema)
-
-
-def test_state_initialization(basic_schema):
-    # Test empty initialization
-    state = State(basic_schema)
-    assert state.data == {}
-
-    # Test initialization with data
-    initial_data = {"numbers": [1, 2, 3], "name": "test"}
-    state = State(basic_schema, initial_data)
-    assert state.data["numbers"] == [1, 2, 3]
-    assert state.data["name"] == "test"
-
-
-def test_state_get(basic_schema):
-    state = State(basic_schema, {"name": "test"})
-    assert state.get("name") == "test"
-    assert state.get("non_existent") is None
-    assert state.get("non_existent", "default") == "default"
-
-
-def test_state_set_basic(basic_schema):
-    state = State(basic_schema)
-
-    # Test setting new values
-    state.set("numbers", [1, 2])
-    assert state.get("numbers") == [1, 2]
-
-    # Test updating existing values
-    state.set("numbers", [3, 4])
-    assert state.get("numbers") == [1, 2, 3, 4]
-
-
-def test_state_set_with_handler(complex_schema):
-    state = State(complex_schema)
-
-    # Test custom handler for numbers
-    state.set("numbers", [3, 2, 1])
-    assert state.get("numbers") == [1, 2, 3]
-
-    state.set("numbers", [6, 5, 4])
-    assert state.get("numbers") == [1, 2, 3, 4, 5, 6]
-
-
-def test_state_set_with_handler_override(basic_schema):
-    state = State(basic_schema)
-
-    # Custom handler that concatenates strings
-    custom_handler = lambda current, new: f"{current}-{new}" if current else new
-
-    state.set("name", "first")
-    state.set("name", "second", handler_override=custom_handler)
-    assert state.get("name") == "first-second"
-
-
-def test_state_has(basic_schema):
-    state = State(basic_schema, {"name": "test"})
-    assert state.has("name") is True
-    assert state.has("non_existent") is False
-
-
-def test_state_empty_schema():
-    state = State({})
-    assert state.data == {}
-    assert state.schema == {"messages": {"type": List[ChatMessage], "handler": merge_lists}}
-    with pytest.raises(ValueError, match="Key 'any_key' not found in schema"):
-        state.set("any_key", "value")
-
-
-def test_state_none_values(basic_schema):
-    state = State(basic_schema)
-    state.set("name", None)
-    assert state.get("name") is None
-    state.set("name", "value")
-    assert state.get("name") == "value"
-
-
-def test_state_merge_lists(basic_schema):
-    state = State(basic_schema)
-    state.set("numbers", "not_a_list")
-    assert state.get("numbers") == ["not_a_list"]
-    state.set("numbers", [1, 2])
-    assert state.get("numbers") == ["not_a_list", 1, 2]
-
-
-def test_state_nested_structures():
-    schema = {
-        "complex": {
-            "type": Dict[str, List[int]],
-            "handler": lambda current, new: {
-                k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys())
-            }
-            if current
-            else new,
-        }
-    }
-
-    state = State(schema)
-    state.set("complex", {"a": [1, 2], "b": [3, 4]})
-    state.set("complex", {"b": [5, 6], "c": [7, 8]})
-
-    expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]}
-    assert state.get("complex") == expected
-
-
-def test_schema_to_dict(basic_schema):
-    expected_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
-    result = _schema_to_dict(basic_schema)
-    assert result == expected_dict
-
-
-def test_schema_to_dict_with_handlers(complex_schema):
-    expected_dict = {
-        "numbers": {"type": "list", "handler": "test_state.numbers_handler"},
-        "metadata": {"type": "dict"},
-        "name": {"type": "str"},
-    }
-    result = _schema_to_dict(complex_schema)
-    assert result == expected_dict
-
-
-def test_schema_from_dict(basic_schema):
-    schema_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}}
-    result = _schema_from_dict(schema_dict)
-    assert result == basic_schema
-
-
-def test_schema_from_dict_with_handlers(complex_schema):
-    schema_dict = {
-        "numbers": {"type": "list", "handler": "test_state.numbers_handler"},
-        "metadata": {"type": "dict"},
-        "name": {"type": "str"},
-    }
-    result = _schema_from_dict(schema_dict)
-    assert result == complex_schema
-
-
-def test_state_mutability():
-    state = State({"my_list": {"type": list}}, {"my_list": [1, 2]})
-
-    my_list = state.get("my_list")
-    my_list.append(3)
-
-    assert state.get("my_list") == [1, 2]
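The deleted file's coverage moves along with State, which this change set imports from haystack.components.agents.state instead of haystack.dataclasses.state. A sketch of the schema and merge mechanics those tests covered, using the new import path (the replacement behavior for scalar values is inferred from the removed tests):

    from haystack.components.agents.state import State

    schema = {"numbers": {"type": list}, "name": {"type": str}}
    state = State(schema)
    state.set("numbers", [1, 2])
    state.set("numbers", [3, 4])  # list-typed values are merged by default
    assert state.get("numbers") == [1, 2, 3, 4]
    state.set("name", "first")
    state.set("name", "second")   # scalar values are replaced
    assert state.get("name") == "second"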
@@ -99,11 +99,6 @@ def test_tool_call_delta():
     assert tool_call.index == 0
 
 
-def test_tool_call_delta_with_missing_fields():
-    with pytest.raises(ValueError):
-        _ = ToolCallDelta(id="123", index=0)
-
-
 def test_create_chunk_with_finish_reason():
     """Test creating a chunk with the new finish_reason field."""
     chunk = StreamingChunk(content="Test content", finish_reason="stop")
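The removed test asserted that a partially specified delta raises ValueError; after this change partial deltas are legal, which the streaming tests earlier in this compare rely on (they construct ToolCallDelta(index=0) with no name or arguments at all). A small sketch:

    from haystack.dataclasses import ToolCallDelta

    delta = ToolCallDelta(id="123", index=0)  # previously raised ValueError
    assert delta.tool_name is None and delta.arguments is None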