Improve speed for SQLDocumentStore (#330)

This commit is contained in:
Tanay Soni 2020-08-21 09:24:49 +02:00 committed by GitHub
parent a54d6a5bd7
commit 7d2a8f19fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 29 deletions

View File

@@ -150,9 +150,9 @@ class FAISSDocumentStore(SQLDocumentStore):
_, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
documents = [
self.get_all_documents(filters={"vector_id": [vector_id]})[0] for vector_id in vector_ids_for_query
]
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
# sort the documents as per query results
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
return documents

View File

@@ -31,8 +31,8 @@ class DocumentORM(ORMBase):
class MetaORM(ORMBase):
__tablename__ = "meta"
name = Column(String)
value = Column(String)
name = Column(String, index=True)
value = Column(String, index=True)
documents = relationship(DocumentORM, secondary="document_meta", backref="Meta")
@@ -80,36 +80,18 @@ class SQLDocumentStore(BaseDocumentStore):
return documents
def get_all_documents( # type: ignore
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
def get_all_documents(
self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None
) -> List[Document]:
index = index or self.index
document_rows = self.session.query(DocumentORM).filter_by(index=index).all()
if offset:
document_rows = document_rows.offset(offset)
if limit:
document_rows = document_rows.limit(limit)
documents = []
for row in document_rows:
documents.append(self._convert_sql_row_to_document(row))
query = self.session.query(DocumentORM).filter_by(index=index)
if filters:
for key, values in filters.items():
results = (
self.session.query(DocumentORM)
.filter(DocumentORM.meta.any(MetaORM.name.in_([key])))
.filter(DocumentORM.meta.any(MetaORM.value.in_(values)))
.all()
)
else:
results = self.session.query(DocumentORM).filter_by(index=index).all()
query = query.filter(DocumentORM.meta.any(MetaORM.name.in_([key])))\
.filter(DocumentORM.meta.any(MetaORM.value.in_(values)))
documents = [self._convert_sql_row_to_document(row) for row in results]
documents = [self._convert_sql_row_to_document(row) for row in query.all()]
return documents
def get_all_labels(self, index=None, filters: Optional[dict] = None):