diff --git a/README.rst b/README.rst
index 717b7674b..72fdc1317 100644
--- a/README.rst
+++ b/README.rst
@@ -82,6 +82,5 @@ Elasticsearch Backend
REST API
--------
-A Flask based HTTP REST API is included to use the QA Framework with UI or integrating with other systems. To serve the API, run :code:`FLASK_APP=farm_hackstack.api.inference flask run`.
-
-
+A simple REST API based on `FastAPI <https://github.com/tiangolo/fastapi>`_ is included to answer questions at inference time. To serve the API, run :code:`uvicorn haystack.api.inference:app`.
+You will find the Swagger API documentation at http://127.0.0.1:8000/docs
\ No newline at end of file
diff --git a/haystack/api/inference.py b/haystack/api/inference.py
index e343e7815..31577bca2 100644
--- a/haystack/api/inference.py
+++ b/haystack/api/inference.py
@@ -1,23 +1,30 @@
-import json
from pathlib import Path
+from fastapi import FastAPI, HTTPException
-import numpy as np
-from flask import request, make_response
-from flask_cors import CORS
-from flask_restplus import Api, Resource
+import logging
from haystack import Finder
from haystack.database import app
from haystack.reader.farm import FARMReader
from haystack.retriever.tfidf import TfidfRetriever
-CORS(app)
-api = Api(
- app, debug=True, validate=True, version="1.0", title="FARM Question Answering API"
-)
+from pydantic import BaseModel
+from typing import List, Dict
+import uvicorn
-MODELS_DIRS = ["saved_models"]
+logger = logging.getLogger(__name__)
+# TODO: Enable CORS
+
+MODELS_DIRS = ["saved_models", "models", "model"]
+USE_GPU = False
+BATCH_SIZE = 16
+
+app = FastAPI(title="Haystack API", version="0.1")
+
+#############################################
+# Load all models in memory
+#############################################
model_paths = []
for model_dir in MODELS_DIRS:
path = Path(model_dir)
@@ -25,49 +32,59 @@ for model_dir in MODELS_DIRS:
models = [f for f in path.iterdir() if f.is_dir()]
model_paths.extend(models)
+if len(model_paths) == 0:
+ logger.error(f"Could not find any model to load. Checked folders: {MODELS_DIRS}")
+
retriever = TfidfRetriever()
FINDERS = {}
-for idx, model_dir in enumerate(model_paths):
- reader = FARMReader(model_dir=str(model_dir), batch_size=16)
- FINDERS[idx + 1] = Finder(reader, retriever)
+for idx, model_dir in enumerate(model_paths, start=1):
+ reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
+ FINDERS[idx] = Finder(reader, retriever)
+ logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")
+
+logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
+logger.info(""" Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/finders/1/ask' --data '{"question": "Who is the father of Arya Stark?"}'""")
+
+#############################################
+# Basic data schema for request & response
+#############################################
+class Request(BaseModel):
+ question: str
+ filters: Dict[str, str] = None
+ top_k_reader: int = 5
+ top_k_retriever: int = 10
-class NumpyEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, np.ndarray):
- return obj.tolist()
- if isinstance(obj, np.float32):
- return str(obj)
- return json.JSONEncoder.default(self, obj)
+class Answer(BaseModel):
+ answer: str
+ score: float = None
+ probability: float = None
+ context: str
+ offset_start: int
+ offset_end: int
+ document_id: str = None
-@api.representation("application/json")
-def resp_json(data, code, headers=None):
- resp = make_response(json.dumps(data, cls=NumpyEncoder), code)
- resp.headers.extend(headers or {})
- return resp
+class Response(BaseModel):
+ question: str
+ answers: List[Answer]
+#############################################
+# Endpoints
+#############################################
+@app.post("/finders/{finder_id}/ask", response_model=Response, response_model_exclude_unset=True)
+def ask(finder_id: int, request: Request):
+ finder = FINDERS.get(finder_id, None)
+ if not finder:
+ raise HTTPException(status_code=404, detail=f"Couldn't get Finder with ID {finder_id}. Available IDs: {list(FINDERS.keys())}")
-@api.route("/finders//ask")
-class InferenceEndpoint(Resource):
- def post(self, finder_id):
- finder = FINDERS.get(finder_id, None)
- if not finder:
- return "Model not found", 404
+ results = finder.get_answers(
+ question=request.question, top_k_retriever=request.top_k_retriever,
+ top_k_reader=request.top_k_reader, filters=request.filters
+ )
- request_body = request.get_json()
- questions = request_body.get("questions", None)
- if not questions:
- return "The request is missing 'questions' field", 400
-
- filters = request_body.get("filters", None)
-
- results = finder.get_answers(
- question=request_body["questions"][0], top_k_reader=3, filters=filters
- )
-
- return results
+ return results
if __name__ == "__main__":
- app.run(host="0.0.0.0")
+ uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 2cdb0954a..a455a19d7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,8 @@
# FARM (incl. transformers 2.3.0 with pipelines)
#farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43
-e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm
-flask
-flask_cors
-flask_restplus
+fastapi
+uvicorn
flask_sqlalchemy
pandas
psycopg2-binary