651 lines
23 KiB
Python
Raw Permalink Normal View History

import base64
import contextlib
import json
import os
import pathlib
from typing import Any
from unittest.mock import Mock
import pytest
import requests
from unstructured_client.general import General
from unstructured_client.models import shared
from unstructured_client.models.operations import PartitionRequest
from unstructured_client.models.shared import PartitionParameters
from unstructured_client.utils import retries
from unstructured.documents.elements import ElementType, NarrativeText
from unstructured.partition.api import (
DEFAULT_RETRIES_MAX_ELAPSED_TIME_SEC,
DEFAULT_RETRIES_MAX_INTERVAL_SEC,
get_retries_config,
partition_multiple_via_api,
partition_via_api,
)
from ..unit_utils import ANY, FixtureRequest, example_doc_path, method_mock
DIRECTORY = pathlib.Path(__file__).parent.resolve()
# NOTE(yao): point to paid API for now
API_URL = "https://api.unstructuredapp.io/general/v0/general"
is_in_ci = os.getenv("CI", "").lower() not in {"", "false", "f", "0"}
skip_not_on_main = os.getenv("GITHUB_REF_NAME", "").lower() != "main"
def test_partition_via_api_with_filename_correctly_calls_sdk(
request: FixtureRequest, expected_call_: list[Any]
):
partition_mock_ = method_mock(
request, General, "partition", return_value=FakeResponse(status_code=200)
)
elements = partition_via_api(filename=example_doc_path("eml/fake-email.eml"))
partition_mock_.assert_called_once_with(*expected_call_)
assert isinstance(partition_mock_.call_args_list[0].args[0], General)
assert len(elements) == 1
assert elements[0] == NarrativeText("This is a test email to use for unit tests.")
assert elements[0].metadata.filetype == "message/rfc822"
def test_partition_via_api_with_file_correctly_calls_sdk(
request: FixtureRequest, expected_call_: list[Any]
):
partition_mock_ = method_mock(
request, General, "partition", return_value=FakeResponse(status_code=200)
)
with open(example_doc_path("eml/fake-email.eml"), "rb") as f:
elements = partition_via_api(
file=f, metadata_filename=example_doc_path("eml/fake-email.eml")
)
# Update the fixture content to match the format passed to partition_via_api
modified_expected_call = expected_call_[:]
modified_expected_call[1].partition_parameters.files.content = f
partition_mock_.assert_called_once_with(*modified_expected_call)
assert isinstance(partition_mock_.call_args_list[0].args[0], General)
assert len(elements) == 1
assert elements[0] == NarrativeText("This is a test email to use for unit tests.")
assert elements[0].metadata.filetype == "message/rfc822"
def test_partition_via_api_warns_with_file_and_filename_and_calls_sdk(
request: FixtureRequest, expected_call_: list[Any], caplog: pytest.LogCaptureFixture
):
partition_mock_ = method_mock(
request, General, "partition", return_value=FakeResponse(status_code=200)
)
with open(example_doc_path("eml/fake-email.eml"), "rb") as f:
partition_via_api(file=f, file_filename=example_doc_path("eml/fake-email.eml"))
# Update the fixture content to match the format passed to partition_via_api
modified_expected_call = expected_call_[:]
modified_expected_call[1].partition_parameters.files.content = f
partition_mock_.assert_called_once_with(*modified_expected_call)
assert "WARNING" in caplog.text
assert "The file_filename kwarg will be deprecated" in caplog.text
def test_partition_via_api_from_file_raises_with_metadata_and_file_and_filename():
filename = example_doc_path("eml/fake-email.eml")
with open(filename, "rb") as f, pytest.raises(ValueError):
partition_via_api(file=f, file_filename=filename, metadata_filename=filename)
def test_partition_via_api_from_file_raises_without_filename():
with open(example_doc_path("eml/fake-email.eml"), "rb") as f, pytest.raises(ValueError):
partition_via_api(file=f)
def test_partition_via_api_raises_with_bad_response(request: FixtureRequest):
partition_mock_ = method_mock(
request, General, "partition", return_value=FakeResponse(status_code=500)
)
with pytest.raises(ValueError):
partition_via_api(filename=example_doc_path("eml/fake-email.eml"))
partition_mock_.assert_called_once()
@pytest.mark.skipif(not is_in_ci, reason="Skipping test run outside of CI")
@pytest.mark.skipif(skip_not_on_main, reason="Skipping test run outside of main branch")
def test_partition_via_api_with_no_strategy():
test_file = example_doc_path("pdf/loremipsum-flat.pdf")
elements_no_strategy = partition_via_api(
filename=test_file,
strategy="auto",
api_key=get_api_key(),
# The url has changed since the 06/24 API release while the sdk defaults to the old url
api_url=API_URL,
fix: set `skip_infer_tables` explicitly in `test_partition_via_api_with_no_strategy` (#3057) ### Summary A `partition_via_api` test that only runs on `main` was [failing](https://github.com/Unstructured-IO/unstructured/actions/runs/9159429513/job/25181600959) with the following output, likely due to the change in the default behavior for `skip_infer_table_types`. This PR explicitly sets the `skip_infer_table_types` param to avoid the failure.. ```python =========================== short test summary info ============================ FAILED test_unstructured/partition/test_api.py::test_partition_via_api_with_no_strategy - AssertionError: assert 'Zejiang Shen® (<), Ruochen Zhang?, Melissa Dell®, Benjamin Charles Germain Lee?, Jacob Carlson®, and Weining Li®' != 'Zejiang Shen® (<), Ruochen Zhang?, Melissa Dell®, Benjamin Charles Germain Lee?, Jacob Carlson®, and Weining Li®' + where 'Zejiang Shen® (<), Ruochen Zhang?, Melissa Dell®, Benjamin Charles Germain Lee?, Jacob Carlson®, and Weining Li®' = <unstructured.documents.elements.Text object at 0x7fb9069fc610>.text + and 'Zejiang Shen® (<), Ruochen Zhang?, Melissa Dell®, Benjamin Charles Germain Lee?, Jacob Carlson®, and Weining Li®' = <unstructured.documents.elements.Text object at 0x7fb90648ad90>.text = 1 failed, 2299 passed, 9 skipped, 2 deselected, 2 xfailed, 9 xpassed, 14 warnings in 1241.64s (0:20:41) = make: *** [Makefile:302: test] Error 1 ``` ### Testing After temporarily removing the "skip if not on `main`" `pytest` mark, the [unit tests pass](https://github.com/Unstructured-IO/unstructured/actions/runs/9163268381/job/25192040902?pr=3057O) on the feature branch.
2024-05-20 19:05:13 -04:00
skip_infer_table_types=["pdf"],
)
elements_hi_res = partition_via_api(
filename=test_file,
strategy="hi_res",
api_key=get_api_key(),
# The url has changed since the 06/24 API release while the sdk defaults to the old url
api_url=API_URL,
skip_infer_table_types=["pdf"],
)
elements_fast_res = partition_via_api(
filename=test_file,
strategy="fast",
api_key=get_api_key(),
# The url has changed since the 06/24 API release while the sdk defaults to the old url
api_url=API_URL,
skip_infer_table_types=["pdf"],
)
# confirm that hi_res strategy was not passed as default to partition by comparing outputs
# elements_hi_res[3].text =
# 'LayoutParser: A Unified Toolkit for Deep Learning Based Document Image Analysis'
# while elements_no_strategy[3].text = ']' (as of this writing)
assert len(elements_no_strategy) == len(elements_hi_res)
assert len(elements_hi_res) != len(elements_fast_res)
# NOTE(crag): slightly out scope assertion, but avoid extra API call
assert elements_hi_res[0].metadata.coordinates is None
@pytest.mark.skipif(not is_in_ci, reason="Skipping test run outside of CI")
@pytest.mark.skipif(skip_not_on_main, reason="Skipping test run outside of main branch")
def test_partition_via_api_with_image_hi_res_strategy_includes_coordinates():
# coordinates not included by default to limit payload size
elements = partition_via_api(
filename=example_doc_path("pdf/fake-memo.pdf"),
strategy="hi_res",
coordinates="true",
api_key=get_api_key(),
api_url=API_URL,
)
assert elements[0].metadata.coordinates is not None
@pytest.mark.skipif(not is_in_ci, reason="Skipping test run outside of CI")
@pytest.mark.skipif(skip_not_on_main, reason="Skipping test run outside of main branch")
def test_partition_via_api_image_block_extraction():
elements = partition_via_api(
filename=example_doc_path("pdf/embedded-images-tables.pdf"),
strategy="hi_res",
extract_image_block_types=["image", "table"],
api_key=get_api_key(),
# The url has changed since the 06/24 API release while the sdk defaults to the old url
api_url=API_URL,
)
image_elements = [el for el in elements if el.category == ElementType.IMAGE]
for el in image_elements:
assert el.metadata.image_base64 is not None
assert el.metadata.image_mime_type is not None
image_data = base64.b64decode(el.metadata.image_base64)
assert isinstance(image_data, bytes)
@pytest.mark.skipif(not is_in_ci, reason="Skipping test run outside of CI")
@pytest.mark.skipif(skip_not_on_main, reason="Skipping test run outside of main branch")
def test_partition_via_api_retries_config():
elements = partition_via_api(
filename=example_doc_path("pdf/embedded-images-tables.pdf"),
strategy="fast",
api_key=get_api_key(),
# The url has changed since the 06/24 API release while the sdk defaults to the old url
api_url=API_URL,
retries_initial_interval=5,
retries_max_interval=15,
retries_max_elapsed_time=100,
retries_connection_errors=True,
retries_exponent=1.5,
)
assert len(elements) > 0
# Note(austin) - This test is way too noisy against the hosted api
# def test_partition_via_api_invalid_request_data_kwargs():
# filename = os.path.join(DIRECTORY, "..", "..", "example-docs", "layout-parser-paper-fast.pdf")
# with pytest.raises(SDKError):
# partition_via_api(filename=filename, strategy="not_a_strategy")
def test_retries_config_with_parameters_set():
sdk = Mock()
retries_config = get_retries_config(
retries_connection_errors=True,
retries_exponent=1.75,
retries_initial_interval=20,
retries_max_elapsed_time=1000,
retries_max_interval=100,
sdk=sdk,
)
assert retries_config.retry_connection_errors
assert retries_config.backoff.exponent == 1.75
assert retries_config.backoff.initial_interval == 20
assert retries_config.backoff.max_elapsed_time == 1000
assert retries_config.backoff.max_interval == 100
def test_retries_config_none_parameters_return_empty_config():
sdk = Mock()
retries_config = get_retries_config(
retries_connection_errors=None,
retries_exponent=None,
retries_initial_interval=None,
retries_max_elapsed_time=None,
retries_max_interval=None,
sdk=sdk,
)
assert retries_config is None
def test_retry_config_with_empty_sdk_retry_config_returns_default():
sdk = Mock()
sdk.sdk_configuration.retry_config = None
retries_config = get_retries_config(
retries_connection_errors=True,
retries_exponent=1.88,
retries_initial_interval=3000,
retries_max_elapsed_time=None,
retries_max_interval=None,
sdk=sdk,
)
assert retries_config.retry_connection_errors
assert retries_config.backoff.exponent == 1.88
assert retries_config.backoff.initial_interval == 3000
assert retries_config.backoff.max_elapsed_time == DEFAULT_RETRIES_MAX_ELAPSED_TIME_SEC
assert retries_config.backoff.max_interval == DEFAULT_RETRIES_MAX_INTERVAL_SEC
def test_retries_config_with_no_parameters_set():
retry_config = retries.RetryConfig(
"backoff", retries.BackoffStrategy(3000, 720000, 1.88, 1800000), True
)
sdk = Mock()
sdk.sdk_configuration.retry_config = retry_config
retries_config = get_retries_config(
retries_connection_errors=True,
retries_exponent=None,
retries_initial_interval=None,
retries_max_elapsed_time=None,
retries_max_interval=None,
sdk=sdk,
)
assert retries_config.retry_connection_errors
assert retries_config.backoff.exponent == 1.88
assert retries_config.backoff.initial_interval == 3000
assert retries_config.backoff.max_elapsed_time == 1800000
assert retries_config.backoff.max_interval == 720000
def test_retries_config_cascade():
# notice max_interval is set to 0 which is incorrect - so the DEFAULT_RETRIES_MAX_INTERVAL_SEC
# should be used
retry_config = retries.RetryConfig(
"backoff", retries.BackoffStrategy(3000, 0, 1.88, None), True
)
sdk = Mock()
sdk.sdk_configuration.retry_config = retry_config
retries_config = get_retries_config(
retries_connection_errors=False,
retries_exponent=1.75,
retries_initial_interval=20,
retries_max_elapsed_time=None,
retries_max_interval=None,
sdk=sdk,
)
assert not retries_config.retry_connection_errors
assert retries_config.backoff.exponent == 1.75
assert retries_config.backoff.initial_interval == 20
assert retries_config.backoff.max_elapsed_time == DEFAULT_RETRIES_MAX_ELAPSED_TIME_SEC
assert retries_config.backoff.max_interval == DEFAULT_RETRIES_MAX_INTERVAL_SEC
def test_partition_multiple_via_api_with_single_filename(request: FixtureRequest):
partition_mock_ = method_mock(
request, requests, "post", return_value=FakeResponse(status_code=200)
)
filename = example_doc_path("eml/fake-email.eml")
elements = partition_multiple_via_api(filenames=[filename])
partition_mock_.assert_called_once_with(
"https://api.unstructured.io/general/v0/general",
headers={"ACCEPT": "application/json", "UNSTRUCTURED-API-KEY": ANY},
data={},
files=[("files", (example_doc_path("eml/fake-email.eml"), ANY, None))],
)
assert elements[0][0] == NarrativeText("This is a test email to use for unit tests.")
assert elements[0][0].metadata.filetype == "message/rfc822"
def test_partition_multiple_via_api_from_filenames(request: FixtureRequest):
partition_mock_ = method_mock(
request, requests, "post", return_value=FakeMultipleResponse(status_code=200)
)
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
elements = partition_multiple_via_api(filenames=filenames)
partition_mock_.assert_called_once_with(
"https://api.unstructured.io/general/v0/general",
headers={"ACCEPT": "application/json", "UNSTRUCTURED-API-KEY": ANY},
data={},
files=[
("files", (example_doc_path("eml/fake-email.eml"), ANY, None)),
("files", (example_doc_path("fake.docx"), ANY, None)),
],
)
assert len(elements) == 2
assert elements[0][0] == NarrativeText("This is a test email to use for unit tests.")
assert elements[0][0].metadata.filetype == "message/rfc822"
def test_partition_multiple_via_api_from_files(request: FixtureRequest):
partition_mock_ = method_mock(
request, requests, "post", return_value=FakeMultipleResponse(status_code=200)
)
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with contextlib.ExitStack() as stack:
files = [stack.enter_context(open(filename, "rb")) for filename in filenames]
elements = partition_multiple_via_api(
files=files,
metadata_filenames=filenames,
)
partition_mock_.assert_called_once_with(
"https://api.unstructured.io/general/v0/general",
headers={"ACCEPT": "application/json", "UNSTRUCTURED-API-KEY": ANY},
data={},
files=[
("files", (example_doc_path("eml/fake-email.eml"), ANY, None)),
("files", (example_doc_path("fake.docx"), ANY, None)),
],
)
assert len(elements) == 2
assert elements[0][0] == NarrativeText("This is a test email to use for unit tests.")
assert elements[0][0].metadata.filetype == "message/rfc822"
def test_partition_multiple_via_api_warns_with_file_filename(
caplog: pytest.LogCaptureFixture, request: FixtureRequest
):
partition_mock_ = method_mock(
request, requests, "post", return_value=FakeMultipleResponse(status_code=200)
)
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with contextlib.ExitStack() as stack:
files = [stack.enter_context(open(filename, "rb")) for filename in filenames]
partition_multiple_via_api(
files=files,
file_filenames=filenames,
)
partition_mock_.assert_called_once_with(
"https://api.unstructured.io/general/v0/general",
headers={"ACCEPT": "application/json", "UNSTRUCTURED-API-KEY": ANY},
data={},
files=[
("files", (example_doc_path("eml/fake-email.eml"), ANY, None)),
("files", (example_doc_path("fake.docx"), ANY, None)),
],
)
assert "WARNING" in caplog.text
assert "The file_filenames kwarg will be deprecated" in caplog.text
def test_partition_multiple_via_api_raises_with_file_and_metadata_filename():
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with contextlib.ExitStack() as stack:
files = [stack.enter_context(open(filename, "rb")) for filename in filenames]
with pytest.raises(ValueError):
partition_multiple_via_api(
files=files,
metadata_filenames=filenames,
file_filenames=filenames,
)
def test_partition_multiple_via_api_raises_with_bad_response(request: FixtureRequest):
partition_mock_ = method_mock(
request, requests, "post", return_value=FakeMultipleResponse(status_code=500)
)
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with pytest.raises(ValueError):
partition_multiple_via_api(filenames=filenames)
partition_mock_.assert_called_once_with(
"https://api.unstructured.io/general/v0/general",
headers={"ACCEPT": "application/json", "UNSTRUCTURED-API-KEY": ANY},
data={},
files=[
("files", (example_doc_path("eml/fake-email.eml"), ANY, None)),
("files", (example_doc_path("fake.docx"), ANY, None)),
],
)
def test_partition_multiple_via_api_raises_with_content_types_size_mismatch():
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with pytest.raises(ValueError):
partition_multiple_via_api(
filenames=filenames,
content_types=["text/plain"],
)
def test_partition_multiple_via_api_from_files_raises_with_size_mismatch():
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with contextlib.ExitStack() as stack:
files = [stack.enter_context(open(filename, "rb")) for filename in filenames]
with pytest.raises(ValueError):
partition_multiple_via_api(
files=files,
metadata_filenames=filenames,
content_types=["text/plain"],
)
def test_partition_multiple_via_api_from_files_raises_without_filenames():
filenames = [example_doc_path("eml/fake-email.eml"), example_doc_path("fake.docx")]
with contextlib.ExitStack() as stack:
files = [stack.enter_context(open(filename, "rb")) for filename in filenames]
with pytest.raises(ValueError):
partition_multiple_via_api(
files=files,
)
def get_api_key():
api_key = os.getenv("UNS_API_KEY")
if api_key is None:
raise ValueError("UNS_API_KEY environment variable not set")
return api_key
@pytest.mark.skipif(not is_in_ci, reason="Skipping test run outside of CI")
@pytest.mark.skipif(skip_not_on_main, reason="Skipping test run outside of main branch")
def test_partition_multiple_via_api_valid_request_data_kwargs():
filenames = [
example_doc_path("fake-text.txt"),
example_doc_path("fake-email.txt"),
]
list_of_lists_of_elements = partition_multiple_via_api(
filenames=filenames,
strategy="fast",
api_key=get_api_key(),
api_url=API_URL,
)
# assert there is a list of elements for each file
assert len(list_of_lists_of_elements) == 2
assert isinstance(list_of_lists_of_elements[0], list)
assert isinstance(list_of_lists_of_elements[1], list)
@pytest.mark.skipif(not is_in_ci, reason="Skipping test run outside of CI")
def test_partition_multiple_via_api_invalid_request_data_kwargs():
filenames = [
example_doc_path("pdf/layout-parser-paper-fast.pdf"),
example_doc_path("img/layout-parser-paper-fast.jpg"),
]
with pytest.raises(ValueError):
partition_multiple_via_api(
filenames=filenames,
strategy="not_a_strategy",
api_key=get_api_key(),
# The url has changed since the 06/24 API release while the sdk defaults to the old url
api_url=API_URL,
)
MOCK_TEXT = """[
{
"element_id": "f49fbd614ddf5b72e06f59e554e6ae2b",
"text": "This is a test email to use for unit tests.",
"type": "NarrativeText",
"metadata": {
"sent_from": [
"Matthew Robinson <mrobinson@unstructured.io>"
],
"sent_to": [
"Matthew Robinson <mrobinson@unstructured.io>"
],
"subject": "Test Email",
"filename": "fake-email.eml",
"filetype": "message/rfc822"
}
}
]"""
class FakeResponse:
def __init__(self, status_code: int):
self.status_code = status_code
# The string representation of partitioned elements is nested in an additional
# layer in the new unstructured-client:
# `elements_from_json(text=response.raw_response.text)`
self.raw_response = FakeRawResponse()
self.headers = {"Content-Type": "application/json"}
def json(self):
return json.loads(self.text)
@property
def text(self):
return MOCK_TEXT
class FakeRawResponse:
def __init__(self):
self.text = MOCK_TEXT
class FakeMultipleResponse:
def __init__(self, status_code: int):
self.status_code = status_code
def json(self):
return json.loads(self.text)
@property
def text(self):
return """[
[
{
"element_id": "f49fbd614ddf5b72e06f59e554e6ae2b",
"text": "This is a test email to use for unit tests.",
"type": "NarrativeText",
"metadata": {
"sent_from": [
"Matthew Robinson <mrobinson@unstructured.io>"
],
"sent_to": [
"Matthew Robinson <mrobinson@unstructured.io>"
],
"subject": "Test Email",
"filename": "fake-email.eml",
"filetype": "message/rfc822"
}
}
],
[
{
"element_id": "f49fbd614ddf5b72e06f59e554e6ae2b",
"text": "This is a test email to use for unit tests.",
"type": "NarrativeText",
"metadata": {
"sent_from": [
"Matthew Robinson <mrobinson@unstructured.io>"
],
"sent_to": [
"Matthew Robinson <mrobinson@unstructured.io>"
],
"subject": "Test Email",
"filename": "fake-email.eml",
"filetype": "message/rfc822"
}
}
]
]"""
@pytest.fixture()
def expected_call_():
with open(example_doc_path("eml/fake-email.eml"), "rb") as f:
file_bytes = f.read()
return [
ANY,
PartitionRequest(
partition_parameters=PartitionParameters(
files=shared.Files(
content=file_bytes,
file_name=example_doc_path("eml/fake-email.eml"),
),
chunking_strategy=None,
combine_under_n_chars=None,
coordinates=False,
encoding=None,
extract_image_block_types=None,
gz_uncompressed_content_type=None,
hi_res_model_name=None,
include_orig_elements=None,
include_page_breaks=False,
languages=None,
max_characters=None,
multipage_sections=True,
new_after_n_chars=None,
ocr_languages=None,
output_format=shared.OutputFormat.APPLICATION_JSON,
overlap=0,
overlap_all=False,
pdf_infer_table_structure=True,
similarity_threshold=None,
skip_infer_table_types=None,
split_pdf_concurrency_level=5,
split_pdf_page=True,
starting_page_number=None,
strategy=shared.Strategy.HI_RES,
unique_element_ids=False,
xml_keep_tags=False,
)
),
None, # retries kwarg
]