import itertools
import logging
from typing import Any, Dict, Union, List, Optional, Generator
from uuid import uuid4

import numpy as np
from sqlalchemy import and_, func, create_engine, Column, Integer, String, DateTime, ForeignKey, Boolean, Text, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import case, null

from haystack import Document, Label
from haystack.document_store.base import BaseDocumentStore

logger = logging.getLogger(__name__)


Base = declarative_base()  # type: Any


class ORMBase(Base):
    __abstract__ = True

    id = Column(String(100), default=lambda: str(uuid4()), primary_key=True)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), server_onupdate=func.now())


class DocumentORM(ORMBase):
    __tablename__ = "document"

    text = Column(Text, nullable=False)
    index = Column(String(100), nullable=False)
    vector_id = Column(String(100), unique=True, nullable=True)

    # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
    meta = relationship("MetaORM", backref="Document", lazy="joined")


class MetaORM(ORMBase):
    __tablename__ = "meta"

    name = Column(String(100), index=True)
    value = Column(String(1000), index=True)
    document_id = Column(
        String(100),
        ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        index=True
    )

    documents = relationship(DocumentORM, backref="Meta")


class LabelORM(ORMBase):
    __tablename__ = "label"

    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
    index = Column(String(100), nullable=False)
    no_answer = Column(Boolean, nullable=False)
    origin = Column(String(100), nullable=False)
    question = Column(Text, nullable=False)
    is_correct_answer = Column(Boolean, nullable=False)
    is_correct_document = Column(Boolean, nullable=False)
    answer = Column(Text, nullable=False)
    offset_start_in_doc = Column(Integer, nullable=False)
    model_id = Column(Integer, nullable=True)


class SQLDocumentStore(BaseDocumentStore):
    def __init__(
        self,
        url: str = "sqlite://",
        index: str = "document",
        label_index: str = "label",
        duplicate_documents: str = "overwrite"
    ):
        """
        An SQL-backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.

        :param url: URL for the SQL database as expected by SQLAlchemy. More info here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
        :param index: Documents are scoped to an index attribute that can be used when writing, querying, or deleting documents.
                      This parameter sets the default value for the document index.
        :param label_index: The default value of the index attribute for labels.
        :param duplicate_documents: How to handle duplicate documents.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the ID of the document being added already exists.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents
        )

        engine = create_engine(url)
        ORMBase.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)
        self.session = Session()
        self.index: str = index
        self.label_index = label_index
        self.duplicate_documents = duplicate_documents
        if getattr(self, "similarity", None) is None:
            self.similarity = None
        self.use_windowed_query = True
        if "sqlite" in url:
            import sqlite3
            # compare version tuples rather than strings, so e.g. "3.9" is not ranked above "3.25";
            # window functions were added in SQLite 3.25.0
            if sqlite3.sqlite_version_info < (3, 25, 0):
                self.use_windowed_query = False
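
    # Example usage (a minimal sketch; the SQLite file path "qa.db" and the sample
    # text are only illustrations, not part of this module):
    #
    #   document_store = SQLDocumentStore(url="sqlite:///qa.db")
    #   document_store.write_documents([{"text": "Haystack is a QA framework.",
    #                                    "meta": {"name": "intro"}}])
    #   print(document_store.get_document_count())  # -> 1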

    def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
        """Fetch a document by specifying its text id string"""
        documents = self.get_documents_by_id([id], index)
        document = documents[0] if documents else None
        return document

    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000) -> List[Document]:
        """Fetch documents by specifying a list of text id strings"""
        index = index or self.index

        documents = []
        for i in range(0, len(ids), batch_size):
            query = self.session.query(DocumentORM).filter(
                DocumentORM.id.in_(ids[i: i + batch_size]),
                DocumentORM.index == index
            )
            for row in query.all():
                documents.append(self._convert_sql_row_to_document(row))

        return documents

    def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None, batch_size: int = 10_000):
        """Fetch documents by specifying a list of text vector id strings"""
        index = index or self.index

        documents = []
        for i in range(0, len(vector_ids), batch_size):
            query = self.session.query(DocumentORM).filter(
                DocumentORM.vector_id.in_(vector_ids[i: i + batch_size]),
                DocumentORM.index == index
            )
            for row in query.all():
                documents.append(self._convert_sql_row_to_document(row))

        # restore the order of the input vector_ids, which the SQL query does not guarantee
        sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"]))
        return sorted_documents
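
    # Example usage (a minimal sketch; assumes vector ids were previously assigned,
    # e.g. via update_vector_ids below):
    #
    #   docs = document_store.get_documents_by_vector_ids(["0", "1"])
    #   # docs come back in the same order as the given vector ids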

    def get_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
    ) -> List[Document]:
        documents = list(self.get_all_documents_generator(
            index=index, filters=filters, return_embedding=return_embedding
        ))
        return documents

    def get_all_documents_generator(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
    ) -> Generator[Document, None, None]:
        """
        Get documents from the document store. Under the hood, documents are fetched in batches from the
        document store and yielded as individual documents. This method can be used to iteratively process
        a large number of documents without having to load all documents in memory.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
        """

        if return_embedding is True:
            raise Exception("return_embedding is not supported by SQLDocumentStore.")
        result = self._query(
            index=index,
            filters=filters,
            batch_size=batch_size,
        )
        yield from result
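
    # Example usage (a minimal sketch, reusing the document_store from above;
    # process() is a hypothetical callback, not part of this module):
    #
    #   for doc in document_store.get_all_documents_generator(
    #           filters={"name": ["intro"]}, batch_size=1_000):
    #       process(doc)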

    def _query(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        vector_ids: Optional[List[str]] = None,
        only_documents_without_embedding: bool = False,
        batch_size: int = 10_000
    ):
        """
        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param vector_ids: List of vector_id strings to filter the documents by.
        :param only_documents_without_embedding: Return only documents without an embedding.
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
        """
        index = index or self.index
        # Generally, keeping full ORM objects in memory causes performance issues,
        # so querying the columns directly improves memory usage and speed.
        # See https://stackoverflow.com/questions/23185319/why-is-loading-sqlalchemy-objects-via-the-orm-5-8x-slower-than-rows-via-a-raw-my
        documents_query = self.session.query(
            DocumentORM.id,
            DocumentORM.text,
            DocumentORM.vector_id
        ).filter_by(index=index)

        if filters:
            documents_query = documents_query.join(MetaORM)
            # NOTE: every key is applied to the same joined meta row, so a filter with
            # more than one key would require MetaORM.name to equal several values at
            # once and matches nothing; single-key filters work as expected.
            for key, values in filters.items():
                documents_query = documents_query.filter(
                    MetaORM.name == key,
                    MetaORM.value.in_(values),
                    DocumentORM.id == MetaORM.document_id
                )
        if only_documents_without_embedding:
            documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
        if vector_ids:
            documents_query = documents_query.filter(DocumentORM.vector_id.in_(vector_ids))

        documents_map = {}

        if self.use_windowed_query:
            documents_query = self._windowed_query(documents_query, DocumentORM.id, batch_size)

        for i, row in enumerate(documents_query, start=1):
            documents_map[row.id] = Document(
                id=row.id,
                text=row.text,
                meta=None if row.vector_id is None else {"vector_id": row.vector_id}
            )
            # fetch the metadata and flush the batch once batch_size rows have been collected
            if i % batch_size == 0:
                documents_map = self._get_documents_meta(documents_map)
                yield from documents_map.values()
                documents_map = {}
        if documents_map:
            documents_map = self._get_documents_meta(documents_map)
            yield from documents_map.values()

    def _get_documents_meta(self, documents_map):
        doc_ids = documents_map.keys()
        meta_query = self.session.query(
            MetaORM.document_id,
            MetaORM.name,
            MetaORM.value
        ).filter(MetaORM.document_id.in_(doc_ids))

        for row in meta_query.all():
            documents_map[row.document_id].meta[row.name] = row.value
        return documents_map

    def get_all_labels(self, index=None, filters: Optional[dict] = None):
        """
        Return all labels in the document store
        """
        index = index or self.label_index
        # TODO: Use batch_size
        label_rows = self.session.query(LabelORM).filter_by(index=index).all()
        labels = [self._convert_sql_row_to_label(row) for row in label_rows]

        return labels

    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
                        batch_size: int = 10_000, duplicate_documents: Optional[str] = None):
        """
        Indexes documents for later queries.

        :param documents: a list of Python dictionaries or a list of Haystack Document objects.
                          For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
                          Optionally: Include meta data via {"text": "<the-actual-text>",
                          "meta": {"name": "<some-document-name>", "author": "somebody", ...}}
                          It can be used for filtering and is accessible in the responses of the Finder.
        :param index: add an optional index attribute to documents. It can be later used for filtering. For instance,
                      documents for evaluation can be indexed in a separate index than the documents for search.
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
        :param duplicate_documents: How to handle duplicate documents.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the ID of the document being added already exists.

        :return: None
        """

        index = index or self.index
        duplicate_documents = duplicate_documents or self.duplicate_documents
        if len(documents) == 0:
            return
        # Make sure we comply with the Document class format
        if isinstance(documents[0], dict):
            document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
        else:
            document_objects = documents

        document_objects = self._handle_duplicate_documents(document_objects, duplicate_documents)
        for i in range(0, len(document_objects), batch_size):
            for doc in document_objects[i: i + batch_size]:
                meta_fields = doc.meta or {}
                vector_id = meta_fields.pop("vector_id", None)
                meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
                doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
                if duplicate_documents == "overwrite":
                    # the old metadata must be deleted first, then the document row can be merged
                    self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
                    self.session.merge(doc_orm)
                else:
                    self.session.add(doc_orm)
            try:
                self.session.commit()
            except Exception as ex:
                logger.error(f"Transaction rollback: {ex.__cause__}")
                # Rollback is important here, otherwise self.session will be in an inconsistent state and the next call will fail
                self.session.rollback()
                raise ex
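
    # Example usage (a minimal sketch; the texts and the "topic" meta field are
    # only illustrations):
    #
    #   document_store.write_documents(
    #       [{"text": "Berlin is the capital of Germany.", "meta": {"topic": "geo"}},
    #        {"text": "Paris is the capital of France.", "meta": {"topic": "geo"}}],
    #       duplicate_documents="skip",  # keep existing rows if the same document ID is written again
    #   )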

    def write_labels(self, labels, index=None):
        """Write annotation labels into document store."""

        labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
        index = index or self.label_index
        # TODO: Use batch_size
        for label in labels:
            label_orm = LabelORM(
                document_id=label.document_id,
                no_answer=label.no_answer,
                origin=label.origin,
                question=label.question,
                is_correct_answer=label.is_correct_answer,
                is_correct_document=label.is_correct_document,
                answer=label.answer,
                offset_start_in_doc=label.offset_start_in_doc,
                model_id=label.model_id,
                index=index,
            )
            self.session.add(label_orm)
        self.session.commit()

    def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None, batch_size: int = 10_000):
        """
        Update vector_ids for given document_ids.

        :param vector_id_map: dict containing mapping of document_id -> vector_id.
        :param index: filter documents by the optional index attribute for documents in database.
        :param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
        """
        index = index or self.index
        for chunk_map in self.chunked_dict(vector_id_map, size=batch_size):
            self.session.query(DocumentORM).filter(
                DocumentORM.id.in_(chunk_map),
                DocumentORM.index == index
            ).update({
                # the CASE expression maps each document id in the chunk to its new vector_id in a single UPDATE
                DocumentORM.vector_id: case(
                    chunk_map,
                    value=DocumentORM.id,
                )
            }, synchronize_session=False)
            try:
                self.session.commit()
            except Exception as ex:
                logger.error(f"Transaction rollback: {ex.__cause__}")
                self.session.rollback()
                raise ex
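
    # Example usage (a minimal sketch; the ids below are hypothetical):
    #
    #   document_store.update_vector_ids({"doc-1": "0", "doc-2": "1"})
    #   # afterwards, get_documents_by_vector_ids(["0", "1"]) returns doc-1 and doc-2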

    def reset_vector_ids(self, index: Optional[str] = None):
        """
        Set vector IDs for all documents as None
        """
        index = index or self.index
        self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
        self.session.commit()

    def update_document_meta(self, id: str, meta: Dict[str, str]):
        """
        Update the metadata dictionary of a document by specifying its string id
        """
        self.session.query(MetaORM).filter_by(document_id=id).delete()
        meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
        for m in meta_orms:
            self.session.add(m)
        self.session.commit()
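
    # Example usage (a minimal sketch; note that this replaces all existing metadata
    # of the document, since the old meta rows are deleted first):
    #
    #   document_store.update_document_meta("doc-1", {"topic": "news", "author": "jane"})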

    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
        """
        Return the number of documents in the document store.
        """
        index = index or self.index
        query = self.session.query(DocumentORM).filter_by(index=index)

        if filters:
            query = query.join(MetaORM)
            for key, values in filters.items():
                query = query.filter(MetaORM.name == key, MetaORM.value.in_(values))

        count = query.count()
        return count
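
    # Example usage (a minimal sketch, assuming documents with a "topic" meta field
    # were written as in the write_documents example above):
    #
    #   n_geo = document_store.get_document_count(filters={"topic": ["geo"]})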

    def get_label_count(self, index: Optional[str] = None) -> int:
        """
        Return the number of labels in the document store
        """
        index = index or self.label_index
        return self.session.query(LabelORM).filter_by(index=index).count()

    def _convert_sql_row_to_document(self, row) -> Document:
        document = Document(
            id=row.id,
            text=row.text,
            meta={meta.name: meta.value for meta in row.meta}
        )
        if row.vector_id:
            document.meta["vector_id"] = row.vector_id
        return document

    def _convert_sql_row_to_label(self, row) -> Label:
        label = Label(
            document_id=row.document_id,
            no_answer=row.no_answer,
            origin=row.origin,
            question=row.question,
            is_correct_answer=row.is_correct_answer,
            is_correct_document=row.is_correct_document,
            answer=row.answer,
            offset_start_in_doc=row.offset_start_in_doc,
            model_id=row.model_id,
            created_at=row.created_at,
            updated_at=row.updated_at
        )
        return label

    def query_by_embedding(self,
                           query_emb: np.ndarray,
                           filters: Optional[dict] = None,
                           top_k: int = 10,
                           index: Optional[str] = None,
                           return_embedding: Optional[bool] = None) -> List[Document]:

        raise NotImplementedError("SQLDocumentStore is currently not supporting embedding queries. "
                                  "Change the query type (e.g. by choosing a different retriever) "
                                  "or change the DocumentStore (e.g. to ElasticsearchDocumentStore)")

    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
        """
        Delete documents in an index. All documents are deleted if no filters are passed.

        :param index: Index name to delete the document from.
        :param filters: Optional filters to narrow down the documents to be deleted.
        :return: None
        """
        logger.warning(
            """DEPRECATION WARNINGS:
            1. delete_all_documents() method is deprecated, please use delete_documents method
            For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/1045
            """
        )
        self.delete_documents(index, filters)

    def delete_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
        """
        Delete documents in an index. All documents are deleted if no filters are passed.

        :param index: Index name to delete the document from.
        :param filters: Optional filters to narrow down the documents to be deleted.
        :return: None
        """
        index = index or self.index

        if filters:
            document_ids_to_delete = self.session.query(DocumentORM.id).filter_by(index=index)
            for key, values in filters.items():
                document_ids_to_delete = document_ids_to_delete.filter(
                    MetaORM.name == key,
                    MetaORM.value.in_(values),
                    DocumentORM.id == MetaORM.document_id
                )
            self.session.query(DocumentORM).filter(DocumentORM.id.in_(document_ids_to_delete)).delete(
                synchronize_session=False)
        else:
            self.session.query(DocumentORM).filter_by(index=index).delete(synchronize_session=False)

        self.session.commit()
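
    # Example usage (a minimal sketch, reusing the "topic" meta field from above):
    #
    #   document_store.delete_documents(filters={"topic": ["geo"]})  # delete only matching documents
    #   document_store.delete_documents()  # delete everything in the default index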

    def _get_or_create(self, session, model, **kwargs):
        instance = session.query(model).filter_by(**kwargs).first()
        if instance:
            return instance
        else:
            instance = model(**kwargs)
            session.add(instance)
            session.commit()
            return instance

    def chunked_dict(self, dictionary, size):
        # yield successive sub-dicts with at most `size` items each
        it = iter(dictionary)
        for i in range(0, len(dictionary), size):
            yield {k: dictionary[k] for k in itertools.islice(it, size)}
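
    # Example usage (a minimal sketch):
    #
    #   list(document_store.chunked_dict({"a": 1, "b": 2, "c": 3}, size=2))
    #   # -> [{"a": 1, "b": 2}, {"c": 3}]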

    def _column_windows(self, session, column, windowsize):
        """Return a series of WHERE clauses against
        a given column that break it into windows.

        Result is an iterable of WHERE clauses, each covering the id range
        [start, end) of one window.

        The code is taken from: https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
        """

        def int_for_range(start_id, end_id):
            if end_id:
                return and_(
                    column >= start_id,
                    column < end_id
                )
            else:
                return column >= start_id

        q = session.query(
            column,
            func.row_number().over(order_by=column).label('rownum')
        ).from_self(column)
        if windowsize > 1:
            q = q.filter(text("rownum %% %d=1" % windowsize))

        intervals = [id for id, in q]

        while intervals:
            start = intervals.pop(0)
            if intervals:
                end = intervals[0]
            else:
                end = None
            yield int_for_range(start, end)
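
    # For example, with ids "d1".."d9" and windowsize=3, the generated clauses are
    # roughly: (id >= "d1" AND id < "d4"), (id >= "d4" AND id < "d7"), (id >= "d7").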

    def _windowed_query(self, q, column, windowsize):
        """Break a Query into windows on a given column."""

        for whereclause in self._column_windows(q.session, column, windowsize):
            for row in q.filter(whereclause).order_by(column):
                yield row