Refactor database layer

This commit is contained in:
Tanay Soni 2020-01-22 15:33:18 +01:00
parent 8cdb0ff482
commit b5b62c569e
7 changed files with 120 additions and 107 deletions

View File

@ -1,8 +1,6 @@
from haystack.retriever.tfidf import TfidfRetriever
from haystack.reader.farm import FARMReader
from haystack.database import db
import logging
import farm
import pandas as pd
pd.options.display.max_colwidth = 80
@ -40,25 +38,7 @@ class Finder:
# 1) Optional: reduce the search space via document tags
if filters:
query = """
SELECT id FROM document WHERE id in (
SELECT dt.document_id
FROM document_tag dt JOIN
tag t
ON t.id = dt.tag_id
GROUP BY dt.document_id
"""
tag_filters = []
if filters:
for tag, value in filters.items():
if value:
tag_filters.append(
f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0"
)
final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
query_results = db.session.execute(final_query)
candidate_doc_ids = [row[0] for row in query_results]
candidate_doc_ids = self.retriever.datastore.get_document_ids_by_tags(filters)
else:
candidate_doc_ids = None

View File

@ -1,28 +0,0 @@
import logging
import os
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
DATABASE_URL = os.getenv("DATABASE_URL", None)
if not DATABASE_URL:
try:
from qa_config import DATABASE_URL
except ModuleNotFoundError:
logging.info(
"Using localhost sqlite as the database backend. as Database not configured. Add a qa_config.py file in the Python path with DATABASE_URL set."
"Continuing with the default sqlite on localhost."
)
DATABASE_URL = "sqlite://"
app = Flask(__name__)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_DATABASE_URI"] = f"{DATABASE_URL}"
db = SQLAlchemy(app)
db.create_all()

View File

@ -1,32 +0,0 @@
from sqlalchemy.orm import relationship
from haystack.database import db
class ORMBase(db.Model):
__abstract__ = True
id = db.Column(db.Integer, primary_key=True)
created = db.Column(db.DateTime, server_default=db.func.now())
updated = db.Column(
db.DateTime, server_default=db.func.now(), server_onupdate=db.func.now()
)
class Document(ORMBase):
name = db.Column(db.String)
text = db.Column(db.String)
tags = relationship("Tag", secondary="document_tag", backref="Document")
class Tag(ORMBase):
name = db.Column(db.String)
value = db.Column(db.String)
documents = relationship("Document", secondary="document_tag", backref="Tag")
class DocumentTag(ORMBase):
document_id = db.Column(db.Integer, db.ForeignKey("document.id"), nullable=False)
tag_id = db.Column(db.Integer, db.ForeignKey("tag.id"), nullable=False)

100
haystack/database/sql.py Normal file
View File

@ -0,0 +1,100 @@
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from haystack.database.base import BaseDocumentStore
Base = declarative_base()
class ORMBase(Base):
__abstract__ = True
id = Column(Integer, primary_key=True)
created = Column(DateTime, server_default=func.now())
updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
class Document(ORMBase):
__tablename__ = "document"
name = Column(String)
text = Column(String)
tags = relationship("Tag", secondary="document_tag", backref="Document")
class Tag(ORMBase):
__tablename__ = "tag"
name = Column(String)
value = Column(String)
documents = relationship("Document", secondary="document_tag", backref="Tag")
class DocumentTag(ORMBase):
__tablename__ = "document_tag"
document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
tag_id = Column(Integer, ForeignKey("tag.id"), nullable=False)
class SQLDocumentStore(BaseDocumentStore):
def __init__(self, url="sqlite://"):
engine = create_engine(url)
ORMBase.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
self.session = Session()
def get_document_by_id(self, id):
document_row = self.session.query(Document).get(id)
document = {
"id": document_row.id,
"name": document_row.name,
"text": document_row.text,
"tags": document_row.tags,
}
return document
def get_all_documents(self):
document_rows = self.session.query(Document).all()
documents = []
for row in document_rows:
documents.append({"id": row.id, "name": row.name, "text": row.text, "tags": row.tags})
return documents
def get_document_ids_by_tags(self, tags):
"""
Get list of document ids that have tags from the given list of tags.
:param tags: limit scope to documents having the given tags and their corresponding values.
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
"""
if not tags:
raise Exception("No tag supplied for filtering the documents")
query = """
SELECT id FROM document WHERE id in (
SELECT dt.document_id
FROM document_tag dt JOIN
tag t
ON t.id = dt.tag_id
GROUP BY dt.document_id
"""
tag_filters = []
for tag, value in tags.items():
if value:
tag_filters.append(f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0")
final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
query_results = self.session.execute(final_query)
doc_ids = [row[0] for row in query_results]
return doc_ids
def write_documents(self, documents):
for doc in documents:
row = Document(name=doc["name"], text=doc["text"])
self.session.add(row)
self.session.commit()

View File

@ -1,4 +1,3 @@
from haystack.database.orm import Document, db
from pathlib import Path
import logging
from farm.data_handler.utils import http_get
@ -9,7 +8,7 @@ import zipfile
logger = logging.getLogger(__name__)
def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
def write_documents_to_db(datastore, document_dir, clean_func=None, only_empty_db=False):
"""
Write all text files(.txt) in the sub-directories of the given path to the connected database.
@ -24,23 +23,28 @@ def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
# check if db has already docs
if only_empty_db:
n_docs = db.session.query(Document).count()
n_docs = len(datastore.get_all_documents())
if n_docs > 0:
logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... "
"(Disable `only_empty_db`, if you want to add docs anyway.)")
return None
# read and add docs
documents_to_write = []
for path in file_paths:
with open(path) as doc:
text = doc.read()
if clean_func:
text = clean_func(text)
doc = Document(name=path.name, text=text)
db.session.add(doc)
db.session.commit()
n_docs += 1
logger.info(f"Wrote {n_docs} docs to DB")
documents_to_write.append(
{
"name": path.name,
"text": text,
}
)
datastore.write_documents(documents_to_write)
logger.info(f"Wrote {len(documents_to_write)} docs to DB")
def fetch_archive_from_http(url, output_dir, proxies=None):

View File

@ -4,8 +4,6 @@ import logging
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from haystack.database import db
from haystack.database.orm import Document
logger = logging.getLogger(__name__)
@ -37,7 +35,7 @@ class TfidfRetriever(BaseRetriever):
It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
"""
def __init__(self):
def __init__(self, datastore):
self.vectorizer = TfidfVectorizer(
lowercase=True,
stop_words=None,
@ -45,6 +43,7 @@ class TfidfRetriever(BaseRetriever):
ngram_range=(1, 1),
)
self.datastore = datastore
self.paragraphs = self._get_all_paragraphs()
self.df = None
self.fit()
@ -53,17 +52,17 @@ class TfidfRetriever(BaseRetriever):
"""
Split the list of documents in paragraphs
"""
documents = db.session.query(Document).all()
documents = self.datastore.get_all_documents()
paragraphs = []
p_id = 0
for doc in documents:
_pgs = [d for d in doc.text.splitlines() if d.strip()]
for p in doc.text.split("\n\n"):
_pgs = [d for d in doc["text"].splitlines() if d.strip()]
for p in doc["text"].split("\n\n"):
if not p.strip(): # skip empty paragraphs
continue
paragraphs.append(
Paragraph(document_id=doc.id, paragraph_id=p_id, text=(p,))
Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
)
p_id += 1
logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")

View File

@ -3,21 +3,11 @@ from collections import defaultdict
import logging
import pprint
from haystack.database.orm import Document
from haystack.database.orm import db
from haystack.database.sql import Document
logger = logging.getLogger(__name__)
def create_db():
"""
Create all tables as defined by the ORM in the connected SQL database.
:return:
"""
db.session.create_all()
def print_answers(results, details="all"):
answers = results["answers"]
pp = pprint.PrettyPrinter(indent=4)