Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-01 10:19:23 +00:00)

Commit b5b62c569e: Refactor database layer
Parent: 8cdb0ff482
@@ -1,8 +1,6 @@
 from haystack.retriever.tfidf import TfidfRetriever
 from haystack.reader.farm import FARMReader
-from haystack.database import db
 import logging
-import farm

 import pandas as pd
 pd.options.display.max_colwidth = 80
@@ -40,25 +38,7 @@ class Finder:

         # 1) Optional: reduce the search space via document tags
         if filters:
-            query = """
-                SELECT id FROM document WHERE id in (
-                    SELECT dt.document_id
-                    FROM document_tag dt JOIN
-                        tag t
-                        ON t.id = dt.tag_id
-                    GROUP BY dt.document_id
-            """
-            tag_filters = []
-            if filters:
-                for tag, value in filters.items():
-                    if value:
-                        tag_filters.append(
-                            f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0"
-                        )
-
-            final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
-            query_results = db.session.execute(final_query)
-            candidate_doc_ids = [row[0] for row in query_results]
+            candidate_doc_ids = self.retriever.datastore.get_document_ids_by_tags(filters)
         else:
             candidate_doc_ids = None

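Note: the tag filtering that `Finder` previously ran as inline SQL against the global `db` session is now delegated to the document store behind the retriever. A minimal sketch of the new call path (the tag name and value are illustrative, not from the commit):

```python
# Hypothetical filter dict mapping tag names to required values
filters = {"category": "news"}

# Finder now asks the datastore for matching document ids instead of building SQL itself
candidate_doc_ids = retriever.datastore.get_document_ids_by_tags(filters)
```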
@@ -1,28 +0,0 @@
-import logging
-import os
-
-from flask import Flask
-from flask_sqlalchemy import SQLAlchemy
-
-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
-)
-
-DATABASE_URL = os.getenv("DATABASE_URL", None)
-if not DATABASE_URL:
-    try:
-        from qa_config import DATABASE_URL
-    except ModuleNotFoundError:
-        logging.info(
-            "Database not configured. Add a qa_config.py file in the Python path with DATABASE_URL set. "
-            "Continuing with the default sqlite on localhost."
-        )
-        DATABASE_URL = "sqlite://"
-
-app = Flask(__name__)
-app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
-app.config["SQLALCHEMY_DATABASE_URI"] = f"{DATABASE_URL}"
-db = SQLAlchemy(app)
-db.create_all()
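Design note: the deleted module above (presumably `haystack/database/__init__.py`, given the `from haystack.database import db` imports elsewhere in this diff) created a global Flask app and SQLAlchemy session at import time, configured via a `DATABASE_URL` environment variable or a `qa_config.py` file. The ORM models deleted in the next hunk (presumably `haystack/database/orm.py`) move into the new `haystack/database/sql.py`. After this refactor the connection URL is passed explicitly to the store instead; a sketch of the equivalent setup (the environment lookup is an assumption, not part of the commit):

```python
import os

from haystack.database.sql import SQLDocumentStore

# Explicit configuration with no import-time side effects;
# falls back to in-memory sqlite, mirroring the old default
datastore = SQLDocumentStore(url=os.getenv("DATABASE_URL", "sqlite://"))
```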
@@ -1,32 +0,0 @@
-from sqlalchemy.orm import relationship
-
-from haystack.database import db
-
-
-class ORMBase(db.Model):
-    __abstract__ = True
-
-    id = db.Column(db.Integer, primary_key=True)
-    created = db.Column(db.DateTime, server_default=db.func.now())
-    updated = db.Column(
-        db.DateTime, server_default=db.func.now(), server_onupdate=db.func.now()
-    )
-
-
-class Document(ORMBase):
-    name = db.Column(db.String)
-    text = db.Column(db.String)
-
-    tags = relationship("Tag", secondary="document_tag", backref="Document")
-
-
-class Tag(ORMBase):
-    name = db.Column(db.String)
-    value = db.Column(db.String)
-
-    documents = relationship("Document", secondary="document_tag", backref="Tag")
-
-
-class DocumentTag(ORMBase):
-    document_id = db.Column(db.Integer, db.ForeignKey("document.id"), nullable=False)
-    tag_id = db.Column(db.Integer, db.ForeignKey("tag.id"), nullable=False)
haystack/database/sql.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship, sessionmaker
+
+from haystack.database.base import BaseDocumentStore
+
+Base = declarative_base()
+
+
+class ORMBase(Base):
+    __abstract__ = True
+
+    id = Column(Integer, primary_key=True)
+    created = Column(DateTime, server_default=func.now())
+    updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
+
+
+class Document(ORMBase):
+    __tablename__ = "document"
+
+    name = Column(String)
+    text = Column(String)
+
+    tags = relationship("Tag", secondary="document_tag", backref="Document")
+
+
+class Tag(ORMBase):
+    __tablename__ = "tag"
+
+    name = Column(String)
+    value = Column(String)
+
+    documents = relationship("Document", secondary="document_tag", backref="Tag")
+
+
+class DocumentTag(ORMBase):
+    __tablename__ = "document_tag"
+
+    document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
+    tag_id = Column(Integer, ForeignKey("tag.id"), nullable=False)
+
+
+class SQLDocumentStore(BaseDocumentStore):
+    def __init__(self, url="sqlite://"):
+        engine = create_engine(url)
+        ORMBase.metadata.create_all(engine)
+        Session = sessionmaker(bind=engine)
+        self.session = Session()
+
+    def get_document_by_id(self, id):
+        document_row = self.session.query(Document).get(id)
+        document = {
+            "id": document_row.id,
+            "name": document_row.name,
+            "text": document_row.text,
+            "tags": document_row.tags,
+        }
+        return document
+
+    def get_all_documents(self):
+        document_rows = self.session.query(Document).all()
+        documents = []
+        for row in document_rows:
+            documents.append({"id": row.id, "name": row.name, "text": row.text, "tags": row.tags})
+        return documents
+
+    def get_document_ids_by_tags(self, tags):
+        """
+        Get the list of ids for documents that have all of the given tags.
+
+        :param tags: limit scope to documents having the given tags and their corresponding values.
+            The format for the dict is {"tag-1": "value-1", "tag-2": "value-2", ...}
+        """
+        if not tags:
+            raise Exception("No tag supplied for filtering the documents")
+
+        query = """
+            SELECT id FROM document WHERE id in (
+                SELECT dt.document_id
+                FROM document_tag dt JOIN
+                    tag t
+                    ON t.id = dt.tag_id
+                GROUP BY dt.document_id
+        """
+        tag_filters = []
+        for tag, value in tags.items():
+            if value:
+                tag_filters.append(f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0")
+
+        final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
+        query_results = self.session.execute(final_query)
+
+        doc_ids = [row[0] for row in query_results]
+        return doc_ids
+
+    def write_documents(self, documents):
+        for doc in documents:
+            row = Document(name=doc["name"], text=doc["text"])
+            self.session.add(row)
+        self.session.commit()
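A short usage sketch of the `SQLDocumentStore` API added above (the URL, file name, and tag values are illustrative, not from the commit):

```python
from haystack.database.sql import SQLDocumentStore

store = SQLDocumentStore(url="sqlite://")  # any SQLAlchemy URL should work here

# write_documents expects plain dicts with "name" and "text" keys
store.write_documents([{"name": "doc1.txt", "text": "First paragraph.\n\nSecond paragraph."}])

docs = store.get_all_documents()               # -> list of {"id", "name", "text", "tags"} dicts
doc = store.get_document_by_id(docs[0]["id"])

# Requires matching Tag/DocumentTag rows; the dict maps tag names to required values
ids = store.get_document_ids_by_tags({"category": "news"})
```

For a filter like `{"category": "news"}`, the f-string renders roughly `SELECT id FROM document WHERE id in (... GROUP BY dt.document_id HAVING SUM(CASE WHEN t.value='news' THEN 1 ELSE 0 END) > 0);`. Interpolating values directly into SQL like this is injection-prone; bound parameters would be safer.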
@@ -1,4 +1,3 @@
-from haystack.database.orm import Document, db
 from pathlib import Path
 import logging
 from farm.data_handler.utils import http_get
@@ -9,7 +8,7 @@ import zipfile
 logger = logging.getLogger(__name__)


-def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
+def write_documents_to_db(datastore, document_dir, clean_func=None, only_empty_db=False):
     """
     Write all text files (.txt) in the sub-directories of the given path to the connected database.

@@ -24,23 +23,28 @@ def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):

     # check if db already has docs
     if only_empty_db:
-        n_docs = db.session.query(Document).count()
+        n_docs = len(datastore.get_all_documents())
         if n_docs > 0:
             logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... "
                         "(Disable `only_empty_db`, if you want to add docs anyway.)")
             return None

     # read and add docs
+    documents_to_write = []
     for path in file_paths:
         with open(path) as doc:
             text = doc.read()
             if clean_func:
                 text = clean_func(text)
-            doc = Document(name=path.name, text=text)
-            db.session.add(doc)
-            db.session.commit()
-            n_docs += 1
-    logger.info(f"Wrote {n_docs} docs to DB")
+
+            documents_to_write.append(
+                {
+                    "name": path.name,
+                    "text": text,
+                }
+            )
+    datastore.write_documents(documents_to_write)
+    logger.info(f"Wrote {len(documents_to_write)} docs to DB")


 def fetch_archive_from_http(url, output_dir, proxies=None):
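A hedged example of the updated indexing entry point (the module path `haystack.indexing.io` and the directory are assumptions based on context, not stated in the diff):

```python
from haystack.database.sql import SQLDocumentStore
from haystack.indexing.io import write_documents_to_db  # module path assumed

store = SQLDocumentStore()

# Collects {"name", "text"} dicts from .txt files and writes them in one batch;
# only_empty_db=True skips writing when the store already holds documents
write_documents_to_db(store, document_dir="data/docs", clean_func=str.strip, only_empty_db=True)
```

Batching the inserts through `datastore.write_documents` also replaces the old per-file `db.session.commit()` with a single commit.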
@@ -4,8 +4,6 @@ import logging
 import pandas as pd
 from sklearn.feature_extraction.text import TfidfVectorizer

-from haystack.database import db
-from haystack.database.orm import Document

 logger = logging.getLogger(__name__)

@@ -37,7 +35,7 @@ class TfidfRetriever(BaseRetriever):
     It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
     """

-    def __init__(self):
+    def __init__(self, datastore):
         self.vectorizer = TfidfVectorizer(
             lowercase=True,
             stop_words=None,
@@ -45,6 +43,7 @@ class TfidfRetriever(BaseRetriever):
             ngram_range=(1, 1),
         )

+        self.datastore = datastore
         self.paragraphs = self._get_all_paragraphs()
         self.df = None
         self.fit()
@@ -53,17 +52,17 @@ class TfidfRetriever(BaseRetriever):
         """
         Split the list of documents into paragraphs
         """
-        documents = db.session.query(Document).all()
+        documents = self.datastore.get_all_documents()

         paragraphs = []
         p_id = 0
         for doc in documents:
-            _pgs = [d for d in doc.text.splitlines() if d.strip()]
-            for p in doc.text.split("\n\n"):
+            _pgs = [d for d in doc["text"].splitlines() if d.strip()]
+            for p in doc["text"].split("\n\n"):
                 if not p.strip():  # skip empty paragraphs
                     continue
                 paragraphs.append(
-                    Paragraph(document_id=doc.id, paragraph_id=p_id, text=(p,))
+                    Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
                 )
                 p_id += 1
         logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
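For intuition, `_get_all_paragraphs` now consumes the plain dicts returned by the datastore instead of ORM rows; a small illustrative example of the splitting behavior:

```python
# One document dict, as returned by datastore.get_all_documents()
doc = {"id": 1, "name": "sample.txt", "text": "First paragraph.\n\nSecond paragraph.", "tags": []}

# Splitting on blank lines ("\n\n") yields two candidate paragraphs:
#   Paragraph(document_id=1, paragraph_id=0, text=("First paragraph.",))
#   Paragraph(document_id=1, paragraph_id=1, text=("Second paragraph.",))
```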
@@ -3,21 +3,11 @@ from collections import defaultdict
 import logging
 import pprint

-from haystack.database.orm import Document
-from haystack.database.orm import db
+from haystack.database.sql import Document

 logger = logging.getLogger(__name__)


-def create_db():
-    """
-    Create all tables as defined by the ORM in the connected SQL database.
-
-    :return:
-    """
-    db.session.create_all()
-
-
 def print_answers(results, details="all"):
     answers = results["answers"]
     pp = pprint.PrettyPrinter(indent=4)