Rename and restructure modules (database, indexing, schemas) (#379)

* rename database to documentstore

* move document, label, multilabel to haystack/schema.py

* rename documentstore -> document_store

* split indexing modules -> file_converter + preprocessor

* fix order of imports

* Update tutorial notebooks

* fix torch version in tutorial 4
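In practice, the renames amount to new import paths. A minimal before/after sketch (representative modules only, taken from the diffs below):

```python
# before this commit
# from haystack.database.elasticsearch import ElasticsearchDocumentStore
# from haystack.database.base import Document, Label, MultiLabel
# from haystack.indexing.file_converters.pdf import PDFToTextConverter
# from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http
# from haystack.indexing.cleaning import clean_wiki_text

# after this commit
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack import Document, Label, MultiLabel  # now defined in haystack/schema.py
from haystack.file_converter.pdf import PDFToTextConverter
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.preprocessor.cleaning import clean_wiki_text
```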
Malte Pietsch 2020-09-16 18:33:23 +02:00 committed by GitHub
parent bde33ddaaa
commit 9727829cc6
54 changed files with 258 additions and 232 deletions

.gitignore

@ -138,7 +138,7 @@ dmypy.json
.idea
# haystack files
haystack/database/qa.db
haystack/document_store/qa.db
data
mlruns
src


@ -17,7 +17,7 @@ COPY README.rst models* /home/user/models/
# optional : copy sqlite db if needed for testing
#COPY qa.db /home/user/
# optional: copy data directory containing docs for indexing
# optional: copy data directory containing docs for ingestion
#COPY data /home/user/data
EXPOSE 8000


@ -1,6 +1,7 @@
import logging
import pandas as pd
from haystack.schema import Document, Label, MultiLabel
from haystack.finder import Finder
pd.options.display.max_colwidth = 80


@ -0,0 +1,125 @@
import logging
from abc import abstractmethod, ABC
from typing import Any, Optional, Dict, List, Union
from haystack import Document, Label, MultiLabel
logger = logging.getLogger(__name__)
class BaseDocumentStore(ABC):
"""
Base class for implementing Document Stores.
"""
index: Optional[str]
label_index: Optional[str]
@abstractmethod
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""
Indexes documents for later queries.
:param documents: a list of Python dictionaries or a list of Haystack Document objects.
For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
Optionally: Include meta data via {"text": "<the-actual-text>",
"meta":{"name": "<some-document-name>, "author": "somebody", ...}}
It can be used for filtering and is accessible in the responses of the Finder.
:param index: Optional name of index where the documents shall be written to.
If None, the DocumentStore's default index (self.index) will be used.
:return: None
"""
pass
@abstractmethod
def get_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Document]:
pass
@abstractmethod
def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
pass
def get_all_labels_aggregated(self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None) -> List[MultiLabel]:
aggregated_labels = []
all_labels = self.get_all_labels(index=index, filters=filters)
# Collect all answers to a question in a dict
question_ans_dict = {} # type: ignore
for l in all_labels:
# only aggregate labels with correct answers, as only those can be currently used in evaluation
if not l.is_correct_answer:
continue
if l.question in question_ans_dict:
question_ans_dict[l.question].append(l)
else:
question_ans_dict[l.question] = [l]
# Aggregate labels
for q, ls in question_ans_dict.items():
ls = list(set(ls)) # get rid of exact duplicates
# check if there are both text answer and "no answer" present
t_present = False
no_present = False
no_idx = []
for idx, l in enumerate(ls):
if len(l.answer) == 0:
no_present = True
no_idx.append(idx)
else:
t_present = True
# if both text and no answer are present, remove no answer labels
if t_present and no_present:
logger.warning(
f"Both text label and 'no answer possible' label is present for question: {ls[0].question}")
for remove_idx in no_idx[::-1]:
ls.pop(remove_idx)
# construct Aggregated_label
for i, l in enumerate(ls):
if i == 0:
agg_label = MultiLabel(question=l.question,
multiple_answers=[l.answer],
is_correct_answer=l.is_correct_answer,
is_correct_document=l.is_correct_document,
origin=l.origin,
multiple_document_ids=[l.document_id],
multiple_offset_start_in_docs=[l.offset_start_in_doc],
no_answer=l.no_answer,
model_id=l.model_id,
)
else:
agg_label.multiple_answers.append(l.answer)
agg_label.multiple_document_ids.append(l.document_id)
agg_label.multiple_offset_start_in_docs.append(l.offset_start_in_doc)
aggregated_labels.append(agg_label)
return aggregated_labels
@abstractmethod
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
pass
@abstractmethod
def get_document_count(self, index: Optional[str] = None) -> int:
pass
@abstractmethod
def query_by_embedding(self,
query_emb: List[float],
filters: Optional[Optional[Dict[str, List[str]]]] = None,
top_k: int = 10,
index: Optional[str] = None) -> List[Document]:
pass
@abstractmethod
def get_label_count(self, index: Optional[str] = None) -> int:
pass
@abstractmethod
def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"):
pass
def delete_all_documents(self, index: str):
pass
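As a usage sketch for the `write_documents()` contract documented above (using InMemoryDocumentStore, one of the stores moved by this commit; the sample text and meta values are made up):

```python
from haystack.document_store.memory import InMemoryDocumentStore

document_store = InMemoryDocumentStore()
document_store.write_documents([
    # plain dict format: "text" is required, "meta" is optional and can later be used for filtering
    {"text": "Berlin is the capital of Germany.",
     "meta": {"name": "berlin.txt", "author": "somebody"}},
])
print(document_store.get_document_count())  # writes went to the store's default index
```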


@ -7,8 +7,9 @@ from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan
import numpy as np
from haystack.database.base import BaseDocumentStore, Document, Label
from haystack.indexing.utils import eval_data_from_file
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.preprocessor.utils import eval_data_from_file
from haystack.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
@ -70,7 +71,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
:param refresh_type: Type of ES refresh used to control when changes made by a request (e.g. bulk) are made visible to search.
Values:
- 'wait_for' => continue only after changes are visible (slow, but safe)
- 'false' => continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after indexing)
- 'false' => continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion)
More info at https://www.elastic.co/guide/en/elasticsearch/reference/6.8/docs-refresh.html
"""
self.client = Elasticsearch(hosts=[{"host": host, "port": port}], http_auth=(username, password),
@ -470,7 +471,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
f" doesn't match embedding dim. in documentstore ({self.embedding_dim})."
f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
"Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
doc_updates = []
for doc, emb in zip(docs, embeddings):
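A sketch of how the `refresh_type` option documented above (and the `embedding_dim` argument referenced in the error message) are passed when constructing the store; the connection values are placeholders:

```python
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(
    host="localhost", username="", password="", index="document",
    refresh_type="wait_for",  # writes are visible to search before write_documents() returns
    embedding_dim=768,        # must match the embedding size of the retriever model
)
```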


@ -6,8 +6,8 @@ import faiss
import numpy as np
from faiss.swigfaiss import IndexHNSWFlat
from haystack.database.base import Document
from haystack.database.sql import SQLDocumentStore
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
@ -35,7 +35,7 @@ class FAISSDocumentStore(SQLDocumentStore):
"""
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
deployment, Postgres is recommended.
:param index_buffer_size: When working with large dataset, the indexing process(FAISS + SQL) can be buffered in
:param index_buffer_size: When working with large datasets, the ingestion process (FAISS + SQL) can be buffered in
smaller chunks to reduce memory footprint.
:param vector_size: the embedding vector size.
:param faiss_index: load an existing FAISS Index.
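For illustration, the buffered ingestion described above might be configured like this (a sketch; the concrete values are arbitrary):

```python
from haystack.document_store.faiss import FAISSDocumentStore

document_store = FAISSDocumentStore(
    sql_url="sqlite:///qa.db",  # persistent SQLite instead of the in-memory default
    vector_size=768,            # embedding dimension of the dense retriever
    index_buffer_size=10_000,   # flush FAISS + SQL writes in chunks to limit memory use
)
```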


@ -2,8 +2,9 @@ from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
from collections import defaultdict
from haystack.database.base import BaseDocumentStore, Document, Label
from haystack.indexing.utils import eval_data_from_file
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.preprocessor.utils import eval_data_from_file
from haystack.retriever.base import BaseRetriever
import logging
@ -114,7 +115,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
f" doesn't match embedding dim. in documentstore ({self.embedding_dim})."
f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
"Specify the arg `embedding_dim` when initializing InMemoryDocumentStore()")
for doc, emb in zip(docs, embeddings):


@ -5,8 +5,9 @@ from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, F
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from haystack.database.base import BaseDocumentStore, Document, Label
from haystack.indexing.utils import eval_data_from_file
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.preprocessor.utils import eval_data_from_file
Base = declarative_base() # type: Any


@ -1,6 +1,6 @@
from typing import List, Tuple, Dict, Any
from haystack.database.base import MultiLabel
from haystack import MultiLabel
def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int):


@ -9,7 +9,7 @@ import langdetect
class BaseConverter:
"""
Base class for implementing file converts to transform input documents to text format for indexing in database.
Base class for implementing file converters to transform input documents to text format for ingestion in a DocumentStore.
"""
def __init__(
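A usage sketch for the converters that now live under `haystack.file_converter` (the file path is a placeholder, and the `remove_numeric_tables` / `valid_languages` options are assumed to be the pre-existing BaseConverter arguments):

```python
from pathlib import Path
from haystack.file_converter.txt import TextConverter  # new location after this commit

converter = TextConverter(remove_numeric_tables=True, valid_languages=["en"])
document = converter.convert(file_path=Path("data/sample.txt"))  # extracted text plus meta data
```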


@ -1,4 +1,4 @@
from haystack.indexing.file_converters.base import BaseConverter
from haystack.file_converter.base import BaseConverter
import logging
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple


@ -4,7 +4,7 @@ import subprocess
from pathlib import Path
from typing import List, Optional, Dict, Tuple, Any
from haystack.indexing.file_converters.base import BaseConverter
from haystack.file_converter.base import BaseConverter
logger = logging.getLogger(__name__)


@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Dict, Any
import requests
from tika import parser as tikaparser
from haystack.indexing.file_converters.base import BaseConverter
from haystack.file_converter.base import BaseConverter
logger = logging.getLogger(__name__)


@ -3,7 +3,7 @@ import re
from pathlib import Path
from typing import List, Optional, Tuple, Any, Dict
from haystack.indexing.file_converters.base import BaseConverter
from haystack.file_converter.base import BaseConverter
logger = logging.getLogger(__name__)


@ -9,7 +9,7 @@ from scipy.special import expit
from haystack.reader.base import BaseReader
from haystack.retriever.base import BaseRetriever
from haystack.database.base import MultiLabel, Document
from haystack import MultiLabel, Document
from haystack.eval import calculate_average_precision, eval_counts_reader_batch, calculate_reader_metrics, \
eval_counts_reader


@ -9,9 +9,9 @@ import json
from farm.data_handler.utils import http_get
from haystack.indexing.file_converters.pdf import PDFToTextConverter
from haystack.indexing.file_converters.tika import TikaConverter
from haystack.database.base import Document, Label
from haystack.file_converter.pdf import PDFToTextConverter
from haystack.file_converter.tika import TikaConverter
from haystack import Document, Label
logger = logging.getLogger(__name__)
@ -78,7 +78,7 @@ def convert_files_to_dicts(dir_path: str, clean_func: Optional[Callable] = None,
Convert all files(.txt, .pdf) in the sub-directories of the given path to Python dicts that can be written to a
Document Store.
:param dir_path: path for the documents to be written to the database
:param dir_path: path for the documents to be written to the DocumentStore
:param clean_func: a custom cleaning function that gets applied to each doc (input: str, output:str)
:param split_paragraphs: split text in paragraphs.
@ -127,7 +127,7 @@ def tika_convert_files_to_dicts(
Convert all files(.txt, .pdf) in the sub-directories of the given path to Python dicts that can be written to a
Document Store.
:param dir_path: path for the documents to be written to the database
:param dir_path: path for the documents to be written to the DocumentStore
:param clean_func: a custom cleaning function that gets applied to each doc (input: str, output:str)
:param split_paragraphs: split text in paragraphs.
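Taken together with `clean_wiki_text`, a typical call (mirroring the tutorials updated later in this commit) looks roughly like this:

```python
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts

dicts = convert_files_to_dicts(dir_path="data/article_txt_got",
                               clean_func=clean_wiki_text,
                               split_paragraphs=True)
document_store.write_documents(dicts)  # document_store: any DocumentStore instance from above
```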


@ -3,7 +3,7 @@ from scipy.special import expit
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence
from haystack.database.base import Document
from haystack import Document
class BaseReader(ABC):


@ -19,8 +19,10 @@ from farm.utils import set_all_seeds, initialize_device_settings
from scipy.special import expit
import shutil
from haystack.database.base import Document, BaseDocumentStore
from haystack import Document
from haystack.document_store.base import BaseDocumentStore
from haystack.reader.base import BaseReader
logger = logging.getLogger(__name__)


@ -2,7 +2,7 @@ from typing import List, Optional
from transformers import pipeline
from haystack.database.base import Document
from haystack import Document
from haystack.reader.base import BaseReader


@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from typing import List
import logging
from haystack.database.base import Document
from haystack.database.base import BaseDocumentStore
from haystack import Document
from haystack.document_store.base import BaseDocumentStore
logger = logging.getLogger(__name__)


@ -6,7 +6,9 @@ from pathlib import Path
from farm.infer import Inferencer
from haystack.database.base import Document, BaseDocumentStore
from haystack.document_store.base import BaseDocumentStore
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.retriever.sparse import logger


@ -48,7 +48,7 @@ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
class ModelOutput:
"""
Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows preprocessor by integer or slice (like
a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes.
"""


@ -5,8 +5,9 @@ from typing import List
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from haystack.database.base import Document, BaseDocumentStore
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.base import BaseDocumentStore
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.base import BaseRetriever
from collections import namedtuple


@ -1,14 +1,9 @@
import logging
from abc import abstractmethod, ABC
from typing import Any, Optional, Dict, List, Union
from uuid import uuid4
import numpy as np
logger = logging.getLogger(__name__)
class Document:
def __init__(self, text: str,
id: str = None,
@ -183,123 +178,4 @@ class MultiLabel:
return cls(**dict)
def to_dict(self):
return self.__dict__
class BaseDocumentStore(ABC):
"""
Base class for implementing Document Stores.
"""
index: Optional[str]
label_index: Optional[str]
@abstractmethod
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""
Indexes documents for later queries.
:param documents: a list of Python dictionaries or a list of Haystack Document objects.
For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
Optionally: Include meta data via {"text": "<the-actual-text>",
"meta":{"name": "<some-document-name>, "author": "somebody", ...}}
It can be used for filtering and is accessible in the responses of the Finder.
:param index: Optional name of index where the documents shall be written to.
If None, the DocumentStore's default index (self.index) will be used.
:return: None
"""
pass
@abstractmethod
def get_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Document]:
pass
@abstractmethod
def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
pass
def get_all_labels_aggregated(self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None) -> List[MultiLabel]:
aggregated_labels = []
all_labels = self.get_all_labels(index=index, filters=filters)
# Collect all answers to a question in a dict
question_ans_dict = {} # type: ignore
for l in all_labels:
# only aggregate labels with correct answers, as only those can be currently used in evaluation
if not l.is_correct_answer:
continue
if l.question in question_ans_dict:
question_ans_dict[l.question].append(l)
else:
question_ans_dict[l.question] = [l]
# Aggregate labels
for q, ls in question_ans_dict.items():
ls = list(set(ls)) # get rid of exact duplicates
# check if there are both text answer and "no answer" present
t_present = False
no_present = False
no_idx = []
for idx, l in enumerate(ls):
if len(l.answer) == 0:
no_present = True
no_idx.append(idx)
else:
t_present = True
# if both text and no answer are present, remove no answer labels
if t_present and no_present:
logger.warning(
f"Both text label and 'no answer possible' label is present for question: {ls[0].question}")
for remove_idx in no_idx[::-1]:
ls.pop(remove_idx)
# construct Aggregated_label
for i, l in enumerate(ls):
if i == 0:
agg_label = MultiLabel(question=l.question,
multiple_answers=[l.answer],
is_correct_answer=l.is_correct_answer,
is_correct_document=l.is_correct_document,
origin=l.origin,
multiple_document_ids=[l.document_id],
multiple_offset_start_in_docs=[l.offset_start_in_doc],
no_answer=l.no_answer,
model_id=l.model_id,
)
else:
agg_label.multiple_answers.append(l.answer)
agg_label.multiple_document_ids.append(l.document_id)
agg_label.multiple_offset_start_in_docs.append(l.offset_start_in_doc)
aggregated_labels.append(agg_label)
return aggregated_labels
@abstractmethod
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
pass
@abstractmethod
def get_document_count(self, index: Optional[str] = None) -> int:
pass
@abstractmethod
def query_by_embedding(self,
query_emb: List[float],
filters: Optional[Optional[Dict[str, List[str]]]] = None,
top_k: int = 10,
index: Optional[str] = None) -> List[Document]:
pass
@abstractmethod
def get_label_count(self, index: Optional[str] = None) -> int:
pass
@abstractmethod
def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"):
pass
def delete_all_documents(self, index: str):
pass
return self.__dict__


@ -4,7 +4,7 @@ import logging
import pprint
import pandas as pd
from typing import Dict, Any, List
from haystack.database.sql import DocumentORM
from haystack.document_store.sql import DocumentORM
logger = logging.getLogger(__name__)


@ -3,7 +3,7 @@ from typing import Optional
from fastapi import APIRouter
from pydantic import BaseModel, Field
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from rest_api.config import (
DB_HOST,
DB_PORT,


@ -12,9 +12,9 @@ from fastapi import UploadFile, File, Form
from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_SCHEME, TEXT_FIELD_NAME, \
SEARCH_FIELD_NAME, FILE_UPLOAD_PATH, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, VALID_LANGUAGES, \
FAQ_QUESTION_FIELD_NAME, REMOVE_NUMERIC_TABLES, REMOVE_WHITESPACE, REMOVE_EMPTY_LINES, REMOVE_HEADER_FOOTER
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.indexing.file_converters.pdf import PDFToTextConverter
from haystack.indexing.file_converters.txt import TextConverter
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.file_converter.pdf import PDFToTextConverter
from haystack.file_converter.txt import TextConverter
logger = logging.getLogger(__name__)


@ -16,7 +16,7 @@ from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_
DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \
EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME
from rest_api.controller.utils import RequestLimiter
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.base import BaseRetriever


@ -8,11 +8,11 @@ import pytest
import requests
from elasticsearch import Elasticsearch
from haystack.database.base import Document
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.database.faiss import FAISSDocumentStore
from haystack.database.memory import InMemoryDocumentStore
from haystack.database.sql import SQLDocumentStore
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.document_store.sql import SQLDocumentStore
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader


@ -2,9 +2,9 @@ import numpy as np
import pytest
from elasticsearch import Elasticsearch
from haystack.database.base import Document, Label
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.database.faiss import FAISSDocumentStore
from haystack import Document, Label
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore
def test_get_all_documents_without_filters(document_store_with_docs):


@ -1,4 +1,4 @@
from haystack.database.base import Document
from haystack import Document
def test_document_data_access():


@ -1,6 +1,6 @@
from pathlib import Path
from haystack.indexing.file_converters.docx import DocxToTextConverter
from haystack.file_converter.docx import DocxToTextConverter
def test_extract_pages():


@ -2,8 +2,8 @@ import pytest
import time
from haystack.retriever.dense import DensePassageRetriever
from haystack.database.base import Document
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)


@ -1,4 +1,4 @@
from haystack.database.base import Document
from haystack import Document
import pytest


@ -1,5 +1,5 @@
import pytest
from haystack.database.base import BaseDocumentStore
from haystack.document_store.base import BaseDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.finder import Finder


@ -3,7 +3,7 @@ from haystack import Finder
def test_faq_retriever_in_memory_store():
from haystack.database.memory import InMemoryDocumentStore
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import EmbeddingRetriever
document_store = InMemoryDocumentStore(embedding_field="embedding")


@ -1,8 +1,8 @@
def test_module_imports():
from haystack import Finder
from haystack.database.sql import SQLDocumentStore
from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.document_store.sql import SQLDocumentStore
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.sparse import TfidfRetriever


@ -2,8 +2,8 @@ from pathlib import Path
import pytest
from haystack.indexing.file_converters.pdf import PDFToTextConverter
from haystack.indexing.file_converters.tika import TikaConverter
from haystack.file_converter.pdf import PDFToTextConverter
from haystack.file_converter.tika import TikaConverter
@pytest.mark.parametrize("Converter", [PDFToTextConverter, TikaConverter])


@ -1,6 +1,6 @@
import math
from haystack.database.base import Document
from haystack import Document
from haystack.reader.base import BaseReader
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader


@ -7,7 +7,7 @@ def test_tfidf_retriever():
{"name": "testing the finder 3", "text": "alien says arghh"}
]
from haystack.database.memory import InMemoryDocumentStore
from haystack.document_store.memory import InMemoryDocumentStore
document_store = InMemoryDocumentStore()
document_store.write_documents(test_docs)


@ -41,8 +41,8 @@
"outputs": [],
"source": [
"from haystack import Finder\n",
"from haystack.indexing.cleaning import clean_wiki_text\n",
"from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http\n",
"from haystack.preprocessor.cleaning import clean_wiki_text\n",
"from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http\n",
"from haystack.reader.farm import FARMReader\n",
"from haystack.reader.transformers import TransformersReader\n",
"from haystack.utils import print_answers"
@ -125,7 +125,7 @@
"source": [
"# Connect to Elasticsearch\n",
"\n",
"from haystack.database.elasticsearch import ElasticsearchDocumentStore\n",
"from haystack.document_store.elasticsearch import ElasticsearchDocumentStore\n",
"document_store = ElasticsearchDocumentStore(host=\"localhost\", username=\"\", password=\"\", index=\"document\")"
]
},
@ -137,9 +137,13 @@
}
},
"source": [
"## Cleaning & indexing documents\n",
"## Preprocessing of documents\n",
"\n",
"Haystack provides a customizable cleaning and indexing pipeline for ingesting documents in Document Stores.\n",
"Haystack provides a customizable pipeline for:\n",
" - converting files into texts\n",
" - cleaning texts\n",
" - splitting texts\n",
" - writing them to a Document Store\n",
"\n",
"In this tutorial, we download Wikipedia articles on Game of Thrones, apply a basic cleaning function, and index them in Elasticsearch."
]


@ -14,9 +14,9 @@ import subprocess
import time
from haystack import Finder
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.utils import print_answers
@ -56,11 +56,15 @@ if LAUNCH_ELASTICSEARCH:
# Connect to Elasticsearch
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document")
# ## Cleaning & indexing documents
# ## Preprocessing of documents
#
# Haystack provides a customizable cleaning and indexing pipeline for ingesting documents in Document Stores.
#
# In this tutorial, we download Wikipedia articles on Game of Thrones, apply a basic cleaning function, and index
# Haystack provides a customizable pipeline for:
# - converting files into texts
# - cleaning texts
# - splitting texts
# - writing them to a Document Store
# In this tutorial, we download Wikipedia articles on Game of Thrones, apply a basic cleaning function, and add
# them in Elasticsearch.


@ -38,8 +38,8 @@
"outputs": [],
"source": [
"from haystack import Finder\n",
"from haystack.indexing.cleaning import clean_wiki_text\n",
"from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http\n",
"from haystack.preprocessor.cleaning import clean_wiki_text\n",
"from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http\n",
"from haystack.reader.farm import FARMReader\n",
"from haystack.reader.transformers import TransformersReader\n",
"from haystack.utils import print_answers"
@ -59,7 +59,7 @@
"outputs": [],
"source": [
"# In-Memory Document Store\n",
"from haystack.database.memory import InMemoryDocumentStore\n",
"from haystack.document_store.memory import InMemoryDocumentStore\n",
"document_store = InMemoryDocumentStore()"
]
},
@ -70,7 +70,7 @@
"outputs": [],
"source": [
"# SQLite Document Store\n",
"# from haystack.database.sql import SQLDocumentStore\n",
"# from haystack.document_store.sql import SQLDocumentStore\n",
"# document_store = SQLDocumentStore(url=\"sqlite:///qa.db\")"
]
},
@ -82,9 +82,13 @@
}
},
"source": [
"## Cleaning & indexing documents\n",
"## Preprocessing of documents\n",
"\n",
"Haystack provides a customizable cleaning and indexing pipeline for ingesting documents in Document Stores.\n",
"Haystack provides a customizable pipeline for:\n",
" - converting files into texts\n",
" - cleaning texts\n",
" - splitting texts\n",
" - writing them to a Document Store\n",
"\n",
"In this tutorial, we download Wikipedia articles on Game of Thrones, apply a basic cleaning function, and index them in Elasticsearch."
]


@ -8,10 +8,10 @@
from haystack import Finder
from haystack.database.memory import InMemoryDocumentStore
from haystack.database.sql import SQLDocumentStore
from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.document_store.sql import SQLDocumentStore
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.sparse import TfidfRetriever
@ -25,10 +25,14 @@ document_store = InMemoryDocumentStore()
# document_store = SQLDocumentStore(url="sqlite:///qa.db")
# ## Cleaning & indexing documents
#
# Haystack provides a customizable cleaning and indexing pipeline for ingesting documents in Document Stores.
# ## Preprocessing of documents
#
# Haystack provides a customizable pipeline for:
# - converting files into texts
# - cleaning texts
# - splitting texts
# - writing them to a Document Store
# In this tutorial, we download Wikipedia articles on Game of Thrones, apply a basic cleaning function, and index
# them in Elasticsearch.
# Let's first get some documents that we want to query


@ -46,7 +46,7 @@
"outputs": [],
"source": [
"from haystack import Finder\n",
"from haystack.database.elasticsearch import ElasticsearchDocumentStore\n",
"from haystack.document_store.elasticsearch import ElasticsearchDocumentStore\n",
"\n",
"from haystack.retriever.dense import EmbeddingRetriever\n",
"from haystack.utils import print_answers\n",
@ -125,7 +125,7 @@
}
],
"source": [
"from haystack.database.elasticsearch import ElasticsearchDocumentStore\n",
"from haystack.document_store.elasticsearch import ElasticsearchDocumentStore\n",
"document_store = ElasticsearchDocumentStore(host=\"localhost\", username=\"\", password=\"\",\n",
" index=\"document\",\n",
" embedding_field=\"question_emb\",\n",
@ -188,7 +188,7 @@
"# Get embeddings for our questions from the FAQs\n",
"questions = list(df[\"question\"].values)\n",
"df[\"question_emb\"] = retriever.embed_queries(texts=questions)\n",
"df[\"question_emb\"] = df[\"question_emb\"].apply(list) # convert from numpy to list for ES indexing\n",
"df[\"question_emb\"] = df[\"question_emb\"].apply(list) # convert from numpy to list for ES ingestion\n",
"df = df.rename(columns={\"answer\": \"text\"})\n",
"\n",
"# Convert Dataframe to list of dicts and index them in our DocumentStore\n",


@ -1,5 +1,5 @@
from haystack import Finder
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.dense import EmbeddingRetriever
from haystack.utils import print_answers
@ -67,7 +67,7 @@ print(df.head())
# Get embeddings for our questions from the FAQs
questions = list(df["question"].values)
df["question_emb"] = retriever.embed_queries(texts=questions)
df["question_emb"] = df["question_emb"].apply(list) # convert from numpy to list for ES indexing
df["question_emb"] = df["question_emb"].apply(list) # convert from numpy to list for ES ingestion
df = df.rename(columns={"answer": "text"})
# Convert Dataframe to list of dicts and index them in our DocumentStore


@ -115,7 +115,7 @@
},
"outputs": [],
"source": [
"from haystack.indexing.utils import fetch_archive_from_http\n",
"from haystack.preprocessor.utils import fetch_archive_from_http\n",
"\n",
"# Download evaluation data, which is a subset of Natural Questions development set containing 50 documents\n",
"doc_dir = \"../data/nq\"\n",
@ -148,7 +148,7 @@
"outputs": [],
"source": [
"# Connect to Elasticsearch\n",
"from haystack.database.elasticsearch import ElasticsearchDocumentStore\n",
"from haystack.document_store.elasticsearch import ElasticsearchDocumentStore\n",
"\n",
"# Connect to Elasticsearch\n",
"document_store = ElasticsearchDocumentStore(host=\"localhost\", username=\"\", password=\"\", index=\"document\",\n",
@ -174,7 +174,7 @@
},
"outputs": [],
"source": [
"# Add evaluation data to Elasticsearch database\n",
"# Add evaluation data to Elasticsearch Document Store\n",
"# We first delete the custom tutorial indices to not have duplicate elements\n",
"document_store.delete_all_documents(index=doc_index)\n",
"document_store.delete_all_documents(index=label_index)\n",


@ -1,5 +1,5 @@
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.indexing.utils import fetch_archive_from_http
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.preprocessor.utils import fetch_archive_from_http
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.retriever.dense import DensePassageRetriever
from haystack.reader.farm import FARMReader
@ -52,7 +52,7 @@ document_store = ElasticsearchDocumentStore(host="localhost", username="", passw
embedding_dim=768, excluded_meta_data=["emb"])
# Add evaluation data to Elasticsearch database
# Add evaluation data to Elasticsearch document store
# We first delete the custom tutorial indices to not have duplicate elements
document_store.delete_all_documents(index=doc_index)
document_store.delete_all_documents(index=label_index)


@ -295,8 +295,8 @@
"outputs": [],
"source": [
"from haystack import Finder\n",
"from haystack.indexing.cleaning import clean_wiki_text\n",
"from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http\n",
"from haystack.preprocessor.cleaning import clean_wiki_text\n",
"from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http\n",
"from haystack.reader.farm import FARMReader\n",
"from haystack.reader.transformers import TransformersReader\n",
"from haystack.utils import print_answers"
@ -343,7 +343,7 @@
}
],
"source": [
"from haystack.database.faiss import FAISSDocumentStore\n",
"from haystack.document_store.faiss import FAISSDocumentStore\n",
"\n",
"document_store = FAISSDocumentStore()"
]


@ -1,19 +1,19 @@
from haystack import Finder
from haystack.database.faiss import FAISSDocumentStore
from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.reader.farm import FARMReader
from haystack.utils import print_answers
from haystack.retriever.dense import DensePassageRetriever
# FAISS is a library for efficient similarity search on a cluster of dense vectors.
# The FAISSDocumentStore uses a SQL(SQLite in-memory be default) database under-the-hood
# The FAISSDocumentStore uses a SQL (SQLite in-memory by default) document store under the hood
# to store the document text and other meta data. The vector embeddings of the text are
# indexed on a FAISS Index that later is queried for searching answers.
document_store = FAISSDocumentStore()
# ## Cleaning & indexing documents
# ## Preprocessing of documents
# Let's first get some documents that we want to query
doc_dir = "data/article_txt_got"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"