Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-05 11:38:20 +00:00)

Update tutorials (#12)

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>

This commit is contained in:
parent 1718ea55b8
commit c52266e520
@@ -4,9 +4,9 @@ from fastapi import FastAPI, HTTPException
import logging

from haystack import Finder
from haystack.database import app
from haystack.reader.farm import FARMReader
from haystack.retriever.tfidf import TfidfRetriever
from haystack.database.sql import SQLDocumentStore

from pydantic import BaseModel
from typing import List, Dict
@@ -19,25 +19,19 @@ logger = logging.getLogger(__name__)
MODELS_DIRS = ["saved_models", "models", "model"]
USE_GPU = False
BATCH_SIZE = 16
DATABASE_URL = "sqlite:///qa.db"
MODEL_PATHS = ['deepset/bert-base-cased-squad2']

app = FastAPI(title="Haystack API", version="0.1")

#############################################
# Load all models in memory
#############################################
model_paths = []
for model_dir in MODELS_DIRS:
    path = Path(model_dir)
    if path.is_dir():
        models = [f for f in path.iterdir() if f.is_dir()]
        model_paths.extend(models)
if len(MODEL_PATHS) == 0:
    logger.error(f"No model to load. Please specify one via MODEL_PATHS (e.g. ['deepset/bert-base-cased-squad2'])")

if len(model_paths) == 0:
    logger.error(f"Could not find any model to load. Checked folders: {MODELS_DIRS}")
datastore = SQLDocumentStore(url=DATABASE_URL)
retriever = TfidfRetriever(datastore=datastore)

retriever = TfidfRetriever()
FINDERS = {}
for idx, model_dir in enumerate(model_paths, start=1):
for idx, model_dir in enumerate(MODEL_PATHS, start=1):
    reader = FARMReader(model_name_or_path=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
    FINDERS[idx] = Finder(reader, retriever)
    logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")

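Note: this hunk only covers model loading; the request handling is outside the diff. As a rough, hypothetical sketch of how the FINDERS registry might be exposed (the route, request model, and field names below are illustrative and not taken from this commit, assuming the imports shown above):

class Question(BaseModel):
    # illustrative request body
    question: str
    top_k_reader: int = 5


@app.post("/models/{model_id}/answers")
def answers(model_id: int, request: Question):
    finder = FINDERS.get(model_id)
    if finder is None:
        raise HTTPException(status_code=404, detail=f"Unknown model ID {model_id}. Loaded IDs: {list(FINDERS.keys())}")
    # Finder.get_answers() chains the TF-IDF retriever and the FARM reader
    return finder.get_answers(question=request.question, top_k_reader=request.top_k_reader)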
@@ -40,11 +40,13 @@
},
"outputs": [],
"source": [
"from haystack.reader.farm import FARMReader\n",
"from haystack.retriever.tfidf import TfidfRetriever\n",
"from haystack import Finder\n",
"from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http\n",
"from haystack.database.sql import SQLDocumentStore\n",
"from haystack.indexing.cleaning import clean_wiki_text\n",
"from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http\n",
"from haystack.reader.farm import FARMReader\n",
"from haystack.reader.transformers import TransformersReader\n",
"from haystack.retriever.tfidf import TfidfRetriever\n",
"from haystack.utils import print_answers"
]
},
@@ -75,20 +77,21 @@
}
],
"source": [
"# Init a database (default: sqlite)\n",
"from haystack.database import db\n",
"db.create_all()\n",
"\n",
"# Let's first get some documents that we want to query\n",
"# Here: 517 Wikipedia articles for Game of Thrones\n",
"doc_dir = \"data/article_txt_got\"\n",
"s3_url = \"https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip\"\n",
"fetch_archive_from_http(url=s3_url, output_dir=doc_dir)\n",
"\n",
"# Now, let's write the docs to our DB. \n",
"# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)\n",
"# The documents can be stored in different types of \"DocumentStores\".\n",
"# For dev we suggest a light-weight SQL DB\n",
"# For production we suggest elasticsearch\n",
"datastore = SQLDocumentStore(url=\"sqlite:///qa.db\")\n",
"\n",
"# Now, let's write the docs to our DB.\n",
"# You can optionally supply a cleaning function that is applied to each doc (e.g. to remove footers)\n",
"# It must take a str as input, and return a str.\n",
"write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)"
"write_documents_to_db(datastore=datastore, document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)"
]
},
{
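Note: clean_wiki_text is only one possible clean_func; write_documents_to_db accepts any callable that takes a str and returns a str. A minimal sketch of a custom cleaner (the function below is illustrative, not part of this commit):

def clean_footer(text: str) -> str:
    # takes a str, returns a str: cut everything after a "References" footer
    return text.split("== References ==")[0].strip()

# write_documents_to_db(datastore=datastore, document_dir=doc_dir, clean_func=clean_footer, only_empty_db=True)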
@@ -119,7 +122,7 @@
"source": [
"# A retriever identifies the k most promising chunks of text that might contain the answer for our question\n",
"# Retrievers use some simple but fast algorithm, here: TF-IDF\n",
"retriever = TfidfRetriever()"
"retriever = TfidfRetriever(datastore=datastore)"
]
},
{
@@ -143,13 +146,13 @@
],
"source": [
"# A reader scans the text chunks in detail and extracts the k best answers\n",
"# Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0\n",
"from haystack.indexing.io import fetch_archive_from_http\n",
"fetch_archive_from_http(url=\"https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz\", output_dir=\"model\")\n",
"reader = FARMReader(model_dir=\"model/bert-english-qa-large\", use_gpu=False)\n",
"# Readers use more powerful but slower deep learning models\n",
"# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)\n",
"# here: a medium sized BERT QA model trained via FARM on Squad 2.0\n",
"reader = FARMReader(model_name_or_path=\"deepset/bert-base-cased-squad2\", use_gpu=False)\n",
"\n",
"# OR: use alternatively a reader from huggingface's Transformers package\n",
"# reader = TransformersReader(use_gpu=-1)"
"# OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)\n",
"# reader = TransformersReader(model=\"distilbert-base-uncased-distilled-squad\", tokenizer=\"distilbert-base-uncased\", use_gpu=-1)"
]
},
{
@@ -162,7 +165,7 @@
},
"outputs": [],
"source": [
"# The Finder sticks together reader and retriever in a pipeline to answer our actual questions \n",
"# The Finder sticks together reader and retriever in a pipeline to answer our actual questions\n",
"finder = Finder(reader, retriever)"
]
},

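Note: after the Finder is built, the tutorials query it as shown further down in this commit; for quick reference, the call looks like this:

prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_reader=5)
print_answers(prediction, details="minimal")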
@@ -1,16 +1,14 @@
from haystack import Finder
from haystack.database.sql import SQLDocumentStore
from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.tfidf import TfidfRetriever
from haystack import Finder
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
from haystack.indexing.cleaning import clean_wiki_text
from haystack.utils import print_answers


## Indexing & cleaning documents
# Init a database (default: sqlite)
from haystack.database import db
db.create_all()

# Let's first get some documents that we want to query
# Here: 517 Wikipedia articles for Game of Thrones
@@ -18,25 +16,31 @@ doc_dir = "data/article_txt_got"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)

# Now, let's write the docs to our DB.
# You can supply a cleaning function that is applied to each doc (e.g. to remove footers)
# It must take a str as input, and return a str.
write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)

# The documents can be stored in different types of "DocumentStores".
# For dev we suggest a light-weight SQL DB
# For production we suggest elasticsearch
datastore = SQLDocumentStore(url="sqlite:///qa.db")

# Now, let's write the docs to our DB.
# You can optionally supply a cleaning function that is applied to each doc (e.g. to remove footers)
# It must take a str as input, and return a str.
write_documents_to_db(datastore=datastore, document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)

## Initialize Reader, Retriever & Finder

# A retriever identifies the k most promising chunks of text that might contain the answer for our question
# Retrievers use some simple but fast algorithm, here: TF-IDF
retriever = TfidfRetriever()
retriever = TfidfRetriever(datastore=datastore)

# A reader scans the text chunks in detail and extracts the k best answers
# Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0
fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)
# Readers use more powerful but slower deep learning models
# You can select a local model or any of the QA models published on huggingface's model hub (https://huggingface.co/models)
# here: a medium sized BERT QA model trained via FARM on Squad 2.0
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)

# OR: use alternatively a reader from huggingface's Transformers package
# reader = TransformersReader(use_gpu=-1)
# OR: use alternatively a reader from huggingface's transformers package (https://github.com/huggingface/transformers)
# reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", tokenizer="distilbert-base-uncased", use_gpu=-1)

# The Finder sticks together reader and retriever in a pipeline to answer our actual questions
finder = Finder(reader, retriever)
@@ -50,6 +54,3 @@ prediction = finder.get_answers(question="Who is the father of Arya Stark?", top
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)

print_answers(prediction, details="minimal")
@@ -1,9 +1,10 @@
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.tfidf import TfidfRetriever

from haystack import Finder
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
from haystack.database.sql import SQLDocumentStore
from haystack.indexing.cleaning import clean_wiki_text
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
from haystack.reader.farm import FARMReader
from haystack.retriever.tfidf import TfidfRetriever
from haystack.utils import print_answers

#### TRAINING #############
@@ -11,26 +12,31 @@ from haystack.utils import print_answers
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)

# and fine-tune it on your own custom dataset (should be in SQuAD like format)
reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)
train_data = "PATH/TO_YOUR/TRAIN_DATA"
reader.train(data_dir=train_data, train_filename="train.json", use_gpu=False, n_epochs=1)


#### Use it (same as in Tutorial 1) #############

# Okay, we have a fine-tuned model now. Let's test it on some docs:
## Let's get some docs for testing (see Tutorial 1 for more explanations)
from haystack.database import db
db.create_all()
## Indexing & cleaning documents

# Download docs
# Let's get the data (Game of thrones articles from wikipedia)
doc_dir = "data/article_txt_got"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)

# Write docs to our DB.
write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)

# Initialize Finder Pipeline
retriever = TfidfRetriever()
# Init Document store & write docs to it
datastore = SQLDocumentStore(url="sqlite:///qa.db")
write_documents_to_db(datastore=datastore, document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)

## Initialize Reader, Retriever & Finder

# A retriever identifies the k most promising chunks of text that might contain the answer for our question
# Retrievers use some simple but fast algorithm, here: TF-IDF
retriever = TfidfRetriever(datastore=datastore)

# The Finder sticks together reader and retriever in a pipeline to answer our actual questions
finder = Finder(reader, retriever)

## Voilà! Ask a question!
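Note on the fine-tuning step above: reader.train() expects train.json in a SQuAD-like layout. A minimal sketch of that structure (based on the public SQuAD v2 schema; all values are placeholders, not part of this commit):

train_json = {
    "data": [
        {
            "title": "Arya Stark",
            "paragraphs": [
                {
                    "context": "Arya Stark is the daughter of Eddard Stark and Catelyn Stark.",
                    "qas": [
                        {
                            "id": "q1",
                            "question": "Who is the father of Arya Stark?",
                            "is_impossible": False,
                            # answer_start is the character offset of the answer span in the context
                            "answers": [{"text": "Eddard Stark", "answer_start": 30}],
                        }
                    ],
                }
            ],
        }
    ]
}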