Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-25 05:58:57 +00:00)

Add eval for Dense Passage Retriever & Refactor handling of labels/feedback (#243)

This commit is contained in:
parent 52370c7bd4
commit 29a15c0d59
README.rst

@@ -116,7 +116,7 @@ Elasticsearch (Recommended)

 You can get started by running a single Elasticsearch node using docker::

-    docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.6.1
+    docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.6.2

 Or if docker is not possible for you::
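For orientation (not part of the diff): once that container is running, a document store can connect to it. A minimal sketch; the host and index name are just the defaults implied by the command above:

    from haystack.database.elasticsearch import ElasticsearchDocumentStore

    # Assumes the single-node container above is listening on localhost:9200
    document_store = ElasticsearchDocumentStore(host="localhost", index="document")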
haystack/database/base.py

@@ -1,21 +1,117 @@
 from abc import abstractmethod, ABC
-from typing import Any, Optional, Dict, List
-
-from pydantic import BaseModel, Field
+from typing import Any, Optional, Dict, List, Union
+from uuid import UUID, uuid4


-class Document(BaseModel):
-    id: str = Field(..., description="_id field from Elasticsearch")
-    text: str = Field(..., description="Text of the document")
-    external_source_id: Optional[str] = Field(
-        None,
-        description="id for the source file the document was created from. In the case when a large file is divided "
-                    "across multiple Elasticsearch documents, this id can be used to reference the original source file.",
-    )
-    question: Optional[str] = Field(None, description="Question text for FAQs.")
-    query_score: Optional[float] = Field(None, description="Elasticsearch query score for a retrieved document")
-    meta: Dict[str, Any] = Field({}, description="Meta fields for a document like name, url, or author.")
-    tags: Optional[Dict[str, Any]] = Field(None, description="Tags that allow filtering of the data")
+class Document:
+    def __init__(self, text: str,
+                 id: Optional[Union[str, UUID]] = None,
+                 query_score: Optional[float] = None,
+                 question: Optional[str] = None,
+                 meta: Dict[str, Any] = None,
+                 tags: Optional[Dict[str, Any]] = None,
+                 embedding: Optional[List[float]] = None):
+        """
+        Object used to represent documents / passages in a standardized way within Haystack.
+        For example, this is what the retriever will return from the DocumentStore,
+        regardless of whether it's an ElasticsearchDocumentStore or an InMemoryDocumentStore.
+
+        Note that there can be multiple Documents originating from one file (e.g. PDF),
+        if you split the text into smaller passages. We'll have one Document per passage in this case.
+
+        :param id: ID used within the DocumentStore
+        :param text: Text of the document
+        :param query_score: Retriever's query score for a retrieved document
+        :param question: Question text for FAQs.
+        :param meta: Meta fields for a document like name, url, or author.
+        :param tags: Tags that allow filtering of the data
+        :param embedding: Vector encoding of the text
+        """
+
+        self.text = text
+        # Create a unique ID (either a new one, or one from user input)
+        if id:
+            if isinstance(id, str):
+                self.id = UUID(hex=str(id), version=4)
+            if isinstance(id, UUID):
+                self.id = id
+        else:
+            self.id = uuid4()
+
+        self.query_score = query_score
+        self.question = question
+        self.meta = meta
+        self.tags = tags  # deprecate?
+        self.embedding = embedding
+
+    def to_dict(self):
+        return self.__dict__
+
+    @classmethod
+    def from_dict(cls, dict):
+        _doc = dict.copy()
+        init_args = ["text", "id", "query_score", "question", "meta", "tags", "embedding"]
+        if "meta" not in _doc.keys():
+            _doc["meta"] = {}
+        # copy additional fields into "meta"
+        for k, v in _doc.items():
+            if k not in init_args:
+                _doc["meta"][k] = v
+        # remove additional fields from the top level
+        _doc = {k: v for k, v in _doc.items() if k in init_args}
+
+        return cls(**_doc)
+
+
+class Label:
+    def __init__(self, question: str,
+                 answer: str,
+                 is_correct_answer: bool,
+                 is_correct_document: bool,
+                 origin: str,
+                 document_id: Optional[UUID] = None,
+                 offset_start_in_doc: Optional[int] = None,
+                 no_answer: Optional[bool] = None,
+                 model_id: Optional[int] = None):
+        """
+        Object used to represent labels/feedback in a standardized way within Haystack.
+        This includes labels from datasets like SQuAD, annotations from labeling tools,
+        or user feedback from the Haystack REST API.
+
+        :param question: the question (or query) for finding answers.
+        :param answer: the answer string.
+        :param is_correct_answer: whether the sample is positive or negative.
+        :param is_correct_document: in case of a negative sample (is_correct_answer is False), there could be two cases:
+                                    incorrect answer but correct document & incorrect document. This flag denotes if
+                                    the returned document was correct.
+        :param origin: the source of the labels. It can later be used for filtering.
+        :param document_id: the document_store's ID for the returned answer document.
+        :param offset_start_in_doc: the answer start offset in the document.
+        :param no_answer: whether the question is unanswerable.
+        :param model_id: model_id used for prediction (in case of user feedback).
+        """
+        self.no_answer = no_answer
+        self.origin = origin
+        self.question = question
+        self.is_correct_answer = is_correct_answer
+        self.is_correct_document = is_correct_document
+        if document_id:
+            if isinstance(document_id, str):
+                self.document_id: Optional[UUID] = UUID(hex=str(document_id), version=4)
+            if isinstance(document_id, UUID):
+                self.document_id = document_id
+        else:
+            self.document_id = document_id
+        self.answer = answer
+        self.offset_start_in_doc = offset_start_in_doc
+        self.model_id = model_id

+    @classmethod
+    def from_dict(cls, dict):
+        return cls(**dict)
+
+    def to_dict(self):
+        return self.__dict__
+
+
 class BaseDocumentStore(ABC):
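For a concrete feel of the new data classes, a small usage sketch (mine, not part of the commit), exercising Document.from_dict() and Label exactly as defined above:

    from haystack.database.base import Document, Label

    # Keys that are not init args (here "author") get copied into doc.meta by from_dict()
    doc = Document.from_dict({"text": "Paris is the capital of France.", "author": "jane"})
    assert doc.meta["author"] == "jane"

    label = Label(
        question="What is the capital of France?",
        answer="Paris",
        is_correct_answer=True,
        is_correct_document=True,
        origin="gold_label",
        document_id=doc.id,        # a UUID generated in Document.__init__
        offset_start_in_doc=0,
    )
    print(label.to_dict())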
@@ -25,7 +121,7 @@ class BaseDocumentStore(ABC):
     index: Optional[str]

     @abstractmethod
-    def write_documents(self, documents: List[dict]):
+    def write_documents(self, documents: List[dict], index: Optional[str] = None):
         """
         Indexes documents for later queries.

@@ -34,25 +130,31 @@ class BaseDocumentStore(ABC):
                            Optionally: Include meta data via {"text": "<the-actual-text>",
                            "meta": {"name": "<some-document-name>, "author": "somebody", ...}}
                            It can be used for filtering and is accessible in the responses of the Finder.
+        :param index: Optional name of the index where the documents shall be written to.
+                      If None, the DocumentStore's default index (self.index) will be used.

         :return: None
         """
         pass

     @abstractmethod
-    def get_all_documents(self) -> List[Document]:
+    def get_all_documents(self, index: Optional[str] = None) -> List[Document]:
         pass

     @abstractmethod
-    def get_document_by_id(self, id: str) -> Optional[Document]:
+    def get_all_labels(self, index: str = "label", filters: Optional[dict] = None) -> List[Label]:
         pass

     @abstractmethod
-    def get_document_ids_by_tags(self, tag) -> List[str]:
+    def get_document_by_id(self, id: UUID, index: Optional[str] = None) -> Optional[Document]:
         pass

     @abstractmethod
-    def get_document_count(self) -> int:
+    def get_document_ids_by_tags(self, tag, index) -> List[str]:
+        pass
+
+    @abstractmethod
+    def get_document_count(self, index: Optional[str] = None) -> int:
         pass

     @abstractmethod

@@ -62,3 +164,15 @@ class BaseDocumentStore(ABC):
                            top_k: int = 10,
                            index: Optional[str] = None) -> List[Document]:
         pass
+
+    @abstractmethod
+    def get_label_count(self, index: Optional[str] = None) -> int:
+        pass
+
+    @abstractmethod
+    def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"):
+        pass
+
+    def delete_all_documents(self, index: str):
+        pass
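To illustrate the index parameter that these hunks thread through the abstract API, a hedged sketch (document_store stands for any concrete implementation):

    # Search documents go to the store's default index, evaluation documents to their own namespace
    document_store.write_documents([{"text": "some passage"}])                         # uses self.index
    document_store.write_documents([{"text": "an eval passage"}], index="eval_document")
    print(document_store.get_document_count(index="eval_document"))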
haystack/database/elasticsearch.py

@@ -1,12 +1,15 @@
 import json
 import logging
+import time
 from string import Template
 from typing import List, Optional, Union, Dict, Any
 from elasticsearch import Elasticsearch
 from elasticsearch.helpers import bulk, scan
+import numpy as np
+from uuid import UUID

-from haystack.database.base import BaseDocumentStore, Document
+from haystack.database.base import BaseDocumentStore, Document, Label
+from haystack.indexing.utils import eval_data_from_file

 logger = logging.getLogger(__name__)
@@ -19,12 +22,12 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                  username: str = "",
                  password: str = "",
                  index: str = "document",
+                 label_index: str = "label",
                  search_fields: Union[str, list] = "text",
                  text_field: str = "text",
                  name_field: str = "name",
-                 external_source_id_field: str = "external_source_id",
-                 embedding_field: Optional[str] = None,
-                 embedding_dim: Optional[int] = None,
+                 embedding_field: str = "embedding",
+                 embedding_dim: int = 768,
                  custom_mapping: Optional[dict] = None,
                  excluded_meta_data: Optional[list] = None,
                  faq_question_field: Optional[str] = None,

@@ -49,7 +52,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         :param text_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
                            If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
         :param name_field: Name of field that contains the title of the doc
-        :param external_source_id_field: If you have an external id (= non-elasticsearch) that identifies your documents, you can specify it here.
         :param embedding_field: Name of field containing an embedding vector (Only needed when using a dense retriever (e.g. DensePassageRetriever, EmbeddingRetriever) on top)
         :param embedding_dim: Dimensionality of embedding vector (Only needed when using a dense retriever (e.g. DensePassageRetriever, EmbeddingRetriever) on top)
         :param custom_mapping: If you want to use your own custom mapping for creating a new index in Elasticsearch, you can supply it here as a dictionary.
@@ -63,25 +65,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self.client = Elasticsearch(hosts=[{"host": host, "port": port}], http_auth=(username, password),
                                     scheme=scheme, ca_certs=ca_certs, verify_certs=verify_certs)

-        # if no custom_mapping is supplied, use the default mapping
-        if not custom_mapping:
-            custom_mapping = {
-                "mappings": {
-                    "properties": {
-                        name_field: {"type": "text"},
-                        text_field: {"type": "text"},
-                        external_source_id_field: {"type": "text"},
-                    }
-                }
-            }
-        if embedding_field:
-            custom_mapping["mappings"]["properties"][embedding_field] = {"type": "dense_vector",
-                                                                         "dims": embedding_dim}
-        # create an index if not exists
-        if create_index:
-            self.client.indices.create(index=index, ignore=400, body=custom_mapping)
-        self.index = index

         # configure mappings to ES fields that will be used for querying / displaying results
         if type(search_fields) == str:
             search_fields = [search_fields]
@@ -91,50 +74,114 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self.search_fields = search_fields
         self.text_field = text_field
         self.name_field = name_field
-        self.external_source_id_field = external_source_id_field
         self.embedding_field = embedding_field
         self.embedding_dim = embedding_dim
         self.excluded_meta_data = excluded_meta_data
         self.faq_question_field = faq_question_field
+        self.custom_mapping = custom_mapping
+        if create_index:
+            self._create_document_index(index)
+        self.index: str = index
+
+        self._create_label_index(label_index)
+        self.label_index = label_index

-    def get_document_by_id(self, id: str) -> Optional[Document]:
+    def _create_document_index(self, index_name):
+        if self.custom_mapping:
+            mapping = self.custom_mapping
+        else:
+            mapping = {
+                "mappings": {
+                    "properties": {
+                        self.name_field: {"type": "text"},
+                        self.text_field: {"type": "text"},
+                    }
+                }
+            }
+            if self.embedding_field:
+                mapping["mappings"]["properties"][self.embedding_field] = {"type": "dense_vector", "dims": self.embedding_dim}
+        self.client.indices.create(index=index_name, ignore=400, body=mapping)
+
+    def _create_label_index(self, index_name):
+        mapping = {
+            "mappings": {
+                "properties": {
+                    "question": {"type": "text"},
+                    "answer": {"type": "text"},
+                    "is_correct_answer": {"type": "boolean"},
+                    "is_correct_document": {"type": "boolean"},
+                    "origin": {"type": "keyword"},
+                    "document_id": {"type": "keyword"},
+                    "offset_start_in_doc": {"type": "long"},
+                    "no_answer": {"type": "boolean"},
+                    "model_id": {"type": "keyword"},
+                    "type": {"type": "keyword"},
+                }
+            }
+        }
+        self.client.indices.create(index=index_name, ignore=400, body=mapping)
+
+    def get_document_by_id(self, id: Union[UUID, str], index=None) -> Optional[Document]:
+        if index is None:
+            index = self.index
         query = {"query": {"ids": {"values": [id]}}}
-        result = self.client.search(index=self.index, body=query)["hits"]["hits"]
+        result = self.client.search(index=index, body=query)["hits"]["hits"]

         document = self._convert_es_hit_to_document(result[0]) if result else None
         return document

-    def get_document_ids_by_tags(self, tags: dict) -> List[str]:
+    def get_document_ids_by_tags(self, tags: dict, index: Optional[str]) -> List[str]:
+        index = index or self.index
         term_queries = [{"terms": {key: value}} for key, value in tags.items()]
         query = {"query": {"bool": {"must": term_queries}}}
         logger.debug(f"Tag filter query: {query}")
-        result = self.client.search(index=self.index, body=query, size=10000)["hits"]["hits"]
+        result = self.client.search(index=index, body=query, size=10000)["hits"]["hits"]
         doc_ids = []
         for hit in result:
             doc_ids.append(hit["_id"])
         return doc_ids

-    def write_documents(self, documents: List[dict]):
+    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """
         Indexes documents for later queries in Elasticsearch.

-        :param documents: List of dictionaries.
-                          Default format: {"text": "<the-actual-text>"}
+        :param documents: a list of Python dictionaries or a list of Haystack Document objects.
+                          For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
                           Optionally: Include meta data via {"text": "<the-actual-text>",
                           "meta": {"name": "<some-document-name>, "author": "somebody", ...}}
                           It can be used for filtering and is accessible in the responses of the Finder.
                           Advanced: If you are using your own Elasticsearch mapping, the key names in the dictionary
-                          should be changed to what you have set for self.text_field and self.name_field .
+                          should be changed to what you have set for self.text_field and self.name_field.
+        :param index: Elasticsearch index where the documents should be indexed. If not supplied, self.index will be used.
         :return: None
         """

+        if index and not self.client.indices.exists(index=index):
+            self._create_document_index(index)
+
+        if index is None:
+            index = self.index
+
+        # Make sure we comply to the Document class format
+        documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+
         documents_to_index = []
-        for doc in documents:
+        for doc in documents_objects:

             _doc = {
                 "_op_type": "create",
-                "_index": self.index,
-                **doc
+                "_index": index,
+                **doc.to_dict()
             }  # type: Dict[str, Any]

+            # rename id for elastic
+            _doc["_id"] = str(_doc.pop("id"))
+
+            # don't index query score and empty fields
+            _ = _doc.pop("query_score", None)
+            _doc = {k: v for k, v in _doc.items() if v is not None}
+
             # In order to have a flat structure in elastic + similar behaviour to the other DocumentStores,
             # we "unnest" all values within "meta"
             if "meta" in _doc.keys():
@@ -142,24 +189,78 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                     _doc[k] = v
                 _doc.pop("meta")
             documents_to_index.append(_doc)
-        bulk(self.client, documents_to_index, request_timeout=300)
+        bulk(self.client, documents_to_index, request_timeout=300, refresh="wait_for")
+
+    def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = "label"):
+        if index and not self.client.indices.exists(index=index):
+            self._create_label_index(index)
+
+        # Make sure we comply to the Label class format
+        label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
+
+        labels_to_index = []
+        for label in label_objects:
+            _label = {
+                "_op_type": "create",
+                "_index": index,
+                **label.to_dict()
+            }  # type: Dict[str, Any]
+
+            labels_to_index.append(_label)
+        bulk(self.client, labels_to_index, request_timeout=300, refresh="wait_for")

     def update_document_meta(self, id: str, meta: Dict[str, str]):
         body = {"doc": meta}
         self.client.update(index=self.index, doc_type="_doc", id=id, body=body)

-    def get_document_count(self, index: Optional[str] = None,) -> int:
+    def get_document_count(self, index: Optional[str] = None) -> int:
         if index is None:
             index = self.index
         result = self.client.count(index=index)
         count = result["count"]
         return count

-    def get_all_documents(self) -> List[Document]:
-        result = scan(self.client, query={"query": {"match_all": {}}}, index=self.index)
+    def get_label_count(self, index: Optional[str] = None) -> int:
+        return self.get_document_count(index=index)
+
+    def get_all_documents(self, index: Optional[str] = None, filters: Optional[dict] = None) -> List[Document]:
+        if index is None:
+            index = self.index
+
+        result = self.get_all_documents_in_index(index=index, filters=filters)
         documents = [self._convert_es_hit_to_document(hit) for hit in result]

         return documents

+    def get_all_labels(self, index: str = "label", filters: Optional[dict] = None) -> List[Label]:
+        result = self.get_all_documents_in_index(index=index, filters=filters)
+        labels = [Label.from_dict(hit["_source"]) for hit in result]
+        return labels
+
+    def get_all_documents_in_index(self, index: str, filters: Optional[dict] = None) -> List[dict]:
+        body = {
+            "query": {
+                "bool": {
+                    "must": {
+                        "match_all": {}
+                    }
+                }
+            }
+        }  # type: Dict[str, Any]
+
+        if filters:
+            filter_clause = []
+            for key, values in filters.items():
+                filter_clause.append(
+                    {
+                        "terms": {key: values}
+                    }
+                )
+            body["query"]["bool"]["filter"] = filter_clause
+        result = scan(self.client, query=body, index=index)
+
+        return result

     def query(
             self,
             query: Optional[str],
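The filters format expected here maps each field name to a list of accepted values, mirroring the Elasticsearch "terms" clause built above. A hedged example:

    # Only labels whose "origin" field equals "gold_label"
    labels = document_store.get_all_labels(index="label", filters={"origin": ["gold_label"]})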
@@ -277,27 +378,42 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
             body["_source"] = {"excludes": self.excluded_meta_data}

         logger.debug(f"Retriever query: {body}")
-        result = self.client.search(index=index, body=body)["hits"]["hits"]
+        result = self.client.search(index=index, body=body, request_timeout=300)["hits"]["hits"]

         documents = [self._convert_es_hit_to_document(hit, score_adjustment=-1) for hit in result]
         return documents

     def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document:
         # We put all additional data of the doc into meta_data and return it in the API
-        meta_data = {k: v for k, v in hit["_source"].items() if k not in (self.text_field, self.external_source_id_field)}
+        meta_data = {k: v for k, v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field, "tags")}
         meta_data["name"] = meta_data.pop(self.name_field, None)

         document = Document(
             id=hit["_id"],
-            text=hit["_source"][self.text_field],
-            external_source_id=hit["_source"].get(self.external_source_id_field),
+            text=hit["_source"].get(self.text_field),
             meta=meta_data,
             query_score=hit["_score"] + score_adjustment if hit["_score"] else None,
-            question=hit["_source"].get(self.faq_question_field)
+            question=hit["_source"].get(self.faq_question_field),
+            tags=hit["_source"].get("tags"),
+            embedding=hit["_source"].get(self.embedding_field)
         )
         return document

-    def update_embeddings(self, retriever):
+    def describe_documents(self, index=None):
+        if index is None:
+            index = self.index
+        docs = self.get_all_documents(index)
+
+        l = [len(d.text) for d in docs]
+        stats = {"count": len(docs),
+                 "chars_mean": np.mean(l),
+                 "chars_max": max(l),
+                 "chars_min": min(l),
+                 "chars_median": np.median(l),
+                 }
+        return stats
+
+    def update_embeddings(self, retriever, index=None):
         """
         Updates the embeddings in the document store using the encoding model specified in the retriever.
         This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).
@@ -305,20 +421,29 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         :param retriever: Retriever
         :return: None
         """
+        if index is None:
+            index = self.index
+
         if not self.embedding_field:
-            raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
-        docs = self.get_all_documents()
+            raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")
+
+        docs = self.get_all_documents(index)
         passages = [d.text for d in docs]

         #TODO Index embeddings every X batches to avoid OOM for huge document collections
         logger.info(f"Updating embeddings for {len(passages)} docs ...")
         embeddings = retriever.embed_passages(passages)

         assert len(docs) == len(embeddings)

+        if embeddings[0].shape[0] != self.embedding_dim:
+            raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+                               f" doesn't match embedding dim. in documentstore ({self.embedding_dim})."
+                               "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
         doc_updates = []
         for doc, emb in zip(docs, embeddings):
             update = {"_op_type": "update",
-                      "_index": self.index,
+                      "_index": index,
                       "_id": doc.id,
                       "doc": {self.embedding_field: emb.tolist()},
                       }
@@ -326,7 +451,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):

         bulk(self.client, doc_updates, request_timeout=300)

-    def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "feedback"):
+    def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "label"):
         """
         Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.

@@ -338,63 +463,23 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         :type label_index: str
         """

-        eval_docs_to_index = []
-        questions_to_index = []
-
-        with open(filename, "r") as file:
-            data = json.load(file)
-            for document in data["data"]:
-                for paragraph in document["paragraphs"]:
-                    doc_to_index = {}
-                    id = hash(paragraph["context"])
-                    for fieldname, value in paragraph.items():
-                        # write docs to doc_index
-                        if fieldname == "context":
-                            doc_to_index[self.text_field] = value
-                            doc_to_index["doc_id"] = str(id)
-                            doc_to_index["_op_type"] = "create"
-                            doc_to_index["_index"] = doc_index
-                        # write questions to label_index
-                        elif fieldname == "qas":
-                            for qa in value:
-                                question_to_index = {
-                                    "question": qa["question"],
-                                    "answers": qa["answers"],
-                                    "doc_id": str(id),
-                                    "origin": "gold_label",
-                                    "index_name": doc_index,
-                                    "_op_type": "create",
-                                    "_index": label_index
-                                }
-                                questions_to_index.append(question_to_index)
-                        # additional fields for docs
-                        else:
-                            doc_to_index[fieldname] = value
-
-                    for key, value in document.items():
-                        if key == "title":
-                            doc_to_index[self.name_field] = value
-                        elif key != "paragraphs":
-                            doc_to_index[key] = value
-
-                    eval_docs_to_index.append(doc_to_index)
-
-        bulk(self.client, eval_docs_to_index)
-        bulk(self.client, questions_to_index)
-
-    def get_all_documents_in_index(self, index: str, filters: Optional[dict] = None):
-        body = {
-            "query": {
-                "bool": {
-                    "must": {
-                        "match_all": {}
-                    }
-                }
-            }
-        }  # type: Dict[str, Any]
-
-        if filters:
-            body["query"]["bool"]["filter"] = {"term": filters}
-        result = scan(self.client, query=body, index=index)
-
-        return result
+        docs, labels = eval_data_from_file(filename)
+        self.write_documents(docs, index=doc_index)
+        self.write_labels(labels, index=label_index)
+
+    def delete_all_documents(self, index: str):
+        """
+        Delete all documents in an index.
+
+        :param index: index name
+        :return: None
+        """
+        self.client.delete_by_query(index=index, body={"query": {"match_all": {}}}, ignore=[404])
+        # We want to be sure that all docs are deleted before continuing (delete_by_query doesn't support wait_for)
+        time.sleep(1)
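A usage sketch (mine, with an illustrative file path) tying the new Elasticsearch pieces together:

    from haystack.database.elasticsearch import ElasticsearchDocumentStore

    document_store = ElasticsearchDocumentStore(host="localhost", index="document", label_index="label")
    # Writes docs to "eval_document" and gold labels to "label" via eval_data_from_file()
    document_store.add_eval_data("dev-v2.0.json", doc_index="eval_document", label_index="label")
    print(document_store.get_label_count(index="label"))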
haystack/database/memory.py

@@ -1,6 +1,9 @@
-from typing import Any, Dict, List, Optional, Union, Tuple
-
-from haystack.database.base import BaseDocumentStore, Document
+from typing import Any, Dict, List, Optional, Union
+from uuid import UUID
+from collections import defaultdict
+import uuid
+from haystack.database.base import BaseDocumentStore, Document, Label
+from haystack.indexing.utils import eval_data_from_file


 class InMemoryDocumentStore(BaseDocumentStore):
@@ -9,44 +12,42 @@ class InMemoryDocumentStore(BaseDocumentStore):
     """

     def __init__(self, embedding_field: Optional[str] = None):
-        self.docs = {}  # type: Dict[str, Any]
-        self.doc_tags = {}  # type: Dict[str, Any]
         self.embedding_field = embedding_field
-        self.index = None
+        self.doc_tags: Dict[str, Any] = {}
+        self.indexes: Dict[str, Dict] = defaultdict(dict)
+        self.index: str = "document"
+        self.label_index: str = "label"

-    def write_documents(self, documents: List[dict]):
+    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """
         Indexes documents for later queries.

-        :param documents: List of dictionaries in the format {"text": "<the-actual-text>"}.
+        :param documents: a list of Python dictionaries or a list of Haystack Document objects.
+                          For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
                           Optionally, you can also supply "tags": ["one-tag", "another-one"]
                           or additional meta data via "meta": {"name": "<some-document-name>, "author": "someone", "url": "some-url" ...}
+        :param index: write documents to a custom namespace. For instance, documents for evaluation can be indexed in a
+                      separate index than the documents for search.
         :return: None
         """
-        import hashlib
+        index = index or self.index

-        if documents is None:
-            return
+        documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]

-        for document in documents:
-            text = document["text"]
-            if "meta" not in document.keys():
-                document["meta"] = {}
-            for k, v in document.items():  # put additional fields other than text in meta
-                if k not in ["text", "meta", "tags"]:
-                    document["meta"][k] = v
-
-            if not text:
-                raise Exception("A document cannot have an empty text field.")
-
-            hash = hashlib.md5(text.encode("utf-8")).hexdigest()
-
-            self.docs[hash] = document
-
-            tags = document.get("tags", [])
-
-            self._map_tags_to_ids(hash, tags)
+        for document in documents_objects:
+            self.indexes[index][document.id] = document
+
+            #TODO fix tags after id refactoring
+            tags = document.tags
+            self._map_tags_to_ids(document.id, tags)
+
+    def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
+        index = index or self.label_index
+        label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
+
+        for label in label_objects:
+            label_id = uuid.uuid4()
+            self.indexes[index][label_id] = label

     def _map_tags_to_ids(self, hash: str, tags: List[str]):
         if isinstance(tags, list):
@@ -63,9 +64,9 @@ class InMemoryDocumentStore(BaseDocumentStore):
         else:
             self.doc_tags[comp_key] = [hash]

-    def get_document_by_id(self, id: str) -> Document:
-        document = self._convert_memory_hit_to_document(self.docs[id], doc_id=id)
-        return document
+    def get_document_by_id(self, id: Union[str, UUID], index: Optional[str] = None) -> Document:
+        index = index or self.index
+        return self.indexes[index][id]

     def _convert_memory_hit_to_document(self, hit: Dict[str, Any], doc_id: Optional[str] = None) -> Document:
         document = Document(
@@ -90,22 +91,17 @@ class InMemoryDocumentStore(BaseDocumentStore):
                             "InMemoryDocumentStore.query_by_embedding(). Please remove filters or "
                             "use a different DocumentStore (e.g. ElasticsearchDocumentStore).")

-        if self.embedding_field is None:
-            raise Exception(
-                "To use query_by_embedding() 'embedding field' must "
-                "be specified when initializing the document store."
-            )
+        index = index or self.index

         if query_emb is None:
             return []

         candidate_docs = []
-        for idx, hit in self.docs.items():
-            hit["query_score"] = dot(query_emb, hit[self.embedding_field]) / (
-                norm(query_emb) * norm(hit[self.embedding_field])
+        for idx, doc in self.indexes[index].items():
+            doc.query_score = dot(query_emb, doc.embedding) / (
+                norm(query_emb) * norm(doc.embedding)
             )
-            _doc = self._convert_memory_hit_to_document(hit=hit, doc_id=idx)
-            candidate_docs.append(_doc)
+            candidate_docs.append(doc)

         return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]
@@ -120,17 +116,18 @@ class InMemoryDocumentStore(BaseDocumentStore):
         #TODO
         raise NotImplementedError("update_embeddings() is not yet implemented for this DocumentStore")

-    def get_document_ids_by_tags(self, tags: Union[List[Dict[str, Union[str, List[str]]]], Dict[str, Union[str, List[str]]]]) -> List[str]:
+    def get_document_ids_by_tags(self, tags: Union[List[Dict[str, Union[str, List[str]]]], Dict[str, Union[str, List[str]]]], index: Optional[str] = None) -> List[str]:
         """
-        The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
+        The format for the dict is {"tag-1": ["value-1", "value-2"], "tag-2": ["value-3"] ...}
         """
+        index = index or self.index
         if not isinstance(tags, list):
             tags = [tags]
-        result = self._find_ids_by_tags(tags)
+        result = self._find_ids_by_tags(tags, index=index)
         return result

-    def _find_ids_by_tags(self, tags: List[Dict[str, Union[str, List[str]]]]):
+    def _find_ids_by_tags(self, tags: List[Dict[str, Union[str, List[str]]]], index: str):
         result = []
         for tag in tags:
             tag_keys = tag.keys()
@@ -141,14 +138,63 @@ class InMemoryDocumentStore(BaseDocumentStore):
                     comp_key = str((tag_key, tag_value))
                     doc_ids = self.doc_tags.get(comp_key, [])
                     for doc_id in doc_ids:
-                        result.append(self.docs.get(doc_id))
+                        result.append(self.indexes[index].get(doc_id))
         return result

-    def get_document_count(self) -> int:
-        return len(self.docs.items())
+    def get_document_count(self, index=None) -> int:
+        index = index or self.index
+        return len(self.indexes[index].items())

-    def get_all_documents(self) -> List[Document]:
-        return [
-            Document(id=item[0], text=item[1]["text"], meta=item[1].get("meta", {}))
-            for item in self.docs.items()
-        ]
+    def get_label_count(self, index=None) -> int:
+        index = index or self.label_index
+        return len(self.indexes[index].items())
+
+    def get_all_documents(self, index=None) -> List[Document]:
+        index = index or self.index
+        return list(self.indexes[index].values())
+
+    def get_all_labels(self, index=None, filters=None) -> List[Label]:
+        index = index or self.label_index
+
+        if filters:
+            result = []
+            for label in self.indexes[index].values():
+                label_dict = label.to_dict()
+                is_hit = True
+                for key, values in filters.items():
+                    if label_dict[key] not in values:
+                        is_hit = False
+                        break
+                if is_hit:
+                    result.append(label)
+        else:
+            result = list(self.indexes[index].values())
+
+        return result
+
+    def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"):
+        """
+        Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.
+
+        :param filename: Name of the file containing evaluation data
+        :type filename: str
+        :param doc_index: Index where evaluation documents should be stored
+        :type doc_index: str
+        :param label_index: Index where labeled questions should be stored
+        :type label_index: str
+        """
+
+        docs, labels = eval_data_from_file(filename)
+        self.write_documents(docs, index=doc_index)
+        self.write_labels(labels, index=label_index)
+
+    def delete_all_documents(self, index=None):
+        """
+        Delete all documents in an index.
+
+        :param index: index name
+        :return: None
+        """
+
+        index = index or self.index
+        self.indexes[index] = {}
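A quick hedged sketch (not from the diff) of the in-memory store's new label handling:

    from haystack.database.memory import InMemoryDocumentStore

    store = InMemoryDocumentStore()
    store.write_documents([{"text": "Berlin is the capital of Germany."}])
    store.write_labels([{"question": "What is the capital of Germany?", "answer": "Berlin",
                         "is_correct_answer": True, "is_correct_document": True,
                         "origin": "user-feedback"}])
    print(store.get_document_count(), store.get_label_count())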
haystack/database/sql.py

@@ -1,10 +1,13 @@
+import uuid
 from typing import Any, Dict, Union, List, Optional

-from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType
+from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType, Boolean
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker
+from sqlalchemy_utils import UUIDType
+from uuid import UUID

-from haystack.database.base import BaseDocumentStore, Document as DocumentSchema
+from haystack.indexing.utils import eval_data_from_file
+from haystack.database.base import BaseDocumentStore, Document, Label

 Base = declarative_base()  # type: Any
@@ -12,58 +15,82 @@ Base = declarative_base()  # type: Any
 class ORMBase(Base):
     __abstract__ = True

-    id = Column(Integer, primary_key=True)
+    id = Column(UUIDType(binary=False), default=uuid.uuid4, primary_key=True)
     created = Column(DateTime, server_default=func.now())
     updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())


-class Document(ORMBase):
+class DocumentORM(ORMBase):
     __tablename__ = "document"

-    text = Column(String)
+    text = Column(String, nullable=False)
+    index = Column(String, nullable=False)
     meta_data = Column(PickleType)

-    tags = relationship("Tag", secondary="document_tag", backref="Document")
+    tags = relationship("TagORM", secondary="document_tag", backref="Document")


-class Tag(ORMBase):
+class TagORM(ORMBase):
     __tablename__ = "tag"

     name = Column(String)
     value = Column(String)

-    documents = relationship("Document", secondary="document_tag", backref="Tag")
+    documents = relationship(DocumentORM, secondary="document_tag", backref="Tag")


-class DocumentTag(ORMBase):
+class DocumentTagORM(ORMBase):
     __tablename__ = "document_tag"

-    document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
+    document_id = Column(UUIDType(binary=False), ForeignKey("document.id"), nullable=False)
     tag_id = Column(Integer, ForeignKey("tag.id"), nullable=False)


+class LabelORM(ORMBase):
+    __tablename__ = "label"
+
+    document_id = Column(UUIDType(binary=False), ForeignKey("document.id"), nullable=False)
+    index = Column(String, nullable=False)
+    no_answer = Column(Boolean, nullable=False)
+    origin = Column(String, nullable=False)
+    question = Column(String, nullable=False)
+    is_correct_answer = Column(Boolean, nullable=False)
+    is_correct_document = Column(Boolean, nullable=False)
+    answer = Column(String, nullable=False)
+    offset_start_in_doc = Column(Integer, nullable=False)
+    model_id = Column(Integer, nullable=True)
+
+
 class SQLDocumentStore(BaseDocumentStore):
-    def __init__(self, url: str = "sqlite://"):
+    def __init__(self, url: str = "sqlite://", index="document"):
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()
+        self.index = index
+        self.label_index = "label"

-    def get_document_by_id(self, id: str) -> Optional[DocumentSchema]:
-        document_row = self.session.query(Document).get(id)
-        document = self._convert_sql_row_to_document(document_row)
+    def get_document_by_id(self, id: UUID, index=None) -> Optional[Document]:
+        index = index or self.index
+        document_row = self.session.query(DocumentORM).filter_by(index=index, id=id).first()
+        document = self._convert_sql_row_to_document(document_row) if document_row else None
         return document

-    def get_all_documents(self) -> List[DocumentSchema]:
-        document_rows = self.session.query(Document).all()
-        documents = []
-        for row in document_rows:
-            documents.append(self._convert_sql_row_to_document(row))
+    def get_all_documents(self, index=None) -> List[Document]:
+        index = index or self.index
+        document_rows = self.session.query(DocumentORM).filter_by(index=index).all()
+        documents = [self._convert_sql_row_to_document(row) for row in document_rows]

         return documents

-    def get_document_ids_by_tags(self, tags: Dict[str, Union[str, List]]) -> List[str]:
+    def get_all_labels(self, index=None, filters: Optional[dict] = None):
+        index = index or self.label_index
+        label_rows = self.session.query(LabelORM).filter_by(index=index).all()
+        labels = [self._convert_sql_row_to_label(row) for row in label_rows]
+
+        return labels
+
+    def get_document_ids_by_tags(self, tags: Dict[str, Union[str, List]], index: Optional[str] = None) -> List[str]:
         """
         Get list of document ids that have tags from the given list of tags.
@@ -73,6 +100,9 @@ class SQLDocumentStore(BaseDocumentStore):
         if not tags:
             raise Exception("No tag supplied for filtering the documents")

+        if index:
+            raise Exception("'index' parameter is not supported in SQLDocumentStore.get_document_ids_by_tags().")
+
         query = """
             SELECT id FROM document WHERE id in (
                 SELECT dt.document_id
@@ -91,31 +121,74 @@ class SQLDocumentStore(BaseDocumentStore):
         doc_ids = [row[0] for row in query_results]
         return doc_ids

-    def write_documents(self, documents: List[dict]):
+    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """
         Indexes documents for later queries.

-        :param documents: List of dictionaries in the format {"text": "<the-actual-text>"}.
-                          Optionally, you can also supply meta data via "meta": {"author": "someone", "url": "some-url" ...}
+        :param documents: a list of Python dictionaries or a list of Haystack Document objects.
+                          For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
+                          Optionally, you can also supply "tags": ["one-tag", "another-one"]
+                          or additional meta data via "meta": {"name": "<some-document-name>, "author": "someone", "url": "some-url" ...}
+        :param index: add an optional index attribute to documents. It can later be used for filtering. For instance,
+                      documents for evaluation can be indexed in a separate index than the documents for search.

         :return: None
         """

+        # Make sure we comply to the Document class format
+        documents = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
+        index = index or self.index
         for doc in documents:
-            if "meta" not in doc.keys():
-                doc["meta"] = {}
-            for k, v in doc.items():  # put additional fields other than text in meta
-                if k not in ["text", "meta", "tags"]:
-                    doc["meta"][k] = v
-            row = Document(text=doc["text"], meta_data=doc.get("meta", {}))
+            row = DocumentORM(id=doc.id, text=doc.text, meta_data=doc.meta, index=index)  # type: ignore
             self.session.add(row)
         self.session.commit()

-    def get_document_count(self) -> int:
-        return self.session.query(Document).count()
+    def write_labels(self, labels, index=None):
+        labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
+        index = index or self.index
+        for label in labels:
+            label_orm = LabelORM(
+                document_id=label.document_id,
+                no_answer=label.no_answer,
+                origin=label.origin,
+                question=label.question,
+                is_correct_answer=label.is_correct_answer,
+                is_correct_document=label.is_correct_document,
+                answer=label.answer,
+                offset_start_in_doc=label.offset_start_in_doc,
+                model_id=label.model_id,
+                index=index,
+            )
+            self.session.add(label_orm)
+        self.session.commit()

-    def _convert_sql_row_to_document(self, row) -> DocumentSchema:
-        document = DocumentSchema(
+    def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"):
+        """
+        Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.
+
+        :param filename: Name of the file containing evaluation data
+        :type filename: str
+        :param doc_index: Table name where evaluation documents should be stored
+        :type doc_index: str
+        :param label_index: Table name where labeled questions should be stored
+        :type label_index: str
+        """
+
+        docs, labels = eval_data_from_file(filename)
+        self.write_documents(docs, index=doc_index)
+        self.write_labels(labels, index=label_index)
+
+    def get_document_count(self, index=None) -> int:
+        index = index or self.index
+        return self.session.query(DocumentORM).filter_by(index=index).count()
+
+    def get_label_count(self, index: Optional[str] = None) -> int:
+        index = index or self.index
+        return self.session.query(LabelORM).filter_by(index=index).count()
+
+    def _convert_sql_row_to_document(self, row) -> Document:
+        document = Document(
             id=row.id,
             text=row.text,
             meta=row.meta_data,
@@ -123,12 +196,38 @@ class SQLDocumentStore(BaseDocumentStore):
         )
         return document

+    def _convert_sql_row_to_label(self, row) -> Label:
+        label = Label(
+            document_id=row.document_id,
+            no_answer=row.no_answer,
+            origin=row.origin,
+            question=row.question,
+            is_correct_answer=row.is_correct_answer,
+            is_correct_document=row.is_correct_document,
+            answer=row.answer,
+            offset_start_in_doc=row.offset_start_in_doc,
+            model_id=row.model_id,
+        )
+        return label
+
     def query_by_embedding(self,
                            query_emb: List[float],
                            filters: Optional[dict] = None,
                            top_k: int = 10,
-                           index: Optional[str] = None) -> List[DocumentSchema]:
+                           index: Optional[str] = None) -> List[Document]:

         raise NotImplementedError("SQLDocumentStore is currently not supporting embedding queries. "
                                   "Change the query type (e.g. by choosing a different retriever) "
                                   "or change the DocumentStore (e.g. to ElasticsearchDocumentStore)")

+    def delete_all_documents(self, index=None):
+        """
+        Delete all documents in an index.
+
+        :param index: index name
+        :return: None
+        """
+
+        index = index or self.index
+        documents = self.session.query(DocumentORM).filter_by(index=index)
+        documents.delete(synchronize_session=False)
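And the SQL store equivalent, a hedged sketch using the new ORM-backed API:

    from haystack.database.sql import SQLDocumentStore

    store = SQLDocumentStore(url="sqlite://", index="document")   # in-memory SQLite
    store.write_documents([{"text": "Rome is the capital of Italy."}])
    print(store.get_document_count())  # rows in the "document" index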
haystack/finder.py

@@ -88,9 +88,16 @@ class Finder:
         # 2) Format response
         for doc in documents:
             #TODO proper calibration of pseudo probabilities
-            cur_answer = {"question": doc.question, "answer": doc.text, "context": doc.text,  # type: ignore
-                          "score": doc.query_score, "offset_start": 0, "offset_end": len(doc.text), "meta": doc.meta
-                          }
+            cur_answer = {
+                "question": doc.question,
+                "answer": doc.text,
+                "document_id": doc.id,
+                "context": doc.text,
+                "score": doc.query_score,
+                "offset_start": 0,
+                "offset_end": len(doc.text),
+                "meta": doc.meta
+            }
             if self.retriever.embedding_model:  # type: ignore
                 probability = (doc.query_score + 1) / 2  # type: ignore
             else:

@@ -103,7 +110,7 @@ class Finder:

     def eval(
         self,
-        label_index: str = "feedback",
+        label_index: str = "label",
         doc_index: str = "eval_document",
         label_origin: str = "gold_label",
         top_k_retriever: int = 10,

@@ -154,14 +161,16 @@ class Finder:
         :param top_k_reader: How many answers to return per question
         :type top_k_reader: int
         """
+        raise NotImplementedError("The Finder evaluation is unavailable in the current Haystack version due to "
+                                  "code refactoring in progress. Please use the Reader and Retriever evaluation.")

         if not self.reader or not self.retriever:
             raise Exception("Finder needs to have a reader and retriever for the evaluation.")

         finder_start_time = time.time()
         # extract all questions for evaluation
-        filter = {"origin": label_origin}
-        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)  # type: ignore
+        filters = {"origin": [label_origin]}
+        questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filters)  # type: ignore

         correct_retrievals = 0
         summed_avg_precision_retriever = 0

@@ -193,7 +202,7 @@ class Finder:
             retrieve_times.append(time.time() - single_retrieve_start)
             for doc_idx, doc in enumerate(retrieved_docs):
                 # check if correct doc among retrieved docs
-                if doc.meta["doc_id"] == question["_source"]["doc_id"]:
+                if doc.meta["doc_id"] == question["_source"]["doc_id"]:  # type: ignore
                     correct_retrievals += 1
                     summed_avg_precision_retriever += 1 / (doc_idx + 1)  # type: ignore
             questions_with_docs.append({
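The document_id added to each answer is what allows user feedback to be stored as a Label afterwards. A hedged sketch of that flow (the answer dict keys come from the hunk above; the surrounding calls are my assumptions, not from this diff):

    prediction = finder.get_answers_via_similar_questions(question="What is the capital of France?", top_k_retriever=3)
    answer = prediction["answers"][0]
    feedback = Label(question="What is the capital of France?", answer=answer["answer"],
                     is_correct_answer=True, is_correct_document=True,
                     origin="user-feedback", document_id=answer["document_id"])
    document_store.write_labels([feedback])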
haystack/indexing/utils.py

@@ -3,15 +3,62 @@ import tarfile
 import tempfile
 import zipfile
 from pathlib import Path
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
+import json

 from farm.data_handler.utils import http_get

 from haystack.indexing.file_converters.pdf import PDFToTextConverter
+from haystack.database.base import Document, Label

 logger = logging.getLogger(__name__)


+def eval_data_from_file(filename: str) -> Tuple[List[Document], List[Label]]:
+    """
+    Read Documents + Labels from a SQuAD-style file.
+    Documents and Labels can then be indexed to the DocumentStore and be used for evaluation.
+
+    :param filename: Path to file in SQuAD format
+    :return: (List of Documents, List of Labels)
+    """
+    docs = []
+    labels = []
+
+    with open(filename, "r") as file:
+        data = json.load(file)
+        for document in data["data"]:
+            # get all extra fields from document level (e.g. title)
+            meta_doc = {k: v for k, v in document.items() if k not in ("paragraphs", "title")}
+            for paragraph in document["paragraphs"]:
+                cur_meta = {"name": document["title"]}
+                # all other fields from paragraph level
+                meta_paragraph = {k: v for k, v in paragraph.items() if k not in ("qas", "context")}
+                cur_meta.update(meta_paragraph)
+                # meta from parent document
+                cur_meta.update(meta_doc)
+                # Create Document
+                cur_doc = Document(text=paragraph["context"], meta=cur_meta)
+                docs.append(cur_doc)
+
+                # Get Labels
+                for qa in paragraph["qas"]:
+                    for answer in qa["answers"]:
+                        label = Label(
+                            question=qa["question"],
+                            answer=answer["text"],
+                            is_correct_answer=True,
+                            is_correct_document=True,
+                            document_id=cur_doc.id,
+                            offset_start_in_doc=answer["answer_start"],
+                            no_answer=qa["is_impossible"],
+                            origin="gold_label",
+                        )
+                        labels.append(label)
+
+    return docs, labels
+
+
 def convert_files_to_dicts(dir_path: str, clean_func: Optional[Callable] = None, split_paragraphs: bool = False) -> List[dict]:
     """
     Convert all files (.txt, .pdf) in the sub-directories of the given path to Python dicts that can be written to a
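A hedged usage sketch for the new helper (the file path is illustrative):

    from haystack.indexing.utils import eval_data_from_file

    docs, labels = eval_data_from_file("dev-v2.0.json")  # any SQuAD-style file
    print(len(docs), "documents,", len(labels), "labels")
    print(labels[0].question, "->", labels[0].answer)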
haystack/reader/farm.py

@@ -1,7 +1,8 @@
 import logging
-from pathlib import Path
-from typing import List, Optional, Union
 import multiprocessing
+from pathlib import Path
+from typing import List, Optional, Union, Dict, Any
+from collections import defaultdict

 import numpy as np
 from farm.data_handler.data_silo import DataSilo

@@ -18,8 +19,7 @@ from farm.utils import set_all_seeds, initialize_device_settings
 from scipy.special import expit
 import shutil

-from haystack.database.base import Document
-from haystack.database.elasticsearch import ElasticsearchDocumentStore
+from haystack.database.base import Document, BaseDocumentStore
 from haystack.reader.base import BaseReader

 logger = logging.getLogger(__name__)
@ -351,68 +351,76 @@ class FARMReader(BaseReader):
|
||||
return results
|
||||
|
||||
def eval(
|
||||
self,
|
||||
document_store: ElasticsearchDocumentStore,
|
||||
device: str,
|
||||
label_index: str = "feedback",
|
||||
doc_index: str = "eval_document",
|
||||
label_origin: str = "gold_label",
|
||||
self,
|
||||
document_store: BaseDocumentStore,
|
||||
device: str,
|
||||
label_index: str = "label",
|
||||
doc_index: str = "eval_document",
|
||||
label_origin: str = "gold_label",
|
||||
):
|
||||
"""
|
||||
Performs evaluation on evaluation documents in Elasticsearch DocumentStore.
|
||||
Performs evaluation on evaluation documents in the DocumentStore.
|
||||
|
||||
Returns a dict containing the following metrics:
|
||||
- "EM": Proportion of exact matches of predicted answers with their corresponding correct answers
|
||||
- "f1": Average overlap between predicted answers and their corresponding correct answers
|
||||
- "top_n_accuracy": Proportion of predicted answers that match with correct answer
|
||||
|
||||
:param document_store: The ElasticsearchDocumentStore containing the evaluation documents
|
||||
:type document_store: ElasticsearchDocumentStore
|
||||
:param document_store: DocumentStore containing the evaluation documents
|
||||
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda".
|
||||
:type device: str
|
||||
:param label_index: Elasticsearch index where labeled questions are stored
|
||||
:type label_index: str
|
||||
:param doc_index: Elasticsearch index where documents that are used for evaluation are stored
|
||||
:type doc_index: str
|
||||
:param label_index: Index/Table name where labeled questions are stored
|
||||
:param doc_index: Index/Table name where documents that are used for evaluation are stored
|
||||
"""
|
||||
|
||||
# extract all questions for evaluation
|
||||
filter = {"origin": label_origin}
|
||||
questions = document_store.get_all_documents_in_index(index=label_index, filters=filter)
|
||||
filters = {"origin": [label_origin]}
|
||||
|
||||
# mapping from doc_id to questions
|
||||
doc_questions_dict = {}
|
||||
id = 0
|
||||
for question in questions:
|
||||
doc_id = question["_source"]["doc_id"]
|
||||
if doc_id not in doc_questions_dict:
|
||||
doc_questions_dict[doc_id] = [{
|
||||
"id": id,
|
||||
"question" : question["_source"]["question"],
|
||||
"answers" : question["_source"]["answers"],
|
||||
"is_impossible" : False if question["_source"]["answers"] else True
|
||||
}]
|
||||
else:
|
||||
doc_questions_dict[doc_id].append({
|
||||
"id": id,
|
||||
"question" : question["_source"]["question"],
|
||||
"answers" : question["_source"]["answers"],
|
||||
"is_impossible" : False if question["_source"]["answers"] else True
|
||||
})
|
||||
id += 1
|
||||
labels = document_store.get_all_labels(index=label_index, filters=filters)
|
||||
|
||||
# extract eval documents and convert data back to SQuAD-like format
|
||||
documents = document_store.get_all_documents_in_index(index=doc_index)
|
||||
dicts = []
|
||||
for document in documents:
|
||||
doc_id = document["_source"]["doc_id"]
|
||||
text = document["_source"]["text"]
|
||||
questions = doc_questions_dict[doc_id]
|
||||
dicts.append({"qas" : questions, "context" : text})
|
||||
# Aggregate all answer labels per question
|
||||
aggregated_per_doc = defaultdict(list)
|
||||
for label in labels:
|
||||
if not label.document_id:
|
||||
logger.error(f"Label does not contain a document_id")
|
||||
continue
|
||||
aggregated_per_doc[label.document_id].append(label)
|
||||
|
||||
# Create squad style dicts
|
||||
d: Dict[str, Any] = {}
|
||||
for doc_id in aggregated_per_doc.keys():
|
||||
doc = document_store.get_document_by_id(doc_id, index=doc_index)
|
||||
if not doc:
|
||||
logger.error(f"Document with the ID '{doc_id}' is not present in the document store.")
|
||||
continue
|
||||
d[str(doc_id)] = {
|
||||
"context": doc.text
|
||||
}
|
||||
# get all questions / answers
|
||||
aggregated_per_question: Dict[str, Any] = defaultdict(list)
|
||||
for label in aggregated_per_doc[doc_id]:
|
||||
# add to existing answers
|
||||
if label.question in aggregated_per_question.keys():
|
||||
aggregated_per_question[label.question]["answers"].append({
|
||||
"text": label.answer,
|
||||
"answer_start": label.offset_start_in_doc})
|
||||
# create new one
|
||||
else:
|
||||
aggregated_per_question[label.question] = {
|
||||
"id": str(hash(str(doc_id)+label.question)),
|
||||
"question": label.question,
|
||||
"answers": [{
|
||||
"text": label.answer,
|
||||
"answer_start": label.offset_start_in_doc}]
|
||||
}
|
||||
# Get rid of the question key again (after we aggregated we don't need it anymore)
|
||||
d[str(doc_id)]["qas"] = [v for v in aggregated_per_question.values()]
|
||||
|
||||
# Convert input format for FARM
|
||||
farm_input = [v for v in d.values()]
|
||||
|
||||
# Create DataLoader that can be passed to the Evaluator
|
||||
indices = range(len(dicts))
|
||||
dataset, tensor_names = self.inferencer.processor.dataset_from_dicts(dicts, indices=indices)
|
||||
indices = range(len(farm_input))
|
||||
dataset, tensor_names = self.inferencer.processor.dataset_from_dicts(farm_input, indices=indices)
|
||||
data_loader = NamedDataLoader(dataset=dataset, batch_size=self.inferencer.batch_size, tensor_names=tensor_names)
|
||||
|
||||
evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)
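For reference, "EM" and "f1" above follow the usual SQuAD conventions. A standalone sketch of the two metrics for a single prediction (illustration only; the numbers returned by eval() come from FARM's Evaluator, not from this snippet):

from collections import Counter

def exact_match(prediction: str, gold: str) -> int:
    # 1 if the normalized answer strings are identical, else 0
    return int(prediction.strip().lower() == gold.strip().lower())

def f1(prediction: str, gold: str) -> float:
    # token-level overlap between predicted and gold answer
    pred_tokens = prediction.lower().split()
    gold_tokens = gold.lower().split()
    num_common = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_common == 0:
        return 0.0
    precision = num_common / len(pred_tokens)
    recall = num_common / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

assert exact_match("Carla", "Carla") == 1
assert f1("Carla and Abdul", "Abdul") == 0.5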
@@ -466,10 +474,9 @@ class FARMReader(BaseReader):

def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
    documents = []
    for i, text in enumerate(texts):
    for text in texts:
        documents.append(
            Document(
                id=i,
                text=text
            )
        )

@@ -1,11 +1,91 @@
from abc import ABC, abstractmethod
from typing import List
import logging
from collections import defaultdict

from haystack.database.base import Document
from haystack.database.base import BaseDocumentStore

logger = logging.getLogger(__name__)


class BaseRetriever(ABC):
    document_store: BaseDocumentStore

    @abstractmethod
    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        pass

    def eval(
        self,
        label_index: str = "label",
        doc_index: str = "eval_document",
        label_origin: str = "gold_label",
        top_k: int = 10,
        open_domain: bool = False
    ) -> dict:
        """
        Performs evaluation on the Retriever.
        Retriever is evaluated based on whether it finds the correct document given the question string and at which
        position in the ranking of documents the correct document is.

        Returns a dict containing the following metrics:
        - "recall": Proportion of questions for which correct document is among retrieved documents
        - "mean avg precision": Mean of average precision for each question. Rewards retrievers that give relevant
          documents a higher rank.

        :param label_index: Index/Table in DocumentStore where labeled questions are stored
        :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
        :param top_k: How many documents to return per question
        :param open_domain: If true, retrieval will be evaluated by checking if the answer string to a question is
                            contained in the retrieved docs (common approach in open-domain QA).
                            If false, retrieval uses a stricter evaluation that checks if the retrieved document ids
                            are within ids explicitly stated in the labels.
        """

        # Extract all questions for evaluation
        filters = {"origin": [label_origin]}

        labels = self.document_store.get_all_labels(index=label_index, filters=filters)

        correct_retrievals = 0
        summed_avg_precision = 0

        # Aggregate all positive document ids / answers per question
        aggregated_labels = defaultdict(set)
        for label in labels:
            if open_domain:
                aggregated_labels[label.question].add(label.answer)
            else:
                aggregated_labels[label.question].add(str(label.document_id))

        # Option 1: Open-domain evaluation by checking if the answer string is in the retrieved docs
        if open_domain:
            for question, gold_answers in aggregated_labels.items():
                retrieved_docs = self.retrieve(question, top_k=top_k, index=doc_index)
                # check if correct doc in retrieved docs
                for doc_idx, doc in enumerate(retrieved_docs):
                    for gold_answer in gold_answers:
                        if gold_answer in doc.text:
                            correct_retrievals += 1
                            summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
                            break
        # Option 2: Strict evaluation by document ids that are listed in the labels
        else:
            for question, gold_ids in aggregated_labels.items():
                retrieved_docs = self.retrieve(question, top_k=top_k, index=doc_index)
                # check if correct doc in retrieved docs
                for doc_idx, doc in enumerate(retrieved_docs):
                    if str(doc.id) in gold_ids:
                        correct_retrievals += 1
                        summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
                        break
        # Metrics
        number_of_questions = len(aggregated_labels)
        recall = correct_retrievals / number_of_questions
        mean_avg_precision = summed_avg_precision / number_of_questions

        logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                     f" the top-{top_k} candidate passages selected by the retriever."))

        return {"recall": recall, "map": mean_avg_precision}
@@ -194,7 +194,7 @@ class DensePassageRetriever(BaseRetriever):
class EmbeddingRetriever(BaseRetriever):
    def __init__(
        self,
        document_store: ElasticsearchDocumentStore,
        document_store: BaseDocumentStore,
        embedding_model: str,
        use_gpu: bool = True,
        model_format: str = "farm",

@@ -46,7 +46,7 @@ class ElasticsearchRetriever(BaseRetriever):
        self.retrieve(query="Why did the revenue increase?",
                      filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
        """
        self.document_store = document_store  # type: ignore
        self.document_store: ElasticsearchDocumentStore = document_store
        self.custom_query = custom_query

    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
@@ -58,54 +58,6 @@ class ElasticsearchRetriever(BaseRetriever):

        return documents

    def eval(
        self,
        label_index: str = "feedback",
        doc_index: str = "eval_document",
        label_origin: str = "gold_label",
        top_k: int = 10,
    ) -> dict:
        """
        Performs evaluation on the Retriever.
        Retriever is evaluated based on whether it finds the correct document given the question string and at which
        position in the ranking of documents the correct document is.

        Returns a dict containing the following metrics:
        - "recall": Proportion of questions for which correct document is among retrieved documents
        - "mean avg precision": Mean of average precision for each question. Rewards retrievers that give relevant
          documents a higher rank.

        :param label_index: Index/Table in DocumentStore where labeled questions are stored
        :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
        :param top_k: How many documents to return per question
        """

        # extract all questions for evaluation
        filter = {"origin": label_origin}
        questions = self.document_store.get_all_documents_in_index(index=label_index, filters=filter)

        # calculate recall and mean-average-precision
        correct_retrievals = 0
        summed_avg_precision = 0
        for q_idx, question in enumerate(questions):
            question_string = question["_source"]["question"]
            retrieved_docs = self.retrieve(question_string, top_k=top_k, index=doc_index)
            # check if correct doc in retrieved docs
            for doc_idx, doc in enumerate(retrieved_docs):
                if doc.meta["doc_id"] == question["_source"]["doc_id"]:
                    correct_retrievals += 1
                    summed_avg_precision += 1 / (doc_idx + 1)  # type: ignore
                    break

        number_of_questions = q_idx + 1
        recall = correct_retrievals / number_of_questions
        mean_avg_precision = summed_avg_precision / number_of_questions

        logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
                     f" the top-{top_k} candidate passages selected by the retriever."))

        return {"recall": recall, "map": mean_avg_precision}


class ElasticsearchFilterOnlyRetriever(ElasticsearchRetriever):
    """

@@ -4,7 +4,8 @@ import logging
import pprint
import pandas as pd
from typing import Dict, Any, List
from haystack.database.sql import Document
from haystack.database.sql import DocumentORM


logger = logging.getLogger(__name__)

@@ -79,7 +80,7 @@ def convert_labels_to_squad(labels_file: str):
    for document_id, labels in labels_grouped_by_documents.items():
        qas = []
        for label in labels:
            doc = Document.query.get(label["document_id"])
            doc = DocumentORM.query.get(label["document_id"])

            assert (
                doc.text[label["start_offset"] : label["end_offset"]]

@@ -15,4 +15,5 @@ langdetect # for PDF conversions
#temporarily (used for DPR downloads)
wget
python-multipart
python-docx
python-docx
sqlalchemy_utils
@@ -2,11 +2,10 @@ import logging

import uvicorn
from elasticapm.contrib.starlette import make_apm_client, ElasticAPM
from elasticsearch import Elasticsearch
from fastapi import FastAPI, HTTPException
from starlette.middleware.cors import CORSMiddleware

from rest_api.config import DB_HOST, DB_USER, DB_PW, DB_PORT, ES_CONN_SCHEME, APM_SERVER, APM_SERVICE_NAME
from rest_api.config import APM_SERVER, APM_SERVICE_NAME
from rest_api.controller.errors.http_error import http_error_handler
from rest_api.controller.router import router as api_router

@@ -14,10 +13,6 @@ logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S
logger = logging.getLogger(__name__)
logging.getLogger("elasticsearch").setLevel(logging.WARNING)

elasticsearch_client = Elasticsearch(
    hosts=[{"host": DB_HOST, "port": DB_PORT}], http_auth=(DB_USER, DB_PW), scheme=ES_CONN_SCHEME, ca_certs=False, verify_certs=False
)


def get_application() -> FastAPI:
    application = FastAPI(title="Haystack-API", debug=True, version="0.1")

@@ -17,7 +17,7 @@ DB_PORT = int(os.getenv("DB_PORT", 9200))
DB_USER = os.getenv("DB_USER", "")
DB_PW = os.getenv("DB_PW", "")
DB_INDEX = os.getenv("DB_INDEX", "document")
DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "feedback")
DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "label")
ES_CONN_SCHEME = os.getenv("ES_CONN_SCHEME", "http")
TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME", "text")
SEARCH_FIELD_NAME = os.getenv("SEARCH_FIELD_NAME", "text")

@@ -1,11 +1,9 @@
from collections import defaultdict
from typing import Optional

from elasticsearch.helpers import scan
from fastapi import APIRouter, status
from fastapi.responses import JSONResponse
from fastapi import APIRouter
from pydantic import BaseModel, Field

from haystack.database.elasticsearch import ElasticsearchDocumentStore
from rest_api.config import (
    DB_HOST,
    DB_PORT,
@@ -18,10 +16,9 @@ from rest_api.config import (
    EMBEDDING_DIM,
    EMBEDDING_FIELD_NAME,
    EXCLUDE_META_DATA_FIELDS,
    FAQ_QUESTION_FIELD_NAME,
)
from rest_api.config import DB_INDEX_FEEDBACK
from rest_api.elasticsearch_client import elasticsearch_client
from haystack.database.elasticsearch import ElasticsearchDocumentStore

router = APIRouter()

@@ -36,66 +33,77 @@ document_store = ElasticsearchDocumentStore(
    verify_certs=False,
    text_field=TEXT_FIELD_NAME,
    search_fields=SEARCH_FIELD_NAME,
    faq_question_field=FAQ_QUESTION_FIELD_NAME,
    embedding_dim=EMBEDDING_DIM,
    embedding_field=EMBEDDING_FIELD_NAME,
    excluded_meta_data=EXCLUDE_META_DATA_FIELDS,  # type: ignore
)


class Feedback(BaseModel):
class FAQQAFeedback(BaseModel):
    question: str = Field(..., description="The question input by the user, i.e., the query.")
    label: str = Field(..., description="The Label for the feedback, eg, relevant or irrelevant.")
    is_correct_answer: bool = Field(..., description="Whether the answer is correct or not.")
    document_id: str = Field(..., description="The document in the query result for which feedback is given.")
    answer: Optional[str] = Field(None, description="The answer string. Only required for doc-qa feedback.")
    offset_start_in_doc: Optional[int] = Field(None, description="The answer start offset in the original doc. Only required for doc-qa feedback.")
    model_id: Optional[int] = Field(None, description="The model used for the query.")


class DocQAFeedback(FAQQAFeedback):
    is_correct_document: bool = Field(
        ...,
        description="In case of negative feedback, there could be two cases; incorrect answer but correct "
                    "document & incorrect document. This flag denotes if the returned document was correct.",
    )
    answer: str = Field(..., description="The answer string.")
    offset_start_in_doc: int = Field(
        ..., description="The answer start offset in the original doc. Only required for doc-qa feedback."
    )
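A hypothetical client call against the new doc-qa feedback model could look like this (host and port are made up; all field values are illustrative only):

import requests

# Payload mirroring the DocQAFeedback model above (all values invented)
feedback = {
    "question": "Who lives in Berlin?",
    "is_correct_answer": True,
    "is_correct_document": True,
    "document_id": "some-document-id",
    "answer": "Carla",
    "offset_start_in_doc": 11,
}
requests.post("http://localhost:8000/doc-qa-feedback", json=feedback)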
@router.post("/doc-qa-feedback")
def doc_qa_feedback(feedback: Feedback):
    if feedback.answer and feedback.offset_start_in_doc:
        elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
    else:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content="doc-qa feedback must contain 'answer' and 'answer_doc_start' fields.",
        )
def doc_qa_feedback(feedback: DocQAFeedback):
    document_store.write_labels([{"origin": "user-feedback", **feedback.dict()}])


@router.post("/faq-qa-feedback")
def faq_qa_feedback(feedback: Feedback):
    elasticsearch_client.index(index=DB_INDEX_FEEDBACK, body=feedback.dict())
def faq_qa_feedback(feedback: FAQQAFeedback):
    feedback_payload = {"is_correct_document": feedback.is_correct_answer, "answer": None, **feedback.dict()}
    document_store.write_labels([{"origin": "user-feedback-faq", **feedback_payload}])


@router.get("/export-doc-qa-feedback")
def export_doc_qa_feedback():
def export_doc_qa_feedback(context_size: int = 2_000):
    """
    SQuAD format JSON export for question/answer pairs that were marked as "relevant".

    The context_size param can be used to limit response size for large documents.
    """
    #TODO filter out faq-qa feedback.
    #TODO Reduce length of context for large documents

    relevant_feedback_query = {"query": {"bool": {"must": [{"term": {"label": "relevant"}}]}}}
    result = scan(elasticsearch_client, index=DB_INDEX_FEEDBACK, query=relevant_feedback_query)

    per_document_feedback = defaultdict(list)
    for feedback in result:
        document_id = feedback["_source"]["document_id"]
        per_document_feedback[document_id].append(
            {
                "question": feedback["_source"]["question"],
                "id": feedback["_id"],
                "answers": [
                    {"text": feedback["_source"]["answer"], "answer_start": feedback["_source"]["offset_start_in_doc"]}
                ],
            }
        )
    labels = document_store.get_all_labels(
        index=DB_INDEX_FEEDBACK, filters={"is_correct_answer": [True], "origin": ["user-feedback"]}
    )

    export_data = []
    for document_id, feedback in per_document_feedback.items():
        document = document_store.get_document_by_id(document_id)
        context = document.text
        export_data.append({"paragraphs": [{"qas": feedback, "context": context}]})
    for label in labels:
        document = document_store.get_document_by_id(label.document_id)
        text = document.text

        # the final length of the context (including the answer string) is 'context_size'.
        # we try to add an equal number of context characters before and after the answer string.
        # if either the beginning or the end of the text is reached, we correspondingly
        # append more context characters at the other end of the answer string.
        context_to_add = int((context_size - len(label.answer)) / 2)

        start_pos = max(label.offset_start_in_doc - context_to_add, 0)
        additional_context_at_end = max(context_to_add - label.offset_start_in_doc, 0)

        end_pos = min(label.offset_start_in_doc + len(label.answer) + context_to_add, len(text) - 1)
        additional_context_at_start = max(label.offset_start_in_doc + len(label.answer) + context_to_add - len(text), 0)

        start_pos = max(0, start_pos - additional_context_at_start)
        end_pos = min(len(text) - 1, end_pos + additional_context_at_end)

        context_to_export = text[start_pos:end_pos]

        export_data.append({"paragraphs": [{"qas": label, "context": context_to_export}]})

    export = {"data": export_data}
@@ -107,24 +115,19 @@ def export_faq_feedback():
    """
    Export feedback for faq-qa in JSON format.
    """
    result = scan(elasticsearch_client, index=DB_INDEX_FEEDBACK)

    per_document_feedback = defaultdict(list)
    for feedback in result:
        document_id = feedback["_source"]["document_id"]
        question = feedback["_source"]["question"]
        feedback_id = feedback["_id"]
        feedback_label = feedback["_source"]["label"]
        per_document_feedback[document_id].append(
            {"question": question, "id": feedback_id, "feedback_label": feedback_label}
        )
    labels = document_store.get_all_labels(index=DB_INDEX_FEEDBACK, filters={"origin": ["user-feedback-faq"]})

    export_data = []
    for document_id, feedback in per_document_feedback.items():
        document = document_store.get_document_by_id(document_id)
        export_data.append(
            {"target_question": document.question, "target_answer": document.text, "queries": feedback}
        )
    for label in labels:
        document = document_store.get_document_by_id(label.document_id)
        feedback = {
            "question": document.question,
            "query": label.question,
            "is_correct_answer": label.is_correct_answer,
            "is_correct_document": label.is_correct_answer,
        }
        export_data.append(feedback)

    export = {"data": export_data}


@@ -2,6 +2,7 @@ import logging
import time
from datetime import datetime
from typing import List, Dict, Optional
from uuid import UUID

import elasticapm
from fastapi import APIRouter
@@ -115,7 +116,7 @@ class Answer(BaseModel):
    offset_end: int
    offset_start_in_doc: Optional[int]
    offset_end_in_doc: Optional[int]
    document_id: Optional[str] = None
    document_id: Optional[UUID] = None
    meta: Optional[Dict[str, Optional[str]]]


@@ -1,7 +0,0 @@
from elasticsearch import Elasticsearch

from rest_api.config import DB_HOST, DB_USER, DB_PW, DB_PORT, ES_CONN_SCHEME

elasticsearch_client = Elasticsearch(
    hosts=[{"host": DB_HOST, "port": DB_PORT}], http_auth=(DB_USER, DB_PW), scheme=ES_CONN_SCHEME, ca_certs=False, verify_certs=False
)
@@ -56,10 +56,12 @@ def xpdf_fixture():
@pytest.fixture()
def test_docs_xs():
    return [
        # current "dict" format for a document
        {"text": "My name is Carla and I live in Berlin", "meta": {"meta_field": "test1", "name": "filename1"}},
        {"text": "My name is Paul and I live in New York", "meta": {"meta_field": "test2", "name": "filename2"}},
        {"text": "My name is Christelle and I live in Paris", "meta_field": "test3", "meta": {"name": "filename3"}}
        # last doc has meta_field at the top level for backward compatibility
        # meta_field at the top level for backward compatibility
        {"text": "My name is Paul and I live in New York", "meta_field": "test2", "name": "filename2"},
        # Document object for a doc
        Document(text="My name is Christelle and I live in Paris", meta={"meta_field": "test3", "name": "filename3"})
    ]


@@ -89,20 +91,14 @@ def no_answer_reader(request):

@pytest.fixture()
def prediction(reader, test_docs_xs):
    docs = []
    for d in test_docs_xs:
        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
        docs.append(doc)
    docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
    prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
    return prediction


@pytest.fixture()
def no_answer_prediction(no_answer_reader, test_docs_xs):
    docs = []
    for d in test_docs_xs:
        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
        docs.append(doc)
    docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
    prediction = no_answer_reader.predict(question="What is the meaning of life?", documents=docs, top_k=5)
    return prediction

@@ -129,3 +125,22 @@ def document_store_with_docs(request, test_docs_xs, elasticsearch_fixture):
    time.sleep(2)

    return document_store


@pytest.fixture(params=["sql", "memory", "elasticsearch"])
def document_store(request, test_docs_xs, elasticsearch_fixture):
    if request.param == "sql":
        if os.path.exists("qa_test.db"):
            os.remove("qa_test.db")
        document_store = SQLDocumentStore(url="sqlite:///qa_test.db")

    if request.param == "memory":
        document_store = InMemoryDocumentStore()

    if request.param == "elasticsearch":
        # make sure we start from a fresh index
        client = Elasticsearch()
        client.indices.delete(index='haystack_test', ignore=[404])
        document_store = ElasticsearchDocumentStore(index="haystack_test")

    return document_store

test/samples/squad/small.json (new file, 10021 lines)
File diff suppressed because it is too large

test/samples/squad/tiny.json (new file, 54 lines)
@@ -0,0 +1,54 @@
{
  "data": [
    {
      "title": "test1",
      "paragraphs": [
        {
          "context": "My name is Carla and I live together with Abdul in Berlin",
          "qas": [
            {
              "answers": [
                {
                  "answer_start": 11,
                  "text": "Carla"
                },
                {
                  "answer_start": 42,
                  "text": "Abdul"
                }
              ],
              "id": 7211011040021040393,
              "question": "Who lives in Berlin?",
              "is_impossible": false
            }
          ]
        }
      ]
    },
    {
      "title": "test2",
      "paragraphs": [
        {
          "context": "This is another test context",
          "qas": [
            {
              "answers": [
                {
                  "answer_start": 0,
                  "text": "This"
                },
                {
                  "answer_start": 5,
                  "text": "is"
                }
              ],
              "id": -5782547119306399562,
              "question": "The model can't answer this",
              "is_impossible": false
            }
          ]
        }
      ]
    }
  ]
}
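The fixture above is consumed by the new eval tests below; loading it into a DocumentStore is a single call (index names mirror the ones used in test_eval.py, and the relative path assumes the tests' working directory):

# Documents go to doc_index, question/answer labels to label_index
document_store.add_eval_data(filename="samples/squad/tiny.json",
                             doc_index="test_eval_document",
                             label_index="test_feedback")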
@@ -2,5 +2,5 @@ from haystack.database.base import Document


def test_document_data_access():
    doc = Document(id=1, text="test")
    doc = Document(text="test")
    assert doc.text == "test"

@@ -3,7 +3,7 @@ from haystack.retriever.dense import DensePassageRetriever


def test_dpr_inmemory_retrieval():
    document_store = InMemoryDocumentStore(embedding_field="embedding")
    document_store = InMemoryDocumentStore()

    documents = [
        {'name': '0', 'text': """Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from"""},

test/test_eval.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import pytest
from haystack.database.base import BaseDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever


def test_add_eval_data(document_store):
    # add eval data (SQUAD format)
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")
    document_store.add_eval_data(filename="samples/squad/small.json", doc_index="test_eval_document", label_index="test_feedback")

    assert document_store.get_document_count(index="test_eval_document") == 87
    assert document_store.get_label_count(index="test_feedback") == 881

    # test documents
    docs = document_store.get_all_documents(index="test_eval_document")
    assert docs[0].text[:10] == "The Norman"
    assert docs[0].meta["name"] == "Normans"
    assert len(docs[0].meta.keys()) == 1

    # test labels
    labels = document_store.get_all_labels(index="test_feedback")
    assert labels[0].answer == "France"
    assert labels[0].no_answer == False
    assert labels[0].is_correct_answer == True
    assert labels[0].is_correct_document == True
    assert labels[0].question == 'In what country is Normandy located?'
    assert labels[0].origin == "gold_label"
    assert labels[0].offset_start_in_doc == 159

    # check combination
    assert labels[0].document_id == docs[0].id
    start = labels[0].offset_start_in_doc
    end = start + len(labels[0].answer)
    assert docs[0].text[start:end] == "France"

    # clean up
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")


@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_eval_reader(reader, document_store: BaseDocumentStore):
    # add eval data (SQUAD format)
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")
    document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
    assert document_store.get_document_count(index="test_eval_document") == 2
    # eval reader
    reader_eval_results = reader.eval(document_store=document_store, label_index="test_feedback",
                                      doc_index="test_eval_document", device="cpu")
    assert reader_eval_results["f1"] > 0.65
    assert reader_eval_results["f1"] < 0.67
    assert reader_eval_results["EM"] == 0.5
    assert reader_eval_results["top_n_accuracy"] == 1.0

    # clean up
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")


@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("open_domain", [True, False])
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
    retriever = ElasticsearchRetriever(document_store=document_store)

    # add eval data (SQUAD format)
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")
    document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
    assert document_store.get_document_count(index="test_eval_document") == 2

    # eval retriever
    results = retriever.eval(top_k=1, label_index="test_feedback", doc_index="test_eval_document", open_domain=open_domain)
    assert results["recall"] == 1.0
    assert results["map"] == 1.0

    # clean up
    document_store.delete_all_documents(index="test_eval_document")
    document_store.delete_all_documents(index="test_feedback")
@@ -2,8 +2,7 @@ from haystack import Finder
from haystack.retriever.sparse import TfidfRetriever
import pytest

#@pytest.mark.parametrize("reader", [("farm")], indirect=True)
#@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)

def test_finder_get_answers(reader, document_store_with_docs):
    retriever = TfidfRetriever(document_store=document_store_with_docs)
    finder = Finder(reader, retriever)

@@ -51,12 +51,10 @@ def test_memory_store_get_by_tag_lists_union():
    document_store.write_documents(test_docs)

    docs = document_store.get_document_ids_by_tags({'tag2': ["1"]})

    assert docs == [
        {'text': 'testing the finder with pyhton unit test 1', 'meta': {'name': 'testing the finder 1', 'url': 'url'}, 'tags': [{'tag2': ['1']}]},
        {'text': 'testing the finder with pyhton unit test 3', 'meta': {'name': 'testing the finder 3', 'url': 'url'}, 'tags': [{'tag2': ['1', '2']}]}
    ]

    assert docs[0].text == 'testing the finder with pyhton unit test 1'
    assert docs[1].text == 'testing the finder with pyhton unit test 3'
    assert docs[1].text == 'testing the finder with pyhton unit test 3'
    assert docs[1].tags[0] == {"tag2": ["1", "2"]}

def test_memory_store_get_by_tag_lists_non_existent_tag():
    test_docs = [
@@ -82,5 +80,6 @@ def test_memory_store_get_by_tag_lists_disjoint():
    document_store.write_documents(test_docs)

    docs = document_store.get_document_ids_by_tags({'tag3': ["3"]})

    assert docs == [{'text': 'testing the finder with pyhton unit test 3', 'meta': {'name': 'testing the finder 4', 'url': 'url'}, 'tags': [{'tag3': ['1', '3']}]}]
    assert len(docs) == 1
    assert docs[0].text == 'testing the finder with pyhton unit test 3'
    assert docs[0].tags[0] == {"tag3": ["1", "3"]}
@@ -21,7 +21,6 @@ def test_output(prediction):
    assert prediction["answers"][0]["probability"] <= 1
    assert prediction["answers"][0]["probability"] >= 0
    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
    assert prediction["answers"][0]["document_id"] == "filename1"
    assert len(prediction["answers"]) == 5


@@ -61,10 +60,7 @@ def test_answer_attributes(prediction):

def test_context_window_size(test_docs_xs):
    # TODO parametrize window_size and farm/transformers reader using pytest
    docs = []
    for d in test_docs_xs:
        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
        docs.append(doc)
    docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
    for window_size in [10, 15, 20]:
        farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", num_processes=0,
                                 use_gpu=False, top_k_per_sample=5, no_ans_boost=None, context_window_size=window_size)
@@ -86,10 +82,8 @@ def test_context_window_size(test_docs_xs):
def test_top_k(test_docs_xs):
    # TODO parametrize top_k and farm/transformers reader using pytest
    # TODO transformers reader was crashing when tested on this
    docs = []
    for d in test_docs_xs:
        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
        docs.append(doc)

    docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
    farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", num_processes=0,
                             use_gpu=False, top_k_per_sample=4, no_ans_boost=None, top_k_per_candidate=4)
    for top_k in [2, 5, 10]:

@@ -1,11 +1,11 @@
from haystack.database.base import Document

from uuid import UUID

def test_tfidf_retriever():
    from haystack.retriever.sparse import TfidfRetriever

    test_docs = [
        {"name": "testing the finder 1", "text": "godzilla says hello"},
        {"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
        {"name": "testing the finder 2", "text": "optimus prime says bye"},
        {"name": "testing the finder 3", "text": "alien says arghh"}
    ]
@@ -16,13 +16,7 @@ def test_tfidf_retriever():

    retriever = TfidfRetriever(document_store)
    retriever.fit()
    assert retriever.retrieve("godzilla", top_k=1) == [
        Document(
            id='26f84672c6d7aaeb8e2cd53e9c62d62d',
            text='godzilla says hello',
            external_source_id=None,
            question=None,
            query_score=None,
            meta={"name": "testing the finder 1"},
        )
    ]
    doc = retriever.retrieve("godzilla", top_k=1)[0]
    assert doc.id == UUID("26f84672c6d7aaeb8e2cd53e9c62d62d", version=4)
    assert doc.text == 'godzilla says hello'
    assert doc.meta == {"name": "testing the finder 1"}

@@ -125,7 +125,6 @@
"from haystack.database.elasticsearch import ElasticsearchDocumentStore\n",
"document_store = ElasticsearchDocumentStore(host=\"localhost\", username=\"\", password=\"\",\n",
"                                            index=\"document\",\n",
"                                            text_field=\"answer\",\n",
"                                            embedding_field=\"question_emb\",\n",
"                                            embedding_dim=768,\n",
"                                            excluded_meta_data=[\"question_emb\"])"
@@ -187,6 +186,7 @@
"questions = list(df[\"question\"].values)\n",
"df[\"question_emb\"] = retriever.embed_queries(texts=questions)\n",
"df[\"question_emb\"] = df[\"question_emb\"].apply(list)  # convert from numpy to list for ES indexing\n",
"df = df.rename(columns={\"answer\": \"text\"})\n",
"\n",
"# Convert Dataframe to list of dicts and index them in our DocumentStore\n",
"docs_to_index = df.to_dict(orient=\"records\")\n",

@@ -42,7 +42,6 @@ if LAUNCH_ELASTICSEARCH:

document_store = ElasticsearchDocumentStore(host="localhost", username="", password="",
                                            index="document",
                                            text_field="answer",
                                            embedding_field="question_emb",
                                            embedding_dim=768,
                                            excluded_meta_data=["question_emb"])
@@ -69,6 +68,7 @@ print(df.head())
questions = list(df["question"].values)
df["question_emb"] = retriever.embed_queries(texts=questions)
df["question_emb"] = df["question_emb"].apply(list)  # convert from numpy to list for ES indexing
df = df.rename(columns={"answer": "text"})

# Convert Dataframe to list of dicts and index them in our DocumentStore
docs_to_index = df.to_dict(orient="records")

@@ -1680,7 +1680,9 @@
"# Connect to Elasticsearch\n",
"from haystack.database.elasticsearch import ElasticsearchDocumentStore\n",
"\n",
"document_store = ElasticsearchDocumentStore(host=\"localhost\", username=\"\", password=\"\", create_index=False)"
"document_store = ElasticsearchDocumentStore(host=\"localhost\", username=\"\", password=\"\", index=\"document\",\n",
"                                            create_index=False, embedding_field=\"emb\",\n",
"                                            embedding_dim=768, excluded_meta_data=[\"emb\"])"
],
"execution_count": 0,
"outputs": []
@@ -1743,7 +1745,12 @@
"# Initialize Retriever\n",
"from haystack.retriever.sparse import ElasticsearchRetriever\n",
"\n",
"retriever = ElasticsearchRetriever(document_store=document_store)"
"retriever = ElasticsearchRetriever(document_store=document_store)\n",
"\n",
"# Alternative: Evaluate DensePassageRetriever\n",
"# from haystack.retriever.dense import DensePassageRetriever\n",
"# retriever = DensePassageRetriever(document_store=document_store, embedding_model=\"dpr-bert-base-nq\", batch_size=32)\n",
"# document_store.update_embeddings(retriever, index=\"eval_document\")"
],
"execution_count": 0,
"outputs": []

@@ -16,9 +16,9 @@ logger = logging.getLogger(__name__)
##############################################
LAUNCH_ELASTICSEARCH = True

eval_retriever_only = False
eval_retriever_only = True
eval_reader_only = False
eval_both = True
eval_both = False

##############################################
# Code
@@ -43,17 +43,31 @@ s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/nq_d
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)

# Connect to Elasticsearch
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document", create_index=False)
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document",
                                            create_index=False, embedding_field="emb",
                                            embedding_dim=768, excluded_meta_data=["emb"])

# Add evaluation data to Elasticsearch database
if LAUNCH_ELASTICSEARCH:
    document_store.add_eval_data("../data/nq/nq_dev_subset_v2.json")
    document_store.add_eval_data(filename="../data/nq/nq_dev_subset_v2.json", doc_index="eval_document", label_index="feedback")
else:
    logger.warning("Since we already have a running ES instance we should not index the same documents again. "
                   "If you still want to do this, call 'document_store.add_eval_data('../data/nq/nq_dev_subset_v2.json')' manually.")


# Initialize Retriever
retriever = ElasticsearchRetriever(document_store=document_store)

# Alternative: Evaluate DensePassageRetriever
# Note that DPR works best when you index short passages < 512 tokens, as only those tokens will be used for the embedding.
# Here, for nq_dev_subset_v2.json we have avg. num of tokens = 5220(!).
# DPR still outperforms Elastic's BM25 by a small margin here.

# from haystack.retriever.dense import DensePassageRetriever
# retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq", batch_size=32)
# document_store.update_embeddings(retriever, index="eval_document")


# Initialize Reader
reader = FARMReader("deepset/roberta-base-squad2")

@@ -63,7 +77,7 @@ finder = Finder(reader, retriever)

## Evaluate Retriever on its own
if eval_retriever_only:
    retriever_eval_results = retriever.eval()
    retriever_eval_results = retriever.eval(top_k=1)
    ## Retriever Recall is the proportion of questions for which the correct document containing the answer is
    ## among the correct documents
    print("Retriever Recall:", retriever_eval_results["recall"])
@@ -86,5 +100,5 @@ if eval_reader_only:

# Evaluate combination of Reader and Retriever through Finder
if eval_both:
    finder_eval_results = finder.eval(top_k_retriever = 10, top_k_reader = 10)
    finder_eval_results = finder.eval(top_k_retriever=1, top_k_reader=10)
    finder.print_eval_results(finder_eval_results)

@@ -44,10 +44,10 @@ dicts = convert_files_to_dicts(dir_path=doc_dir, clean_func=clean_wiki_text, spl
# Now, let's write the docs to our DB.
document_store.write_documents(dicts[:16])


### Retriever
retriever = DensePassageRetriever(document_store=document_store, embedding_model="dpr-bert-base-nq",
                                  do_lower_case=True, use_gpu=True)

# Important:
# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all
# previously indexed documents and update their embedding representation.
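A sketch of the resulting flow (the model name and the update_embeddings() call come from this tutorial; the query string and printed fields are illustrative):

# Re-embed all previously indexed documents with the freshly initialized DPR,
# then retrieve; query_score holds the retriever's score for each hit.
document_store.update_embeddings(retriever)

top_docs = retriever.retrieve(query="Who is the brother of Moses?", top_k=5)
for doc in top_docs:
    print(doc.query_score, doc.text[:80])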
tutorials/small_faq_covid.csv (new file, 1145 lines)
File diff suppressed because one or more lines are too long