mirror of https://github.com/deepset-ai/haystack.git
synced 2025-11-11 15:23:41 +00:00

Refactor database layer

parent 8cdb0ff482
commit b5b62c569e
haystack/finder.py
@@ -1,8 +1,6 @@
 from haystack.retriever.tfidf import TfidfRetriever
 from haystack.reader.farm import FARMReader
-from haystack.database import db
 import logging
-import farm
 
 import pandas as pd
 pd.options.display.max_colwidth = 80
@@ -40,25 +38,7 @@ class Finder:
 
         # 1) Optional: reduce the search space via document tags
         if filters:
-            query = """
-                SELECT id FROM document WHERE id in (
-                    SELECT dt.document_id
-                    FROM document_tag dt JOIN
-                        tag t
-                        ON t.id = dt.tag_id
-                    GROUP BY dt.document_id
-            """
-            tag_filters = []
-            if filters:
-                for tag, value in filters.items():
-                    if value:
-                        tag_filters.append(
-                            f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0"
-                        )
-
-            final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
-            query_results = db.session.execute(final_query)
-            candidate_doc_ids = [row[0] for row in query_results]
+            candidate_doc_ids = self.retriever.datastore.get_document_ids_by_tags(filters)
         else:
             candidate_doc_ids = None
 
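Tag filtering in Finder now delegates to the document store instead of issuing raw SQL through the Flask-SQLAlchemy session. A minimal sketch of the new call path, assuming the wiring introduced by the rest of this commit (document contents and filter values are illustrative):

    from haystack.database.sql import SQLDocumentStore
    from haystack.retriever.tfidf import TfidfRetriever

    datastore = SQLDocumentStore(url="sqlite://")  # in-memory sqlite by default
    datastore.write_documents([{"name": "doc_1.txt", "text": "Berlin is the capital of Germany."}])
    retriever = TfidfRetriever(datastore)          # the retriever now carries the store

    # Finder reaches the store through its retriever; filters are {"tag": "value"} pairs
    candidate_doc_ids = retriever.datastore.get_document_ids_by_tags({"category": "news"})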
haystack/database/__init__.py (deleted)
@@ -1,28 +0,0 @@
-import logging
-import os
-
-from flask import Flask
-from flask_sqlalchemy import SQLAlchemy
-
-logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
-)
-
-DATABASE_URL = os.getenv("DATABASE_URL", None)
-if not DATABASE_URL:
-    try:
-        from qa_config import DATABASE_URL
-    except ModuleNotFoundError:
-        logging.info(
-            "Database not configured. Add a qa_config.py file in the Python path with DATABASE_URL set. "
-            "Continuing with the default sqlite on localhost."
-        )
-        DATABASE_URL = "sqlite://"
-
-app = Flask(__name__)
-app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
-app.config["SQLALCHEMY_DATABASE_URI"] = f"{DATABASE_URL}"
-db = SQLAlchemy(app)
-db.create_all()
haystack/database/orm.py (deleted)
@@ -1,32 +0,0 @@
-from sqlalchemy.orm import relationship
-
-from haystack.database import db
-
-
-class ORMBase(db.Model):
-    __abstract__ = True
-
-    id = db.Column(db.Integer, primary_key=True)
-    created = db.Column(db.DateTime, server_default=db.func.now())
-    updated = db.Column(
-        db.DateTime, server_default=db.func.now(), server_onupdate=db.func.now()
-    )
-
-
-class Document(ORMBase):
-    name = db.Column(db.String)
-    text = db.Column(db.String)
-
-    tags = relationship("Tag", secondary="document_tag", backref="Document")
-
-
-class Tag(ORMBase):
-    name = db.Column(db.String)
-    value = db.Column(db.String)
-
-    documents = relationship("Document", secondary="document_tag", backref="Tag")
-
-
-class DocumentTag(ORMBase):
-    document_id = db.Column(db.Integer, db.ForeignKey("document.id"), nullable=False)
-    tag_id = db.Column(db.Integer, db.ForeignKey("tag.id"), nullable=False)
haystack/database/sql.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship, sessionmaker
+
+from haystack.database.base import BaseDocumentStore
+
+Base = declarative_base()
+
+
+class ORMBase(Base):
+    __abstract__ = True
+
+    id = Column(Integer, primary_key=True)
+    created = Column(DateTime, server_default=func.now())
+    updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
+
+
+class Document(ORMBase):
+    __tablename__ = "document"
+
+    name = Column(String)
+    text = Column(String)
+
+    tags = relationship("Tag", secondary="document_tag", backref="Document")
+
+
+class Tag(ORMBase):
+    __tablename__ = "tag"
+
+    name = Column(String)
+    value = Column(String)
+
+    documents = relationship("Document", secondary="document_tag", backref="Tag")
+
+
+class DocumentTag(ORMBase):
+    __tablename__ = "document_tag"
+
+    document_id = Column(Integer, ForeignKey("document.id"), nullable=False)
+    tag_id = Column(Integer, ForeignKey("tag.id"), nullable=False)
+
+
+class SQLDocumentStore(BaseDocumentStore):
+    def __init__(self, url="sqlite://"):
+        engine = create_engine(url)
+        ORMBase.metadata.create_all(engine)
+        Session = sessionmaker(bind=engine)
+        self.session = Session()
+
+    def get_document_by_id(self, id):
+        document_row = self.session.query(Document).get(id)
+        document = {
+            "id": document_row.id,
+            "name": document_row.name,
+            "text": document_row.text,
+            "tags": document_row.tags,
+        }
+        return document
+
+    def get_all_documents(self):
+        document_rows = self.session.query(Document).all()
+        documents = []
+        for row in document_rows:
+            documents.append({"id": row.id, "name": row.name, "text": row.text, "tags": row.tags})
+        return documents
+
+    def get_document_ids_by_tags(self, tags):
+        """
+        Get list of document ids that have tags from the given list of tags.
+
+        :param tags: limit scope to documents having the given tags and their corresponding values.
+            The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
+        """
+        if not tags:
+            raise Exception("No tag supplied for filtering the documents")
+
+        query = """
+            SELECT id FROM document WHERE id in (
+                SELECT dt.document_id
+                FROM document_tag dt JOIN
+                    tag t
+                    ON t.id = dt.tag_id
+                GROUP BY dt.document_id
+        """
+        tag_filters = []
+        for tag, value in tags.items():
+            if value:
+                tag_filters.append(f"SUM(CASE WHEN t.value='{value}' THEN 1 ELSE 0 END) > 0")
+
+        final_query = f"{query} HAVING {' AND '.join(tag_filters)});"
+        query_results = self.session.execute(final_query)
+
+        doc_ids = [row[0] for row in query_results]
+        return doc_ids
+
+    def write_documents(self, documents):
+        for doc in documents:
+            row = Document(name=doc["name"], text=doc["text"])
+            self.session.add(row)
+        self.session.commit()
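Taken on its own, the new store can be exercised like this (document contents illustrative):

    from haystack.database.sql import SQLDocumentStore

    store = SQLDocumentStore(url="sqlite://")   # tables are created on init via metadata.create_all
    store.write_documents([{"name": "doc_1.txt", "text": "Berlin is the capital of Germany."}])
    docs = store.get_all_documents()            # plain dicts: {"id", "name", "text", "tags"}

For a filter such as {"category": "news"}, get_document_ids_by_tags assembles roughly the following SQL; note the f-string interpolation assumes trusted tag values:

    SELECT id FROM document WHERE id in (
        SELECT dt.document_id
        FROM document_tag dt JOIN
            tag t
            ON t.id = dt.tag_id
        GROUP BY dt.document_id
     HAVING SUM(CASE WHEN t.value='news' THEN 1 ELSE 0 END) > 0);

Returning dicts rather than ORM rows is what lets the callers below switch from doc.text to doc["text"].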
haystack/indexing/io.py
@@ -1,4 +1,3 @@
-from haystack.database.orm import Document, db
 from pathlib import Path
 import logging
 from farm.data_handler.utils import http_get
@@ -9,7 +8,7 @@ import zipfile
 logger = logging.getLogger(__name__)
 
 
-def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
+def write_documents_to_db(datastore, document_dir, clean_func=None, only_empty_db=False):
     """
     Write all text files(.txt) in the sub-directories of the given path to the connected database.
 
@@ -24,23 +23,28 @@ def write_documents_to_db(document_dir, clean_func=None, only_empty_db=False):
 
     # check if db has already docs
    if only_empty_db:
-        n_docs = db.session.query(Document).count()
+        n_docs = len(datastore.get_all_documents())
         if n_docs > 0:
             logger.info(f"Skip writing documents since DB already contains {n_docs} docs ... "
                         "(Disable `only_empty_db`, if you want to add docs anyway.)")
             return None
 
     # read and add docs
+    documents_to_write = []
     for path in file_paths:
         with open(path) as doc:
             text = doc.read()
             if clean_func:
                 text = clean_func(text)
-            doc = Document(name=path.name, text=text)
-            db.session.add(doc)
-            db.session.commit()
-            n_docs += 1
-    logger.info(f"Wrote {n_docs} docs to DB")
+            documents_to_write.append(
+                {
+                    "name": path.name,
+                    "text": text,
+                }
+            )
+    datastore.write_documents(documents_to_write)
+    logger.info(f"Wrote {len(documents_to_write)} docs to DB")
 
 
 def fetch_archive_from_http(url, output_dir, proxies=None):
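Indexing now takes the store explicitly instead of writing through the module-level Flask-SQLAlchemy session, and documents are committed in one batch at the end. A minimal sketch, assuming this module's import path and an illustrative directory of .txt files:

    from haystack.database.sql import SQLDocumentStore
    from haystack.indexing.io import write_documents_to_db  # module path assumed

    datastore = SQLDocumentStore(url="sqlite://")
    # reads .txt files from sub-directories of document_dir, then writes them in one batch
    write_documents_to_db(datastore, document_dir="data/docs", only_empty_db=True)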
haystack/retriever/tfidf.py
@@ -4,8 +4,6 @@ import logging
 import pandas as pd
 from sklearn.feature_extraction.text import TfidfVectorizer
 
-from haystack.database import db
-from haystack.database.orm import Document
 
 logger = logging.getLogger(__name__)
 
@@ -37,7 +35,7 @@ class TfidfRetriever(BaseRetriever):
     It uses sklearn's TfidfVectorizer to compute a tf-idf matrix.
     """
 
-    def __init__(self):
+    def __init__(self, datastore):
         self.vectorizer = TfidfVectorizer(
             lowercase=True,
             stop_words=None,
@@ -45,6 +43,7 @@ class TfidfRetriever(BaseRetriever):
             ngram_range=(1, 1),
         )
 
+        self.datastore = datastore
         self.paragraphs = self._get_all_paragraphs()
         self.df = None
         self.fit()
@@ -53,17 +52,17 @@ class TfidfRetriever(BaseRetriever):
         """
         Split the list of documents in paragraphs
         """
-        documents = db.session.query(Document).all()
+        documents = self.datastore.get_all_documents()
 
         paragraphs = []
         p_id = 0
         for doc in documents:
-            _pgs = [d for d in doc.text.splitlines() if d.strip()]
-            for p in doc.text.split("\n\n"):
+            _pgs = [d for d in doc["text"].splitlines() if d.strip()]
+            for p in doc["text"].split("\n\n"):
                 if not p.strip():  # skip empty paragraphs
                     continue
                 paragraphs.append(
-                    Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
+                    Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
                 )
                 p_id += 1
         logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
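The retriever now consumes the plain dicts returned by get_all_documents rather than ORM rows. A standalone sketch of the same paragraph splitting, with an illustrative document:

    doc = {"id": 1, "name": "doc_1.txt", "text": "First paragraph.\n\nSecond paragraph."}
    paragraphs = [p for p in doc["text"].split("\n\n") if p.strip()]
    # -> ["First paragraph.", "Second paragraph."]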
haystack/utils.py
@@ -3,21 +3,11 @@ from collections import defaultdict
 import logging
 import pprint
 
-from haystack.database.orm import Document
-from haystack.database.orm import db
+from haystack.database.sql import Document
 
 logger = logging.getLogger(__name__)
 
 
-def create_db():
-    """
-    Create all tables as defined by the ORM in the connected SQL database.
-
-    :return:
-    """
-    db.session.create_all()
-
-
 def print_answers(results, details="all"):
     answers = results["answers"]
     pp = pprint.PrettyPrinter(indent=4)