Improve open api spec (#1700)

* improve open api spec

* move to automatic generation of better operationIDs
This commit is contained in:
Malte Pietsch 2021-11-11 09:40:58 +01:00 committed by GitHub
parent 14515a861b
commit b28dd823ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 24 additions and 14 deletions

View File

@ -3,6 +3,7 @@ import logging
from pathlib import Path
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.routing import APIRoute
from starlette.middleware.cors import CORSMiddleware
from rest_api.controller.errors.http_error import http_error_handler
@ -19,7 +20,7 @@ from rest_api.controller.router import router as api_router
def get_application() -> FastAPI:
application = FastAPI(title="Haystack-API", debug=True, version="0.1", root_path=ROOT_PATH)
application = FastAPI(title="Haystack-API", debug=True, version="0.10", root_path=ROOT_PATH)
# This middleware enables allow all cross-domain requests to the API from a browser. For production
# deployments, it could be made more restrictive.
@ -28,14 +29,23 @@ def get_application() -> FastAPI:
)
application.add_exception_handler(HTTPException, http_error_handler)
application.include_router(api_router)
return application
def use_route_names_as_operation_ids(app: FastAPI) -> None:
"""
Simplify operation IDs so that generated API clients have simpler function
names (see https://fastapi.tiangolo.com/advanced/path-operation-advanced-configuration/#using-the-path-operation-function-name-as-the-operationid).
The operation IDs will be the same as the route names (i.e. the python method names of the endpoints)
Should be called only after all routes have been added.
"""
for route in app.routes:
if isinstance(route, APIRoute):
route.operation_id = route.name
app = get_application()
use_route_names_as_operation_ids(app)
logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.")
logger.info(

View File

@ -16,8 +16,8 @@ logger = logging.getLogger("haystack")
router = APIRouter()
@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized])
def get_documents_by_filter(filters: FilterRequest):
@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized], response_model_exclude_none=True)
def get_documents(filters: FilterRequest):
"""
Can be used to get documents from a document store.
@ -33,7 +33,7 @@ def get_documents_by_filter(filters: FilterRequest):
@router.post("/documents/delete_by_filters", response_model=bool)
def delete_documents_by_filter(filters: FilterRequest):
def delete_documents(filters: FilterRequest):
"""
Can be used to delete documents from a document store.

View File

@ -11,20 +11,20 @@ router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/feedback")
@router.post("/feedback", operation_id="post_feedback")
def user_feedback(feedback: LabelSerialized):
if feedback.origin is None:
feedback.origin = "user-feedback"
DOCUMENT_STORE.write_labels([feedback])
@router.get("/feedback")
@router.get("/feedback", operation_id="get_feedback")
def user_feedback():
labels = DOCUMENT_STORE.get_all_labels()
return labels
@router.post("/eval-feedback")
@router.post("/eval-feedback", operation_id="get_feedback_metrics")
def eval_extractive_qa_feedback(filters: FilterRequest = None):
"""
Return basic accuracy metrics based on the user feedback.
@ -62,7 +62,7 @@ def eval_extractive_qa_feedback(filters: FilterRequest = None):
@router.get("/export-feedback")
def export_extractive_qa_feedback(
def export_feedback(
context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False
):
"""

View File

@ -63,7 +63,7 @@ class Response(BaseModel):
@router.post("/file-upload")
def file_upload(
def upload_file(
files: List[UploadFile] = File(...),
meta: Optional[str] = Form("null"), # JSON serialized string
fileconverter_params: FileConverterParams = Depends(FileConverterParams.as_form),

View File

@ -31,7 +31,7 @@ concurrency_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)
@router.get("/initialized")
def initialized():
def check_status():
"""
This endpoint can be used during startup to understand if the
server is ready to take any requests, or is still loading.
@ -42,7 +42,7 @@ def initialized():
return True
@router.post("/query", response_model=QueryResponse)
@router.post("/query", response_model=QueryResponse, response_model_exclude_none=True)
def query(request: QueryRequest):
with concurrency_limiter.run():
result = _process_request(PIPELINE, request)

View File

@ -29,7 +29,7 @@ class AnswerSerialized(Answer):
@pydantic_dataclass
class DocumentSerialized(Document):
content: str
embedding: List[float]
embedding: Optional[List[float]]
@pydantic_dataclass
class LabelSerialized(Label):