add params to create_embeddings in retriever (#45)

This commit is contained in:
Malte Pietsch 2020-03-22 15:16:30 +01:00 committed by GitHub
parent d767f12f7c
commit 456df59586
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -44,13 +44,14 @@ class ElasticsearchRetriever(BaseRetriever):
logger.info(f"Got {len(paragraphs)} candidates from retriever: {meta_data}")
return paragraphs, meta_data
def create_embedding(self, text):
def create_embedding(self, text,extraction_strategy="reduce_mean", extraction_layer=-1):
if self.model_format == "farm":
res = self.embedding_model.extract_vectors(dicts=[{"text": text}], extraction_strategy="reduce_mean", extraction_layer=-1)
res = self.embedding_model.extract_vectors(dicts=[{"text": text}],
extraction_strategy=extraction_strategy,
extraction_layer=extraction_layer)
emb = list(res[0]["vec"])
elif self.model_format == "sentence_transformers":
# text is single string, sentence-transformers needs a list of strings
res = self.embedding_model.encode([text]) # get back list of numpy embedding vectors
emb = res[0].tolist()
return emb
return emb