Add metadata for TF-IDF Retriever (#122)

This commit is contained in:
Stan Kirdey 2020-05-28 01:55:28 -07:00 committed by GitHub
parent 46a065e9dc
commit ca6778d934
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 12 deletions

View File

@ -1,4 +1,5 @@
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey
import json
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
@ -20,6 +21,7 @@ class Document(ORMBase):
name = Column(String)
text = Column(String)
meta_data = Column(PickleType)
tags = relationship("Tag", secondary="document_tag", backref="Document")
@ -91,7 +93,7 @@ class SQLDocumentStore(BaseDocumentStore):
def write_documents(self, documents):
for doc in documents:
row = Document(name=doc["name"], text=doc["text"])
row = Document(name=doc["name"], text=doc["text"], meta_data=doc.get("meta", {}))
self.session.add(row)
self.session.commit()
@ -102,6 +104,7 @@ class SQLDocumentStore(BaseDocumentStore):
document = DocumentSchema(
id=row.id,
text=row.text,
meta=row.tags
meta=row.meta_data,
tags=row.tags
)
return document

View File

@ -10,7 +10,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
logger = logging.getLogger(__name__)
# TODO make Paragraph generic for configurable units of text eg, pages, paragraphs, or split by a char_limit
Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text"])
Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text", "meta"])
class TfidfRetriever(BaseRetriever):
@ -45,11 +45,11 @@ class TfidfRetriever(BaseRetriever):
paragraphs = []
p_id = 0
for doc in documents:
for p in doc.text.split("\n\n"):
for p in doc.text.split("\n\n"): # TODO: this assumes paragraphs are separated by "\n\n". Can be switched to paragraph tokenizer.
if not p.strip(): # skip empty paragraphs
continue
paragraphs.append(
Paragraph(document_id=doc.id, paragraph_id=p_id, text=(p,))
Paragraph(document_id=doc.id, paragraph_id=p_id, text=(p,), meta=doc.meta)
)
p_id += 1
logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
@ -83,7 +83,7 @@ class TfidfRetriever(BaseRetriever):
# get actual content for the top candidates
paragraphs = list(df_sliced.text.values)
meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"]}
meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"], "meta": row.get("meta", {})}
for idx, row in df_sliced.iterrows()]
documents = []
@ -91,7 +91,8 @@ class TfidfRetriever(BaseRetriever):
documents.append(
Document(
id=meta["paragraph_id"],
text=para
text=para,
meta=meta.get("meta", {})
))
return documents

View File

@ -6,9 +6,9 @@ from haystack.retriever.tfidf import TfidfRetriever
def test_finder_get_answers():
test_docs = [
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"}
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1", "meta": {"test": "test"}},
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2", "meta": {"test": "test"}},
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3", "meta": {"test": "test"}}
]
document_store = SQLDocumentStore(url="sqlite:///qa_test.db")

View File

@ -16,4 +16,4 @@ def test_tfidf_retriever():
retriever = TfidfRetriever(document_store)
retriever.fit()
assert retriever.retrieve("godzilla", top_k=1) == [Document(id='0', text='godzilla says hello', external_source_id=None, question=None, query_score=None, meta=None)]
assert retriever.retrieve("godzilla", top_k=1) == [Document(id='0', text='godzilla says hello', external_source_id=None, question=None, query_score=None, meta={})]