From 2c2549f13d87193725d1441faa27ed9cad7de1e3 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Thu, 12 Oct 2023 17:52:28 +0200 Subject: [PATCH] move embedding backends (#6033) --- .../embedders/backends}/__init__.py | 0 .../backends}/sentence_transformers_backend.py | 0 .../sentence_transformers_document_embedder.py | 14 ++++++++------ .../sentence_transformers_text_embedder.py | 12 +++++++----- ...est_sentence_transformers_embedding_backend.py} | 8 ++++---- 5 files changed, 19 insertions(+), 15 deletions(-) rename haystack/preview/{embedding_backends => components/embedders/backends}/__init__.py (100%) rename haystack/preview/{embedding_backends => components/embedders/backends}/sentence_transformers_backend.py (100%) rename test/preview/{embedding_backends/test_sentence_transformers.py => components/embedders/test_sentence_transformers_embedding_backend.py} (78%) diff --git a/haystack/preview/embedding_backends/__init__.py b/haystack/preview/components/embedders/backends/__init__.py similarity index 100% rename from haystack/preview/embedding_backends/__init__.py rename to haystack/preview/components/embedders/backends/__init__.py diff --git a/haystack/preview/embedding_backends/sentence_transformers_backend.py b/haystack/preview/components/embedders/backends/sentence_transformers_backend.py similarity index 100% rename from haystack/preview/embedding_backends/sentence_transformers_backend.py rename to haystack/preview/components/embedders/backends/sentence_transformers_backend.py diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index 40ac170aa..381a2d7d9 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union, Dict, Any from haystack.preview import component, Document, default_to_dict, default_from_dict -from haystack.preview.embedding_backends.sentence_transformers_backend import ( +from haystack.preview.components.embedders.backends.sentence_transformers_backend import ( _SentenceTransformersEmbeddingBackendFactory, ) @@ -29,11 +29,13 @@ class SentenceTransformersDocumentEmbedder: """ Create a SentenceTransformersDocumentEmbedder component. - :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``. - :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. + :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, + such as ``'sentence-transformers/all-mpnet-base-v2'``. + :param device: Device (like 'cuda' / 'cpu') that should be used for computation. + Defaults to CPU. :param use_auth_token: The API token used to download private models from Hugging Face. - If this parameter is set to `True`, then the token generated when running - `transformers-cli login` (stored in ~/.huggingface) will be used. + If this parameter is set to `True`, then the token generated when running + `transformers-cli login` (stored in ~/.huggingface) will be used. :param prefix: A string to add to the beginning of each Document text before embedding. :param suffix: A string to add to the end of each Document text before embedding. :param batch_size: Number of strings to encode at once. @@ -95,7 +97,7 @@ class SentenceTransformersDocumentEmbedder: Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document. """ - if not isinstance(documents, list) or not isinstance(documents[0], Document): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): raise TypeError( "SentenceTransformersDocumentEmbedder expects a list of Documents as input." "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder." diff --git a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py index c22e04379..527191b1d 100644 --- a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union, Dict, Any from haystack.preview import component, default_to_dict, default_from_dict -from haystack.preview.embedding_backends.sentence_transformers_backend import ( +from haystack.preview.components.embedders.backends.sentence_transformers_backend import ( _SentenceTransformersEmbeddingBackendFactory, ) @@ -26,11 +26,13 @@ class SentenceTransformersTextEmbedder: """ Create a SentenceTransformersTextEmbedder component. - :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, such as ``'sentence-transformers/all-mpnet-base-v2'``. - :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. + :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, + such as ``'sentence-transformers/all-mpnet-base-v2'``. + :param device: Device (like 'cuda' / 'cpu') that should be used for computation. + Defaults to CPU. :param use_auth_token: The API token used to download private models from Hugging Face. - If this parameter is set to `True`, then the token generated when running - `transformers-cli login` (stored in ~/.huggingface) will be used. + If this parameter is set to `True`, then the token generated when running + `transformers-cli login` (stored in ~/.huggingface) will be used. :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. :param batch_size: Number of strings to encode at once. diff --git a/test/preview/embedding_backends/test_sentence_transformers.py b/test/preview/components/embedders/test_sentence_transformers_embedding_backend.py similarity index 78% rename from test/preview/embedding_backends/test_sentence_transformers.py rename to test/preview/components/embedders/test_sentence_transformers_embedding_backend.py index f9f98d0a0..4ac8c5586 100644 --- a/test/preview/embedding_backends/test_sentence_transformers.py +++ b/test/preview/components/embedders/test_sentence_transformers_embedding_backend.py @@ -1,12 +1,12 @@ from unittest.mock import patch import pytest -from haystack.preview.embedding_backends.sentence_transformers_backend import ( +from haystack.preview.components.embedders.backends.sentence_transformers_backend import ( _SentenceTransformersEmbeddingBackendFactory, ) @pytest.mark.unit -@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") def test_factory_behavior(mock_sentence_transformer): embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="my_model", device="cpu" @@ -21,7 +21,7 @@ def test_factory_behavior(mock_sentence_transformer): @pytest.mark.unit -@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") def test_model_initialization(mock_sentence_transformer): _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="model", device="cpu", use_auth_token="my_token" @@ -32,7 +32,7 @@ def test_model_initialization(mock_sentence_transformer): @pytest.mark.unit -@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer") +@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") def test_embedding_function_with_kwargs(mock_sentence_transformer): embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")