Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-07-05 08:01:02 +00:00

Compare commits: 17 commits
2693f39e44
646eedf26a
050c987946
85258f0654
16fc41cd95
9fd552f906
adb2759d00
848115c65e
3aaa201ed6
f11870b212
97e72b9693
fc64884819
c54a68ab63
c18f81283c
101e9cdc34
bcaef53cbc
85e8493f4f
.github/utils/check_imports.py (vendored, 16 changes)

@@ -1,14 +1,17 @@
+import importlib
 import os
 import sys
-import importlib
 import traceback
 from pathlib import Path
 from typing import List, Optional

 from haystack import logging  # pylint: disable=unused-import # this is needed to avoid circular imports


 def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]] = None) -> tuple[list, list]:
     """
     Recursively search for all Python modules and attempt to import them.

     This includes both packages (directories with __init__.py) and individual Python files.
     """
     imported = []
@@ -25,7 +28,7 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]

     # Convert path to module format
     module_path = ".".join(Path(root).relative_to(base_path.parent).parts)
-    python_files = [f for f in files if f.endswith('.py')]
+    python_files = [f for f in files if f.endswith(".py")]

     # Try importing package and individual files
     for file in python_files:
@@ -39,16 +42,15 @@ def validate_module_imports(root_dir: str, exclude_subdirs: Optional[List[str]]
             importlib.import_module(module_to_import)
             imported.append(module_to_import)
         except:
-            failed.append({
-                'module': module_to_import,
-                'traceback': traceback.format_exc()
-            })
+            failed.append({"module": module_to_import, "traceback": traceback.format_exc()})

     return imported, failed


 def main():
     """
     This script checks that all Haystack modules can be imported successfully.

     This includes both packages and individual Python files.
     This can detect several issues, such as:
     - Syntax errors in Python files
@@ -80,5 +82,5 @@ def main():
     sys.exit(1)


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
.github/utils/create_unstable_docs.py (vendored, 8 changes)

@@ -1,14 +1,16 @@
+import argparse
 import re
 import sys
-import argparse

-from readme_api import get_versions, create_new_unstable
+from readme_api import create_new_unstable, get_versions

 VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def calculate_new_unstable(version: str):
+    """
+    Calculate the new unstable version based on the given version.
+    """
     # version must be formatted like so <major>.<minor>
     major, minor = version.split(".")
     return f"{major}.{int(minor) + 1}-unstable"
.github/utils/deepset_sync.py (vendored, new file, 181 lines)

# /// script
# dependencies = [
#     "requests",
# ]
# ///

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional

import requests


def transform_filename(filepath: Path) -> str:
    """
    Transform a file path to the required format:

    - Replace path separators with underscores
    """
    # Convert to string and replace path separators with underscores
    transformed = str(filepath).replace("/", "_").replace("\\", "_")

    return transformed


def upload_file_to_deepset(filepath: Path, api_key: str, workspace: str) -> bool:
    """
    Upload a single file to Deepset API.
    """
    # Read file content
    try:
        content = filepath.read_text(encoding="utf-8")
    except Exception as e:
        print(f"Error reading file {filepath}: {e}")
        return False

    # Transform filename
    transformed_name = transform_filename(filepath)

    # Prepare metadata
    metadata: dict[str, str] = {"original_file_path": str(filepath)}

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"
    params: dict[str, str] = {"file_name": transformed_name, "write_mode": "OVERWRITE"}

    headers: dict[str, str] = {"accept": "application/json", "authorization": f"Bearer {api_key}"}

    # Prepare multipart form data
    files: dict[str, tuple[None, str, str]] = {
        "meta": (None, json.dumps(metadata), "application/json"),
        "text": (None, content, "text/plain"),
    }

    try:
        response = requests.post(url, params=params, headers=headers, files=files, timeout=300)
        response.raise_for_status()
        print(f"Successfully uploaded: {filepath} as {transformed_name}")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to upload {filepath}: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to upload {filepath}: {e}")
        return False


def delete_files_from_deepset(filepaths: list[Path], api_key: str, workspace: str) -> bool:
    """
    Delete multiple files from Deepset API.
    """
    if not filepaths:
        return True

    # Transform filenames
    transformed_names: list[str] = [transform_filename(fp) for fp in filepaths]

    # Prepare API request
    url = f"https://api.cloud.deepset.ai/api/v1/workspaces/{workspace}/files"

    headers: dict[str, str] = {
        "accept": "application/json",
        "authorization": f"Bearer {api_key}",
        "content-type": "application/json",
    }

    data: dict[str, list[str]] = {"names": transformed_names}

    try:
        response = requests.delete(url, headers=headers, json=data, timeout=300)
        response.raise_for_status()
        print(f"Successfully deleted {len(transformed_names)} file(s):")
        for original, transformed in zip(filepaths, transformed_names):
            print(f"  - {original} (as {transformed})")
        return True
    except requests.exceptions.HTTPError:
        print(f"Failed to delete files: HTTP {response.status_code}")
        print(f"  Response: {response.text}")
        return False
    except Exception as e:
        print(f"Failed to delete files: {e}")
        return False


def main() -> None:
    """
    Main function to process and upload/delete files.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Upload/delete Python files to/from Deepset")
    parser.add_argument("--changed", nargs="*", default=[], help="Changed or added files")
    parser.add_argument("--deleted", nargs="*", default=[], help="Deleted files")
    args = parser.parse_args()

    # Get environment variables
    api_key: Optional[str] = os.environ.get("DEEPSET_API_KEY")
    workspace: str = os.environ.get("DEEPSET_WORKSPACE")

    if not api_key:
        print("Error: DEEPSET_API_KEY environment variable not set")
        sys.exit(1)

    # Process arguments and convert to Path objects
    changed_files: list[Path] = [Path(f.strip()) for f in args.changed if f.strip()]
    deleted_files: list[Path] = [Path(f.strip()) for f in args.deleted if f.strip()]

    if not changed_files and not deleted_files:
        print("No files to process")
        sys.exit(0)

    print(f"Processing files in Deepset workspace: {workspace}")
    print("-" * 50)

    # Track results
    upload_success: int = 0
    upload_failed: list[Path] = []
    delete_success: bool = False

    # Handle deletions first
    if deleted_files:
        print(f"\nDeleting {len(deleted_files)} file(s)...")
        delete_success = delete_files_from_deepset(deleted_files, api_key, workspace)

    # Upload changed/new files
    if changed_files:
        print(f"\nUploading {len(changed_files)} file(s)...")
        for filepath in changed_files:
            if filepath.exists():
                if upload_file_to_deepset(filepath, api_key, workspace):
                    upload_success += 1
                else:
                    upload_failed.append(filepath)
            else:
                print(f"Skipping non-existent file: {filepath}")

    # Summary
    print("-" * 50)
    print("Processing Summary:")
    if changed_files:
        print(f"  Uploads - Successful: {upload_success}, Failed: {len(upload_failed)}")
    if deleted_files:
        print(f"  Deletions - {'Successful' if delete_success else 'Failed'}: {len(deleted_files)} file(s)")

    if upload_failed:
        print("\nFailed uploads:")
        for f in upload_failed:
            print(f"  - {f}")

    # Exit with error if any operation failed
    if upload_failed or (deleted_files and not delete_success):
        sys.exit(1)

    print("\nAll operations completed successfully!")


if __name__ == "__main__":
    main()
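As a quick illustration of the renaming scheme used by the new script, the sketch below repeats the one-line logic of `transform_filename` on a hypothetical haystack path; the example path is illustrative only.

```python
from pathlib import Path

def transform_filename(filepath: Path) -> str:
    # Same logic as in deepset_sync.py above: every path separator becomes an underscore.
    return str(filepath).replace("/", "_").replace("\\", "_")

print(transform_filename(Path("haystack/components/agents/agent.py")))
# haystack_components_agents_agent.py
```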
.github/utils/delete_outdated_docs.py (vendored, 9 changes)

@@ -12,6 +12,9 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("README_API_KEY", None)
     if not api_key:
         raise Exception("README_API_KEY env var is not set")
@@ -21,6 +24,9 @@ def readme_token():


 def create_headers(version: str):
+    """
+    Create headers for the Readme API.
+    """
     return {"authorization": f"Basic {readme_token()}", "x-readme-version": version}


@@ -35,6 +41,9 @@ def get_docs_in_category(category_slug: str, version: str) -> List[str]:


 def delete_doc(slug: str, version: str):
+    """
+    Delete a document from Readme, based on the slug and version.
+    """
     url = f"https://dash.readme.com/api/v1/docs/{slug}"
     headers = create_headers(version)
     res = requests.delete(url, headers=headers, timeout=10)
.github/utils/docstrings_checksum.py (vendored, 8 changes)

@@ -1,11 +1,13 @@
+import ast
+import hashlib
 from pathlib import Path
 from typing import Iterator

-import ast
-import hashlib


 def docstrings_checksum(python_files: Iterator[Path]):
+    """
+    Calculate the checksum of the docstrings in the given Python files.
+    """
     files_content = (f.read_text() for f in python_files)
     trees = (ast.parse(c) for c in files_content)
.github/utils/promote_unstable_docs.py (vendored, 6 changes)

@@ -1,6 +1,6 @@
+import argparse
 import re
 import sys
-import argparse

 from readme_api import get_versions, promote_unstable_to_stable

@@ -8,9 +8,7 @@ VERSION_VALIDATOR = re.compile(r"^[0-9]+\.[0-9]+$")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True
-    )
+    parser.add_argument("-v", "--version", help="The version to promote to stable (e.g. 2.1).", required=True)
     args = parser.parse_args()

     if VERSION_VALIDATOR.match(args.version) is None:
.github/utils/pyproject_to_requirements.py (vendored, 9 changes)

@@ -3,7 +3,8 @@ import re
 import sys
 from pathlib import Path

-import toml
+# toml is available in the default environment but not in the test environment, so pylint complains
+import toml  # pylint: disable=import-error

 matcher = re.compile(r"farm-haystack\[(.+)\]")
 parser = argparse.ArgumentParser(
@@ -14,6 +15,9 @@ parser.add_argument("--extra", default="")


 def resolve(target: str, extras: dict, results: set):
+    """
+    Resolve the dependencies for a given target.
+    """
     if target not in extras:
         results.add(target)
         return
@@ -28,6 +32,9 @@ def resolve(target: str, extras: dict, results: set):


 def main(pyproject_path: Path, extra: str = ""):
+    """
+    Convert a pyproject.toml file to a requirements.txt file.
+    """
     content = toml.load(pyproject_path)
     # basic set of dependencies
     deps = set(content["project"]["dependencies"])
.github/utils/readme_api.py (vendored, 13 changes)

@@ -1,16 +1,23 @@
-import os
 import base64
+import os

 import requests


 class ReadmeAuth(requests.auth.AuthBase):
-    def __call__(self, r):
+    """
+    Custom authentication class for Readme API.
+    """
+
+    def __call__(self, r):  # noqa: D102
         r.headers["authorization"] = f"Basic {readme_token()}"
         return r


 def readme_token():
+    """
+    Get the Readme API token from the environment variable and encode it in base64.
+    """
     api_key = os.getenv("RDME_API_KEY", None)
     if not api_key:
         raise Exception("RDME_API_KEY env var is not set")
.github/workflows/e2e.yml (vendored, 4 changes)

@@ -18,7 +18,9 @@ env:
   PYTHON_VERSION: "3.9"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   HATCH_VERSION: "1.14.1"
-  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}
+  # we use HF_TOKEN instead of HF_API_TOKEN to work around a Hugging Face bug
+  # see https://github.com/deepset-ai/haystack/issues/9552
+  HF_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

 jobs:
   run:
.github/workflows/sync_code_to_deepset.yml (vendored, new file, 56 lines)

name: Upload Code to Deepset
on:
  push:
    branches:
      - main
jobs:
  upload-files:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0 # Fetch all history for proper diff

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        run: |
          pip install uv

      - name: Get changed files
        id: changed-files
        uses: tj-actions/changed-files@v46
        with:
          files: |
            haystack/**/*.py
          separator: ' '

      - name: Upload files to Deepset
        if: steps.changed-files.outputs.any_changed == 'true' || steps.changed-files.outputs.any_deleted == 'true'
        env:
          DEEPSET_API_KEY: ${{ secrets.DEEPSET_API_KEY }}
          DEEPSET_WORKSPACE: haystack-code
        run: |
          # Combine added and modified files for upload
          CHANGED_FILES=""
          if [ -n "${{ steps.changed-files.outputs.added_files }}" ]; then
            CHANGED_FILES="${{ steps.changed-files.outputs.added_files }}"
          fi
          if [ -n "${{ steps.changed-files.outputs.modified_files }}" ]; then
            if [ -n "$CHANGED_FILES" ]; then
              CHANGED_FILES="$CHANGED_FILES ${{ steps.changed-files.outputs.modified_files }}"
            else
              CHANGED_FILES="${{ steps.changed-files.outputs.modified_files }}"
            fi
          fi

          # Run the script with changed and deleted files
          # shellcheck disable=SC2086
          uv run --no-project --no-config --no-cache .github/utils/deepset_sync.py \
            --changed $CHANGED_FILES \
            --deleted ${{ steps.changed-files.outputs.deleted_files }}
.github/workflows/tests.yml (vendored, 2 changes)

@@ -19,7 +19,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
.github/workflows/tests_skipper_trigger.yml (vendored, 4 changes)

@@ -16,7 +16,7 @@ on:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

 jobs:
   check_if_changed:
@@ -39,7 +39,7 @@ jobs:
       - "haystack/core/pipeline/predefined/*"
       - "test/**/*.py"
       - "pyproject.toml"
-      - ".github/utils/check_imports.py"
+      - ".github/utils/*.py"

   trigger-catch-all:
     name: Tests completed
@@ -1 +1 @@
-2.15.0-rc0
+2.16.0-rc0
@@ -2,7 +2,7 @@ loaders:
 - type: haystack_pydoc_tools.loaders.CustomPythonLoader
   search_path: [../../../haystack/dataclasses]
   modules:
-    ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "state", "streaming_chunk"]
+    ["answer", "byte_stream", "chat_message", "document", "sparse_embedding", "streaming_chunk"]
   ignore_when_discovered: ["__init__"]
 processors:
 - type: filter
@@ -68,8 +68,8 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):

 @pytest.mark.parametrize("batch_size", [1, 3])
 @pytest.mark.skipif(
-    not os.environ.get("HF_API_TOKEN", None),
-    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    not os.environ.get("HF_API_TOKEN", None) and not os.environ.get("HF_TOKEN", None),
+    reason="Export an env var called HF_API_TOKEN or HF_TOKEN containing the Hugging Face token to run this test.",
 )
 def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
@@ -67,8 +67,9 @@ class Agent:
         exit_conditions: Optional[List[str]] = None,
         state_schema: Optional[Dict[str, Any]] = None,
         max_agent_steps: int = 100,
-        raise_on_tool_invocation_failure: bool = False,
         streaming_callback: Optional[StreamingCallbackT] = None,
+        raise_on_tool_invocation_failure: bool = False,
+        tool_invoker_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Initialize the agent component.
@@ -82,10 +83,11 @@ class Agent:
         :param state_schema: The schema for the runtime state used by the tools.
         :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
             If the agent exceeds this number of steps, it will stop and return the current state.
-        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
-            If set to False, the exception will be turned into a chat message and passed to the LLM.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
             The same callback can be configured to emit tool results when a tool is called.
+        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
+            If set to False, the exception will be turned into a chat message and passed to the LLM.
+        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
         :raises TypeError: If the chat_generator does not support tools parameter in its run method.
         :raises ValueError: If the exit_conditions are not valid.
         """
@@ -135,9 +137,15 @@ class Agent:
             component.set_input_type(self, name=param, type=config["type"], default=None)
         component.set_output_types(self, **output_types)

+        self.tool_invoker_kwargs = tool_invoker_kwargs
         self._tool_invoker = None
         if self.tools:
-            self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
+            resolved_tool_invoker_kwargs = {
+                "tools": self.tools,
+                "raise_on_failure": self.raise_on_tool_invocation_failure,
+                **(tool_invoker_kwargs or {}),
+            }
+            self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
         else:
             logger.warning(
                 "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
@@ -175,8 +183,9 @@ class Agent:
             # We serialize the original state schema, not the resolved one to reflect the original user input
             state_schema=_schema_to_dict(self._state_schema),
             max_agent_steps=self.max_agent_steps,
-            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
             streaming_callback=streaming_callback,
+            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
+            tool_invoker_kwargs=self.tool_invoker_kwargs,
         )

     @classmethod
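The `resolved_tool_invoker_kwargs` merge above relies on dict-unpacking order: user-supplied `tool_invoker_kwargs` are unpacked last, so they override the Agent's defaults. A minimal standalone sketch with illustrative values:

```python
# Illustrative values only; in the Agent, "tools" and "raise_on_failure" come
# from the component's own attributes.
defaults = {"tools": ["web_search"], "raise_on_failure": False}
tool_invoker_kwargs = {"raise_on_failure": True, "max_workers": 8}

# User-supplied keys win because they are unpacked after the defaults.
resolved = {**defaults, **(tool_invoker_kwargs or {})}
print(resolved)
# {'tools': ['web_search'], 'raise_on_failure': True, 'max_workers': 8}
```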
@@ -193,13 +193,13 @@ class HuggingFaceAPIChatGenerator:

     HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
     format for input and output. Use it to generate text with Hugging Face APIs:
-    - [Free Serverless Inference API](https://huggingface.co/inference-api)
+    - [Serverless Inference API (Inference Providers)](https://huggingface.co/docs/inference-providers)
     - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
     - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

     ### Usage examples

-    #### With the free serverless inference API
+    #### With the serverless inference API (Inference Providers) - free tier available

     ```python
     from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
@@ -215,7 +215,8 @@ class HuggingFaceAPIChatGenerator:
     api_type = "serverless_inference_api" # this is equivalent to the above

     generator = HuggingFaceAPIChatGenerator(api_type=api_type,
-                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+                                            api_params={"model": "microsoft/Phi-3.5-mini-instruct",
+                                                        "provider": "featherless-ai"},
                                             token=Secret.from_token("<your-api-key>"))

     result = generator.run(messages)
@@ -273,13 +274,15 @@ class HuggingFaceAPIChatGenerator:
         The type of Hugging Face API to use. Available types:
         - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
         - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
-        - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
+        - `serverless_inference_api`: See
+          [Serverless Inference API - Inference Providers](https://huggingface.co/docs/inference-providers).
     :param api_params:
         A dictionary with the following keys:
         - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
+        - `provider`: Provider name. Recommended when `api_type` is `SERVERLESS_INFERENCE_API`.
         - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
           `TEXT_GENERATION_INFERENCE`.
-        - Other parameters specific to the chosen API type, such as `timeout`, `headers`, `provider` etc.
+        - Other parameters specific to the chosen API type, such as `timeout`, `headers`, etc.
     :param token:
         The Hugging Face token to use as HTTP bearer authorization.
         Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
@@ -6,7 +6,7 @@ from dataclasses import asdict
 from datetime import datetime
 from typing import Any, Dict, Iterable, List, Optional, Union, cast

-from haystack import component, default_from_dict, default_to_dict
+from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import (
     ComponentInfo,
     FinishReason,
@@ -29,33 +29,26 @@ with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as hugg
 )

+logger = logging.getLogger(__name__)


 @component
 class HuggingFaceAPIGenerator:
     """
     Generates text using Hugging Face APIs.

     Use it with the following Hugging Face APIs:
-    - [Free Serverless Inference API]((https://huggingface.co/inference-api)
     - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
     - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

+    **Note:** As of July 2025, the Hugging Face Inference API no longer offers generative models through the
+    `text_generation` endpoint. Generative models are now only available through providers supporting the
+    `chat_completion` endpoint. As a result, this component might no longer work with the Hugging Face Inference API.
+    Use the `HuggingFaceAPIChatGenerator` component, which supports the `chat_completion` endpoint.

     ### Usage examples

-    #### With the free serverless inference API
-
-    ```python
-    from haystack.components.generators import HuggingFaceAPIGenerator
-    from haystack.utils import Secret
-
-    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
-                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
-                                        token=Secret.from_token("<your-api-key>"))
-
-    result = generator.run(prompt="What's Natural Language Processing?")
-    print(result)
-    ```
-
-    #### With paid inference endpoints
+    #### With Hugging Face Inference Endpoints

     ```python
     from haystack.components.generators import HuggingFaceAPIGenerator
@@ -75,6 +68,24 @@ class HuggingFaceAPIGenerator:
     generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                         api_params={"url": "http://localhost:8080"})
+
+    result = generator.run(prompt="What's Natural Language Processing?")
+    print(result)
+    ```
+
+    #### With the free serverless inference API
+
+    Be aware that this example might not work as the Hugging Face Inference API no longer offer models that support the
+    `text_generation` endpoint. Use the `HuggingFaceAPIChatGenerator` for generative models through the
+    `chat_completion` endpoint.
+
+    ```python
+    from haystack.components.generators import HuggingFaceAPIGenerator
+    from haystack.utils import Secret
+
+    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
+                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+                                        token=Secret.from_token("<your-api-key>"))
+
     result = generator.run(prompt="What's Natural Language Processing?")
     print(result)
     ```
@@ -97,6 +108,8 @@ class HuggingFaceAPIGenerator:
         - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
         - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
         - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
+          This might no longer work due to changes in the models offered in the Hugging Face Inference API.
+          Please use the `HuggingFaceAPIChatGenerator` component instead.
     :param api_params:
         A dictionary with the following keys:
         - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
@@ -120,6 +133,11 @@ class HuggingFaceAPIGenerator:
         api_type = HFGenerationAPIType.from_str(api_type)

         if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
+            logger.warning(
+                "Due to changes in the models offered in Hugging Face Inference API, using this component with the "
+                "Serverless Inference API might no longer work. "
+                "Please use the `HuggingFaceAPIChatGenerator` component instead."
+            )
             model = api_params.get("model")
             if model is None:
                 raise ValueError(
@@ -44,7 +44,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
             if chunk.index and tool_call.index > chunk.index:
                 print("\n\n", flush=True, end="")

-            print("[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")
+            print(f"[TOOL CALL]\nTool: {tool_call.tool_name} \nArguments: ", flush=True, end="")

             # print the tool arguments
             if tool_call.arguments:
@@ -84,24 +84,20 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
     tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
     for chunk in chunks:
         if chunk.tool_calls:
-            # We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
-            # tool_call is present
-            assert chunk.index is not None
-
             for tool_call in chunk.tool_calls:
                 # We use the index of the tool_call to track the tool call across chunks since the ID is not always
                 # provided
                 if tool_call.index not in tool_call_data:
-                    tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
+                    tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}

                 # Save the ID if present
                 if tool_call.id is not None:
-                    tool_call_data[chunk.index]["id"] = tool_call.id
+                    tool_call_data[tool_call.index]["id"] = tool_call.id

                 if tool_call.tool_name is not None:
-                    tool_call_data[chunk.index]["name"] += tool_call.tool_name
+                    tool_call_data[tool_call.index]["name"] += tool_call.tool_name
                 if tool_call.arguments is not None:
-                    tool_call_data[chunk.index]["arguments"] += tool_call.arguments
+                    tool_call_data[tool_call.index]["arguments"] += tool_call.arguments

     # Convert accumulated tool call data into ToolCall objects
     sorted_keys = sorted(tool_call_data.keys())
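The second hunk fixes the accumulation key: entries are now tracked by `tool_call.index` rather than `chunk.index`. A self-contained sketch, with made-up deltas, of why that matters when two tool calls stream through chunks that share an index:

```python
from typing import Dict

# Made-up stream of tool-call deltas; two distinct tool calls (index 0 and 1).
deltas = [
    {"index": 0, "name": "search", "arguments": '{"q": '},
    {"index": 1, "name": "calculator", "arguments": '{"x": 1}'},
    {"index": 0, "name": "", "arguments": '"haystack"}'},
]

tool_call_data: Dict[int, Dict[str, str]] = {}
for delta in deltas:
    # Keying by the delta's own index keeps the two calls separate, mirroring the fix.
    entry = tool_call_data.setdefault(delta["index"], {"name": "", "arguments": ""})
    entry["name"] += delta["name"]
    entry["arguments"] += delta["arguments"]

print(tool_call_data)
# {0: {'name': 'search', 'arguments': '{"q": "haystack"}'},
#  1: {'name': 'calculator', 'arguments': '{"x": 1}'}}
```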
@@ -328,8 +328,8 @@ class SentenceTransformersDiversityRanker:

         # Normalize embeddings to unit length for computing cosine similarity
         if self.similarity == DiversityRankingSimilarity.COSINE:
-            doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
-            query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
+            doc_embeddings = doc_embeddings / torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
+            query_embedding = query_embedding / torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
         return doc_embeddings, query_embedding

     def _maximum_margin_relevance(
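This hunk swaps in-place division (`/=`) for out-of-place division, which binds a fresh tensor instead of mutating the input buffer; the diff itself doesn't state the motivation. The normalization can be checked standalone (the shape below is illustrative):

```python
import torch

doc_embeddings = torch.randn(5, 384)  # illustrative shape

# Out-of-place division: creates a new tensor, leaving the original buffer untouched.
doc_embeddings = doc_embeddings / torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)

# Every row now has unit L2 norm, as required for cosine similarity.
print(torch.norm(doc_embeddings, p=2, dim=-1))  # tensor of ~1.0 values
```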
@ -5,7 +5,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import warnings
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@ -171,7 +170,6 @@ class ToolInvoker:
|
|||||||
*,
|
*,
|
||||||
enable_streaming_callback_passthrough: bool = False,
|
enable_streaming_callback_passthrough: bool = False,
|
||||||
max_workers: int = 4,
|
max_workers: int = 4,
|
||||||
async_executor: Optional[ThreadPoolExecutor] = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the ToolInvoker component.
|
Initialize the ToolInvoker component.
|
||||||
@ -197,13 +195,7 @@ class ToolInvoker:
|
|||||||
If False, the `streaming_callback` will not be passed to the tool invocation.
|
If False, the `streaming_callback` will not be passed to the tool invocation.
|
||||||
:param max_workers:
|
:param max_workers:
|
||||||
The maximum number of workers to use in the thread pool executor.
|
The maximum number of workers to use in the thread pool executor.
|
||||||
:param async_executor:
|
This also decides the maximum number of concurrent tool invocations.
|
||||||
Optional `ThreadPoolExecutor` to use for asynchronous calls.
|
|
||||||
Note: As of Haystack 2.15.0, you no longer need to explicitly pass
|
|
||||||
`async_executor`. Instead, you can provide the `max_workers` parameter,
|
|
||||||
and a `ThreadPoolExecutor` will be created automatically for parallel tool invocations.
|
|
||||||
Support for `async_executor` will be removed in Haystack 2.16.0.
|
|
||||||
Please migrate to using `max_workers` instead.
|
|
||||||
:raises ValueError:
|
:raises ValueError:
|
||||||
If no tools are provided or if duplicate tool names are found.
|
If no tools are provided or if duplicate tool names are found.
|
||||||
"""
|
"""
|
||||||
@ -231,37 +223,6 @@ class ToolInvoker:
|
|||||||
self._tools_with_names = dict(zip(tool_names, converted_tools))
|
self._tools_with_names = dict(zip(tool_names, converted_tools))
|
||||||
self.raise_on_failure = raise_on_failure
|
self.raise_on_failure = raise_on_failure
|
||||||
self.convert_result_to_json_string = convert_result_to_json_string
|
self.convert_result_to_json_string = convert_result_to_json_string
|
||||||
self._owns_executor = async_executor is None
|
|
||||||
if self._owns_executor:
|
|
||||||
warnings.warn(
|
|
||||||
"'async_executor' is deprecated in favor of the 'max_workers' parameter. "
|
|
||||||
"ToolInvoker now creates its own thread pool executor by default using 'max_workers'. "
|
|
||||||
"Support for 'async_executor' will be removed in Haystack 2.16.0. "
|
|
||||||
"Please update your usage to pass 'max_workers' instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.executor = (
|
|
||||||
ThreadPoolExecutor(
|
|
||||||
thread_name_prefix=f"async-ToolInvoker-executor-{id(self)}", max_workers=self.max_workers
|
|
||||||
)
|
|
||||||
if async_executor is None
|
|
||||||
else async_executor
|
|
||||||
)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""
|
|
||||||
Cleanup when the instance is being destroyed.
|
|
||||||
"""
|
|
||||||
if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
|
|
||||||
self.executor.shutdown(wait=True)
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
"""
|
|
||||||
Explicitly shutdown the executor if we own it.
|
|
||||||
"""
|
|
||||||
if self._owns_executor:
|
|
||||||
self.executor.shutdown(wait=True)
|
|
||||||
|
|
||||||
def _handle_error(self, error: Exception) -> str:
|
def _handle_error(self, error: Exception) -> str:
|
||||||
"""
|
"""
|
||||||
@ -655,7 +616,7 @@ class ToolInvoker:
|
|||||||
return e
|
return e
|
||||||
|
|
||||||
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
@component.output_types(tool_messages=List[ChatMessage], state=State)
|
||||||
async def run_async(
|
async def run_async( # noqa: PLR0915
|
||||||
self,
|
self,
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
state: Optional[State] = None,
|
state: Optional[State] = None,
|
||||||
@ -718,10 +679,10 @@ class ToolInvoker:
|
|||||||
|
|
||||||
# 2) Execute valid tool calls in parallel
|
# 2) Execute valid tool calls in parallel
|
||||||
if tool_call_params:
|
if tool_call_params:
|
||||||
with self.executor as executor:
|
|
||||||
tool_call_tasks = []
|
tool_call_tasks = []
|
||||||
valid_tool_calls = []
|
valid_tool_calls = []
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
# 3) Create async tasks for valid tool calls
|
# 3) Create async tasks for valid tool calls
|
||||||
for params in tool_call_params:
|
for params in tool_call_params:
|
||||||
task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
|
task = ToolInvoker.invoke_tool_safely(executor, params["tool_to_invoke"], params["final_args"])
|
||||||
@ -749,7 +710,7 @@ class ToolInvoker:
|
|||||||
try:
|
try:
|
||||||
error_message = self._handle_error(
|
error_message = self._handle_error(
|
||||||
ToolOutputMergeError(
|
ToolOutputMergeError(
|
||||||
f"Failed to merge tool outputs fromtool {tool_call.tool_name} into State: {e}"
|
f"Failed to merge tool outputs from tool {tool_call.tool_name} into State: {e}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tool_messages.append(
|
tool_messages.append(
|
||||||
|
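For callers, the migration implied by these hunks is to drop `async_executor` and size concurrency with `max_workers`, which now also caps parallel tool invocations. A hedged sketch, assuming the Haystack 2.x `Tool` API and import paths; the `adder` tool is hypothetical:

```python
from haystack.components.tools import ToolInvoker  # import path assumed per Haystack 2.x
from haystack.tools import Tool

def add(a: int, b: int) -> int:
    return a + b

adder = Tool(
    name="adder",
    description="Add two integers.",
    parameters={
        "type": "object",
        "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
        "required": ["a", "b"],
    },
    function=add,
)

# Before: ToolInvoker(tools=[adder], async_executor=my_executor)  # parameter removed in this diff
invoker = ToolInvoker(tools=[adder], max_workers=8)
```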
@@ -38,7 +38,6 @@ if TYPE_CHECKING:
     from .chat_message import ToolCallResult as ToolCallResult
     from .document import Document as Document
     from .sparse_embedding import SparseEmbedding as SparseEmbedding
-    from .state import State as State
     from .streaming_chunk import AsyncStreamingCallbackT as AsyncStreamingCallbackT
     from .streaming_chunk import ComponentInfo as ComponentInfo
     from .streaming_chunk import FinishReason as FinishReason
@@ -79,3 +79,24 @@ class ByteStream:
         fields.append(f"mime_type={self.mime_type!r}")
         fields_str = ", ".join(fields)
         return f"{self.__class__.__name__}({fields_str})"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the ByteStream to a dictionary representation.
+
+        :returns: A dictionary with keys 'data', 'meta', and 'mime_type'.
+        """
+        # Note: The data is converted to a list of integers for serialization since JSON does not support bytes
+        # directly.
+        return {"data": list(self.data), "meta": self.meta, "mime_type": self.mime_type}
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ByteStream":
+        """
+        Create a ByteStream from a dictionary representation.
+
+        :param data: A dictionary with keys 'data', 'meta', and 'mime_type'.
+
+        :returns: A ByteStream instance.
+        """
+        return ByteStream(data=bytes(data["data"]), meta=data.get("meta", {}), mime_type=data.get("mime_type"))
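A round-trip sketch of the new serialization methods; it assumes `ByteStream` is importable from `haystack.dataclasses`, consistent with the package diffs above:

```python
from haystack.dataclasses import ByteStream  # import path assumed

bs = ByteStream(data=b"hello", mime_type="text/plain")

# to_dict() stores the bytes as a list of ints so the payload is JSON-serializable.
payload = bs.to_dict()
restored = ByteStream.from_dict(payload)

assert restored.data == b"hello"
assert restored.mime_type == "text/plain"
```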
@@ -127,8 +127,12 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
             Whether to flatten `meta` field or not. Defaults to `True` to be backward-compatible with Haystack 1.x.
         """
         data = asdict(self)
-        if (blob := data.get("blob")) is not None:
-            data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}
+        # Use `ByteStream` and `SparseEmbedding`'s to_dict methods to convert them to JSON-serializable types.
+        if self.blob is not None:
+            data["blob"] = self.blob.to_dict()
+        if self.sparse_embedding is not None:
+            data["sparse_embedding"] = self.sparse_embedding.to_dict()

         if flatten:
             meta = data.pop("meta")
@@ -144,7 +148,7 @@ class Document(metaclass=_BackwardCompatible):  # noqa: PLW1641
         The `blob` field is converted to its original type.
         """
         if blob := data.get("blob"):
-            data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
+            data["blob"] = ByteStream.from_dict(blob)
         if sparse_embedding := data.get("sparse_embedding"):
             data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)
@@ -1,43 +0,0 @@
-# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
-#
-# SPDX-License-Identifier: Apache-2.0
-
-
-import warnings
-from typing import Any, Dict, Optional
-
-from haystack.components.agents import State as Utils_State
-
-
-class State(Utils_State):
-    """
-    A class that wraps a StateSchema and maintains an internal _data dictionary.
-
-    Deprecated in favor of `haystack.components.agents.State`. It will be removed in Haystack 2.16.0.
-
-    Each schema entry has:
-    "parameter_name": {
-        "type": SomeType,
-        "handler": Optional[Callable[[Any, Any], Any]]
-    }
-    """
-
-    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
-        """
-        Initialize a State object with a schema and optional data.
-
-        :param schema: Dictionary mapping parameter names to their type and handler configs.
-            Type must be a valid Python type, and handler must be a callable function or None.
-            If handler is None, the default handler for the type will be used. The default handlers are:
-            - For list types: `haystack.components.agents.state.state_utils.merge_lists`
-            - For all other types: `haystack.components.agents.state.state_utils.replace_values`
-        :param data: Optional dictionary of initial data to populate the state
-        """
-        warnings.warn(
-            "`haystack.dataclasses.State` is deprecated and will be removed in Haystack 2.16.0. "
-            "Use `haystack.components.agents.State` instead.",
-            DeprecationWarning,
-        )
-        super().__init__(schema, data)
@@ -1,52 +0,0 @@
-# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
-#
-# SPDX-License-Identifier: Apache-2.0
-
-
-import warnings
-from typing import Any, List, TypeVar, Union
-
-T = TypeVar("T")
-
-
-def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
-    """
-    Merges two values into a single list.
-
-    Deprecated in favor of `haystack.components.agents.state.merge_lists`. It will be removed in Haystack 2.16.0.
-
-    If either `current` or `new` is not already a list, it is converted into one.
-    The function ensures that both inputs are treated as lists and concatenates them.
-
-    If `current` is None, it is treated as an empty list.
-
-    :param current: The existing value(s), either a single item or a list.
-    :param new: The new value(s) to merge, either a single item or a list.
-    :return: A list containing elements from both `current` and `new`.
-    """
-    warnings.warn(
-        "`haystack.dataclasses.state_utils.merge_lists` is deprecated and will be removed in Haystack 2.16.0. "
-        "Use `haystack.components.agents.state.merge_lists` instead.",
-        DeprecationWarning,
-    )
-    current_list = [] if current is None else current if isinstance(current, list) else [current]
-    new_list = new if isinstance(new, list) else [new]
-    return current_list + new_list
-
-
-def replace_values(current: Any, new: Any) -> Any:
-    """
-    Replace the `current` value with the `new` value.
-
-    :param current: The existing value
-    :param new: The new value to replace
-    :return: The new value
-    """
-    warnings.warn(
-        "`haystack.dataclasses.state_utils.replace_values` is deprecated and will be removed in Haystack 2.16.0. "
-        "Use `haystack.components.agents.state.replace_values` instead.",
-        DeprecationWarning,
-    )
-    return new
@@ -30,12 +30,6 @@ class ToolCallDelta:
     arguments: Optional[str] = field(default=None)
     id: Optional[str] = field(default=None)  # noqa: A003

-    def __post_init__(self):
-        # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the
-        # name and full arguments in one chunk
-        if self.tool_name is None and self.arguments is None:
-            raise ValueError("At least one of tool_name or arguments must be provided.")

 @dataclass
 class ComponentInfo:
@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import warnings
 from typing import Any, Dict

 from haystack.core.errors import DeserializationError, SerializationError
@ -210,14 +209,12 @@ def _deserialize_value_with_schema(serialized: Dict[str, Any]) -> Any:  # pylint

     if not schema_type:
         # for backward comaptability till Haystack 2.16 we use legacy implementation
-        warnings.warn(
+        raise DeserializationError(
             "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
             "State object created with a version of Haystack older than 2.15.0. "
-            "Support for the old serialization format will be removed in Haystack 2.16.0. "
-            "Please upgrade to the new serialization format to ensure forward compatibility.",
-            DeprecationWarning,
+            "Support for the old serialization format is removed in Haystack 2.16.0. "
+            "Please upgrade to the new serialization format to ensure forward compatibility."
         )
-        return _deserialize_value_with_schema_legacy(serialized)

     # Handle object case (dictionary with properties)
     if schema_type == "object":
@ -331,61 +328,3 @@ def _deserialize_value(value: Any) -> Any:  # pylint: disable=too-many-return-st

     # 4) Fallback (shouldn't usually happen with our schema)
     return value
-
-
-def _deserialize_value_with_schema_legacy(serialized: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Legacy function for deserializing a dictionary with schema information and data to original values.
-
-    Kept for backward compatibility till Haystack 2.16.0.
-    Takes a dict of the form:
-    {
-        "schema": {
-            "numbers": {"type": "integer"},
-            "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}},
-        },
-        "data": {
-            "numbers": 1,
-            "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}],
-        }
-
-    :param serialized: The serialized dict with schema and data.
-    :returns: The deserialized dict with original values.
-    """
-    schema = serialized.get("serialization_schema", {})
-    data = serialized.get("serialized_data", {})
-
-    result: Dict[str, Any] = {}
-    for field, raw in data.items():
-        info = schema.get(field)
-        # no schema entry → just deep-deserialize whatever we have
-        if not info:
-            result[field] = _deserialize_value(raw)
-            continue
-
-        t = info["type"]
-
-        # ARRAY case
-        if t == "array":
-            item_type = info["items"]["type"]
-            reconstructed = []
-            for item in raw:
-                envelope = {"type": item_type, "data": item}
-                reconstructed.append(_deserialize_value(envelope))
-            result[field] = reconstructed
-
-        # PRIMITIVE case
-        elif t in ("null", "boolean", "integer", "number", "string"):
-            result[field] = raw
-
-        # GENERIC OBJECT
-        elif t == "object":
-            envelope = {"type": "object", "data": raw}
-            result[field] = _deserialize_value(envelope)
-
-        # CUSTOM CLASS
-        else:
-            envelope = {"type": t, "data": raw}
-            result[field] = _deserialize_value(envelope)
-
-    return result
@ -153,7 +153,8 @@ integration-only-fast = 'pytest --maxfail=5 -m "integration and not slow" {args:
 integration-only-slow = 'pytest --maxfail=5 -m "integration and slow" {args:test}'
 all = 'pytest {args:test}'

-types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack}"
+# TODO We want to eventually type the whole test folder
+types = "mypy --install-types --non-interactive --cache-dir=.mypy_cache/ {args:haystack test/core/pipeline/features/test_run.py}"
 lint = "pylint -ry -j 0 {args:haystack}"

 [tool.hatch.envs.e2e]
@ -283,14 +284,24 @@ disallow_incomplete_defs = false

 [tool.ruff]
 line-length = 120
-exclude = [".github", "proposals"]
+exclude = ["proposals"]

 [tool.ruff.format]
 skip-magic-trailing-comma = true

+[tool.ruff.lint.per-file-ignores]
+"test/**" = [
+    "D102",  # Missing docstring in public method
+    "D103",  # Missing docstring in public function
+    "D205",  # 1 blank line required between summary line and description
+    "PLC0206",  # Extracting value from dictionary without calling `.items()`
+    "SIM105",  # try-except-pass instead of contextlib-suppress
+    "SIM117",  # Use a single `with` statement with multiple contexts instead of nested `with` statements
+]
+
 [tool.ruff.lint]
 isort.split-on-trailing-comma = false
-exclude = ["test/**", "e2e/**"]
+exclude = ["e2e/**"]
 select = [
     "ASYNC",  # flake8-async
     "C4",  # flake8-comprehensions
@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add `to_dict` and `from_dict` to ByteStream so it is consistent with our other dataclasses in having serialization and deserialization methods.
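Note: a minimal sketch of the new round trip, assuming `to_dict` and `from_dict` mirror the other dataclasses; the exact wire encoding of `data` is not shown in this diff:

    from haystack.dataclasses import ByteStream

    stream = ByteStream(data=b"hello", meta={"file_path": "hello.txt"}, mime_type="text/plain")

    # Serialize to a plain dict and rebuild an equivalent ByteStream from it.
    restored = ByteStream.from_dict(stream.to_dict())
    assert restored.data == stream.data and restored.meta == stream.meta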
@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Added the `tool_invoker_kwargs` param to Agent so additional kwargs can be passed to the ToolInvoker like `max_workers` and `enable_streaming_callback_passthrough`.
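Note: a minimal sketch of the new parameter, mirroring the values exercised in the test changes further down; `weather_tool` stands in for any `Tool` instance:

    from haystack.components.agents import Agent
    from haystack.components.generators.chat.openai import OpenAIChatGenerator

    # The extra kwargs are forwarded to the internal ToolInvoker at construction time.
    agent = Agent(
        chat_generator=OpenAIChatGenerator(),
        tools=[weather_tool],  # placeholder Tool
        tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
    )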
@ -0,0 +1,13 @@
+---
+upgrade:
+  - |
+    `HuggingFaceAPIGenerator` might no longer work with the Hugging Face Inference API.
+
+    As of July 2025, the Hugging Face Inference API no longer offers generative models that support the
+    `text_generation` endpoint. Generative models are now only available through providers that support the
+    `chat_completion` endpoint.
+    As a result, the `HuggingFaceAPIGenerator` component might not work with the Hugging Face Inference API.
+    It still works with Hugging Face Inference Endpoints and self-hosted TGI instances.
+
+    To use generative models via Hugging Face Inference API, please use the `HuggingFaceAPIChatGenerator` component,
+    which supports the `chat_completion` endpoint.
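Note: a minimal sketch of the suggested migration, assuming the serverless Inference API setup; the model name is illustrative:

    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Uses the `chat_completion` endpoint instead of the retired `text_generation` one.
    generator = HuggingFaceAPIChatGenerator(
        api_type="serverless_inference_api",
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
    )
    result = generator.run(messages=[ChatMessage.from_user("What is Haystack?")])
    print(result["replies"][0].text)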
@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix `_convert_streaming_chunks_to_chat_message`, which is used to convert Haystack StreamingChunks into a Haystack ChatMessage. This fixes the scenario where one StreamingChunk contains two ToolCallDeltas in StreamingChunk.tool_calls. With this fix, both ToolCallDeltas are saved correctly, whereas before they were overwriting each other. This only occurs with some LLM providers like Mistral (and not OpenAI) due to how the provider returns tool calls.
@ -0,0 +1,3 @@
+---
+fixes:
+  - Fixed a bug in the `print_streaming_chunk` utility function that prevented the tool call name from being printed.
@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    We relaxed the requirement in ToolCallDelta (introduced in Haystack 2.15) that either the `arguments` or `name` parameter must be populated to create a ToolCallDelta dataclass. We removed this requirement to be more in line with OpenAI's SDK, and because it was causing errors for some hosted versions of open-source models that follow OpenAI's SDK specification.
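Note: a minimal sketch of the relaxed behavior, assuming `index` remains the only required field of the dataclass:

    from haystack.dataclasses import ToolCallDelta

    # Before this change, omitting both tool_name and arguments raised a ValueError;
    # now a provider may stream a delta that carries only its index.
    delta = ToolCallDelta(index=0)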
@ -0,0 +1,6 @@
+---
+upgrade:
+  - |
+    The deprecated `async_executor` parameter has been removed from the `ToolInvoker` class.
+    Please use the `max_workers` parameter instead; a `ThreadPoolExecutor` with
+    that many workers will be created automatically for parallel tool invocations.
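Note: a minimal sketch of the replacement, with `my_tool` standing in for any `Tool` instance:

    from haystack.components.tools import ToolInvoker

    # A ThreadPoolExecutor with `max_workers` threads is now created internally
    # for parallel tool invocation; no executor is passed in from outside.
    invoker = ToolInvoker(tools=[my_tool], max_workers=4)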
@ -0,0 +1,5 @@
+---
+upgrade:
+  - |
+    The deprecated `State` class has been removed from the `haystack.dataclasses` module.
+    The `State` class is now part of the `haystack.components.agents` module.
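Note: only the import location changes; the schema format stays the same. A minimal sketch:

    from haystack.components.agents import State

    # Schema maps parameter names to their type (and optionally a handler).
    state = State(schema={"messages": {"type": list}}, data={"messages": ["hello"]})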
@ -0,0 +1,6 @@
+---
+upgrade:
+  - |
+    Remove the `deserialize_value_with_schema_legacy` function from the `base_serialization` module.
+    This function was used to deserialize State objects created with Haystack 2.14.0 or older.
+    Support for the old serialization format is removed in Haystack 2.16.0.
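Note: a minimal sketch of what callers can expect, assuming `State.from_dict` surfaces the `DeserializationError` now raised for pre-2.15 payloads; `old_state_dict` is a placeholder:

    from haystack.components.agents import State
    from haystack.core.errors import DeserializationError

    try:
        state = State.from_dict(old_state_dict)  # payload written by Haystack <= 2.14
    except DeserializationError:
        # Re-serialize the State with Haystack >= 2.15 to migrate to the new format.
        raise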
@ -5,17 +5,16 @@
 import logging
 import os
 from datetime import datetime
-from typing import Iterator, Dict, Any, List, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional, Union
+from unittest.mock import AsyncMock, MagicMock, patch

-from unittest.mock import MagicMock, patch, AsyncMock
 import pytest

 from openai import Stream
 from openai.types.chat import ChatCompletionChunk, chat_completion_chunk

-from haystack.tracing.logging_tracer import LoggingTracer
 from haystack import Pipeline, tracing
 from haystack.components.agents import Agent
+from haystack.components.agents.state import merge_lists
 from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
@ -24,11 +23,10 @@ from haystack.core.component.types import OutputSocket
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.dataclasses.chat_message import ChatRole, TextContent
 from haystack.dataclasses.streaming_chunk import StreamingChunk
-from haystack.tools import Tool, ComponentTool
+from haystack.tools import ComponentTool, Tool
 from haystack.tools.toolset import Toolset
-from haystack.utils import serialize_callable, Secret
-from haystack.components.agents.state import merge_lists
+from haystack.tracing.logging_tracer import LoggingTracer
+from haystack.utils import Secret, serialize_callable


 def streaming_callback_for_serde(chunk: StreamingChunk):
@ -174,6 +172,7 @@ class TestAgent:
             tools=[weather_tool, component_tool],
             exit_conditions=["text", "weather_tool"],
             state_schema={"foo": {"type": str}},
+            tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
         )
         serialized_agent = agent.to_dict()
         assert serialized_agent == {
@ -236,8 +235,9 @@ class TestAgent:
                 "exit_conditions": ["text", "weather_tool"],
                 "state_schema": {"foo": {"type": "str"}},
                 "max_agent_steps": 100,
-                "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "raise_on_tool_invocation_failure": False,
+                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
             },
         }

@ -294,6 +294,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": None,
             },
         }

@ -361,6 +362,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
             },
         }
         agent = Agent.from_dict(data)
@ -375,6 +377,9 @@ class TestAgent:
             "foo": {"type": str},
             "messages": {"handler": merge_lists, "type": List[ChatMessage]},
         }
+        assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True}
+        assert agent._tool_invoker.max_workers == 5
+        assert agent._tool_invoker.enable_streaming_callback_passthrough is True

     def test_from_dict_with_toolset(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
@ -426,6 +431,7 @@ class TestAgent:
                 "max_agent_steps": 100,
                 "raise_on_tool_invocation_failure": False,
                 "streaming_callback": None,
+                "tool_invoker_kwargs": None,
             },
         }
         agent = Agent.from_dict(data)
@ -882,15 +888,15 @@ class TestAgentTracing:
             '{"messages": "list", "tools": "list"}',
             "{}",
             "{}",
-            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',
+            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',  # noqa: E501
             1,
             '{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',
             100,
-            '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',
+            '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',  # noqa: E501
             '["text"]',
-            '{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',
+            '{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',  # noqa: E501
-            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',
+            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',  # noqa: E501
-            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',
+            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}',  # noqa: E501
             1,
         ]
         for idx, record in enumerate(tags_records):
@ -942,15 +948,15 @@ class TestAgentTracing:
             '{"messages": "list", "tools": "list"}',
             "{}",
             "{}",
-            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',
+            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}',  # noqa: E501
             1,
-            '{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',
+            '{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',  # noqa: E501
             100,
-            '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',
+            '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]',  # noqa: E501
             '["text"]',
-            '{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',
+            '{"messages": {"type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}',  # noqa: E501
-            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',
+            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null}',  # noqa: E501
-            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',
+            '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}',  # noqa: E501
             1,
         ]
         for idx, record in enumerate(tags_records):
@ -2,21 +2,22 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import pytest
+import inspect
 from dataclasses import dataclass
+from typing import Dict, Generic, List, Optional, TypeVar, Union
+
+import pytest

-from haystack.dataclasses import ChatMessage
 from haystack.components.agents.state.state import (
     State,
-    _validate_schema,
-    _schema_to_dict,
-    _schema_from_dict,
     _is_list_type,
-    merge_lists,
     _is_valid_type,
+    _schema_from_dict,
+    _schema_to_dict,
+    _validate_schema,
+    merge_lists,
 )
-from typing import List, Dict, Optional, Union, TypeVar, Generic
-import inspect
+from haystack.dataclasses import ChatMessage


 @pytest.fixture
@ -432,45 +433,3 @@ class TestState:
         assert state.data["numbers"] == 1
         assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
         assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}
-
-    def test_state_from_dict_legacy(self):
-        # this is the old format of the state dictionary
-        # it is kept for backward compatibility
-        # it will be removed in Haystack 2.16.0
-        state_dict = {
-            "schema": {
-                "numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"},
-                "messages": {
-                    "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]",
-                    "handler": "haystack.components.agents.state.state_utils.merge_lists",
-                },
-                "dict_of_lists": {
-                    "type": "dict",
-                    "handler": "haystack.components.agents.state.state_utils.replace_values",
-                },
-            },
-            "data": {
-                "serialization_schema": {
-                    "numbers": {"type": "integer"},
-                    "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}},
-                    "dict_of_lists": {"type": "object"},
-                },
-                "serialized_data": {
-                    "numbers": 1,
-                    "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}],
-                    "dict_of_lists": {"numbers": [1, 2, 3]},
-                },
-            },
-        }
-        state = State.from_dict(state_dict)
-        # Check types are correctly converted
-        assert state.schema["numbers"]["type"] == int
-        assert state.schema["dict_of_lists"]["type"] == dict
-        # Check handlers are functions, not comparing exact functions as they might be different references
-        assert callable(state.schema["numbers"]["handler"])
-        assert callable(state.schema["messages"]["handler"])
-        assert callable(state.schema["dict_of_lists"]["handler"])
-        # Check data is correct
-        assert state.data["numbers"] == 1
-        assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")]
-        assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]}
@ -4,18 +4,17 @@

 import sys
 from pathlib import Path
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch

 import pytest
 import torch

 from haystack import Pipeline
-from haystack.components.fetchers import LinkContentFetcher
-from haystack.dataclasses import Document, ByteStream
 from haystack.components.audio import LocalWhisperTranscriber
+from haystack.components.fetchers import LinkContentFetcher
+from haystack.dataclasses import ByteStream, Document
 from haystack.utils.device import ComponentDevice, Device


 SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files"
@ -192,12 +191,12 @@ class TestLocalWhisperTranscriber:
         docs = output["documents"]
         assert len(docs) == 3

-        assert all(word in docs[0].content.strip().lower() for word in {"content", "the", "document"}), (
+        assert all(word in docs[0].content.strip().lower() for word in ("content", "the", "document")), (
             f"Expected words not found in: {docs[0].content.strip().lower()}"
         )
         assert test_files_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"]

-        assert all(word in docs[1].content.strip().lower() for word in {"context", "answer"}), (
+        assert all(word in docs[1].content.strip().lower() for word in ("context", "answer")), (
             f"Expected words not found in: {docs[1].content.strip().lower()}"
         )
         path = test_files_path / "audio" / "the context for this answer is here.wav"
@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+
 import pytest

 from haystack import Pipeline
@ -2,14 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
-from jinja2 import TemplateSyntaxError
-import arrow
 import logging
-import pytest
+from typing import Any, Dict, List, Optional

+import arrow
+import pytest
+from jinja2 import TemplateSyntaxError

-from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
 from haystack import component
+from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
 from haystack.core.pipeline.pipeline import Pipeline
 from haystack.dataclasses.chat_message import ChatMessage
 from haystack.dataclasses.document import Document
@ -2,11 +2,11 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import logging
 from typing import Any, Dict, List, Optional
 from unittest.mock import patch

 import arrow
-import logging
 import pytest
 from jinja2 import TemplateSyntaxError
@ -2,13 +2,14 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from unittest.mock import MagicMock
+
 import pytest

-from haystack import Document, DeserializationError
-from haystack.testing.factory import document_store_class
-from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack import DeserializationError, Document
 from haystack.components.caching.cache_checker import CacheChecker
-from unittest.mock import MagicMock
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.testing.factory import document_store_class


 class TestCacheChecker:
@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
+
 import pytest

 from haystack import Document
@ -2,10 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import pytest
-
 from unittest.mock import patch

+import pytest
+
 from haystack import Document, Pipeline
 from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
 from haystack.components.retrievers import InMemoryBM25Retriever
@ -30,7 +30,7 @@ class TestTransformersZeroShotDocumentClassifier:
         )
         component_dict = component.to_dict()
         assert component_dict == {
-            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
+            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
             "init_parameters": {
                 "model": "cross-encoder/nli-deberta-v3-xsmall",
                 "labels": ["positive", "negative"],
@ -47,7 +47,7 @@ class TestTransformersZeroShotDocumentClassifier:
         monkeypatch.delenv("HF_API_TOKEN", raising=False)
         monkeypatch.delenv("HF_TOKEN", raising=False)
         data = {
-            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
+            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
             "init_parameters": {
                 "model": "cross-encoder/nli-deberta-v3-xsmall",
                 "labels": ["positive", "negative"],
@ -76,7 +76,7 @@ class TestTransformersZeroShotDocumentClassifier:
         monkeypatch.delenv("HF_API_TOKEN", raising=False)
         monkeypatch.delenv("HF_TOKEN", raising=False)
         data = {
-            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",
+            "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
             "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]},
         }
         component = TransformersZeroShotDocumentClassifier.from_dict(data)
@ -6,9 +6,10 @@ import os
 from unittest.mock import Mock, patch

 import pytest

 from haystack import Pipeline
-from haystack.utils import Secret
 from haystack.components.connectors.openapi import OpenAPIConnector
+from haystack.utils import Secret

 # Mock OpenAPI spec for testing
 MOCK_OPENAPI_SPEC = """
@ -7,19 +7,18 @@ import os
 from typing import Any, Dict, List
 from unittest.mock import MagicMock, Mock, patch

+import pytest
 import requests
+from openapi3 import OpenAPI

 from haystack import Pipeline
-import pytest
+from haystack.components.connectors import OpenAPIServiceConnector
 from haystack.components.converters.openapi_functions import OpenAPIServiceToFunctions
 from haystack.components.converters.output_adapter import OutputAdapter
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses.byte_stream import ByteStream
-from openapi3 import OpenAPI
-
-from haystack.components.connectors import OpenAPIServiceConnector
 from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.dataclasses.byte_stream import ByteStream


 @pytest.fixture
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.dataclasses import ByteStream
|
|
||||||
from haystack.components.converters.csv import CSVToDocument
|
from haystack.components.converters.csv import CSVToDocument
|
||||||
|
from haystack.dataclasses import ByteStream
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -32,7 +32,7 @@ class TestCSVToDocument:
|
|||||||
output = converter.run(sources=files)
|
output = converter.run(sources=files)
|
||||||
docs = output["documents"]
|
docs = output["documents"]
|
||||||
assert len(docs) == 3
|
assert len(docs) == 3
|
||||||
assert "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n" == docs[0].content
|
assert docs[0].content == "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n"
|
||||||
assert isinstance(docs[0].content, str)
|
assert isinstance(docs[0].content, str)
|
||||||
assert docs[0].meta == {"file_path": os.path.basename(bytestream.meta["file_path"]), "key": "value"}
|
assert docs[0].meta == {"file_path": os.path.basename(bytestream.meta["file_path"]), "key": "value"}
|
||||||
assert docs[1].meta["file_path"] == os.path.basename(files[1])
|
assert docs[1].meta["file_path"] == os.path.basename(files[1])
|
||||||
@ -50,7 +50,7 @@ class TestCSVToDocument:
|
|||||||
output = converter.run(sources=files)
|
output = converter.run(sources=files)
|
||||||
docs = output["documents"]
|
docs = output["documents"]
|
||||||
assert len(docs) == 3
|
assert len(docs) == 3
|
||||||
assert "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n" == docs[0].content
|
assert docs[0].content == "Name,Age\r\nJohn Doe,27\r\nJane Smith,37\r\nMike Johnson,47\r\n"
|
||||||
assert isinstance(docs[0].content, str)
|
assert isinstance(docs[0].content, str)
|
||||||
assert docs[0].meta["file_path"] == "sample_1.csv"
|
assert docs[0].meta["file_path"] == "sample_1.csv"
|
||||||
assert docs[0].meta["key"] == "value"
|
assert docs[0].meta["key"] == "value"
|
||||||
|
@ -2,15 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import json
-import os
-import logging
-import pytest
 import csv
+import json
+import logging
+import os
 from io import StringIO

+import pytest
+
 from haystack import Document, Pipeline
-from haystack.components.converters.docx import DOCXMetadata, DOCXToDocument, DOCXTableFormat, DOCXLinkFormat
+from haystack.components.converters.docx import DOCXLinkFormat, DOCXMetadata, DOCXTableFormat, DOCXToDocument
 from haystack.dataclasses import ByteStream
@ -4,9 +4,9 @@

 import logging
 from pathlib import Path
+from unittest.mock import patch

 import pytest
-from unittest.mock import patch

 from haystack.components.converters import HTMLToDocument
 from haystack.dataclasses import ByteStream
@ -3,17 +3,16 @@
 # SPDX-License-Identifier: Apache-2.0

 import json
-import os
-from unittest.mock import patch
-from pathlib import Path
 import logging
+import os
+from pathlib import Path
+from unittest.mock import patch

 import pytest

 from haystack.components.converters import JSONConverter
 from haystack.dataclasses import ByteStream


 test_data = [
     {
         "year": "1997",
@ -23,7 +22,8 @@ test_data = [
             "id": "674",
             "firstname": "Dario",
             "surname": "Fokin",
-            "motivation": "who emulates the jesters of the Middle Ages in scourging authority and upholding the dignity of the downtrodden",
+            "motivation": "who emulates the jesters of the Middle Ages in scourging authority and upholding the "
+            "dignity of the downtrodden",
             "share": "1",
         }
     ],
@ -56,7 +56,9 @@ test_data = [
             "id": "46",
             "firstname": "Enrico",
             "surname": "Fermi",
-            "motivation": "for his demonstrations of the existence of new radioactive elements produced by neutron irradiation, and for his related discovery of nuclear reactions brought about by slow neutrons",
+            "motivation": "for his demonstrations of the existence of new radioactive elements produced by neutron "
+            "irradiation, and for his related discovery of nuclear reactions brought about by slow "
+            "neutrons",
             "share": "1",
         }
     ],
@ -211,7 +213,8 @@ def test_run_with_non_json_file(tmpdir, caplog):
     assert len(records) == 1
     assert (
         records[0].msg
-        == f"Failed to extract text from {test_file}. Skipping it. Error: parse error: Invalid numeric literal at line 1, column 5"
+        == f"Failed to extract text from {test_file}. Skipping it. Error: parse error: Invalid numeric literal at "
+        f"line 1, column 5"
     )
     assert result == {"documents": []}

@ -481,7 +484,8 @@ def test_run_with_jq_schema_content_key_and_extra_meta_fields_literal(tmpdir):
     assert len(result["documents"]) == 4
     assert (
         result["documents"][0].content
-        == "who emulates the jesters of the Middle Ages in scourging authority and upholding the dignity of the downtrodden"
+        == "who emulates the jesters of the Middle Ages in scourging authority and upholding the dignity of the "
+        "downtrodden"
     )
     assert result["documents"][0].meta == {
         "file_path": os.path.basename(first_test_file),
@ -3,8 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
-
 from unittest.mock import patch
+
 import pytest

 from haystack.components.converters.markdown import MarkdownToDocument
@ -5,10 +5,10 @@
 import pytest

 from haystack import Document, Pipeline
-from haystack.core.pipeline.base import component_to_dict, component_from_dict
-from haystack.core.component.component import Component
-from haystack.dataclasses import ByteStream
 from haystack.components.converters.multi_file_converter import MultiFileConverter
+from haystack.core.component.component import Component
+from haystack.core.pipeline.base import component_from_dict, component_to_dict
+from haystack.dataclasses import ByteStream


 @pytest.fixture
@ -2,15 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import List
 import json
+from typing import List

 import pytest

 from haystack import Pipeline, component
-from haystack.dataclasses import Document
 from haystack.components.converters import OutputAdapter
 from haystack.components.converters.output_adapter import OutputAdaptationException
+from haystack.dataclasses import Document


 def custom_filter_to_sede(value):
@ -8,9 +8,9 @@ from unittest.mock import patch
 import pytest

 from haystack import Document
+from haystack.components.converters.pdfminer import PDFMinerToDocument
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.dataclasses import ByteStream
-from haystack.components.converters.pdfminer import PDFMinerToDocument


 class TestPDFMinerToDocument:
@ -5,8 +5,8 @@
 import logging
 import os

-from haystack.dataclasses import ByteStream
 from haystack.components.converters.pptx import PPTXToDocument
+from haystack.dataclasses import ByteStream


 class TestPPTXToDocument:
@ -3,12 +3,12 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch

 import pytest

 from haystack import Document
-from haystack.components.converters.pypdf import PyPDFToDocument, PyPDFExtractionMode
+from haystack.components.converters.pypdf import PyPDFExtractionMode, PyPDFToDocument
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.dataclasses import ByteStream
@ -7,8 +7,8 @@ import os

 import pytest

-from haystack.dataclasses import ByteStream
 from haystack.components.converters.txt import TextFileToDocument
+from haystack.dataclasses import ByteStream


 class TestTextfileToDocument:
@ -2,12 +2,13 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import sys
 from unittest.mock import patch

 import pytest
-import sys
-from haystack.dataclasses import ByteStream
+
 from haystack.components.converters.tika import TikaDocumentConverter
+from haystack.dataclasses import ByteStream


 class TestTikaDocumentConverter:
@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

+
 import pytest

 from haystack.components.converters.utils import normalize_metadata
@ -3,16 +3,15 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+from unittest.mock import Mock, patch

-from openai import APIError
-
-from haystack.utils.auth import Secret
 import pytest
+from openai import APIError

 from haystack import Document
 from haystack.components.embedders import AzureOpenAIDocumentEmbedder
+from haystack.utils.auth import Secret
 from haystack.utils.azure import default_azure_ad_token_provider
-from unittest.mock import Mock, patch


 class TestAzureOpenAIDocumentEmbedder:
@ -3,12 +3,11 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+import random
 from unittest.mock import MagicMock, patch

-import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
-
 from numpy import array

 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
@ -370,13 +369,13 @@ class TestHuggingFaceAPIDocumentEmbedder:
         assert truncate is True
         assert normalize is False

-    @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.slow
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     def test_live_run_serverless(self):
         docs = [
             Document(content="I love cheese", meta={"topic": "Cuisine"}),
@ -3,12 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+import random
 from unittest.mock import MagicMock, patch

-import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
 from numpy import array

 from haystack.components.embedders import HuggingFaceAPITextEmbedder
 from haystack.utils.auth import Secret
 from haystack.utils.hf import HFEmbeddingAPIType
@ -213,9 +214,9 @@ class TestHuggingFaceAPITextEmbedder:
         with pytest.raises(ValueError):
             embedder.run(text="The food was delicious")

-    @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.slow
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
@ -233,6 +234,7 @@ class TestHuggingFaceAPITextEmbedder:
     @pytest.mark.integration
     @pytest.mark.asyncio
     @pytest.mark.slow
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
     async def test_live_run_async_serverless(self):
         model_name = "sentence-transformers/all-MiniLM-L6-v2"
@ -15,20 +15,6 @@ from haystack.components.embedders.openai_document_embedder import OpenAIDocumen
 from haystack.utils.auth import Secret


-def mock_openai_response(input: List[str], model: str = "text-embedding-ada-002", **kwargs) -> dict:
-    dict_response = {
-        "object": "list",
-        "data": [
-            {"object": "embedding", "index": i, "embedding": [random.random() for _ in range(1536)]}
-            for i in range(len(input))
-        ],
-        "model": model,
-        "usage": {"prompt_tokens": 4, "total_tokens": 4},
-    }
-
-    return dict_response
-
-
 class TestOpenAIDocumentEmbedder:
     def test_init_default(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
@ -5,10 +5,10 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from openai.types import CreateEmbeddingResponse, Embedding
|
||||||
|
|
||||||
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
||||||
from haystack.utils.auth import Secret
|
from haystack.utils.auth import Secret
|
||||||
from openai.types import CreateEmbeddingResponse, Embedding
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAITextEmbedder:
|
class TestOpenAITextEmbedder:
|
||||||
@ -149,8 +149,8 @@ class TestOpenAITextEmbedder:
|
|||||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
|
||||||
embedder = OpenAITextEmbedder(dimensions=1536)
|
embedder = OpenAITextEmbedder(dimensions=1536)
|
||||||
|
|
||||||
input = "The food was delicious"
|
inp = "The food was delicious"
|
||||||
prepared_input = embedder._prepare_input(input)
|
prepared_input = embedder._prepare_input(inp)
|
||||||
assert prepared_input == {
|
assert prepared_input == {
|
||||||
"model": "text-embedding-ada-002",
|
"model": "text-embedding-ada-002",
|
||||||
"input": "The food was delicious",
|
"input": "The food was delicious",
|
||||||

@@ -67,7 +67,7 @@ class TestSentenceTransformersDocumentEmbedder:
         component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
         data = component.to_dict()
         assert data == {
-            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
            "init_parameters": {
                 "model": "model",
                 "device": ComponentDevice.from_str("cpu").to_dict(),
@@ -115,7 +115,7 @@ class TestSentenceTransformersDocumentEmbedder:
         data = component.to_dict()
 
         assert data == {
-            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
             "init_parameters": {
                 "model": "model",
                 "device": ComponentDevice.from_str("cuda:0").to_dict(),
@@ -161,7 +161,7 @@ class TestSentenceTransformersDocumentEmbedder:
         }
         component = SentenceTransformersDocumentEmbedder.from_dict(
             {
-                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
                 "init_parameters": init_parameters,
             }
         )
@@ -186,7 +186,7 @@ class TestSentenceTransformersDocumentEmbedder:
     def test_from_dict_no_default_parameters(self):
         component = SentenceTransformersDocumentEmbedder.from_dict(
             {
-                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
                 "init_parameters": {},
             }
         )
@@ -224,7 +224,7 @@ class TestSentenceTransformersDocumentEmbedder:
         }
         component = SentenceTransformersDocumentEmbedder.from_dict(
             {
-                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
                 "init_parameters": init_parameters,
             }
         )
@@ -382,9 +382,9 @@ class TestSentenceTransformersDocumentEmbedder:
             model="sentence-transformers/all-MiniLM-L6-v2",
             token=None,
             device=ComponentDevice.from_str("cpu"),
-            model_kwargs={
-                "file_name": "onnx/model.onnx"
-            },  # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
+            # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent
+            # a HF warning
+            model_kwargs={"file_name": "onnx/model.onnx"},
             backend="onnx",
         )
         onnx_embedder.warm_up()
@@ -410,9 +410,9 @@ class TestSentenceTransformersDocumentEmbedder:
             model="sentence-transformers/all-MiniLM-L6-v2",
             token=None,
             device=ComponentDevice.from_str("cpu"),
-            model_kwargs={
-                "file_name": "openvino/openvino_model.xml"
-            },  # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
+            # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is
+            # to prevent a HF warning
+            model_kwargs={"file_name": "openvino/openvino_model.xml"},
             backend="openvino",
         )
         openvino_embedder.warm_up()

@@ -60,7 +60,7 @@ class TestSentenceTransformersTextEmbedder:
         component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
         data = component.to_dict()
         assert data == {
-            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
             "init_parameters": {
                 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                 "model": "model",
@@ -103,7 +103,7 @@ class TestSentenceTransformersTextEmbedder:
         )
         data = component.to_dict()
         assert data == {
-            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
             "init_parameters": {
                 "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "model": "model",
@@ -132,7 +132,7 @@ class TestSentenceTransformersTextEmbedder:
 
     def test_from_dict(self):
         data = {
-            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
             "init_parameters": {
                 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                 "model": "model",
@@ -170,7 +170,7 @@ class TestSentenceTransformersTextEmbedder:
 
     def test_from_dict_no_default_parameters(self):
         data = {
-            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
             "init_parameters": {},
         }
         component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -189,7 +189,7 @@ class TestSentenceTransformersTextEmbedder:
 
     def test_from_dict_none_device(self):
         data = {
-            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
             "init_parameters": {
                 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                 "model": "model",
@@ -283,7 +283,8 @@ class TestSentenceTransformersTextEmbedder:
     @pytest.mark.slow
     def test_run_trunc(self, monkeypatch):
         """
-        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector space
+        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector
+        space
         """
         monkeypatch.delenv("HF_API_TOKEN", raising=False)  # https://github.com/deepset-ai/haystack/issues/8811
         checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
@@ -306,7 +307,8 @@ class TestSentenceTransformersTextEmbedder:
     @pytest.mark.slow
     def test_run_quantization(self):
         """
-        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector space
+        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector
+        space
         """
         checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
         text = "a nice text to embed"
@@ -341,9 +343,9 @@ class TestSentenceTransformersTextEmbedder:
             model="sentence-transformers/all-MiniLM-L6-v2",
             token=None,
             device=ComponentDevice.from_str("cpu"),
-            model_kwargs={
-                "file_name": "onnx/model.onnx"
-            },  # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent a HF warning
+            # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent
+            # a HF warning
+            model_kwargs={"file_name": "onnx/model.onnx"},
             backend="onnx",
         )
         onnx_embedder.warm_up()
@@ -369,9 +371,9 @@ class TestSentenceTransformersTextEmbedder:
             model="sentence-transformers/all-MiniLM-L6-v2",
             token=None,
             device=ComponentDevice.from_str("cpu"),
-            model_kwargs={
-                "file_name": "openvino/openvino_model.xml"
-            },  # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is to prevent a HF warning
+            # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but
+            # this is to prevent a HF warning
+            model_kwargs={"file_name": "openvino/openvino_model.xml"},
             backend="openvino",
         )
         openvino_embedder.warm_up()

@@ -2,18 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import math
 import os
 from typing import List
 
-import math
-
 import pytest
 
 from haystack import Pipeline
 from haystack.components.evaluators import ContextRelevanceEvaluator
-from haystack.utils.auth import Secret
-from haystack.dataclasses.chat_message import ChatMessage
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses.chat_message import ChatMessage
+from haystack.utils.auth import Secret
 
 
 class TestContextRelevanceEvaluator:

@@ -122,7 +122,7 @@ def test_run_empty_retrieved_and_empty_ground_truth():
 def test_run_no_retrieved():
     evaluator = DocumentNDCGEvaluator()
     with pytest.raises(ValueError):
-        result = evaluator.run(ground_truth_documents=[[Document(content="France")]], retrieved_documents=[])
+        _ = evaluator.run(ground_truth_documents=[[Document(content="France")]], retrieved_documents=[])
 
 
 def test_run_no_ground_truth():

@@ -4,9 +4,9 @@
 
 import pytest
 
+from haystack import default_from_dict
 from haystack.components.evaluators.document_recall import DocumentRecallEvaluator, RecallMode
 from haystack.dataclasses import Document
-from haystack import default_from_dict
 
 
 def test_init_with_unknown_mode_string():

@@ -2,17 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 import math
+import os
 from typing import List
 
 import pytest
 
 from haystack import Pipeline
 from haystack.components.evaluators import FaithfulnessEvaluator
-from haystack.utils.auth import Secret
-from haystack.dataclasses.chat_message import ChatMessage
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses.chat_message import ChatMessage
+from haystack.utils.auth import Secret
 
 
 class TestFaithfulnessEvaluator:

@@ -8,8 +8,8 @@ import pytest
 
 from haystack import Pipeline
 from haystack.components.evaluators import LLMEvaluator
-from haystack.dataclasses.chat_message import ChatMessage
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses.chat_message import ChatMessage
 
 
 class TestLLMEvaluator:
@@ -349,7 +349,11 @@ class TestLLMEvaluator:
         template = component.prepare_template()
         assert (
             template
-            == 'Instructions:\ntest-instruction\n\nGenerate the response in JSON format with the following keys:\n["score"]\nConsider the instructions and the examples below to determine those values.\n\nExamples:\nInputs:\n{"predicted_answers": "Damn, this is straight outta hell!!!"}\nOutputs:\n{"score": 1}\nInputs:\n{"predicted_answers": "Football is the most popular sport."}\nOutputs:\n{"score": 0}\n\nInputs:\n{"predicted_answers": {{ predicted_answers }}}\nOutputs:\n'
+            == "Instructions:\ntest-instruction\n\nGenerate the response in JSON format with the following keys:"
+            '\n["score"]\nConsider the instructions and the examples below to determine those values.\n\n'
+            'Examples:\nInputs:\n{"predicted_answers": "Damn, this is straight outta hell!!!"}\nOutputs:'
+            '\n{"score": 1}\nInputs:\n{"predicted_answers": "Football is the most popular sport."}\nOutputs:'
+            '\n{"score": 0}\n\nInputs:\n{"predicted_answers": {{ predicted_answers }}}\nOutputs:\n'
         )
 
     def test_invalid_input_parameters(self, monkeypatch):

@@ -3,19 +3,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from unittest.mock import Mock
 
 import pytest
-from unittest.mock import Mock
 from haystack import Document, Pipeline
+from haystack.components.extractors import LLMMetadataExtractor
+from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.components.writers import DocumentWriter
 from haystack.dataclasses import ChatMessage
 from haystack.document_stores.in_memory import InMemoryDocumentStore
 
 
-from haystack.components.extractors import LLMMetadataExtractor
-from haystack.components.generators.chat import OpenAIChatGenerator
-
 
 class TestLLMMetadataExtractor:
     def test_init(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
@@ -139,7 +138,8 @@ class TestLLMMetadataExtractor:
                 "_name": None,
                 "_content": [
                     {
-                        "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
+                        "text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for "
+                        "its Haystack framework"
                     }
                 ],
             }
@@ -151,7 +151,8 @@ class TestLLMMetadataExtractor:
                 "_name": None,
                 "_content": [
                     {
-                        "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
+                        "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and "
+                        "is known for its Transformers library"
                    }
                 ],
             }
@@ -179,7 +180,8 @@ class TestLLMMetadataExtractor:
                 "_name": None,
                 "_content": [
                     {
-                        "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"
+                        "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, "
+                        "France and is known for its Transformers library"
                     }
                 ],
             }
@@ -195,7 +197,8 @@ class TestLLMMetadataExtractor:
         )
         docs = [
             Document(
-                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3"
+                content="Hugging Face is a company founded in Paris, France and is known for its Transformers "
+                "library\fPage 2\fPage 3"
             )
         ]
         prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
@@ -208,7 +211,8 @@ class TestLLMMetadataExtractor:
                 "_name": None,
                 "_content": [
                     {
-                        "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"
+                        "text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and "
+                        "is known for its Transformers library\x0cPage 2\x0c"
                    }
                 ],
             }
@@ -305,7 +309,7 @@ entity_types: [company, organization, person, country, product, service]
 text: {{ document.content }}
 ######################
 output:
-"""
+"""  # noqa: E501
 
         doc_store = InMemoryDocumentStore()
         extractor = LLMMetadataExtractor(

@@ -6,11 +6,11 @@
 # Spacy is not installed in the test environment to keep the CI fast.
 # We test the Spacy backend in e2e/pipelines/test_named_entity_extractor.py.
 
-from haystack.utils.auth import Secret
 import pytest
 
 from haystack import ComponentError, DeserializationError, Pipeline
 from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
+from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
 
 

@@ -2,16 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
 
-import pytest
 import httpx
+import pytest
 
 from haystack.components.fetchers.link_content import (
-    LinkContentFetcher,
-    _text_content_handler,
-    _binary_content_handler,
     DEFAULT_USER_AGENT,
+    LinkContentFetcher,
+    _binary_content_handler,
+    _text_content_handler,
 )
 
 HTML_URL = "https://docs.haystack.deepset.ai/docs"
@@ -30,11 +30,9 @@ def mock_get_link_text_content():
 @pytest.fixture
 def mock_get_link_content(test_files_path):
     with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get:
-        mock_response = Mock(
-            status_code=200,
-            content=open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb").read(),
-            headers={"Content-Type": "application/pdf"},
-        )
+        with open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb") as f1:
+            file_bytes = f1.read()
+        mock_response = Mock(status_code=200, content=file_bytes, headers={"Content-Type": "application/pdf"})
         mock_get.return_value = mock_response
         yield mock_get
 
@@ -110,7 +108,8 @@ class TestLinkContentFetcher:
 
     def test_run_binary(self, test_files_path):
         """Test fetching binary content"""
-        file_bytes = open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb").read()
+        with open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb") as f1:
+            file_bytes = f1.read()
         with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get:
             mock_response = Mock(status_code=200, content=file_bytes, headers={"Content-Type": "application/pdf"})
             mock_get.return_value = mock_response
@@ -143,8 +142,8 @@ class TestLinkContentFetcher:
 
     def test_bad_request_exception_raised(self):
         """
-        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to
-        do so.
+        This test is to ensure that the fetcher raises an exception when a single bad request is made and it is
+        configured to do so.
         """
         fetcher = LinkContentFetcher(raise_on_failure=True, retry_attempts=0)
 
@@ -247,7 +246,8 @@ class TestLinkContentFetcherAsync:
         streams = (await fetcher.run_async(urls=["https://www.example.com"]))["streams"]
 
         first_stream = streams[0]
-        assert first_stream.data == b"Example test response"
+        expected_content = b"Example test response"
+        assert first_stream.data == expected_content
         assert first_stream.meta["content_type"] == "text/plain"
         assert first_stream.mime_type == "text/plain"
 
@@ -265,7 +265,8 @@ class TestLinkContentFetcherAsync:
 
         assert len(streams) == 2
         for stream in streams:
-            assert stream.data == b"Example test response"
+            expected_data = b"Example test response"
+            assert stream.data == expected_data
             assert stream.meta["content_type"] == "text/plain"
             assert stream.mime_type == "text/plain"
 
@@ -324,7 +325,8 @@ class TestLinkContentFetcherAsync:
         # Should succeed on the second attempt with the second user agent
         streams = (await fetcher.run_async(urls=["https://www.example.com"]))["streams"]
         assert len(streams) == 1
-        assert streams[0].data == b"Success"
+        expected_result = b"Success"
+        assert streams[0].data == expected_result
 
         mock_sleep.assert_called_once()
 

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
 import pytest
 from openai import OpenAIError
 
-from haystack import component, Pipeline
+from haystack import Pipeline, component
 from haystack.components.generators.chat import AzureOpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ToolCall

@@ -2,18 +2,14 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from datetime import datetime
 import os
+from datetime import datetime
 from typing import Any, Dict
-from unittest.mock import MagicMock, Mock, AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
-from haystack import Pipeline
-from haystack.dataclasses import StreamingChunk
-from haystack.utils.auth import Secret
-from haystack.utils.hf import HFGenerationAPIType
-
 from huggingface_hub import (
+    ChatCompletionInputStreamOptions,
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
     ChatCompletionOutputFunctionDefinition,
@@ -23,21 +19,22 @@ from huggingface_hub import (
     ChatCompletionStreamOutput,
     ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
-    ChatCompletionInputStreamOptions,
     ChatCompletionStreamOutputUsage,
 )
 from huggingface_hub.errors import RepositoryNotFoundError
 
+from haystack import Pipeline
 from haystack.components.generators.chat.hugging_face_api import (
     HuggingFaceAPIChatGenerator,
+    _convert_chat_completion_stream_output_to_streaming_chunk,
     _convert_hfapi_tool_calls,
     _convert_tools_to_hfapi_tools,
-    _convert_chat_completion_stream_output_to_streaming_chunk,
 )
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
 from haystack.tools import Tool
-from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools.toolset import Toolset
+from haystack.utils.auth import Secret
+from haystack.utils.hf import HFGenerationAPIType
 
 
 @pytest.fixture
@@ -753,11 +750,11 @@ class TestHuggingFaceAPIChatGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     def test_live_run_serverless(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
+            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
             generation_kwargs={"max_tokens": 20},
         )
 
@@ -788,11 +785,11 @@ class TestHuggingFaceAPIChatGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     def test_live_run_serverless_streaming(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
+            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
             generation_kwargs={"max_tokens": 20},
             streaming_callback=streaming_callback_handler,
         )
@@ -837,7 +834,7 @@ class TestHuggingFaceAPIChatGenerator:
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "Qwen/Qwen2.5-72B-Instruct", "provider": "hf-inference"},
+            api_params={"model": "Qwen/Qwen2.5-72B-Instruct", "provider": "together"},
             generation_kwargs={"temperature": 0.5},
         )
 
@@ -851,7 +848,7 @@ class TestHuggingFaceAPIChatGenerator:
         assert tool_call.tool_name == "weather"
         assert "city" in tool_call.arguments
         assert "Paris" in tool_call.arguments["city"]
-        assert message.meta["finish_reason"] == "stop"
+        assert message.meta["finish_reason"] == "tool_calls"
 
         new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]
 
@@ -1024,12 +1021,12 @@ class TestHuggingFaceAPIChatGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    @pytest.mark.flaky(reruns=2, reruns_delay=10)
     @pytest.mark.asyncio
     async def test_live_run_async_serverless(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "hf-inference"},
+            api_params={"model": "microsoft/Phi-3.5-mini-instruct", "provider": "featherless-ai"},
             generation_kwargs={"max_tokens": 20},
         )
 

@@ -4,20 +4,20 @@
 
 import asyncio
 import gc
-from typing import Optional, List
+from typing import List, Optional
 from unittest.mock import Mock, patch
 
-from haystack.utils.hf import AsyncHFTokenStreamingHandler
 import pytest
 from transformers import PreTrainedTokenizer
 
 from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
 from haystack.dataclasses import ChatMessage, ChatRole, ToolCall
-from haystack.dataclasses.streaming_chunk import StreamingChunk, AsyncStreamingCallbackT
+from haystack.dataclasses.streaming_chunk import AsyncStreamingCallbackT, StreamingChunk
 from haystack.tools import Tool
+from haystack.tools.toolset import Toolset
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
-from haystack.tools.toolset import Toolset
+from haystack.utils.hf import AsyncHFTokenStreamingHandler
 
 
 # used to test serialization of streaming_callback
@@ -558,7 +558,7 @@ class TestHuggingFaceLocalChatGeneratorAsync:
         )
 
     def test_executor_shutdown(self, model_info_mock, mock_pipeline_with_tokenizer):
-        with patch("haystack.components.generators.chat.hugging_face_local.pipeline") as mock_pipeline:
+        with patch("haystack.components.generators.chat.hugging_face_local.pipeline"):
            generator = HuggingFaceLocalChatGenerator(model="mocked-model")
             executor = generator.executor
             with patch.object(executor, "shutdown", wraps=executor.shutdown) as mock_shutdown:

@@ -2,35 +2,37 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from unittest.mock import patch, ANY, MagicMock
-import pytest
 
 
 import logging
 import os
 from datetime import datetime
 from typing import Any, Dict, List, Optional
+from unittest.mock import ANY, MagicMock, patch
+
+import pytest
 from openai import OpenAIError
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    chat_completion_chunk,
+)
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
-from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 from openai.types.chat.chat_completion_message_tool_call import Function
-from openai.types.chat import chat_completion_chunk
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 
 from haystack import component
-from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses import StreamingChunk, ToolCallDelta
-from haystack.utils.auth import Secret
-from haystack.dataclasses import ChatMessage, ToolCall
-from haystack.tools import ComponentTool, Tool
 from haystack.components.generators.chat.openai import (
     OpenAIChatGenerator,
     _check_finish_reason,
     _convert_chat_completion_chunk_to_streaming_chunk,
 )
+from haystack.components.generators.utils import print_streaming_chunk
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, ToolCallDelta
+from haystack.tools import ComponentTool, Tool
 from haystack.tools.toolset import Toolset
+from haystack.utils.auth import Secret
 
 
 @pytest.fixture
@@ -1177,6 +1179,32 @@ class TestChatCompletionChunkConversion:
             assert stream_chunk == haystack_chunk
             previous_chunks.append(stream_chunk)
 
+    def test_convert_chat_completion_chunk_with_empty_tool_calls(self):
+        # This can happen with some LLM providers where tool calls are not present but the pydantic models are still
+        # initialized.
+        chunk = ChatCompletionChunk(
+            id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw",
+            choices=[
+                chat_completion_chunk.Choice(
+                    delta=chat_completion_chunk.ChoiceDelta(
+                        tool_calls=[ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction())]
+                    ),
+                    index=0,
+                )
+            ],
+            created=1742207200,
+            model="gpt-4o-mini-2024-07-18",
+            object="chat.completion.chunk",
+        )
+        result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])
+        assert result.content == ""
+        assert result.start is False
+        assert result.tool_calls == [ToolCallDelta(index=0)]
+        assert result.tool_call_result is None
+        assert result.index == 0
+        assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
+        assert result.meta["received_at"] is not None
+
     def test_handle_stream_response(self, chat_completion_chunks):
         openai_chunks = chat_completion_chunks
         comp = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

@@ -2,24 +2,27 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from unittest.mock import AsyncMock, patch, MagicMock
 
-from openai import AsyncOpenAI, OpenAIError
-import pytest
-from datetime import datetime
 import os
+from datetime import datetime
+from unittest.mock import AsyncMock, MagicMock, patch
 
-from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, ChatCompletionChunk
+import pytest
+from openai import AsyncOpenAI, OpenAIError
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    chat_completion_chunk,
+)
 from openai.types.chat.chat_completion import Choice
-from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 from openai.types.chat.chat_completion_message_tool_call import Function
-from openai.types.chat import chat_completion_chunk
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
 
-from haystack.dataclasses import StreamingChunk
-from haystack.utils.auth import Secret
-from haystack.dataclasses import ChatMessage, ToolCall
-from haystack.tools import Tool
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
+from haystack.tools import Tool
+from haystack.utils.auth import Secret
 
 
 @pytest.fixture

@@ -4,12 +4,11 @@
 
 from datetime import datetime
 from typing import Iterator
-from unittest.mock import MagicMock, patch, AsyncMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from openai import AsyncStream, Stream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
-from openai.types.chat import chat_completion_chunk
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, chat_completion_chunk
 
 
 @pytest.fixture

@@ -4,14 +4,13 @@
 
 import os
 
-from haystack import Pipeline
-from haystack.utils.auth import Secret
-
 import pytest
 from openai import OpenAIError
 
+from haystack import Pipeline
 from haystack.components.generators import AzureOpenAIGenerator
 from haystack.components.generators.utils import print_streaming_chunk
+from haystack.utils.auth import Secret
 from haystack.utils.azure import default_azure_ad_token_provider
 
 

@@ -3,8 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from unittest.mock import MagicMock, Mock, patch
 from datetime import datetime
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 from huggingface_hub import (
@@ -297,6 +297,10 @@ class TestHuggingFaceAPIGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
+    @pytest.mark.skip(
+        reason="HF Inference API is not currently serving these models. "
+        "See https://github.com/deepset-ai/haystack/issues/9586."
+    )
     def test_run_serverless(self):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
@@ -307,7 +311,8 @@ class TestHuggingFaceAPIGenerator:
         # You must include the instruction tokens in the prompt. HF does not add them automatically.
         # Without them the model will behave erratically.
         response = generator.run(
-            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else.<|end|>\n<|assistant|>\n"
+            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else.<|end|>"
+            "\n<|assistant|>\n"
         )
 
         # Assert that the response contains the generated replies
@@ -329,6 +334,10 @@ class TestHuggingFaceAPIGenerator:
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
+    @pytest.mark.skip(
+        reason="HF Inference API is not currently serving these models. "
+        "See https://github.com/deepset-ai/haystack/issues/9586."
+    )
     def test_live_run_streaming_check_completion_start_time(self):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
@@ -338,7 +347,8 @@ class TestHuggingFaceAPIGenerator:
         )
 
         results = generator.run(
-            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else.<|end|>\n<|assistant|>\n"
+            "<|user|>\nWhat is the capital of France? Be concise only provide the capital, nothing else."
+            "<|end|>\n<|assistant|>\n"
        )
 
         # Assert that the response contains the generated replies

@@ -4,10 +4,11 @@
 
 import os
 from datetime import datetime
+from unittest.mock import MagicMock, patch
 
 import pytest
 from openai import OpenAIError
 from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
-from unittest.mock import MagicMock, patch
 
 from haystack.components.generators import OpenAIGenerator
 from haystack.components.generators.utils import print_streaming_chunk

@@ -2,13 +2,14 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import pytest
 from unittest.mock import Mock, patch
-from haystack.utils import Secret
 
-from openai.types.image import Image
+import pytest
 from openai.types import ImagesResponse
+from openai.types.image import Image
 
 from haystack.components.generators.openai_dalle import DALLEImageGenerator
+from haystack.utils import Secret
 
 
 @pytest.fixture
@@ -30,7 +31,7 @@ class TestDALLEImageGenerator:
         assert component.api_base_url is None
         assert component.organization is None
         assert pytest.approx(component.timeout) == 30.0
-        assert component.max_retries is 5
+        assert component.max_retries == 5
         assert component.http_client_kwargs is None
 
     def test_init_with_params(self, monkeypatch):
|
@ -2,10 +2,12 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from unittest.mock import call, patch
|
||||||
|
|
||||||
from openai.types.chat import chat_completion_chunk
|
from openai.types.chat import chat_completion_chunk
|
||||||
|
|
||||||
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
|
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, print_streaming_chunk
|
||||||
from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta
|
from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta, ToolCallResult
|
||||||
|
|
||||||
|
|
||||||
def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
|
def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
|
||||||
@@ -325,3 +327,256 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
         },
         "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
     }
+
+
+def test_convert_streaming_chunk_to_chat_message_two_tool_calls_in_same_chunk():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "usage": None,
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "mistral-small-latest",
+                "index": 0,
+                "finish_reason": "tool_calls",
+                "usage": {
+                    "completion_tokens": 35,
+                    "prompt_tokens": 77,
+                    "total_tokens": 112,
+                    "completion_tokens_details": None,
+                    "prompt_tokens_details": None,
+                },
+            },
+            component_info=ComponentInfo(
+                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
+                name=None,
+            ),
+            index=0,
+            tool_calls=[
+                ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
+                ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
+            ],
+            start=True,
+            finish_reason="tool_calls",
+        ),
+    ]
+
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    assert not result.texts
+    assert not result.text
+
+    # Verify both tool calls were found and processed
+    assert len(result.tool_calls) == 2
+    assert result.tool_calls[0].id == "FL1FFlqUG"
+    assert result.tool_calls[0].tool_name == "weather"
+    assert result.tool_calls[0].arguments == {"city": "Paris"}
+    assert result.tool_calls[1].id == "xSuhp66iB"
+    assert result.tool_calls[1].tool_name == "weather"
+    assert result.tool_calls[1].arguments == {"city": "Berlin"}
+
+
+def test_convert_streaming_chunk_to_chat_message_empty_tool_call_delta():
+    chunks = [
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.910076",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        id="call_ZOj5l67zhZOx6jqjg7ATQwb6",
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments='{"query":', name="rag_pipeline_tool"
+                        ),
+                        type="function",
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.913919",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            start=True,
+            tool_calls=[
+                ToolCallDelta(
+                    id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments='{"query":', index=0
+                )
+            ],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0,
+                        function=chat_completion_chunk.ChoiceDeltaToolCallFunction(
+                            arguments=' "Where does Mark live?"}'
+                        ),
+                    )
+                ],
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.924420",
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+            index=0,
+            tool_calls=[ToolCallDelta(arguments=' "Where does Mark live?"}', index=0)],
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": [
+                    chat_completion_chunk.ChoiceDeltaToolCall(
+                        index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction()
+                    )
+                ],
+                "finish_reason": "tool_calls",
+                "received_at": "2025-02-19T16:02:55.948772",
+            },
+            tool_calls=[ToolCallDelta(index=0)],
+            component_info=ComponentInfo(name="test", type="test"),
+            finish_reason="tool_calls",
+            index=0,
+        ),
+        StreamingChunk(
+            content="",
+            meta={
+                "model": "gpt-4o-mini-2024-07-18",
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": "2025-02-19T16:02:55.948772",
+                "usage": {
+                    "completion_tokens": 42,
+                    "prompt_tokens": 282,
+                    "total_tokens": 324,
+                    "completion_tokens_details": {
+                        "accepted_prediction_tokens": 0,
+                        "audio_tokens": 0,
+                        "reasoning_tokens": 0,
+                        "rejected_prediction_tokens": 0,
+                    },
+                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                },
+            },
+            component_info=ComponentInfo(name="test", type="test"),
+        ),
+    ]
+
+    # Convert chunks to a chat message
+    result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    assert not result.texts
+    assert not result.text
+
+    # Verify the tool call was found and processed
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6"
+    assert result.tool_calls[0].tool_name == "rag_pipeline_tool"
+    assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"}
+    assert result.meta["finish_reason"] == "tool_calls"
+
+
+def test_print_streaming_chunk_content_only():
+    chunk = StreamingChunk(
+        content="Hello, world!",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[ASSISTANT]\n", flush=True, end=""), call("Hello, world!", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_tool_call():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        index=0,
+        tool_calls=[ToolCallDelta(id="call_123", tool_name="test_tool", arguments='{"param": "value"}', index=0)],
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[TOOL CALL]\nTool: test_tool \nArguments: ", flush=True, end=""),
+            call('{"param": "value"}', flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_tool_call_result():
+    chunk = StreamingChunk(
+        content="",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        index=0,
+        tool_call_result=ToolCallResult(
+            result="Tool execution completed successfully",
+            origin=ToolCall(id="call_123", tool_name="test_tool", arguments={}),
+            error=False,
+        ),
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [call("[TOOL RESULT]\nTool execution completed successfully", flush=True, end="")]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_with_finish_reason():
+    chunk = StreamingChunk(
+        content="Final content.",
+        meta={"model": "test-model"},
+        component_info=ComponentInfo(name="test", type="test"),
+        start=True,
+        finish_reason="stop",
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        expected_calls = [
+            call("[ASSISTANT]\n", flush=True, end=""),
+            call("Final content.", flush=True, end=""),
+            call("\n\n", flush=True, end=""),
+        ]
+        mock_print.assert_has_calls(expected_calls)
+
+
+def test_print_streaming_chunk_empty_chunk():
+    chunk = StreamingChunk(
+        content="", meta={"model": "test-model"}, component_info=ComponentInfo(name="test", type="test")
+    )
+    with patch("builtins.print") as mock_print:
+        print_streaming_chunk(chunk)
+        mock_print.assert_not_called()
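The new tests above pin down how streamed tool-call fragments are reassembled: deltas that share an `index` belong to one tool call, their `arguments` strings are concatenated, and empty trailing deltas must not disturb the result. A rough sketch of that aggregation idea, using plain dicts rather than Haystack's dataclasses (this is not the actual `_convert_streaming_chunks_to_chat_message` implementation):

```python
import json
from collections import defaultdict

def merge_tool_call_deltas(deltas):
    """Illustrative only: group streamed fragments by tool-call index and
    concatenate their argument strings; empty deltas contribute nothing."""
    grouped = defaultdict(lambda: {"id": None, "tool_name": None, "arguments": ""})
    for delta in deltas:
        slot = grouped[delta["index"]]
        slot["id"] = slot["id"] or delta.get("id")
        slot["tool_name"] = slot["tool_name"] or delta.get("tool_name")
        slot["arguments"] += delta.get("arguments") or ""
    return [
        {"id": s["id"], "tool_name": s["tool_name"], "arguments": json.loads(s["arguments"])}
        for s in grouped.values()
    ]

# The "empty delta" case exercised by the test above:
deltas = [
    {"index": 0, "id": "call_1", "tool_name": "rag_pipeline_tool", "arguments": '{"query":'},
    {"index": 0, "arguments": ' "Where does Mark live?"}'},
    {"index": 0},  # trailing empty delta must not corrupt the call
]
merged = merge_tool_call_deltas(deltas)
assert merged[0]["arguments"] == {"query": "Where does Mark live?"}
```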
@@ -5,8 +5,8 @@
 import pytest

 from haystack import Document
-from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer
 from haystack.components.joiners.answer_joiner import AnswerJoiner, JoinMode
+from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer


 class TestAnswerJoiner:
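Most of the smaller hunks that follow are the same mechanical cleanup: imports re-sorted into isort/Ruff-style groups, alphabetized within each group. The convention, sketched generically with module names taken from the hunks above:

```python
# 1) standard library
from typing import List
from unittest.mock import patch

# 2) third-party packages
import pytest

# 3) first-party: the project's own package
from haystack import Document
from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer
```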
@@ -3,16 +3,17 @@
 # SPDX-License-Identifier: Apache-2.0

 from typing import List

 import pytest

 from haystack import Document, Pipeline
-from haystack.dataclasses import ChatMessage
-from haystack.dataclasses.answer import GeneratedAnswer
 from haystack.components.builders import AnswerBuilder, ChatPromptBuilder
+from haystack.components.embedders import SentenceTransformersTextEmbedder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
 from haystack.components.joiners.list_joiner import ListJoiner
-from haystack.components.embedders import SentenceTransformersTextEmbedder
 from haystack.core.errors import PipelineConnectError
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.answer import GeneratedAnswer
 from haystack.utils.auth import Secret

@@ -2,8 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from haystack.core.serialization import component_from_dict, component_to_dict
 from haystack.components.joiners.string_joiner import StringJoiner
+from haystack.core.serialization import component_from_dict, component_to_dict


 class TestStringJoiner:
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0

 from haystack import Document
-
 from haystack.components.preprocessors.csv_document_cleaner import CSVDocumentCleaner

@@ -2,13 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import pytest
 import logging
-from pandas import read_csv
 from io import StringIO

+import pytest
+from pandas import read_csv
+
 from haystack import Document
-from haystack.core.serialization import component_from_dict, component_to_dict
 from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter
+from haystack.core.serialization import component_from_dict, component_to_dict


 @pytest.fixture
@@ -7,8 +7,8 @@ import logging
 import pytest

 from haystack import Document
-from haystack.dataclasses import ByteStream, SparseEmbedding
 from haystack.components.preprocessors import DocumentCleaner
+from haystack.dataclasses import ByteStream, SparseEmbedding


 class TestDocumentCleaner:
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from unittest.mock import patch
+
 import pytest

 from haystack import Document, Pipeline
@@ -2,9 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import List
-
 import re
+from typing import List

 import pytest

@@ -27,8 +26,7 @@ def merge_documents(documents):
         start = doc.meta["split_idx_start"]  # start of the current content

         # if the start of the current content is before the end of the last appended content, adjust it
-        if start < last_idx_end:
-            start = last_idx_end
+        start = max(start, last_idx_end)

         # append the non-overlapping part to the merged text
         merged_text += doc.content[start - doc.meta["split_idx_start"] :]
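The `merge_documents` change collapses the conditional into `max`: inside the old `if start < last_idx_end:` branch the assignment yields exactly `last_idx_end`, and outside the branch `start` is left unchanged, so the new form is behavior-preserving. A quick property check of that equivalence (illustrative only, not part of the diff):

```python
def old_adjust(start, last_idx_end):
    # original form: branch, then overwrite
    if start < last_idx_end:
        start = last_idx_end
    return start

def new_adjust(start, last_idx_end):
    # new form: branch-free
    return max(start, last_idx_end)

# the two forms agree on every integer pair in a small grid
assert all(
    old_adjust(s, e) == new_adjust(s, e) for s in range(-5, 6) for e in range(-5, 6)
)
```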
@@ -40,7 +38,7 @@ def merge_documents(documents):


 class TestSplittingByFunctionOrCharacterRegex:
-    def test_non_text_document(self):
+    def test_non_text_document(self, caplog):
         with pytest.raises(
             ValueError, match="DocumentSplitter only works with text documents but content for document ID"
         ):
@@ -110,7 +108,10 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_split_by_word_multiple_input_docs(self):
         splitter = DocumentSplitter(split_by="word", split_length=10)
         text1 = "This is a text with some words. There is a second sentence. And there is a third sentence."
-        text2 = "This is a different text with some words. There is a second sentence. And there is a third sentence. And there is a fourth sentence."
+        text2 = (
+            "This is a different text with some words. There is a second sentence. And there is a third sentence. "
+            "And there is a fourth sentence."
+        )
         splitter.warm_up()
         result = splitter.run(documents=[Document(content=text1), Document(content=text2)])
         docs = result["documents"]
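This hunk, like several below, only rewraps long string literals using implicit concatenation: adjacent literals inside parentheses are joined at compile time, so the wrapped form is byte-for-byte identical to the original one-liner. For example:

```python
single = "This is a text with some words. And there is a fourth sentence."
wrapped = (
    "This is a text with some words. "
    "And there is a fourth sentence."
)
assert single == wrapped  # adjacent literals are joined at compile time
```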
@@ -155,7 +156,10 @@ class TestSplittingByFunctionOrCharacterRegex:

     def test_split_by_passage(self):
         splitter = DocumentSplitter(split_by="passage", split_length=1)
-        text = "This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n And another passage."
+        text = (
+            "This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n "
+            "And another passage."
+        )
         splitter.warm_up()
         result = splitter.run(documents=[Document(content=text)])
         docs = result["documents"]
@@ -172,7 +176,10 @@ class TestSplittingByFunctionOrCharacterRegex:

     def test_split_by_page(self):
         splitter = DocumentSplitter(split_by="page", split_length=1)
-        text = "This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+        text = (
+            "This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And "
+            "another passage."
+        )
         splitter.warm_up()
         result = splitter.run(documents=[Document(content=text)])
         docs = result["documents"]
@@ -310,7 +317,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_no_overlap_passage_split(self):
         splitter = DocumentSplitter(split_by="passage", split_length=1)
         doc1 = Document(
-            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence.\n\nAnd more passages.\n\n\f And another passage."
+            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence."
+            "\n\nAnd more passages.\n\n\f And another passage."
         )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -322,7 +330,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_no_overlap_page_split(self):
         splitter = DocumentSplitter(split_by="page", split_length=1)
         doc1 = Document(
-            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f "
+            "And another passage."
         )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -332,7 +341,8 @@ class TestSplittingByFunctionOrCharacterRegex:

         splitter = DocumentSplitter(split_by="page", split_length=2)
         doc1 = Document(
-            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f "
+            "And another passage."
         )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -366,7 +376,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_overlap_passage_split(self):
         splitter = DocumentSplitter(split_by="passage", split_length=2, split_overlap=1)
         doc1 = Document(
-            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence.\n\nAnd more passages.\n\n\f And another passage."
+            content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence."
+            "\n\nAnd more passages.\n\n\f And another passage."
         )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -378,7 +389,8 @@ class TestSplittingByFunctionOrCharacterRegex:
     def test_add_page_number_to_metadata_with_overlap_page_split(self):
         splitter = DocumentSplitter(split_by="page", split_length=2, split_overlap=1)
         doc1 = Document(
-            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage."
+            content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f "
+            "And another passage."
         )
         splitter.warm_up()
         result = splitter.run(documents=[doc1])
@@ -120,7 +120,7 @@ class TestHierarchicalDocumentSplitter:
             "max_runs_per_component": 100,
             "components": {
                 "hierarchical_document_splitter": {
-                    "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",
+                    "type": "haystack.components.preprocessors.hierarchical_document_splitter.HierarchicalDocumentSplitter",  # noqa: E501
                     "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"},
                 },
                 "doc_writer": {
@@ -203,7 +203,8 @@ class TestHierarchicalDocumentSplitter:
     def test_hierarchical_splitter_multiple_block_sizes(self):
         # Test with three different block sizes
         doc = Document(
-            content="This is a simple test document with multiple sentences. It should be split into various sizes. This helps test the hierarchy."
+            content="This is a simple test document with multiple sentences. It should be split into various sizes. "
+            "This helps test the hierarchy."
         )

         # Using three block sizes: 10, 5, 2 words
@@ -58,8 +58,8 @@ def test_apply_overlap_with_overlap_capturing_completely_previous_chunk(caplog):
     chunks = ["chunk1", "chunk2", "chunk3", "chunk4"]
     _ = splitter._apply_overlap(chunks)
     assert (
-        "Overlap is the same as the previous chunk. Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter."
-        in caplog.text
+        "Overlap is the same as the previous chunk. Consider increasing the `split_length` parameter or decreasing "
+        "the `split_overlap` parameter." in caplog.text
     )

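The assertion above uses pytest's `caplog` fixture, whose `.text` property concatenates all captured log output, so a substring check is enough to prove the warning fired. A self-contained sketch of the pattern, with a hypothetical function standing in for the splitter:

```python
import logging

def risky_overlap(split_length, split_overlap):
    # hypothetical stand-in: warn when the overlap swallows the whole chunk
    if split_overlap >= split_length:
        logging.getLogger(__name__).warning(
            "Overlap is the same as the previous chunk. "
            "Consider increasing the `split_length` parameter."
        )

def test_warns_on_full_overlap(caplog):
    with caplog.at_level(logging.WARNING):
        risky_overlap(split_length=2, split_overlap=2)
    # caplog.text holds the formatted records captured during the test
    assert "Overlap is the same as the previous chunk." in caplog.text
```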
@@ -134,12 +134,15 @@ AI technology is widely used throughout industry, government, and science. Some
     assert (
         chunks[1].content
         == "AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\n"
-    )  # noqa: E501
-    assert chunks[2].content == "AI technology is widely used throughout industry, government, and science."  # noqa: E501
+    )
+    assert chunks[2].content == "AI technology is widely used throughout industry, government, and science."
     assert (
         chunks[3].content
-        == "Some high-profile applications include advanced web search engines (e.g., Google Search); recommendation systems (used by YouTube, Amazon, and Netflix); interacting via human speech (e.g., Google Assistant, Siri, and Alexa); autonomous vehicles (e.g., Waymo); generative and creative tools (e.g., ChatGPT and AI art); and superhuman play and analysis in strategy games (e.g., chess and Go)."
-    )  # noqa: E501
+        == "Some high-profile applications include advanced web search engines (e.g., Google Search); recommendation "
+        "systems (used by YouTube, Amazon, and Netflix); interacting via human speech (e.g., Google Assistant, "
+        "Siri, and Alexa); autonomous vehicles (e.g., Waymo); generative and creative tools (e.g., ChatGPT and "
+        "AI art); and superhuman play and analysis in strategy games (e.g., chess and Go)."
+    )


 def test_run_split_by_dot_count_page_breaks_split_unit_char() -> None:
@@ -3,15 +3,14 @@
 # SPDX-License-Identifier: Apache-2.0

 import time
-import pytest
-from unittest.mock import patch
 from pathlib import Path
+from unittest.mock import patch

-from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter
-from haystack.components.preprocessors.sentence_tokenizer import QUOTE_SPANS_RE
+import pytest
 from pytest import LogCaptureFixture

+from haystack.components.preprocessors.sentence_tokenizer import QUOTE_SPANS_RE, SentenceSplitter


 def test_apply_split_rules_no_join() -> None:
     text = "This is a test. This is another test. And a third test."
@@ -56,7 +55,7 @@ def mock_file_content():
 def test_read_abbreviations_existing_file(tmp_path, mock_file_content):
     abbrev_dir = tmp_path / "data" / "abbreviations"
     abbrev_dir.mkdir(parents=True)
-    abbrev_file = abbrev_dir / f"en.txt"
+    abbrev_file = abbrev_dir / "en.txt"
     abbrev_file.write_text(mock_file_content)

     with patch("haystack.components.preprocessors.sentence_tokenizer.Path") as mock_path:
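The `f"en.txt"` to `"en.txt"` change clears Ruff's F541 (f-string without placeholders): the `f` prefix did nothing there because the literal contains no `{...}` expression. The prefix only matters once interpolation is involved:

```python
lang = "en"
plain = "en.txt"           # no placeholders: a plain string, no prefix needed
templated = f"{lang}.txt"  # interpolation: here the f prefix does real work
assert plain == templated
```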
@@ -2,11 +2,11 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import pytest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch

-import requests
 import httpx
+import pytest
+import requests

 from haystack import Document
 from haystack.components.rankers.hugging_face_tei import HuggingFaceTEIRanker, TruncationDirection
Some files were not shown because too many files have changed in this diff.