Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-28 18:36:36 +00:00)
move embedding backends (#6033)
This commit is contained in:
parent d51be9edac
commit 2c2549f13d
@ -1,7 +1,7 @@
|
|||||||
from typing import List, Optional, Union, Dict, Any
|
from typing import List, Optional, Union, Dict, Any
|
||||||
|
|
||||||
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
||||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
||||||
_SentenceTransformersEmbeddingBackendFactory,
|
_SentenceTransformersEmbeddingBackendFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,11 +29,13 @@ class SentenceTransformersDocumentEmbedder:
|
|||||||
"""
|
"""
|
||||||
Create a SentenceTransformersDocumentEmbedder component.
|
Create a SentenceTransformersDocumentEmbedder component.
|
||||||
|
|
||||||
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
|
||||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
|
such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
||||||
|
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
|
||||||
|
Defaults to CPU.
|
||||||
:param use_auth_token: The API token used to download private models from Hugging Face.
|
:param use_auth_token: The API token used to download private models from Hugging Face.
|
||||||
If this parameter is set to `True`, then the token generated when running
|
If this parameter is set to `True`, then the token generated when running
|
||||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||||
:param prefix: A string to add to the beginning of each Document text before embedding.
|
:param prefix: A string to add to the beginning of each Document text before embedding.
|
||||||
:param suffix: A string to add to the end of each Document text before embedding.
|
:param suffix: A string to add to the end of each Document text before embedding.
|
||||||
:param batch_size: Number of strings to encode at once.
|
:param batch_size: Number of strings to encode at once.
|
||||||
@ -95,7 +97,7 @@ class SentenceTransformersDocumentEmbedder:
|
|||||||
Embed a list of Documents.
|
Embed a list of Documents.
|
||||||
The embedding of each Document is stored in the `embedding` field of the Document.
|
The embedding of each Document is stored in the `embedding` field of the Document.
|
||||||
"""
|
"""
|
||||||
if not isinstance(documents, list) or not isinstance(documents[0], Document):
|
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"SentenceTransformersDocumentEmbedder expects a list of Documents as input."
|
"SentenceTransformersDocumentEmbedder expects a list of Documents as input."
|
||||||
"In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
|
"In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import List, Optional, Union, Dict, Any
|
from typing import List, Optional, Union, Dict, Any
|
||||||
|
|
||||||
from haystack.preview import component, default_to_dict, default_from_dict
|
from haystack.preview import component, default_to_dict, default_from_dict
|
||||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
||||||
_SentenceTransformersEmbeddingBackendFactory,
|
_SentenceTransformersEmbeddingBackendFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,11 +26,13 @@ class SentenceTransformersTextEmbedder:
|
|||||||
"""
|
"""
|
||||||
Create a SentenceTransformersTextEmbedder component.
|
Create a SentenceTransformersTextEmbedder component.
|
||||||
|
|
||||||
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
:param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
|
||||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
|
such as ``'sentence-transformers/all-mpnet-base-v2'``.
|
||||||
|
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
|
||||||
|
Defaults to CPU.
|
||||||
:param use_auth_token: The API token used to download private models from Hugging Face.
|
:param use_auth_token: The API token used to download private models from Hugging Face.
|
||||||
If this parameter is set to `True`, then the token generated when running
|
If this parameter is set to `True`, then the token generated when running
|
||||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||||
:param prefix: A string to add to the beginning of each text.
|
:param prefix: A string to add to the beginning of each text.
|
||||||
:param suffix: A string to add to the end of each text.
|
:param suffix: A string to add to the end of each text.
|
||||||
:param batch_size: Number of strings to encode at once.
|
:param batch_size: Number of strings to encode at once.
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
||||||
_SentenceTransformersEmbeddingBackendFactory,
|
_SentenceTransformersEmbeddingBackendFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
|
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||||
def test_factory_behavior(mock_sentence_transformer):
|
def test_factory_behavior(mock_sentence_transformer):
|
||||||
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||||
model_name_or_path="my_model", device="cpu"
|
model_name_or_path="my_model", device="cpu"
|
||||||
@ -21,7 +21,7 @@ def test_factory_behavior(mock_sentence_transformer):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
|
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||||
def test_model_initialization(mock_sentence_transformer):
|
def test_model_initialization(mock_sentence_transformer):
|
||||||
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||||
model_name_or_path="model", device="cpu", use_auth_token="my_token"
|
model_name_or_path="model", device="cpu", use_auth_token="my_token"
|
||||||
@ -32,7 +32,7 @@ def test_model_initialization(mock_sentence_transformer):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
|
@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||||
def test_embedding_function_with_kwargs(mock_sentence_transformer):
|
def test_embedding_function_with_kwargs(mock_sentence_transformer):
|
||||||
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")
|
embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user