Fix file upload API (#808)

This commit is contained in:
Tanay Soni 2021-02-05 12:17:38 +01:00 committed by GitHub
parent 7b18e324f2
commit f95b70df38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 22 deletions

View File

@ -53,13 +53,20 @@ EMBEDDING_MODEL_FORMAT = os.getenv("EMBEDDING_MODEL_FORMAT", "farm")
# File uploads # File uploads
FILE_UPLOAD_PATH = os.getenv("FILE_UPLOAD_PATH", "file-uploads") FILE_UPLOAD_PATH = os.getenv("FILE_UPLOAD_PATH", "file-uploads")
REMOVE_NUMERIC_TABLES = os.getenv("REMOVE_NUMERIC_TABLES", "True").lower() == "true" REMOVE_NUMERIC_TABLES = os.getenv("REMOVE_NUMERIC_TABLES", "True").lower() == "true"
REMOVE_WHITESPACE = os.getenv("REMOVE_WHITESPACE", "True").lower() == "true"
REMOVE_EMPTY_LINES = os.getenv("REMOVE_EMPTY_LINES", "True").lower() == "true"
REMOVE_HEADER_FOOTER = os.getenv("REMOVE_HEADER_FOOTER", "True").lower() == "true"
VALID_LANGUAGES = os.getenv("VALID_LANGUAGES", None) VALID_LANGUAGES = os.getenv("VALID_LANGUAGES", None)
if VALID_LANGUAGES: if VALID_LANGUAGES:
VALID_LANGUAGES = ast.literal_eval(VALID_LANGUAGES) VALID_LANGUAGES = ast.literal_eval(VALID_LANGUAGES)
# Preprocessing
REMOVE_WHITESPACE = os.getenv("REMOVE_WHITESPACE", "True").lower() == "true"
REMOVE_EMPTY_LINES = os.getenv("REMOVE_EMPTY_LINES", "True").lower() == "true"
REMOVE_HEADER_FOOTER = os.getenv("REMOVE_HEADER_FOOTER", "True").lower() == "true"
SPLIT_BY = os.getenv("SPLIT_BY", "word")
SPLIT_LENGTH = os.getenv("SPLIT_LENGTH", 1_000)
SPLIT_OVERLAP = os.getenv("SPLIT_OVERLAP", None)
SPLIT_RESPECT_SENTENCE_BOUNDARY = os.getenv("SPLIT_RESPECT_SENTENCE_BOUNDARY", True)
# Monitoring # Monitoring
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
APM_SERVER = os.getenv("APM_SERVER", None) APM_SERVER = os.getenv("APM_SERVER", None)

View File

@ -12,10 +12,12 @@ from fastapi import UploadFile, File, Form
from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, DB_INDEX_FEEDBACK, ES_CONN_SCHEME, TEXT_FIELD_NAME, \ from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, DB_INDEX_FEEDBACK, ES_CONN_SCHEME, TEXT_FIELD_NAME, \
SEARCH_FIELD_NAME, FILE_UPLOAD_PATH, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, VALID_LANGUAGES, \ SEARCH_FIELD_NAME, FILE_UPLOAD_PATH, EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, VALID_LANGUAGES, \
FAQ_QUESTION_FIELD_NAME, REMOVE_NUMERIC_TABLES, REMOVE_WHITESPACE, REMOVE_EMPTY_LINES, REMOVE_HEADER_FOOTER, \ FAQ_QUESTION_FIELD_NAME, REMOVE_NUMERIC_TABLES, REMOVE_WHITESPACE, REMOVE_EMPTY_LINES, REMOVE_HEADER_FOOTER, \
CREATE_INDEX, UPDATE_EXISTING_DOCUMENTS, VECTOR_SIMILARITY_METRIC CREATE_INDEX, UPDATE_EXISTING_DOCUMENTS, VECTOR_SIMILARITY_METRIC, SPLIT_BY, SPLIT_LENGTH, SPLIT_OVERLAP, \
SPLIT_RESPECT_SENTENCE_BOUNDARY
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.file_converter.pdf import PDFToTextConverter from haystack.file_converter.pdf import PDFToTextConverter
from haystack.file_converter.txt import TextConverter from haystack.file_converter.txt import TextConverter
from haystack.preprocessor.preprocessor import PreProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,7 +32,7 @@ document_store = ElasticsearchDocumentStore(
index=DB_INDEX, index=DB_INDEX,
label_index=DB_INDEX_FEEDBACK, label_index=DB_INDEX_FEEDBACK,
scheme=ES_CONN_SCHEME, scheme=ES_CONN_SCHEME,
ca_certs=False, ca_certs=None,
verify_certs=False, verify_certs=False,
text_field=TEXT_FIELD_NAME, text_field=TEXT_FIELD_NAME,
search_fields=SEARCH_FIELD_NAME, search_fields=SEARCH_FIELD_NAME,
@ -54,6 +56,10 @@ def upload_file_to_document_store(
remove_empty_lines: Optional[bool] = Form(REMOVE_EMPTY_LINES), remove_empty_lines: Optional[bool] = Form(REMOVE_EMPTY_LINES),
remove_header_footer: Optional[bool] = Form(REMOVE_HEADER_FOOTER), remove_header_footer: Optional[bool] = Form(REMOVE_HEADER_FOOTER),
valid_languages: Optional[List[str]] = Form(VALID_LANGUAGES), valid_languages: Optional[List[str]] = Form(VALID_LANGUAGES),
split_by: Optional[str] = Form(SPLIT_BY),
split_length: Optional[int] = Form(SPLIT_LENGTH),
split_overlap: Optional[int] = Form(SPLIT_OVERLAP),
split_respect_sentence_boundary: Optional[bool] = Form(SPLIT_RESPECT_SENTENCE_BOUNDARY),
): ):
try: try:
file_path = Path(FILE_UPLOAD_PATH) / f"{uuid.uuid4().hex}_{file.filename}" file_path = Path(FILE_UPLOAD_PATH) / f"{uuid.uuid4().hex}_{file.filename}"
@ -62,27 +68,31 @@ def upload_file_to_document_store(
if file.filename.split(".")[-1].lower() == "pdf": if file.filename.split(".")[-1].lower() == "pdf":
pdf_converter = PDFToTextConverter( pdf_converter = PDFToTextConverter(
remove_numeric_tables=remove_numeric_tables, remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages
remove_whitespace=remove_whitespace,
remove_empty_lines=remove_empty_lines,
remove_header_footer=remove_header_footer,
valid_languages=valid_languages,
) )
document = pdf_converter.convert(file_path) document = pdf_converter.convert(file_path)
elif file.filename.split(".")[-1].lower() == "txt": elif file.filename.split(".")[-1].lower() == "txt":
txt_converter = TextConverter( txt_converter = TextConverter(
remove_numeric_tables=remove_numeric_tables, remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages,
remove_whitespace=remove_whitespace,
remove_empty_lines=remove_empty_lines,
remove_header_footer=remove_header_footer,
valid_languages=valid_languages,
) )
document = txt_converter.convert(file_path) document = txt_converter.convert(file_path)
else: else:
raise HTTPException(status_code=415, detail=f"Only .pdf and .txt file formats are supported.") raise HTTPException(status_code=415, detail=f"Only .pdf and .txt file formats are supported.")
document_to_write = {TEXT_FIELD_NAME: document["text"], "name": file.filename} document = {TEXT_FIELD_NAME: document["text"], "name": file.filename}
document_store.write_documents([document_to_write])
preprocessor = PreProcessor(
clean_whitespace=remove_whitespace,
clean_header_footer=remove_header_footer,
clean_empty_lines=remove_empty_lines,
split_by=split_by,
split_length=split_length,
split_overlap=split_overlap,
split_respect_sentence_boundary=split_respect_sentence_boundary,
)
documents = preprocessor.process(document)
document_store.write_documents(documents)
return "File upload was successful." return "File upload was successful."
finally: finally:
file.file.close() file.file.close()

View File

@ -1,19 +1,20 @@
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pathlib import Path
from haystack import Finder from haystack import Finder
from haystack.retriever.sparse import ElasticsearchRetriever from haystack.retriever.sparse import ElasticsearchRetriever
# TODO: Add integration tests for other APIs # TODO: Add integration tests for other APIs
def get_test_client_and_override_dependencies(reader, document_store_with_docs): def get_test_client_and_override_dependencies(reader, document_store):
from rest_api.application import app from rest_api.application import app
from rest_api.controller import search from rest_api.controller import search, file_upload
search.document_store = document_store_with_docs search.document_store = document_store
search.retriever = ElasticsearchRetriever(document_store=document_store_with_docs) search.retriever = ElasticsearchRetriever(document_store=document_store)
search.FINDERS = {1: Finder(reader=reader, retriever=search.retriever)} search.FINDERS = {1: Finder(reader=reader, retriever=search.retriever)}
file_upload.document_store = document_store
return TestClient(app) return TestClient(app)
@ -96,3 +97,14 @@ def test_query_api_filters(reader, document_store_with_docs):
assert "New York" == response_json['hits']['hits'][0]["_source"]["answer"] assert "New York" == response_json['hits']['hits'][0]["_source"]["answer"]
assert "My name is Paul and I live in New York" == response_json['hits']['hits'][0]["_source"]["context"] assert "My name is Paul and I live in New York" == response_json['hits']['hits'][0]["_source"]["context"]
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_file_upload(document_store):
assert document_store.get_document_count() == 0
client = get_test_client_and_override_dependencies(reader=None, document_store=document_store)
file_to_upload = {'file': Path("samples/pdf/sample_pdf_1.pdf").open('rb')}
response = client.post(url="/file-upload", files=file_to_upload)
assert 200 == response.status_code
assert document_store.get_document_count() > 0