feat: Add BM25 support for tables in InMemoryDocumentStore (#4090)

* Add BM25 support for tables in InMemoryDocumentStore

* Add table type to query method

* Fix import order

* Adapt tests
This commit is contained in:
bogdankostic 2023-02-09 10:47:35 +01:00 committed by GitHub
parent 93962c09fc
commit 986472c26f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 6 deletions

View File

@ -15,6 +15,7 @@ import numpy as np
import torch
from tqdm.auto import tqdm
import rank_bm25
import pandas as pd
from haystack.schema import Document, FilterType, Label
from haystack.errors import DuplicateDocumentError, DocumentStoreError
@ -206,7 +207,15 @@ class InMemoryDocumentStore(KeywordDocumentStore):
index = index or self.index
all_documents = self.get_all_documents(index=index)
textual_documents = [doc for doc in all_documents if doc.content_type == "text"]
textual_documents = []
for doc in all_documents:
if doc.content_type == "text":
textual_documents.append(doc.content.lower())
elif doc.content_type == "table":
if isinstance(doc.content, pd.DataFrame):
textual_documents.append(doc.content.astype(str).to_csv(index=False).lower())
else:
raise DocumentStoreError("Documents of type 'table' need to have a pd.DataFrame as content field")
if len(textual_documents) < len(all_documents):
logger.warning(
"Some documents in %s index are non-textual."
@ -215,7 +224,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
)
tokenized_corpus = [
self.bm25_tokenization_regex(doc.content.lower())
self.bm25_tokenization_regex(doc)
for doc in tqdm(textual_documents, unit=" docs", desc="Updating BM25 representation...")
]
self.bm25[index] = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
@ -962,7 +971,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
docs_scores = self.bm25[index].get_scores(tokenized_query)
top_docs_positions = np.argsort(docs_scores)[::-1][:top_k]
textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type == "text"]
textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type in ["text", "table"]]
top_docs = []
for i in top_docs_positions:
doc = textual_docs_list[i]

View File

@ -1,5 +1,6 @@
import logging
import pandas as pd
import pytest
from rank_bm25 import BM25
@ -65,10 +66,19 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
assert set(ids) == result
@pytest.mark.integration
def test_update_bm25(self, documents):
ds = InMemoryDocumentStore(use_bm25=False)
def test_update_bm25(self, ds, documents):
ds.write_documents(documents)
ds.update_bm25()
bm25_representation = ds.bm25[ds.index]
assert isinstance(bm25_representation, BM25)
assert bm25_representation.corpus_size == ds.get_document_count()
@pytest.mark.integration
def test_update_bm25_table(self, ds):
    """Write a single table document and check the store's BM25 representation.

    The representation stored under the default index must be a ``BM25``
    instance whose corpus size equals the number of documents in the store.
    """
    # Note: the "id" column holds both an int (0) and a str ("2").
    rows = [[0, "This is a test"], ["2", "This is another test"]]
    frame = pd.DataFrame(columns=["id", "text"], data=rows)

    # update_bm25() is not called explicitly here; presumably the `ds` fixture
    # enables BM25 so that write_documents() builds it -- confirm against the fixture.
    ds.write_documents([Document(content=frame, content_type="table")])

    index_repr = ds.bm25[ds.index]
    assert isinstance(index_repr, BM25)
    assert index_repr.corpus_size == ds.get_document_count()