Guillim 73a4f9825a
Add env var CONCURRENT_REQUEST_PER_WORKER (#1235)
* add an env var `CONCURRENT_REQUEST_PER_WORKER`, following the existing naming convention (I went back a few commits to find the original name)

* defaults to 4
2021-06-29 07:44:25 +02:00
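
The commit message notes that the new variable defaults to 4. As a minimal sketch of how such a setting is typically read from the environment (the actual contents of rest_api/config.py are not shown here, so everything beyond the variable name and the default of 4 is an assumption):

import os

# Hypothetical sketch of the setting in rest_api/config.py.
# Only the name CONCURRENT_REQUEST_PER_WORKER and the default of 4 are
# taken from the commit message; the surrounding code is an assumption.
CONCURRENT_REQUEST_PER_WORKER = int(os.getenv("CONCURRENT_REQUEST_PER_WORKER", "4"))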

import json
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Union, Any

from fastapi import APIRouter
from pydantic import BaseModel

from haystack import Pipeline
from rest_api.config import PIPELINE_YAML_PATH, LOG_LEVEL, QUERY_PIPELINE_NAME, CONCURRENT_REQUEST_PER_WORKER
from rest_api.controller.utils import RequestLimiter

logging.getLogger("haystack").setLevel(LOG_LEVEL)
logger = logging.getLogger("haystack")

router = APIRouter()

class Request(BaseModel):
    query: str
    filters: Optional[Dict[str, Optional[Union[str, List[str]]]]] = None
    top_k_retriever: Optional[int] = None
    top_k_reader: Optional[int] = None


class Answer(BaseModel):
    answer: Optional[str]
    question: Optional[str]
    score: Optional[float] = None
    probability: Optional[float] = None
    context: Optional[str]
    offset_start: int
    offset_end: int
    offset_start_in_doc: Optional[int]
    offset_end_in_doc: Optional[int]
    document_id: Optional[str] = None
    meta: Optional[Dict[str, Any]]


class Response(BaseModel):
    query: str
    answers: List[Answer]

PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_PIPELINE_NAME)
logger.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")

# Cap how many requests a single worker processes at once; the limit comes
# from the CONCURRENT_REQUEST_PER_WORKER env var (default 4).
concurrency_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)

@router.post("/query", response_model=Response)
def query(request: Request):
    with concurrency_limiter.run():
        result = _process_request(PIPELINE, request)
        return result

def _process_request(pipeline, request) -> Response:
    start_time = time.time()

    filters = {}
    if request.filters:
        # put filter values into a list and remove filters with null value
        for key, values in request.filters.items():
            if values is None:
                continue
            if not isinstance(values, list):
                values = [values]
            filters[key] = values

    result = pipeline.run(query=request.query,
                          filters=filters,
                          top_k_retriever=request.top_k_retriever,
                          top_k_reader=request.top_k_reader)
    end_time = time.time()
    logger.info(json.dumps({"request": request.dict(), "response": result,
                            "time": f"{(end_time - start_time):.2f}"}))
    return result
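
RequestLimiter is imported from rest_api.controller.utils but is not shown in this file. A minimal sketch of such a limiter, assuming a non-blocking semaphore that rejects requests over the cap with HTTP 503 (the actual implementation may differ):

from contextlib import contextmanager
from threading import Semaphore

from fastapi import HTTPException


class RequestLimiter:
    # Minimal sketch, not necessarily the real rest_api.controller.utils code:
    # requests beyond the limit are rejected immediately rather than queued,
    # so one slow query cannot pile up unbounded work on the worker.
    def __init__(self, limit: int):
        self.semaphore = Semaphore(limit)

    @contextmanager
    def run(self):
        acquired = self.semaphore.acquire(blocking=False)
        if not acquired:
            raise HTTPException(status_code=503, detail="The server is busy processing requests.")
        try:
            yield
        finally:
            self.semaphore.release()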
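
For completeness, a hypothetical client call against the /query endpoint above; the host, port, and example query are assumptions, not part of this commit:

import requests

# Assumes the REST API is served locally on port 8000.
response = requests.post(
    "http://localhost:8000/query",
    json={"query": "Who is the father of Arya Stark?", "top_k_reader": 3},
)
print(response.json())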