move embedding backends (#6033)

This commit is contained in:
Stefano Fiorucci 2023-10-12 17:52:28 +02:00 committed by GitHub
parent d51be9edac
commit 2c2549f13d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 19 additions and 15 deletions

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Union, Dict, Any
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview.embedding_backends.sentence_transformers_backend import (
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@ -29,8 +29,10 @@ class SentenceTransformersDocumentEmbedder:
"""
Create a SentenceTransformersDocumentEmbedder component.
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
:param use_auth_token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) will be used.
@ -95,7 +97,7 @@ class SentenceTransformersDocumentEmbedder:
Embed a list of Documents.
The embedding of each Document is stored in the `embedding` field of the Document.
"""
if not isinstance(documents, list) or not isinstance(documents[0], Document):
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
raise TypeError(
"SentenceTransformersDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Union, Dict, Any
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview.embedding_backends.sentence_transformers_backend import (
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@ -26,8 +26,10 @@ class SentenceTransformersTextEmbedder:
"""
Create a SentenceTransformersTextEmbedder component.
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
such as ``'sentence-transformers/all-mpnet-base-v2'``.
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
:param use_auth_token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) will be used.

View File

@ -1,12 +1,12 @@
from unittest.mock import patch
import pytest
from haystack.preview.embedding_backends.sentence_transformers_backend import (
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_factory_behavior(mock_sentence_transformer):
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="my_model", device="cpu"
@ -21,7 +21,7 @@ def test_factory_behavior(mock_sentence_transformer):
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="model", device="cpu", use_auth_token="my_token"
@ -32,7 +32,7 @@ def test_model_initialization(mock_sentence_transformer):
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_embedding_function_with_kwargs(mock_sentence_transformer):
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")