Refactor database layer

This commit is contained in:
Tanay Soni 2020-01-22 15:33:18 +01:00
parent 8cdb0ff482
commit b5b62c569e
7 changed files with 120 additions and 107 deletions

View File

@ -1,8 +1,6 @@
from haystack.retriever.tfidf import TfidfRetriever from haystack.retriever.tfidf import TfidfRetriever
from haystack.reader.farm import FARMReader from haystack.reader.farm import FARMReader
from haystack.database import db
import logging import logging
import farm
import pandas as pd import pandas as pd
pd.options.display.max_colwidth = 80 pd.options.display.max_colwidth = 80
@ -40,25 +38,7 @@ class Finder:
# 1) Optional: reduce the search space via document tags # 1) Optional: reduce the search space via document tags
if filters: if filters:
query = """ candidate_doc_ids = self.retriever.datastore.get_document_ids_by_tags(filters)
SELECT id FROM document WHERE id in (
SELECT dt.document_id
FROM document_tag dt JOIN
tag t
ON t.id = dt.tag_id
GROUP BY dt.document_id
"""
tag_filters = []
if filters:
for tag, value in filters.items():
if value:
tag_filters.append(
f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0"
)
final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
query_results = db.session.execute(final_query)
candidate_doc_ids = [row[0] for row in query_results]
else: else:
candidate_doc_ids = None candidate_doc_ids = None

View File

@ -1,28 +0,0 @@
import logging
import os
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

# Resolve the database connection string in order of precedence:
# 1) DATABASE_URL environment variable,
# 2) DATABASE_URL from an optional qa_config.py on the Python path,
# 3) fall back to an in-memory sqlite database.
DATABASE_URL = os.getenv("DATABASE_URL", None)

if not DATABASE_URL:
    try:
        from qa_config import DATABASE_URL
    except ModuleNotFoundError:
        logging.info(
            "Database not configured. Add a qa_config.py file in the Python path "
            "with DATABASE_URL set. Continuing with the default sqlite on localhost."
        )
        DATABASE_URL = "sqlite://"

app = Flask(__name__)
# Disable modification tracking: it adds overhead and emits a warning when unset.
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_DATABASE_URI"] = f"{DATABASE_URL}"

db = SQLAlchemy(app)
# Create all tables known to the ORM models immediately at import time.
db.create_all()

View File

@ -1,32 +0,0 @@
from sqlalchemy.orm import relationship
from haystack.database import db
class ORMBase(db.Model):
    """Abstract base model: surrogate primary key plus audit timestamps."""

    __abstract__ = True

    # Auto-incrementing surrogate primary key shared by every concrete table.
    id = db.Column(db.Integer, primary_key=True)
    # Creation / last-update timestamps maintained server-side by the database.
    created = db.Column(db.DateTime, server_default=db.func.now())
    updated = db.Column(
        db.DateTime, server_default=db.func.now(), server_onupdate=db.func.now()
    )


class Document(ORMBase):
    """A stored text document, linked to tags through the document_tag table."""

    name = db.Column(db.String)
    text = db.Column(db.String)
    tags = relationship("Tag", secondary="document_tag", backref="Document")


class Tag(ORMBase):
    """A name/value label that can be attached to documents for filtering."""

    name = db.Column(db.String)
    value = db.Column(db.String)
    documents = relationship("Document", secondary="document_tag", backref="Tag")


class DocumentTag(ORMBase):
    """Association table realizing the Document <-> Tag many-to-many link."""

    document_id = db.Column(db.Integer, db.ForeignKey("document.id"), nullable=False)
    tag_id = db.Column(db.Integer, db.ForeignKey("tag.id"), nullable=False)

100
haystack/database/sql.py Normal file
View File

@ -0,0 +1,100 @@
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from haystack.database.base import BaseDocumentStore
# Declarative base shared by all ORM models in this module.
Base = declarative_base()


class ORMBase(Base):
    """Abstract base model: surrogate primary key plus audit timestamps."""

    __abstract__ = True

    # Auto-incrementing surrogate primary key shared by every concrete table.
    id = Column(Integer, primary_key=True)
    # Creation / last-update timestamps maintained server-side by the database.
    created = Column(DateTime, server_default=func.now())
    updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())


class Document(ORMBase):
    """A stored text document, linked to tags through the document_tag table."""

    __tablename__ = "document"

    name = Column(String)
    text = Column(String)
    tags = relationship("Tag", secondary="document_tag", backref="Document")


class Tag(ORMBase):
    """A name/value label that can be attached to documents for filtering."""

    __tablename__ = "tag"

    name = Column(String)
    value = Column(String)
    documents = relationship("Document", secondary="document_tag", backref="Tag")


class DocumentTag(ORMBase):
    """Association table realizing the Document <-> Tag many-to-many link."""

    __tablename__ = "document_tag"

    document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
    tag_id = Column(Integer, ForeignKey("tag.id"), nullable=False)
class SQLDocumentStore(BaseDocumentStore):
    """Document store backed by any SQL database supported by SQLAlchemy."""

    def __init__(self, url="sqlite://"):
        """
        :param url: SQLAlchemy database URL; defaults to an in-memory sqlite DB.
        """
        engine = create_engine(url)
        # Create all tables declared on the shared metadata if they don't exist yet.
        ORMBase.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)
        self.session = Session()

    def get_document_by_id(self, id):
        """Return a single document dict looked up by its primary key."""
        document_row = self.session.query(Document).get(id)
        return self._row_to_dict(document_row)

    def get_all_documents(self):
        """Return every stored document as a list of dicts."""
        return [self._row_to_dict(row) for row in self.session.query(Document).all()]

    @staticmethod
    def _row_to_dict(row):
        # Single definition of the public dict shape for a document.
        return {"id": row.id, "name": row.name, "text": row.text, "tags": row.tags}

    def get_document_ids_by_tags(self, tags):
        """
        Get list of document ids that have tags from the given list of tags.

        :param tags: limit scope to documents having the given tags and their corresponding values.
            The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
        :raises ValueError: if no tags (or no non-empty tag values) are supplied.
        """
        if not tags:
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError("No tag supplied for filtering the documents")
        query = """
            SELECT id FROM document WHERE id in (
                SELECT dt.document_id
                FROM document_tag dt JOIN
                     tag t
                     ON t.id = dt.tag_id
                GROUP BY dt.document_id
        """
        # Build one HAVING clause per tag value using bound parameters rather than
        # f-string interpolation, so user-supplied values cannot inject SQL.
        # NOTE(review): only the tag *value* is matched; the tag name is ignored.
        # This mirrors the pre-existing behavior — confirm whether the name should
        # also be checked (t.name = :name AND t.value = :value).
        tag_filters = []
        params = {}
        for idx, (tag, value) in enumerate(tags.items()):
            if value:
                param_name = f"tag_value_{idx}"
                params[param_name] = value
                tag_filters.append(f"SUM(CASE WHEN t.value=:{param_name} THEN 1 ELSE 0 END) > 0")
        if not tag_filters:
            # Previously this produced invalid SQL ("HAVING );") and crashed with an
            # opaque database error; fail early with a clear message instead.
            raise ValueError("No non-empty tag value supplied for filtering the documents")
        final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
        query_results = self.session.execute(final_query, params)
        return [row[0] for row in query_results]

    def write_documents(self, documents):
        """
        Persist a batch of documents.

        :param documents: list of dicts, each with keys "name" and "text".
        """
        for doc in documents:
            self.session.add(Document(name=doc["name"], text=doc["text"]))
        # One commit for the whole batch keeps the write transactional and fast.
        self.session.commit()

View File

@ -1,4 +1,3 @@
from haystack.database.orm import Document, db
from pathlib import Path from pathlib import Path
import logging import logging
from farm.data_handler.utils import http_get from farm.data_handler.utils import http_get
@ -9,7 +8,7 @@ import zipfile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False): def write_documents_to_db(datastore, document_dir, clean_func=None, only_empty_db=False):
""" """
Write all text files(.txt) in the sub-directories of the given path to the connected database. Write all text files(.txt) in the sub-directories of the given path to the connected database.
@ -24,23 +23,28 @@ def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
# check if db has already docs # check if db has already docs
if only_empty_db: if only_empty_db:
n_docs = db.session.query(Document).count() n_docs = len(datastore.get_all_documents())
if n_docs > 0: if n_docs > 0:
logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... " logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... "
"(Disable `only_empty_db`, if you want to add docs anyway.)") "(Disable `only_empty_db`, if you want to add docs anyway.)")
return None return None
# read and add docs # read and add docs
documents_to_write = []
for path in file_paths: for path in file_paths:
with open(path) as doc: with open(path) as doc:
text = doc.read() text = doc.read()
if clean_func: if clean_func:
text = clean_func(text) text = clean_func(text)
doc = Document(name=path.name, text=text)
db.session.add(doc) documents_to_write.append(
db.session.commit() {
n_docs += 1 "name": path.name,
logger.info(f"Wrote {n_docs} docs to DB") "text": text,
}
)
datastore.write_documents(documents_to_write)
logger.info(f"Wrote {len(documents_to_write)} docs to DB")
def fetch_archive_from_http(url, output_dir, proxies=None): def fetch_archive_from_http(url, output_dir, proxies=None):

View File

@ -4,8 +4,6 @@ import logging
import pandas as pd import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from haystack.database import db
from haystack.database.orm import Document
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,7 +35,7 @@ class TfidfRetriever(BaseRetriever):
It uses sklearn's TfidfVectorizer to compute a tf-idf matrix. It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
""" """
def __init__(self): def __init__(self, datastore):
self.vectorizer = TfidfVectorizer( self.vectorizer = TfidfVectorizer(
lowercase=True, lowercase=True,
stop_words=None, stop_words=None,
@ -45,6 +43,7 @@ class TfidfRetriever(BaseRetriever):
ngram_range=(1, 1), ngram_range=(1, 1),
) )
self.datastore = datastore
self.paragraphs = self._get_all_paragraphs() self.paragraphs = self._get_all_paragraphs()
self.df = None self.df = None
self.fit() self.fit()
@ -53,17 +52,17 @@ class TfidfRetriever(BaseRetriever):
""" """
Split the list of documents in paragraphs Split the list of documents in paragraphs
""" """
documents = db.session.query(Document).all() documents = self.datastore.get_all_documents()
paragraphs = [] paragraphs = []
p_id = 0 p_id = 0
for doc in documents: for doc in documents:
_pgs = [d for d in doc.text.splitlines() if d.strip()] _pgs = [d for d in doc["text"].splitlines() if d.strip()]
for p in doc.text.split("\n\n"): for p in doc["text"].split("\n\n"):
if not p.strip(): # skip empty paragraphs if not p.strip(): # skip empty paragraphs
continue continue
paragraphs.append( paragraphs.append(
Paragraph(document_id=doc.id, paragraph_id=p_id, text=(p,)) Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
) )
p_id += 1 p_id += 1
logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB") logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")

View File

@ -3,21 +3,11 @@ from collections import defaultdict
import logging import logging
import pprint import pprint
from haystack.database.orm import Document from haystack.database.sql import Document
from haystack.database.orm import db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_db():
"""
Create all tables as defined by the ORM in the connected SQL database.
:return:
"""
db.session.create_all()
def print_answers(results, details="all"): def print_answers(results, details="all"):
answers = results["answers"] answers = results["answers"]
pp = pprint.PrettyPrinter(indent=4) pp = pprint.PrettyPrinter(indent=4)