Sara Zan 13510aa753
Refactoring of the haystack package (#1624)
* Files moved, imports all broken

* Fix most imports and docstrings into

* Fix the paths to the modules in the API docs

* Add latest docstring and tutorial changes

* Add a few pipelines that were lost in the imports

* Fix a bunch of mypy warnings

* Add latest docstring and tutorial changes

* Create a file_classifier module

* Add docs for file_classifier

* Fixed most circular imports, now the REST API can start

* Add latest docstring and tutorial changes

* Tackling more mypy issues

* Reintroduce  from FARM and fix last mypy issues hopefully

* Re-enable old-style imports

* Fix some more imports from the top-level package in an attempt to sort out circular imports

* Fix some imports in tests to new-style to prevent failed class equalities from breaking tests

* Change document_store into document_stores

* Update imports in tutorials

* Add latest docstring and tutorial changes

* Probably fixes summarizer tests

* Improve the old-style import allowing module imports (should work)

* Try to fix the docs

* Remove dedicated KnowledgeGraph page from autodocs

* Remove dedicated GraphRetriever page from autodocs

* Fix generate_docstrings.sh with an updated list of yaml files to look for

* Fix some more modules in the docs

* Fix the document stores docs too

* Fix a small issue on Tutorial14

* Add latest docstring and tutorial changes

* Add deprecation warning to old-style imports

* Remove stray folder and import Dict into dense.py

* Change import path for MLFlowLogger

* Add old loggers path to the import path aliases

* Fix debug output of convert_ipynb.py

* Fix circular import on BaseRetriever

* Missed one merge block

* re-run tutorial 5

* Fix imports in tutorial 5

* Re-enable squad_to_dpr CLI from the root package and move get_batches_from_generator into document_stores.base

* Add latest docstring and tutorial changes

* Fix typo in utils __init__

* Fix a few more imports

* Fix benchmarks too

* New-style imports in test_knowledge_graph

* Rollback setup.py

* Rollback squad_to_dpr too

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2021-10-25 15:50:23 +02:00
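
Several of the notes above mention re-enabling old-style imports and attaching a deprecation warning to them (for example the old document_store path forwarding to the new document_stores package). The snippet below is only a minimal sketch of how such an alias can be implemented with a module-level __getattr__ (PEP 562); it is illustrative and is not the code from this PR:

# Hypothetical contents of a legacy haystack/document_store/__init__.py:
# forward attribute lookups to the new package and emit a DeprecationWarning.
import importlib
import warnings

_NEW_MODULE = "haystack.document_stores"  # assumed new location of the moved module


def __getattr__(name):
    warnings.warn(
        f"'haystack.document_store' is deprecated; import '{name}' from '{_NEW_MODULE}' instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return getattr(importlib.import_module(_NEW_MODULE), name)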

171 lines
7.4 KiB
Python

import os
from haystack.document_stores.sql import SQLDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.document_stores.elasticsearch import Elasticsearch, ElasticsearchDocumentStore, OpenSearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.document_stores.milvus import MilvusDocumentStore
from haystack.nodes.retriever.sparse import ElasticsearchRetriever, TfidfRetriever
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever
from haystack.nodes.reader.farm import FARMReader
from haystack.nodes.reader.transformers import TransformersReader
from haystack.utils import launch_milvus, launch_es, launch_opensearch
from haystack.modeling.data_handler.processor import http_get
import logging
import subprocess
import time
import json
from typing import Union
from pathlib import Path

logger = logging.getLogger(__name__)

reader_models = ["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2", "deepset/bert-base-cased-squad2", "deepset/bert-large-uncased-whole-word-masking-squad2", "deepset/xlm-roberta-large-squad2"]
reader_types = ["farm"]
doc_index = "eval_document"
label_index = "label"


def get_document_store(document_store_type, similarity='dot_product', index="document"):
    """ TODO This method is taken from test/conftest.py but maybe should be within Haystack.
    Perhaps a class method of DocStore that just takes string for type of DocStore"""
    if document_store_type == "sql":
        if os.path.exists("haystack_test.db"):
            os.remove("haystack_test.db")
        document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
        assert document_store.get_document_count() == 0
    elif document_store_type == "memory":
        document_store = InMemoryDocumentStore()
    elif document_store_type == "elasticsearch":
        launch_es()
        # make sure we start from a fresh index
        client = Elasticsearch()
        client.indices.delete(index='haystack_test*', ignore=[404])
        document_store = ElasticsearchDocumentStore(index="eval_document", similarity=similarity, timeout=3000)
    elif document_store_type in ("milvus_flat", "milvus_hnsw"):
        launch_milvus()
        if document_store_type == "milvus_flat":
            index_type = "FLAT"
            index_param = None
            search_param = None
        elif document_store_type == "milvus_hnsw":
            index_type = "HNSW"
            index_param = {"M": 64, "efConstruction": 80}
            search_param = {"ef": 20}
        document_store = MilvusDocumentStore(
            similarity=similarity,
            index_type=index_type,
            index_param=index_param,
            search_param=search_param,
            index=index
        )
        assert document_store.get_document_count(index="eval_document") == 0
    elif document_store_type in ("faiss_flat", "faiss_hnsw"):
        if document_store_type == "faiss_flat":
            index_type = "Flat"
        elif document_store_type == "faiss_hnsw":
            index_type = "HNSW"
        status = subprocess.run(
            ['docker rm -f haystack-postgres'],
            shell=True)
        time.sleep(1)
        status = subprocess.run(
            ['docker run --name haystack-postgres -p 5432:5432 -e POSTGRES_PASSWORD=password -d postgres'],
            shell=True)
        time.sleep(6)
        status = subprocess.run(
            ['docker exec haystack-postgres psql -U postgres -c "CREATE DATABASE haystack;"'], shell=True)
        time.sleep(1)
        document_store = FAISSDocumentStore(
            sql_url="postgresql://postgres:password@localhost:5432/haystack",
            faiss_index_factory_str=index_type,
            similarity=similarity,
            index=index
        )
        assert document_store.get_document_count() == 0
    elif document_store_type in ("opensearch_flat", "opensearch_hnsw"):
        launch_opensearch()
        if document_store_type == "opensearch_flat":
            index_type = "flat"
        elif document_store_type == "opensearch_hnsw":
            index_type = "hnsw"
        document_store = OpenSearchDocumentStore(index_type=index_type, port=9201, timeout=3000)
    else:
        raise Exception(f"No document store fixture for '{document_store_type}'")
    return document_store


def get_retriever(retriever_name, doc_store, devices):
    if retriever_name == "elastic":
        return ElasticsearchRetriever(doc_store)
    if retriever_name == "tfidf":
        return TfidfRetriever(doc_store)
    if retriever_name == "dpr":
        return DensePassageRetriever(document_store=doc_store,
                                     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                     use_gpu=True,
                                     use_fast_tokenizers=False,
                                     devices=devices)
    if retriever_name == "sentence_transformers":
        return EmbeddingRetriever(document_store=doc_store,
                                  embedding_model="nq-distilbert-base-v1",
                                  use_gpu=True,
                                  model_format="sentence_transformers")


def get_reader(reader_name, reader_type, max_seq_len=384):
    reader_class = None
    if reader_type == "farm":
        reader_class = FARMReader
    elif reader_type == "transformers":
        reader_class = TransformersReader
    return reader_class(reader_name, top_k_per_candidate=4, max_seq_len=max_seq_len)


def index_to_doc_store(doc_store, docs, retriever, labels=None):
    doc_store.write_documents(docs, doc_index)
    if labels:
        doc_store.write_labels(labels, index=label_index)
    # these lines are not run if the docs.embedding field is already populated with precomputed embeddings
    # See the prepare_data() fn in the retriever benchmark script
    if callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
        doc_store.update_embeddings(retriever, index=doc_index)


def load_config(config_filename, ci):
    conf = json.load(open(config_filename))
    if ci:
        params = conf["params"]["ci"]
    else:
        params = conf["params"]["full"]
    filenames = conf["filenames"]
    max_docs = max(params["n_docs_options"])
    n_docs_keys = sorted([int(x) for x in list(filenames["embeddings_filenames"])])
    for k in n_docs_keys:
        if max_docs <= k:
            filenames["embeddings_filenames"] = [filenames["embeddings_filenames"][str(k)]]
            filenames["filename_negative"] = filenames["filenames_negative"][str(k)]
            break
    return params, filenames


def download_from_url(url: str, filepath: Union[str, Path]):
    """
    Download from a url to a local file. Skip already existing files.
    :param url: Url
    :param filepath: local path where the url content shall be stored
    :return: local path of the downloaded file
    """
    logger.info(f"Downloading {url}")
    # Create local folder
    folder, filename = os.path.split(filepath)
    if not os.path.exists(folder):
        os.makedirs(folder)
    # Download file if not present locally
    if os.path.exists(filepath):
        logger.info(f"Skipping {url} (exists locally)")
    else:
        logger.info(f"Downloading {url} to {filepath} ")
        with open(filepath, "wb") as file:
            http_get(url=url, temp_file=file)
    return filepath
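
The helpers in this file are meant to be combined by the benchmark scripts. The following is a minimal, hypothetical sketch of that flow; the store type, retriever name, reader model and the empty docs/labels placeholders are illustrative rather than taken from this file, and a running Elasticsearch instance is assumed (as started by get_document_store):

# Illustrative only: wire the helpers together for one benchmark configuration.
doc_store = get_document_store("elasticsearch", similarity="dot_product")
retriever = get_retriever("elastic", doc_store, devices=None)

docs = []      # Haystack Documents prepared by the benchmark script (e.g. fetched via download_from_url)
labels = None  # optional evaluation labels

# Writes documents (and labels, if given); for dense retrievers,
# index_to_doc_store would also update the document embeddings.
index_to_doc_store(doc_store, docs, retriever, labels=labels)

reader = get_reader("deepset/minilm-uncased-squad2", "farm")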