mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-28 10:26:27 +00:00
move embedding backends (#6033)
This commit is contained in:
parent
d51be9edac
commit
2c2549f13d
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
||||
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
|
||||
@ -29,11 +29,13 @@ class SentenceTransformersDocumentEmbedder:
|
||||
"""
|
||||
Create a SentenceTransformersDocumentEmbedder component.
|
||||
|
||||
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
|
||||
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
|
||||
such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
|
||||
Defaults to CPU.
|
||||
:param use_auth_token: The API token used to download private models from Hugging Face.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
:param prefix: A string to add to the beginning of each Document text before embedding.
|
||||
:param suffix: A string to add to the end of each Document text before embedding.
|
||||
:param batch_size: Number of strings to encode at once.
|
||||
@ -95,7 +97,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
Embed a list of Documents.
|
||||
The embedding of each Document is stored in the `embedding` field of the Document.
|
||||
"""
|
||||
if not isinstance(documents, list) or not isinstance(documents[0], Document):
|
||||
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
|
||||
raise TypeError(
|
||||
"SentenceTransformersDocumentEmbedder expects a list of Documents as input."
|
||||
"In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from haystack.preview import component, default_to_dict, default_from_dict
|
||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
||||
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
|
||||
@ -26,11 +26,13 @@ class SentenceTransformersTextEmbedder:
|
||||
"""
|
||||
Create a SentenceTransformersTextEmbedder component.
|
||||
|
||||
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
|
||||
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
|
||||
such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
|
||||
Defaults to CPU.
|
||||
:param use_auth_token: The API token used to download private models from Hugging Face.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
:param prefix: A string to add to the beginning of each text.
|
||||
:param suffix: A string to add to the end of each text.
|
||||
:param batch_size: Number of strings to encode at once.
|
||||
|
@ -1,12 +1,12 @@
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
||||
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
|
||||
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||
def test_factory_behavior(mock_sentence_transformer):
|
||||
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model_name_or_path="my_model", device="cpu"
|
||||
@ -21,7 +21,7 @@ def test_factory_behavior(mock_sentence_transformer):
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
|
||||
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||
def test_model_initialization(mock_sentence_transformer):
|
||||
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model_name_or_path="model", device="cpu", use_auth_token="my_token"
|
||||
@ -32,7 +32,7 @@ def test_model_initialization(mock_sentence_transformer):
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
|
||||
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||
def test_embedding_function_with_kwargs(mock_sentence_transformer):
|
||||
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")
|
||||
|
Loading…
x
Reference in New Issue
Block a user