mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-05 19:47:45 +00:00
feat: Add BM25 support for tables in InMemoryDocumentStore (#4090)
* Add BM25 support for tables in InMemoryDocumentStore
* Add table type to query method
* Fix import order
* Adapt tests
This commit is contained in:
parent
93962c09fc
commit
986472c26f
@ -15,6 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import rank_bm25
|
||||
import pandas as pd
|
||||
|
||||
from haystack.schema import Document, FilterType, Label
|
||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
||||
@ -206,7 +207,15 @@ class InMemoryDocumentStore(KeywordDocumentStore):
|
||||
index = index or self.index
|
||||
|
||||
all_documents = self.get_all_documents(index=index)
|
||||
textual_documents = [doc for doc in all_documents if doc.content_type == "text"]
|
||||
textual_documents = []
|
||||
for doc in all_documents:
|
||||
if doc.content_type == "text":
|
||||
textual_documents.append(doc.content.lower())
|
||||
elif doc.content_type == "table":
|
||||
if isinstance(doc.content, pd.DataFrame):
|
||||
textual_documents.append(doc.content.astype(str).to_csv(index=False).lower())
|
||||
else:
|
||||
raise DocumentStoreError("Documents of type 'table' need to have a pd.DataFrame as content field")
|
||||
if len(textual_documents) < len(all_documents):
|
||||
logger.warning(
|
||||
"Some documents in %s index are non-textual."
|
||||
@ -215,7 +224,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
|
||||
)
|
||||
|
||||
tokenized_corpus = [
|
||||
self.bm25_tokenization_regex(doc.content.lower())
|
||||
self.bm25_tokenization_regex(doc)
|
||||
for doc in tqdm(textual_documents, unit=" docs", desc="Updating BM25 representation...")
|
||||
]
|
||||
self.bm25[index] = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
|
||||
@ -962,7 +971,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
|
||||
docs_scores = self.bm25[index].get_scores(tokenized_query)
|
||||
top_docs_positions = np.argsort(docs_scores)[::-1][:top_k]
|
||||
|
||||
textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type == "text"]
|
||||
textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type in ["text", "table"]]
|
||||
top_docs = []
|
||||
for i in top_docs_positions:
|
||||
doc = textual_docs_list[i]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from rank_bm25 import BM25
|
||||
|
||||
@ -65,10 +66,19 @@ class TestInMemoryDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
assert set(ids) == result
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_bm25(self, documents):
|
||||
ds = InMemoryDocumentStore(use_bm25=False)
|
||||
def test_update_bm25(self, ds, documents):
|
||||
ds.write_documents(documents)
|
||||
ds.update_bm25()
|
||||
bm25_representation = ds.bm25[ds.index]
|
||||
assert isinstance(bm25_representation, BM25)
|
||||
assert bm25_representation.corpus_size == ds.get_document_count()
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_bm25_table(self, ds):
|
||||
table_doc = Document(
|
||||
content=pd.DataFrame(columns=["id", "text"], data=[[0, "This is a test"], ["2", "This is another test"]]),
|
||||
content_type="table",
|
||||
)
|
||||
ds.write_documents([table_doc])
|
||||
bm25_representation = ds.bm25[ds.index]
|
||||
assert isinstance(bm25_representation, BM25)
|
||||
assert bm25_representation.corpus_size == ds.get_document_count()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user