diff --git a/haystack/retriever/elasticsearch.py b/haystack/retriever/elasticsearch.py index 69cfd7035..284b678cb 100644 --- a/haystack/retriever/elasticsearch.py +++ b/haystack/retriever/elasticsearch.py @@ -44,13 +44,14 @@ class ElasticsearchRetriever(BaseRetriever): logger.info(f"Got {len(paragraphs)} candidates from retriever: {meta_data}") return paragraphs, meta_data - def create_embedding(self, text): + def create_embedding(self, text,extraction_strategy="reduce_mean", extraction_layer=-1): if self.model_format == "farm": - res = self.embedding_model.extract_vectors(dicts=[{"text": text}], extraction_strategy="reduce_mean", extraction_layer=-1) + res = self.embedding_model.extract_vectors(dicts=[{"text": text}], + extraction_strategy=extraction_strategy, + extraction_layer=extraction_layer) emb = list(res[0]["vec"]) elif self.model_format == "sentence_transformers": # text is single string, sentence-transformers needs a list of strings res = self.embedding_model.encode([text]) # get back list of numpy embedding vectors emb = res[0].tolist() - return emb - + return emb \ No newline at end of file