From a80ac220e470803cef927d36ebba7679df9b3666 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Tue, 14 Jan 2020 18:36:33 +0100 Subject: [PATCH] Move API from flask to fastAPI (#3) * moving api from flask to fastAPI * update path parameter. add logs * add error msg for wrong finder id --- README.rst | 5 +- haystack/api/inference.py | 105 ++++++++++++++++++++++---------------- requirements.txt | 5 +- 3 files changed, 65 insertions(+), 50 deletions(-) diff --git a/README.rst b/README.rst index 717b7674b..72fdc1317 100644 --- a/README.rst +++ b/README.rst @@ -82,6 +82,5 @@ Elasticsearch Backend REST API -------- -A Flask based HTTP REST API is included to use the QA Framework with UI or integrating with other systems. To serve the API, run :code:`FLASK_APP=farm_hackstack.api.inference flask run`. - - +A simple REST API based on `FastAPI <https://fastapi.tiangolo.com/>`_ is included to answer questions at inference time. To serve the API, run :code:`uvicorn haystack.api.inference:app`. +You will find the Swagger API documentation at http://127.0.0.1:8000/docs \ No newline at end of file diff --git a/haystack/api/inference.py b/haystack/api/inference.py index e343e7815..31577bca2 100644 --- a/haystack/api/inference.py +++ b/haystack/api/inference.py @@ -1,23 +1,30 @@ -import json from pathlib import Path +from fastapi import FastAPI, HTTPException -import numpy as np -from flask import request, make_response -from flask_cors import CORS -from flask_restplus import Api, Resource +import logging from haystack import Finder from haystack.database import app from haystack.reader.farm import FARMReader from haystack.retriever.tfidf import TfidfRetriever -CORS(app) -api = Api( - app, debug=True, validate=True, version="1.0", title="FARM Question Answering API" -) +from pydantic import BaseModel +from typing import List, Dict +import uvicorn -MODELS_DIRS = ["saved_models"] +logger = logging.getLogger(__name__) +#TODO Enable CORS + +MODELS_DIRS = ["saved_models", "models", "model"] +USE_GPU = False 
+BATCH_SIZE = 16 + +app = FastAPI(title="Haystack API", version="0.1") + +############################################# +# Load all models in memory +############################################# model_paths = [] for model_dir in MODELS_DIRS: path = Path(model_dir) @@ -25,49 +32,59 @@ for model_dir in MODELS_DIRS: models = [f for f in path.iterdir() if f.is_dir()] model_paths.extend(models) +if len(model_paths) == 0: + logger.error(f"Could not find any model to load. Checked folders: {MODELS_DIRS}") + retriever = TfidfRetriever() FINDERS = {} -for idx, model_dir in enumerate(model_paths): - reader = FARMReader(model_dir=str(model_dir), batch_size=16) - FINDERS[idx + 1] = Finder(reader, retriever) +for idx, model_dir in enumerate(model_paths, start=1): + reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU) + FINDERS[idx] = Finder(reader, retriever) + logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'") + +logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.") +logger.info(""" Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/finders/1/ask' --data '{"question": "Who is the father of Arya Starck?"}'""") + +############################################# +# Basic data schema for request & response +############################################# +class Request(BaseModel): + question: str + filters: Dict[str, str] = None + top_k_reader: int = 5 + top_k_retriever: int = 10 -class NumpyEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - if isinstance(obj, np.float32): - return str(obj) - return json.JSONEncoder.default(self, obj) +class Answer(BaseModel): + answer: str + score: float = None + probability: float = None + context: str + offset_start: int + offset_end: int + document_id: str = None -@api.representation("application/json") -def resp_json(data, code, headers=None): - resp = 
make_response(json.dumps(data, cls=NumpyEncoder), code) - resp.headers.extend(headers or {}) - return resp +class Response(BaseModel): + question: str + answers: List[Answer] +############################################# +# Endpoints +############################################# +@app.post("/finders/{finder_id}/ask", response_model=Response, response_model_exclude_unset=True) +def ask(finder_id: int, request: Request): + finder = FINDERS.get(finder_id, None) + if not finder: + raise HTTPException(status_code=404, detail=f"Couldn't get Finder with ID {finder_id}. Available IDs: {list(FINDERS.keys())}") -@api.route("/finders//ask") -class InferenceEndpoint(Resource): - def post(self, finder_id): - finder = FINDERS.get(finder_id, None) - if not finder: - return "Model not found", 404 + results = finder.get_answers( + question=request.question, top_k_retriever=request.top_k_retriever, + top_k_reader=request.top_k_reader, filters=request.filters + ) - request_body = request.get_json() - questions = request_body.get("questions", None) - if not questions: - return "The request is missing 'questions' field", 400 - - filters = request_body.get("filters", None) - - results = finder.get_answers( - question=request_body["questions"][0], top_k_reader=3, filters=filters - ) - - return results + return results if __name__ == "__main__": - app.run(host="0.0.0.0") + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2cdb0954a..a455a19d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ # FARM (incl. transformers 2.3.0 with pipelines) #farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43 -e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm -flask -flask_cors -flask_restplus +fastapi +uvicorn flask_sqlalchemy pandas psycopg2-binary