Move API from flask to fastAPI (#3)

* moving api from flask to fastAPI

* update path parameter. add logs

* add error msg for wrong finder id
This commit is contained in:
Malte Pietsch 2020-01-14 18:36:33 +01:00 committed by GitHub
parent 0b3fd7572b
commit a80ac220e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 50 deletions

View File

@ -82,6 +82,5 @@ Elasticsearch Backend
REST API
--------
A Flask based HTTP REST API is included to use the QA Framework with UI or integrating with other systems. To serve the API, run :code:`FLASK_APP=farm_hackstack.api.inference flask run`.
A simple REST API based on `FastAPI <https://fastapi.tiangolo.com/>`_ is included to answer questions at inference time. To serve the API, run :code:`uvicorn haystack.api.inference:app`.
You will find the Swagger API documentation at http://127.0.0.1:8000/docs

View File

@ -1,23 +1,30 @@
import json
from pathlib import Path
from fastapi import FastAPI, HTTPException
import numpy as np
from flask import request, make_response
from flask_cors import CORS
from flask_restplus import Api, Resource
import logging
from haystack import Finder
from haystack.database import app
from haystack.reader.farm import FARMReader
from haystack.retriever.tfidf import TfidfRetriever
logger = logging.getLogger(__name__)

#TODO Enable CORS

# Folders scanned (in order) for saved reader models; each subdirectory of a
# matching folder is loaded as one model.
MODELS_DIRS = ["saved_models", "models", "model"]
# Inference device/batching defaults for the FARM reader.
USE_GPU = False
BATCH_SIZE = 16

# FastAPI application object; serve with `uvicorn haystack.api.inference:app`.
app = FastAPI(title="Haystack API", version="0.1")

#############################################
# Load all models in memory
#############################################
model_paths = []
for model_dir in MODELS_DIRS:
path = Path(model_dir)
@ -25,49 +32,59 @@ for model_dir in MODELS_DIRS:
models = [f for f in path.iterdir() if f.is_dir()]
model_paths.extend(models)
if len(model_paths) == 0:
    logger.error(f"Could not find any model to load. Checked folders: {MODELS_DIRS}")

# A single TF-IDF retriever is shared by all finders; each finder pairs it
# with one loaded reader model.
retriever = TfidfRetriever()

# Map of finder_id (1-based) -> Finder, looked up by the /finders/{id}/ask endpoint.
FINDERS = {}
for idx, model_dir in enumerate(model_paths, start=1):
    reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
    FINDERS[idx] = Finder(reader, retriever)
    logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")

logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
logger.info(""" Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/finders/1/ask' --data '{"question": "Who is the father of Arya Starck?"}'""")
#############################################
# Basic data schema for request & response
#############################################
# Request body schema for POST /finders/{finder_id}/ask.
class Request(BaseModel):
    question: str
    # Optional retriever metadata filters (field name -> required value).
    filters: Dict[str, str] = None
    # Number of answers the reader returns.
    top_k_reader: int = 5
    # Number of candidate documents the retriever passes to the reader.
    top_k_retriever: int = 10
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder aware of the numpy types produced by model inference.

    Arrays are emitted as (nested) lists and ``np.float32`` values as their
    string representation; everything else falls through to the base encoder.
    """

    def default(self, o):
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.float32):
            return str(o)
        return super().default(o)
# One answer span extracted by the reader.
class Answer(BaseModel):
    answer: str
    score: float = None        # raw reader score, if available
    probability: float = None  # normalized confidence, if available
    context: str               # text surrounding the answer span
    offset_start: int          # span start offset within the context
    offset_end: int            # span end offset within the context
    document_id: str = None    # source document, if known
@api.representation("application/json")
def resp_json(data, code, headers=None):
    """Render *data* as an HTTP JSON response, via NumpyEncoder for numpy types."""
    body = json.dumps(data, cls=NumpyEncoder)
    response = make_response(body, code)
    response.headers.extend(headers or {})
    return response
# Response envelope: the original question plus its ranked answers.
class Response(BaseModel):
    question: str
    answers: List[Answer]
#############################################
# Endpoints
#############################################
@app.post("/finders/{finder_id}/ask", response_model=Response, response_model_exclude_unset=True)
def ask(finder_id: int, request: Request):
    """Answer ``request.question`` using the Finder identified by *finder_id*.

    :param finder_id: 1-based ID assigned to each loaded model at startup.
    :param request: question plus optional filters and top-k settings.
    :raises HTTPException: 404 when no Finder with the given ID exists.
    :return: finder results, serialized through the ``Response`` model.
    """
    finder = FINDERS.get(finder_id, None)
    if not finder:
        raise HTTPException(
            status_code=404,
            detail=f"Couldn't get Finder with ID {finder_id}. Available IDs: {list(FINDERS.keys())}",
        )

    results = finder.get_answers(
        question=request.question, top_k_retriever=request.top_k_retriever,
        top_k_reader=request.top_k_reader, filters=request.filters
    )
    return results
if __name__ == "__main__":
    # Dev convenience entry point; in production serve with
    # `uvicorn haystack.api.inference:app` instead.
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -1,9 +1,8 @@
# FARM (incl. transformers 2.3.0 with pipelines)
#farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43
-e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm
flask
flask_cors
flask_restplus
fastapi
uvicorn
flask_sqlalchemy
pandas
psycopg2-binary