mirror of https://github.com/deepset-ai/haystack.git

Add pipelines for GenerativeQA & FAQs (#645)

parent 216787ed34
commit 8e52b48e1d

.github/workflows/ci.yml (vendored, 16 lines changed)
@@ -22,10 +22,10 @@ jobs:
         sudo sysctl -w vm.max_map_count=262144
 
       - name: Run Elasticsearch
-        run: docker run -d -p 9200:9200 -e "discovery.type=single-node" -e 'ES_JAVA_OPTS=-Xms500m -Xmx500m' elasticsearch:7.9.2
+        run: docker run -d -p 9200:9200 -e "discovery.type=single-node" -e "ES_JAVA_OPTS=-Xms128m -Xmx128m" elasticsearch:7.9.2
 
       - name: Run Apache Tika
-        run: docker run -d -p 9998:9998 apache/tika:1.24.1
+        run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
 
       - name: Set up Python 3.7
         uses: actions/setup-python@v2
@@ -39,14 +39,14 @@ jobs:
         pip install -r requirements.txt
         pip install -e .
 
-      - name: Run Pytest without generator marker
-        run: cd test && pytest -m "not generator"
+      - name: Run Pytest without pipeline marker
+        run: cd test && pytest -m "not pipeline"
 
-      - name: Stop Containers
-        run: docker rm -f `docker ps -a -q`
+      # - name: Stop Containers
+      #   run: docker rm -f `docker ps -a -q`
 
-      - name: Run pytest with generator marker
-        run: cd test && pytest -m generator
+      - name: Run pytest with pipeline marker
+        run: cd test && pytest -m pipeline
 
       - name: Test with mypy
         run: |
@@ -232,7 +232,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         index = index or self.index
         query = {"query": {"ids": {"values": ids}}}
         result = self.client.search(index=index, body=query)["hits"]["hits"]
-        documents = [self._convert_es_hit_to_document(hit) for hit in result]
+        documents = [self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
         return documents
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
@@ -500,7 +500,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         logger.debug(f"Retriever query: {body}")
         result = self.client.search(index=index, body=body)["hits"]["hits"]
 
-        documents = [self._convert_es_hit_to_document(hit) for hit in result]
+        documents = [self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
         return documents
 
     def query_by_embedding(self,
@@ -573,14 +573,18 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         logger.debug(f"Retriever query: {body}")
         result = self.client.search(index=index, body=body, request_timeout=300)["hits"]["hits"]
 
-        documents = [self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True) for hit in result]
+        documents = [
+            self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, return_embedding=return_embedding)
+            for hit in result
+        ]
         return documents
 
     def _convert_es_hit_to_document(
             self,
             hit: dict,
+            return_embedding: bool,
             adapt_score_for_embedding: bool = False,
-            return_embedding: bool = True
     ) -> Document:
         # We put all additional data of the doc into meta_data and return it in the API
         meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)}
@@ -597,6 +601,13 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             probability = float(expit(np.asarray(score / 8)))  # scaling probability from TFIDF/BM25
         else:
             probability = None
 
+        embedding = None
+        if return_embedding:
+            embedding_list = hit["_source"].get(self.embedding_field)
+            if embedding_list:
+                embedding = np.asarray(embedding_list, dtype=np.float32)
+
         document = Document(
             id=hit["_id"],
             text=hit["_source"].get(self.text_field),
@@ -604,7 +615,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             score=score,
             probability=probability,
             question=hit["_source"].get(self.faq_question_field),
-            embedding=hit["_source"].get(self.embedding_field, None) if return_embedding else None,
+            embedding=embedding,
         )
         return document
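Note: with this change, whether stored embeddings are returned is decided by the store's `return_embedding` setting (or the `return_embedding` argument of `query_by_embedding`) instead of being fetched unconditionally, and the value is materialized as a numpy array. A minimal sketch of the resulting behaviour, assuming a local Elasticsearch instance and the 2020-era haystack API shown in this diff; the `query()` call, index name, and `store` variable are our assumptions, not part of the commit:

```python
# Sketch only: assumes Elasticsearch is running on localhost:9200.
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

store = ElasticsearchDocumentStore(index="demo", return_embedding=True)
docs = store.query(query="nobel prize", top_k=3)
for doc in docs:
    # With return_embedding=True each hit's embedding comes back as a
    # float32 numpy array (or None when the field is absent).
    print(doc.id, None if doc.embedding is None else doc.embedding.shape)
```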
@@ -272,7 +272,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         #assign query score to each document
         scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
         for doc in documents:
-            doc.score = scores_for_vector_ids[doc.meta["vector_id"]]  # type: ignore
+            doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
             doc.probability = (doc.score + 1) / 2
             if return_embedding is True:
                 doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
@@ -109,7 +109,7 @@ class SQLDocumentStore(BaseDocumentStore):
             DocumentORM.vector_id.in_(vector_ids),
             DocumentORM.index == index
         ).all()
-        sorted_results = sorted(results, key=lambda doc: vector_ids.index(doc.vector_id))  # type: ignore
+        sorted_results = sorted(results, key=lambda doc: vector_ids.index(doc.vector_id))
         documents = [self._convert_sql_row_to_document(row) for row in sorted_results]
         return documents
 
@@ -282,7 +282,7 @@ class SQLDocumentStore(BaseDocumentStore):
             meta={meta.name: meta.value for meta in row.meta}
         )
         if row.vector_id:
-            document.meta["vector_id"] = row.vector_id  # type: ignore
+            document.meta["vector_id"] = row.vector_id
         return document
 
     def _convert_sql_row_to_label(self, row) -> Label:
@@ -8,14 +8,27 @@ class BaseGenerator(ABC):
     """
     Abstract class for Generators
     """
 
+    outgoing_edges = 1
+
     @abstractmethod
-    def predict(self, question: str, documents: List[Document], top_k: Optional[int]) -> Dict:
+    def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict:
         """
         Abstract method to generate answers.
 
-        :param question: Question
+        :param query: Query
         :param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
         :param top_k: Number of returned answers
         :return: Generated answers plus additional infos in a dict
         """
         pass
 
+    def run(self, query: str, documents: List[Document], top_k_generator: Optional[int] = None, **kwargs):
+
+        if documents:
+            results = self.predict(query=query, documents=documents, top_k=top_k_generator)
+        else:
+            results = {"answers": []}
+
+        results.update(**kwargs)
+        return results, "output_1"
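Note: the new `run()` is what turns a Generator into a pipeline node: it calls `predict()` when documents arrived from upstream, forwards any remaining pipeline parameters via `**kwargs`, and names its outgoing edge. A toy subclass to illustrate the contract; `EchoGenerator` is ours, purely for illustration:

```python
from typing import Dict, List, Optional

from haystack import Document
from haystack.generator.base import BaseGenerator

class EchoGenerator(BaseGenerator):
    # Stand-in generator: a real one would condition an LLM on the documents.
    def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict:
        return {"query": query, "answers": [{"query": query, "answer": documents[0].text}]}

results, edge = EchoGenerator().run(query="what?", documents=[Document(text="42")])
assert edge == "output_1"  # every node reports the outgoing edge it fires
assert results["answers"][0]["answer"] == "42"
```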
@@ -34,23 +34,23 @@ class RAGenerator(BaseGenerator):
     **Example**
 
     ```python
-    |     question = "who got the first nobel prize in physics?"
+    |     query = "who got the first nobel prize in physics?"
     |
     |     # Retrieve related documents from retriever
-    |     retrieved_docs = retriever.retrieve(query=question)
+    |     retrieved_docs = retriever.retrieve(query=query)
     |
-    |     # Now generate answer from question and retrieved documents
+    |     # Now generate answer from query and retrieved documents
     |     generator.predict(
-    |        question=question,
+    |        query=query,
     |        documents=retrieved_docs,
     |        top_k=1
     |     )
     |
     |     # Answer
     |
-    |     {'question': 'who got the first nobel prize in physics',
+    |     {'query': 'who got the first nobel prize in physics',
     |      'answers':
-    |          [{'question': 'who got the first nobel prize in physics',
+    |          [{'query': 'who got the first nobel prize in physics',
     |            'answer': ' albert einstein',
     |            'meta': { 'doc_ids': [...],
     |                      'doc_scores': [80.42758 ...],
@@ -138,7 +138,7 @@ class RAGenerator(BaseGenerator):
         return out
 
     # Copied postprocess_docs method from transformers.RagRetriever and modified
-    def _get_contextualized_inputs(self, texts: List[str], question: str, titles: Optional[List[str]] = None,
+    def _get_contextualized_inputs(self, texts: List[str], query: str, titles: Optional[List[str]] = None,
                                    return_tensors: str = "pt"):
 
         titles_list = titles if self.embed_title and titles is not None else [""] * len(texts)
@@ -148,7 +148,7 @@ class RAGenerator(BaseGenerator):
             self._cat_input_and_doc(
                 doc_title=titles_list[i],
                 doc_text=texts[i],
-                input_string=question,
+                input_string=query,
                 prefix=prefix
             )
             for i in range(len(texts))
@@ -172,7 +172,7 @@ class RAGenerator(BaseGenerator):
 
         if is_embedding_required:
             if self.retriever is None:
-                raise AttributeError("_prepare_passage_embeddings need self.dpr_retriever to embed document")
+                raise AttributeError("_prepare_passage_embeddings need a DPR instance as self.retriever to embed document")
 
             embeddings = self.retriever.embed_passages(docs)
 
@@ -183,20 +183,20 @@ class RAGenerator(BaseGenerator):
 
         return embeddings_in_tensor
 
-    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
+    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
         """
-        Generate the answer to the input question. The generation will be conditioned on the supplied documents.
+        Generate the answer to the input query. The generation will be conditioned on the supplied documents.
         These document can for example be retrieved via the Retriever.
 
-        :param question: Question
+        :param query: Query
         :param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
         :param top_k: Number of returned answers
         :return: Generated answers plus additional infos in a dict like this:
 
         ```python
-        |     {'question': 'who got the first nobel prize in physics',
+        |     {'query': 'who got the first nobel prize in physics',
         |      'answers':
-        |          [{'question': 'who got the first nobel prize in physics',
+        |          [{'query': 'who got the first nobel prize in physics',
         |            'answer': ' albert einstein',
         |            'meta': { 'doc_ids': [...],
         |                      'doc_scores': [80.42758 ...],
@@ -227,28 +227,28 @@ class RAGenerator(BaseGenerator):
         # Extract title
         titles = [d.meta["name"] if d.meta and "name" in d.meta else "" for d in documents]
 
-        # Raw document embedding and set device of question_embedding
+        # Raw document embedding and set device of query_embedding
         passage_embeddings = self._prepare_passage_embeddings(docs=documents, embeddings=flat_docs_dict["embedding"])
 
-        # Question tokenization
+        # Query tokenization
         input_dict = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=[question],
+            src_texts=[query],
             return_tensors="pt"
         )
 
-        # Question embedding
-        question_embedding = self.model.question_encoder(input_dict["input_ids"])[0]
+        # Query embedding
+        query_embedding = self.model.question_encoder(input_dict["input_ids"])[0]
 
         # Prepare contextualized input_ids of documents
         # (will be transformed into contextualized inputs inside generator)
         context_input_ids, context_attention_mask = self._get_contextualized_inputs(
             texts=flat_docs_dict["text"],
             titles=titles,
-            question=question
+            query=query
         )
 
         # Compute doc scores from docs_embedding
-        doc_scores = torch.bmm(question_embedding.unsqueeze(1),
+        doc_scores = torch.bmm(query_embedding.unsqueeze(1),
                                passage_embeddings.unsqueeze(0).transpose(1, 2)).squeeze(1)
 
         # TODO Need transformers 3.4.0
@@ -277,7 +277,7 @@ class RAGenerator(BaseGenerator):
 
         for generated_answer in generated_answers:
             cur_answer = {
-                "question": question,
+                "query": query,
                 "answer": generated_answer,
                 "meta": {
                     "doc_ids": flat_docs_dict["id"],
@@ -289,6 +289,6 @@ class RAGenerator(BaseGenerator):
             }
             answers.append(cur_answer)
 
-        result = {"question": question, "answers": answers}
+        result = {"query": query, "answers": answers}
 
         return result
@@ -1,8 +1,11 @@
+from pathlib import Path
+from typing import List, Optional, Dict
+
 import networkx as nx
 from networkx import DiGraph
 from networkx.drawing.nx_agraph import to_agraph
-from typing import List
-from pathlib import Path
 
+from haystack.generator.base import BaseGenerator
 from haystack.reader.base import BaseReader
 from haystack.retriever.base import BaseRetriever
@@ -15,6 +18,7 @@ class Pipeline:
     flows with options to branch queries(eg, extractive qa vs keyword match query), merge candidate documents for a
     Reader from multiple Retrievers, or re-ranking of candidate documents.
     """
 
     def __init__(self):
         self.graph = DiGraph()
         self.root_node_id = "Query"
@@ -115,7 +119,15 @@ class Pipeline:
         graphviz.draw(path)
 
 
-class ExtractiveQAPipeline:
+class BaseStandardPipeline:
+    def add_node(self, component, name: str, inputs: List[str]):
+        self.pipeline.add_node(component=component, name=name, inputs=inputs)  # type: ignore
+
+    def draw(self, path: Path = Path("pipeline.png")):
+        self.pipeline.draw(path)  # type: ignore
+
+
+class ExtractiveQAPipeline(BaseStandardPipeline):
     def __init__(self, reader: BaseReader, retriever: BaseRetriever):
         """
         Initialize a Pipeline for Extractive Question Answering.
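Note: `BaseStandardPipeline` hoists the `add_node`/`draw` plumbing that was previously duplicated on every prebuilt pipeline class, so all the standard pipelines below share one implementation. A sketch of what the inheritance buys, assuming `reader` and `retriever` instances as used in the classes in this diff:

```python
# Sketch only; drawing the graph may additionally require pygraphviz.
from haystack.pipeline import ExtractiveQAPipeline

pipe = ExtractiveQAPipeline(reader=reader, retriever=retriever)
pipe.draw()  # inherited from BaseStandardPipeline; writes pipeline.png
```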
@@ -127,20 +139,14 @@ class ExtractiveQAPipeline:
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
         self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])
 
-    def run(self, query, top_k_retriever=5, top_k_reader=5):
-        output = self.pipeline.run(query=query,
-                                   top_k_retriever=top_k_retriever,
-                                   top_k_reader=top_k_reader)
+    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: int = 10, top_k_reader: int = 10):
+        output = self.pipeline.run(
+            query=query, filters=filters, top_k_retriever=top_k_retriever, top_k_reader=top_k_reader
+        )
         return output
 
-    def add_node(self, component, name: str, inputs: List[str]):
-        self.pipeline.add_node(component=component, name=name, inputs=inputs)
-
-    def draw(self, path: Path = Path("pipeline.png")):
-        self.pipeline.draw(path)
-
 
-class DocumentSearchPipeline:
+class DocumentSearchPipeline(BaseStandardPipeline):
     def __init__(self, retriever: BaseRetriever):
         """
         Initialize a Pipeline for semantic document search.
@@ -150,17 +156,64 @@ class DocumentSearchPipeline:
         self.pipeline = Pipeline()
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
 
-    def run(self, query, top_k_retriever=5):
-        output = self.pipeline.run(query=query, top_k_retriever=top_k_retriever)
+    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: int = 10):
+        output = self.pipeline.run(query=query, filters=filters, top_k_retriever=top_k_retriever)
         document_dicts = [doc.to_dict() for doc in output["documents"]]
         output["documents"] = document_dicts
         return output
 
-    def add_node(self, component, name: str, inputs: List[str]):
-        self.pipeline.add_node(component=component, name=name, inputs=inputs)
-
-    def draw(self, path: Path = Path("pipeline.png")):
-        self.pipeline.draw(path)
-
+class GenerativeQAPipeline(BaseStandardPipeline):
+    def __init__(self, generator: BaseGenerator, retriever: BaseRetriever):
+        """
+        Initialize a Pipeline for Generative Question Answering.
+
+        :param generator: Generator instance
+        :param retriever: Retriever instance
+        """
+        self.pipeline = Pipeline()
+        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
+        self.pipeline.add_node(component=generator, name="Generator", inputs=["Retriever"])
+
+    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: int = 10, top_k_generator: int = 10):
+        output = self.pipeline.run(
+            query=query, filters=filters, top_k_retriever=top_k_retriever, top_k_generator=top_k_generator
+        )
+        return output
+
+class FAQPipeline(BaseStandardPipeline):
+    def __init__(self, retriever: BaseRetriever):
+        """
+        Initialize a Pipeline for finding similar FAQs using semantic document search.
+
+        :param retriever: Retriever instance
+        """
+        self.pipeline = Pipeline()
+        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
+
+    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: int = 10):
+        output = self.pipeline.run(query=query, filters=filters, top_k_retriever=top_k_retriever)
+        documents = output["documents"]
+
+        results: Dict = {"query": query, "answers": []}
+        for doc in documents:
+            # TODO proper calibratation of pseudo probabilities
+            cur_answer = {
+                "query": doc.text,
+                "answer": doc.meta["answer"],
+                "document_id": doc.id,
+                "context": doc.meta["answer"],
+                "score": doc.score,
+                "probability": doc.probability,
+                "offset_start": 0,
+                "offset_end": len(doc.meta["answer"]),
+                "meta": doc.meta,
+            }
+
+            results["answers"].append(cur_answer)
+
+        return results
 
 
 class QueryNode:
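Note: `GenerativeQAPipeline` mirrors `ExtractiveQAPipeline` with a Generator in place of the Reader, while `FAQPipeline` is retriever-only and maps each hit's `meta["answer"]` straight to an answer dict. A rough wiring sketch for orientation; the `DensePassageRetriever` arguments and module paths are assumptions based on the haystack API of this era, not part of the diff (the tests further below use the same pattern):

```python
# Sketch only; model choices mirror the diff's own fixtures and tests.
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.generator.transformers import RAGenerator
from haystack.pipeline import FAQPipeline, GenerativeQAPipeline
from haystack.retriever.dense import DensePassageRetriever

document_store = FAISSDocumentStore(return_embedding=True)
retriever = DensePassageRetriever(document_store=document_store)
generator = RAGenerator(model_name_or_path="facebook/rag-token-nq")

# Retriever -> Generator, analogous to ExtractiveQAPipeline's Retriever -> Reader
pipe = GenerativeQAPipeline(generator=generator, retriever=retriever)
result = pipe.run(query="Who got the first Nobel Prize in physics?",
                  top_k_retriever=5, top_k_generator=1)
print(result["answers"])

# FAQ flavour: retrieval only, answers come from the stored documents' meta.
faq = FAQPipeline(retriever=retriever)
```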
@@ -46,9 +46,9 @@ class BaseReader(ABC):
                              "meta": None,}
         return no_ans_prediction, max_no_ans_gap
 
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
+    def run(self, query: str, documents: List[Document], top_k_reader: Optional[int] = None, **kwargs):
         if documents:
-            results = self.predict(query=query, documents=documents, top_k=top_k)
+            results = self.predict(query=query, documents=documents, top_k=top_k_reader)
         else:
             results = {"answers": []}
 
@@ -59,4 +59,5 @@ class BaseReader(ABC):
             if doc.id == ans["document_id"]:
                 ans["meta"] = deepcopy(doc.meta)
 
+        results.update(**kwargs)
         return results, "output_1"
@@ -173,16 +173,17 @@ class BaseRetriever(ABC):
             query: str,
             filters: Optional[dict] = None,
             top_k_retriever: Optional[int] = None,
-            top_k_reader: Optional[int] = None,
+            **kwargs,
     ):
         if top_k_retriever:
             documents = self.retrieve(query=query, filters=filters, top_k=top_k_retriever)
         else:
             documents = self.retrieve(query=query, filters=filters)
 
         output = {
             "query": query,
             "documents": documents,
-            "top_k": top_k_reader
+            **kwargs
         }
 
         return output, "output_1"
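Note: together with the reader and generator changes above, this gives every node a uniform `run()` protocol: parameters a node does not consume land in `**kwargs` and are threaded into its output dict so the next node can pick them up. A hypothetical one-hop trace; the `retriever` and `reader` instances are assumed:

```python
# Hypothetical trace of the kwargs pass-through between two pipeline nodes.
output, _ = retriever.run(query="Who lives in Berlin?", filters=None,
                          top_k_retriever=10, top_k_reader=5)
# output now holds "query", "documents", and the untouched top_k_reader=5,
# which BaseReader.run() consumes on the next hop:
prediction, _ = reader.run(**output)
```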
@@ -390,7 +390,6 @@ class EmbeddingRetriever(BaseRetriever):
         """
         self.document_store = document_store
         self.model_format = model_format
-        self.embedding_model = embedding_model
         self.pooling_strategy = pooling_strategy
         self.emb_extraction_layer = emb_extraction_layer
@@ -444,19 +443,19 @@ class EmbeddingRetriever(BaseRetriever):
         """
 
         # for backward compatibility: cast pure str input
-        if type(texts) == str:
-            texts = [texts]  # type: ignore
-        assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
+        if isinstance(texts, str):
+            texts = [texts]
+        assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
 
         if self.model_format == "farm" or self.model_format == "transformers":
             # TODO: FARM's `sample_to_features_text` need to fix following warning -
             # tokenization_utils.py:460: FutureWarning: `is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.
-            emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])  # type: ignore
+            emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
             emb = [(r["vec"]) for r in emb]
         elif self.model_format == "sentence_transformers":
             # text is single string, sentence-transformers needs a list of strings
             # get back list of numpy embedding vectors
-            emb = self.embedding_model.encode(texts)  # type: ignore
+            emb = self.embedding_model.encode(texts)
             emb = [r for r in emb]
         return emb
@@ -159,6 +159,9 @@ class TfidfRetriever(BaseRetriever):
         :param top_k: How many documents to return per query.
         :param index: The name of the index in the DocumentStore from which to retrieve documents
         """
+        if self.df is None:
+            raise Exception("fit() needs to called before retrieve()")
+
         if filters:
             raise NotImplementedError("Filters are not implemented in TfidfRetriever.")
         if index:
@@ -168,7 +171,7 @@ class TfidfRetriever(BaseRetriever):
         indices_and_scores = self._calc_scores(query)
 
         # rank paragraphs
-        df_sliced = self.df.loc[indices_and_scores.keys()]  # type: ignore
+        df_sliced = self.df.loc[indices_and_scores.keys()]
         df_sliced = df_sliced[:top_k]
 
         logger.debug(
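Note: the new guard turns a confusing `None` failure into an explicit error when `retrieve()` is called before `fit()`. Typical order of operations, sketched under the assumption that `TfidfRetriever` lives in `haystack.retriever.sparse` and that `document_store` already holds documents:

```python
from haystack.retriever.sparse import TfidfRetriever

retriever = TfidfRetriever(document_store=document_store)
retriever.fit()  # builds the TF-IDF matrix; skipping this now raises immediately
docs = retriever.retrieve(query="Who lives in Berlin?", top_k=3)
```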
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Dict, List, Union
+from typing import Any, Optional, Dict, List
 from uuid import uuid4
 
 import numpy as np
@@ -6,11 +6,11 @@ import numpy as np
 
 class Document:
     def __init__(self, text: str,
-                 id: str = None,
+                 id: Optional[str] = None,
                  score: Optional[float] = None,
                  probability: Optional[float] = None,
                  question: Optional[str] = None,
-                 meta: Optional[Dict[str, Any]] = None,
+                 meta: Dict[str, Any] = None,
                  embedding: Optional[np.array] = None):
         """
         Object used to represent documents / passages in a standardized way within Haystack.
@@ -23,7 +23,7 @@ class Document:
         :param id: ID used within the DocumentStore
         :param text: Text of the document
         :param score: Retriever's query score for a retrieved document
-        :param probability: a psuedo probability by scaling score in the range 0 to 1
+        :param probability: a pseudo probability by scaling score in the range 0 to 1
         :param question: Question text for FAQs.
         :param meta: Meta fields for a document like name, url, or author.
         :param embedding: Vector encoding of the text
@@ -39,7 +39,7 @@ class Document:
         self.score = score
         self.probability = probability
         self.question = question
-        self.meta = meta
+        self.meta = meta or {}
         self.embedding = embedding
 
     def to_dict(self, field_map={}):
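Note: with `self.meta = meta or {}`, `Document.meta` is always a dict, so callers no longer need a `None` check before lookups such as the `d.meta and "name" in d.meta` guard seen in RAGenerator above:

```python
from haystack import Document

doc = Document(text="Berlin is the capital of Germany")
# meta now defaults to {} instead of None, so this is safe as-is:
name = doc.meta.get("name", "")
```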
@@ -30,6 +30,8 @@ def pytest_collection_modifyitems(items):
             item.add_marker(pytest.mark.tika)
         elif "elasticsearch" in item.nodeid:
             item.add_marker(pytest.mark.elasticsearch)
+        elif "pipeline" in item.nodeid:
+            item.add_marker(pytest.mark.pipeline)
         elif "slow" in item.nodeid:
             item.add_marker(pytest.mark.slow)
@@ -107,47 +109,7 @@ def xpdf_fixture(tika_fixture):
     )
 
 
-@pytest.fixture(scope="session")
-def farm_distilbert():
-    return FARMReader(
-        model_name_or_path="distilbert-base-uncased-distilled-squad",
-        use_gpu=False,
-        top_k_per_sample=5,
-        num_processes=0
-    )
-
-
-@pytest.fixture(scope="session")
-def farm_roberta():
-    return FARMReader(
-        model_name_or_path="deepset/roberta-base-squad2",
-        use_gpu=False,
-        top_k_per_sample=5,
-        no_ans_boost=0,
-        num_processes=0
-    )
-
-
-@pytest.fixture(scope="session")
-def transformers_distilbert():
-    return TransformersReader(
-        model_name_or_path="distilbert-base-uncased-distilled-squad",
-        tokenizer="distilbert-base-uncased",
-        use_gpu=-1
-    )
-
-
-@pytest.fixture(scope="session")
-def transformers_roberta():
-    return TransformersReader(
-        model_name_or_path="deepset/roberta-base-squad2",
-        tokenizer="deepset/roberta-base-squad2",
-        use_gpu=-1,
-        top_k_per_candidate=5
-    )
-
-
-@pytest.fixture()
+@pytest.fixture(scope="module")
 def rag_generator():
     return RAGenerator(
         model_name_or_path="facebook/rag-token-nq",
@@ -161,7 +123,7 @@ def faiss_document_store():
     os.remove("haystack_test_faiss.db")
     document_store = FAISSDocumentStore(
         sql_url="sqlite:///haystack_test_faiss.db",
-        return_embedding=False
+        return_embedding=True
     )
     yield document_store
     document_store.faiss_index.reset()
@@ -169,7 +131,7 @@ def faiss_document_store():
 
 @pytest.fixture()
 def inmemory_document_store():
-    return InMemoryDocumentStore(return_embedding=False)
+    return InMemoryDocumentStore(return_embedding=True)
 
 
 @pytest.fixture()
@@ -198,7 +160,7 @@ def tfidf_retriever(inmemory_document_store):
     return TfidfRetriever(document_store=inmemory_document_store)
 
 
-@pytest.fixture()
+@pytest.fixture(scope="module")
 def test_docs_xs():
     return [
         # current "dict" format for a document
@@ -210,32 +172,52 @@ def test_docs_xs():
     ]
 
 
-@pytest.fixture(params=["farm", "transformers"])
-def reader(request, transformers_distilbert, farm_distilbert):
+@pytest.fixture(params=["farm", "transformers"], scope="module")
+def reader(request):
     if request.param == "farm":
-        return farm_distilbert
+        return FARMReader(
+            model_name_or_path="distilbert-base-uncased-distilled-squad",
+            use_gpu=False,
+            top_k_per_sample=5,
+            num_processes=0
+        )
     if request.param == "transformers":
-        return transformers_distilbert
+        return TransformersReader(
+            model_name_or_path="distilbert-base-uncased-distilled-squad",
+            tokenizer="distilbert-base-uncased",
+            use_gpu=-1
+        )
 
 
 # TODO Fix bug in test_no_answer_output when using
 # @pytest.fixture(params=["farm", "transformers"])
-@pytest.fixture(params=["farm"])
-def no_answer_reader(request, transformers_roberta, farm_roberta):
+@pytest.fixture(params=["farm"], scope="module")
+def no_answer_reader(request):
     if request.param == "farm":
-        return farm_roberta
+        return FARMReader(
+            model_name_or_path="deepset/roberta-base-squad2",
+            use_gpu=False,
+            top_k_per_sample=5,
+            no_ans_boost=0,
+            num_processes=0
+        )
    if request.param == "transformers":
-        return transformers_roberta
+        return TransformersReader(
+            model_name_or_path="deepset/roberta-base-squad2",
+            tokenizer="deepset/roberta-base-squad2",
+            use_gpu=-1,
+            top_k_per_candidate=5
+        )
 
 
-@pytest.fixture()
+@pytest.fixture(scope="module")
 def prediction(reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
     prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5)
     return prediction
 
 
-@pytest.fixture()
+@pytest.fixture(scope="module")
 def no_answer_prediction(no_answer_reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
     prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)
@@ -276,7 +258,7 @@ def get_document_store(document_store_type, faiss_document_store, inmemory_docum
         # make sure we start from a fresh index
         client = Elasticsearch()
         client.indices.delete(index='haystack_test*', ignore=[404])
-        document_store = ElasticsearchDocumentStore(index="haystack_test", return_embedding=False)
+        document_store = ElasticsearchDocumentStore(index="haystack_test", return_embedding=True)
     elif document_store_type == "faiss":
         document_store = faiss_document_store
     else:
@@ -5,3 +5,4 @@ markers =
     tika: marks tests which require tika container (deselect with '-m "not tika"')
     elasticsearch: marks tests which require elasticsearch container (deselect with '-m "not elasticsearch"')
     generator: marks generator tests (deselect with '-m "not generator"')
+    pipeline: marks tests with pipeline
@@ -24,7 +24,7 @@ def test_get_all_document_filter_duplicate_value(document_store):
         ),
         Document(
             text="Doc1",
-            meta={"f1": "1", "vector_id": "0"}
+            meta={"f1": "1", "meta_id": "0"}
         ),
         Document(
             text="Doc2",
@@ -35,7 +35,7 @@ def test_get_all_document_filter_duplicate_value(document_store):
     documents = document_store.get_all_documents(filters={"f1": ["1"]})
     assert documents[0].text == "Doc1"
     assert len(documents) == 1
-    assert {d.meta["vector_id"] for d in documents} == {"0"}
+    assert {d.meta["meta_id"] for d in documents} == {"0"}
 
 
 @pytest.mark.elasticsearch
@@ -4,6 +4,7 @@ import pytest
 from haystack import Document
 from haystack import Finder
 from haystack.document_store.faiss import FAISSDocumentStore
+from haystack.pipeline import Pipeline
 from haystack.retriever.dense import EmbeddingRetriever
 
 DOCUMENTS = [
@@ -133,6 +134,20 @@ def test_faiss_finding(faiss_document_store, embedding_retriever):
     assert len(prediction.get('answers', [])) == 1
 
 
+def test_faiss_pipeline(faiss_document_store, embedding_retriever):
+    documents = [
+        {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
+        {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
+        {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+        {"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+    ]
+    faiss_document_store.write_documents(documents)
+    pipeline = Pipeline()
+    pipeline.add_node(component=embedding_retriever, name="FAISS", inputs=["Query"])
+    output = pipeline.run(query="How to test this?", top_k_retriever=3)
+    assert len(output["documents"]) == 3
 
 
 def test_faiss_passing_index_from_outside():
     d = 768
     nlist = 2
@@ -1,7 +1,8 @@
+import numpy as np
 import pytest
 
 from haystack import Document
-import numpy as np
+from haystack.pipeline import GenerativeQAPipeline
 
 DOCS_WITH_EMBEDDINGS = [
     Document(
@@ -402,8 +403,26 @@ DOCS_WITH_EMBEDDINGS = [
 @pytest.mark.slow
 @pytest.mark.generator
 def test_rag_token_generator(rag_generator):
-    question = "What is capital of the Germany?"
-    generated_docs = rag_generator.predict(question=question, documents=DOCS_WITH_EMBEDDINGS, top_k=1)
+    query = "What is capital of the Germany?"
+    generated_docs = rag_generator.predict(query=query, documents=DOCS_WITH_EMBEDDINGS, top_k=1)
     answers = generated_docs["answers"]
     assert len(answers) == 1
     assert "berlin" in answers[0]["answer"]
 
 
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+    indirect=True,
+)
+def test_generator_pipeline(document_store, retriever, rag_generator):
+    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+    query = "What is capital of the Germany?"
+    pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator)
+    output = pipeline.run(query=query, top_k_generator=2, top_k_retriever=1)
+    answers = output["answers"]
+    assert len(answers) == 2
+    assert "berlin" in answers[0]["answer"]
@@ -1,6 +1,7 @@
 import pytest
 
-from haystack.pipeline import ExtractiveQAPipeline, Pipeline
+from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
+from haystack.pipeline import ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
 
 
 @pytest.mark.slow
@@ -60,3 +61,59 @@ def test_extractive_qa_answers_single_result(reader, retriever_with_docs, docume
     assert prediction is not None
     assert len(prediction["answers"]) == 1
 
 
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
+    indirect=True,
+)
+def test_faq_pipeline(retriever, document_store):
+    documents = [
+        {"text": "How to test module-1?", 'meta': {"source": "wiki1", "answer": "Using tests for module-1"}},
+        {"text": "How to test module-2?", 'meta': {"source": "wiki2", "answer": "Using tests for module-2"}},
+        {"text": "How to test module-3?", 'meta': {"source": "wiki3", "answer": "Using tests for module-3"}},
+        {"text": "How to test module-4?", 'meta': {"source": "wiki4", "answer": "Using tests for module-4"}},
+        {"text": "How to test module-5?", 'meta': {"source": "wiki5", "answer": "Using tests for module-5"}},
+    ]
+
+    document_store.write_documents(documents)
+    document_store.update_embeddings(retriever)
+
+    pipeline = FAQPipeline(retriever=retriever)
+
+    output = pipeline.run(query="How to test this?", top_k_retriever=3)
+    assert len(output["answers"]) == 3
+    assert output["answers"][0]["query"].startswith("How to")
+    assert output["answers"][0]["answer"].startswith("Using tests")
+
+    if isinstance(document_store, ElasticsearchDocumentStore):
+        output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
+        assert len(output["answers"]) == 1
+
+
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
+    indirect=True,
+)
+def test_document_search_pipeline(retriever, document_store):
+    documents = [
+        {"text": "Sample text for document-1", 'meta': {"source": "wiki1"}},
+        {"text": "Sample text for document-2", 'meta': {"source": "wiki2"}},
+        {"text": "Sample text for document-3", 'meta': {"source": "wiki3"}},
+        {"text": "Sample text for document-4", 'meta': {"source": "wiki4"}},
+        {"text": "Sample text for document-5", 'meta': {"source": "wiki5"}},
+    ]
+
+    document_store.write_documents(documents)
+    document_store.update_embeddings(retriever)
+
+    pipeline = DocumentSearchPipeline(retriever=retriever)
+    output = pipeline.run(query="How to test this?", top_k_retriever=4)
+    assert len(output.get('documents', [])) == 4
+
+    if isinstance(document_store, ElasticsearchDocumentStore):
+        output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
+        assert len(output["documents"]) == 1
@@ -275,7 +275,7 @@
     "\n",
     "    # Now generate answer from question and retrieved documents\n",
     "    predicted_result = generator.predict(\n",
-    "        question=question,\n",
+    "        query=question,\n",
     "        documents=retriever_results,\n",
     "        top_k=1\n",
     "    )\n",
 
@@ -101,7 +101,7 @@ for question in QUESTIONS:
 
     # Now generate answer from question and retrieved documents
     predicted_result = generator.predict(
-        question=question,
+        query=question,
         documents=retriever_results,
         top_k=1
    )