diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py index 418b1bbeb..bf3029c74 100644 --- a/haystack/document_stores/memory.py +++ b/haystack/document_stores/memory.py @@ -15,6 +15,7 @@ import numpy as np import torch from tqdm.auto import tqdm import rank_bm25 +import pandas as pd from haystack.schema import Document, FilterType, Label from haystack.errors import DuplicateDocumentError, DocumentStoreError @@ -206,7 +207,15 @@ class InMemoryDocumentStore(KeywordDocumentStore): index = index or self.index all_documents = self.get_all_documents(index=index) - textual_documents = [doc for doc in all_documents if doc.content_type == "text"] + textual_documents = [] + for doc in all_documents: + if doc.content_type == "text": + textual_documents.append(doc.content.lower()) + elif doc.content_type == "table": + if isinstance(doc.content, pd.DataFrame): + textual_documents.append(doc.content.astype(str).to_csv(index=False).lower()) + else: + raise DocumentStoreError("Documents of type 'table' need to have a pd.DataFrame as content field") if len(textual_documents) < len(all_documents): logger.warning( "Some documents in %s index are non-textual." @@ -215,7 +224,7 @@ class InMemoryDocumentStore(KeywordDocumentStore): ) tokenized_corpus = [ - self.bm25_tokenization_regex(doc.content.lower()) + self.bm25_tokenization_regex(doc) for doc in tqdm(textual_documents, unit=" docs", desc="Updating BM25 representation...") ] self.bm25[index] = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters) @@ -962,7 +971,7 @@ class InMemoryDocumentStore(KeywordDocumentStore): docs_scores = self.bm25[index].get_scores(tokenized_query) top_docs_positions = np.argsort(docs_scores)[::-1][:top_k] - textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type == "text"] + textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type in ["text", "table"]] top_docs = [] for i in top_docs_positions: doc = textual_docs_list[i] diff --git a/test/document_stores/test_memory.py b/test/document_stores/test_memory.py index ee686f1fa..92d3196f4 100644 --- a/test/document_stores/test_memory.py +++ b/test/document_stores/test_memory.py @@ -1,5 +1,6 @@ import logging +import pandas as pd import pytest from rank_bm25 import BM25 @@ -65,10 +66,19 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract): assert set(ids) == result @pytest.mark.integration - def test_update_bm25(self, documents): - ds = InMemoryDocumentStore(use_bm25=False) + def test_update_bm25(self, ds, documents): ds.write_documents(documents) - ds.update_bm25() + bm25_representation = ds.bm25[ds.index] + assert isinstance(bm25_representation, BM25) + assert bm25_representation.corpus_size == ds.get_document_count() + + @pytest.mark.integration + def test_update_bm25_table(self, ds): + table_doc = Document( + content=pd.DataFrame(columns=["id", "text"], data=[[0, "This is a test"], ["2", "This is another test"]]), + content_type="table", + ) + ds.write_documents([table_doc]) bm25_representation = ds.bm25[ds.index] assert isinstance(bm25_representation, BM25) assert bm25_representation.corpus_size == ds.get_document_count()