diff --git a/rest_api/application.py b/rest_api/application.py index 1c19ad615..f3ac38e19 100644 --- a/rest_api/application.py +++ b/rest_api/application.py @@ -3,6 +3,7 @@ import logging from pathlib import Path import uvicorn from fastapi import FastAPI, HTTPException +from fastapi.routing import APIRoute from starlette.middleware.cors import CORSMiddleware from rest_api.controller.errors.http_error import http_error_handler @@ -19,7 +20,7 @@ from rest_api.controller.router import router as api_router def get_application() -> FastAPI: - application = FastAPI(title="Haystack-API", debug=True, version="0.1", root_path=ROOT_PATH) + application = FastAPI(title="Haystack-API", debug=True, version="0.10", root_path=ROOT_PATH) # This middleware enables allow all cross-domain requests to the API from a browser. For production # deployments, it could be made more restrictive. @@ -28,14 +29,23 @@ def get_application() -> FastAPI: ) application.add_exception_handler(HTTPException, http_error_handler) - application.include_router(api_router) return application +def use_route_names_as_operation_ids(app: FastAPI) -> None: + """ + Simplify operation IDs so that generated API clients have simpler function + names (see https://fastapi.tiangolo.com/advanced/path-operation-advanced-configuration/#using-the-path-operation-function-name-as-the-operationid). + The operation IDs will be the same as the route names (i.e. the python method names of the endpoints) + Should be called only after all routes have been added. + """ + for route in app.routes: + if isinstance(route, APIRoute): + route.operation_id = route.name app = get_application() - +use_route_names_as_operation_ids(app) logger.info("Open http://127.0.0.1:8000/docs to see Swagger API Documentation.") logger.info( diff --git a/rest_api/controller/document.py b/rest_api/controller/document.py index 6fc857489..4d06e104e 100644 --- a/rest_api/controller/document.py +++ b/rest_api/controller/document.py @@ -16,8 +16,8 @@ logger = logging.getLogger("haystack") router = APIRouter() -@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized]) -def get_documents_by_filter(filters: FilterRequest): +@router.post("/documents/get_by_filters", response_model=List[DocumentSerialized], response_model_exclude_none=True) +def get_documents(filters: FilterRequest): """ Can be used to get documents from a document store. @@ -33,7 +33,7 @@ def get_documents_by_filter(filters: FilterRequest): @router.post("/documents/delete_by_filters", response_model=bool) -def delete_documents_by_filter(filters: FilterRequest): +def delete_documents(filters: FilterRequest): """ Can be used to delete documents from a document store. diff --git a/rest_api/controller/feedback.py b/rest_api/controller/feedback.py index 4d57b8455..e9ad820ad 100644 --- a/rest_api/controller/feedback.py +++ b/rest_api/controller/feedback.py @@ -11,20 +11,20 @@ router = APIRouter() logger = logging.getLogger(__name__) -@router.post("/feedback") +@router.post("/feedback", operation_id="post_feedback") def user_feedback(feedback: LabelSerialized): if feedback.origin is None: feedback.origin = "user-feedback" DOCUMENT_STORE.write_labels([feedback]) -@router.get("/feedback") +@router.get("/feedback", operation_id="get_feedback") def user_feedback(): labels = DOCUMENT_STORE.get_all_labels() return labels -@router.post("/eval-feedback") +@router.post("/eval-feedback", operation_id="get_feedback_metrics") def eval_extractive_qa_feedback(filters: FilterRequest = None): """ Return basic accuracy metrics based on the user feedback. @@ -62,7 +62,7 @@ def eval_extractive_qa_feedback(filters: FilterRequest = None): @router.get("/export-feedback") -def export_extractive_qa_feedback( +def export_feedback( context_size: int = 100_000, full_document_context: bool = True, only_positive_labels: bool = False ): """ diff --git a/rest_api/controller/file_upload.py b/rest_api/controller/file_upload.py index 5dff2dbb0..0d2cd034c 100644 --- a/rest_api/controller/file_upload.py +++ b/rest_api/controller/file_upload.py @@ -63,7 +63,7 @@ class Response(BaseModel): @router.post("/file-upload") -def file_upload( +def upload_file( files: List[UploadFile] = File(...), meta: Optional[str] = Form("null"), # JSON serialized string fileconverter_params: FileConverterParams = Depends(FileConverterParams.as_form), diff --git a/rest_api/controller/search.py b/rest_api/controller/search.py index c82612af4..082479526 100644 --- a/rest_api/controller/search.py +++ b/rest_api/controller/search.py @@ -31,7 +31,7 @@ concurrency_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER) @router.get("/initialized") -def initialized(): +def check_status(): """ This endpoint can be used during startup to understand if the server is ready to take any requests, or is still loading. @@ -42,7 +42,7 @@ def initialized(): return True -@router.post("/query", response_model=QueryResponse) +@router.post("/query", response_model=QueryResponse, response_model_exclude_none=True) def query(request: QueryRequest): with concurrency_limiter.run(): result = _process_request(PIPELINE, request) diff --git a/rest_api/schema.py b/rest_api/schema.py index 432427154..a2b0c3505 100644 --- a/rest_api/schema.py +++ b/rest_api/schema.py @@ -29,7 +29,7 @@ class AnswerSerialized(Answer): @pydantic_dataclass class DocumentSerialized(Document): content: str - embedding: List[float] + embedding: Optional[List[float]] @pydantic_dataclass class LabelSerialized(Label):