mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
Move API from flask to fastAPI (#3)
* moving api from flask to fastAPI * update path parameter. add logs * add error msg for wrong finder id
This commit is contained in:
parent
0b3fd7572b
commit
a80ac220e4
@ -82,6 +82,5 @@ Elasticsearch Backend
|
||||
|
||||
REST API
|
||||
--------
|
||||
A Flask based HTTP REST API is included to use the QA Framework with UI or integrating with other systems. To serve the API, run :code:`FLASK_APP=farm_hackstack.api.inference flask run`.
|
||||
|
||||
|
||||
A simple REST API based on `FastAPI <https://fastapi.tiangolo.com/>`_ is included to answer questions at inference time. To serve the API, run :code:`uvicorn haystack.api.inference:app`.
|
||||
You will find the Swagger API documentation at http://127.0.0.1:8000/docs
|
||||
@ -1,23 +1,30 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
import numpy as np
|
||||
from flask import request, make_response
|
||||
from flask_cors import CORS
|
||||
from flask_restplus import Api, Resource
|
||||
import logging
|
||||
|
||||
from haystack import Finder
|
||||
from haystack.database import app
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.retriever.tfidf import TfidfRetriever
|
||||
|
||||
CORS(app)
|
||||
api = Api(
|
||||
app, debug=True, validate=True, version="1.0", title="FARM Question Answering API"
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict
|
||||
import uvicorn
|
||||
|
||||
MODELS_DIRS = ["saved_models"]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#TODO Enable CORS
|
||||
|
||||
MODELS_DIRS = ["saved_models", "models", "model"]
|
||||
USE_GPU = False
|
||||
BATCH_SIZE = 16
|
||||
|
||||
app = FastAPI(title="Haystack API", version="0.1")
|
||||
|
||||
#############################################
|
||||
# Load all models in memory
|
||||
#############################################
|
||||
model_paths = []
|
||||
for model_dir in MODELS_DIRS:
|
||||
path = Path(model_dir)
|
||||
@ -25,49 +32,59 @@ for model_dir in MODELS_DIRS:
|
||||
models = [f for f in path.iterdir() if f.is_dir()]
|
||||
model_paths.extend(models)
|
||||
|
||||
if len(model_paths) == 0:
|
||||
logger.error(f"Could not find any model to load. Checked folders: {MODELS_DIRS}")
|
||||
|
||||
retriever = TfidfRetriever()
|
||||
FINDERS = {}
|
||||
for idx, model_dir in enumerate(model_paths):
|
||||
reader = FARMReader(model_dir=str(model_dir), batch_size=16)
|
||||
FINDERS[idx + 1] = Finder(reader, retriever)
|
||||
for idx, model_dir in enumerate(model_paths, start=1):
|
||||
reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
|
||||
FINDERS[idx] = Finder(reader, retriever)
|
||||
logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")
|
||||
|
||||
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
|
||||
logger.info(""" Or just try it out directly: curl --request POST --url 'http://127.0.0.1:8000/finders/1/ask' --data '{"question": "Who is the father of Arya Starck?"}'""")
|
||||
|
||||
#############################################
|
||||
# Basic data schema for request & response
|
||||
#############################################
|
||||
class Request(BaseModel):
|
||||
question: str
|
||||
filters: Dict[str, str] = None
|
||||
top_k_reader: int = 5
|
||||
top_k_retriever: int = 10
|
||||
|
||||
|
||||
class NumpyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
if isinstance(obj, np.float32):
|
||||
return str(obj)
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
class Answer(BaseModel):
|
||||
answer: str
|
||||
score: float = None
|
||||
probability: float = None
|
||||
context: str
|
||||
offset_start: int
|
||||
offset_end: int
|
||||
document_id: str = None
|
||||
|
||||
|
||||
@api.representation("application/json")
|
||||
def resp_json(data, code, headers=None):
|
||||
resp = make_response(json.dumps(data, cls=NumpyEncoder), code)
|
||||
resp.headers.extend(headers or {})
|
||||
return resp
|
||||
class Response(BaseModel):
|
||||
question: str
|
||||
answers: List[Answer]
|
||||
|
||||
#############################################
|
||||
# Endpoints
|
||||
#############################################
|
||||
@app.post("/finders/{finder_id}/ask", response_model=Response, response_model_exclude_unset=True)
|
||||
def ask(finder_id: int, request: Request):
|
||||
finder = FINDERS.get(finder_id, None)
|
||||
if not finder:
|
||||
raise HTTPException(status_code=404, detail=f"Couldn't get Finder with ID {finder_id}. Available IDs: {list(FINDERS.keys())}")
|
||||
|
||||
@api.route("/finders/<int:finder_id>/ask")
|
||||
class InferenceEndpoint(Resource):
|
||||
def post(self, finder_id):
|
||||
finder = FINDERS.get(finder_id, None)
|
||||
if not finder:
|
||||
return "Model not found", 404
|
||||
results = finder.get_answers(
|
||||
question=request.question, top_k_retriever=request.top_k_retriever,
|
||||
top_k_reader=request.top_k_reader, filters=request.filters
|
||||
)
|
||||
|
||||
request_body = request.get_json()
|
||||
questions = request_body.get("questions", None)
|
||||
if not questions:
|
||||
return "The request is missing 'questions' field", 400
|
||||
|
||||
filters = request_body.get("filters", None)
|
||||
|
||||
results = finder.get_answers(
|
||||
question=request_body["questions"][0], top_k_reader=3, filters=filters
|
||||
)
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@ -1,9 +1,8 @@
|
||||
# FARM (incl. transformers 2.3.0 with pipelines)
|
||||
#farm -e git+https://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43
|
||||
-e git://github.com/deepset-ai/FARM.git@1d30237b037050ef0ac5516f427443cdd18a4d43#egg=farm
|
||||
flask
|
||||
flask_cors
|
||||
flask_restplus
|
||||
fastapi
|
||||
uvicorn
|
||||
flask_sqlalchemy
|
||||
pandas
|
||||
psycopg2-binary
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user