rfctr: prepare for adding metadata.orig_elements field (#2647)

**Summary**
Some typing modernization in `elements.py` which will get changes to add
the `orig_elements` metadata field.

Also some additions to `unit_util.py` to enable simplified mocking that
will be required in the next PR.
This commit is contained in:
Steve Canny 2024-03-14 14:31:58 -07:00 committed by GitHub
parent d9e557459c
commit 94535e353c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 222 additions and 62 deletions

View File

@ -25,11 +25,7 @@ repos:
hooks:
- id: ruff
args:
[
"--fix",
"--select=C4,COM,E,F,I,PLR0402,PT,SIM,UP015,UP018,UP032,UP034",
"--ignore=PT011,PT012,SIM117,COM812",
]
["--fix"]
- repo: https://github.com/pycqa/flake8
rev: 4.0.1

View File

@ -1,4 +1,4 @@
## 0.12.7-dev2
## 0.12.7-dev3
### Enhancements

View File

@ -393,7 +393,8 @@ check-flake8-print:
.PHONY: check-ruff
check-ruff:
ruff . --select C4,COM,E,F,I,PLR0402,PT,SIM,UP015,UP018,UP032,UP034 --ignore COM812,PT011,PT012,SIM117
# -- ruff options are determined by pyproject.toml --
ruff .
.PHONY: check-autoflake
check-autoflake:

View File

@ -31,6 +31,7 @@ select = [
]
ignore = [
"COM812", # -- over aggressively insists on trailing commas where not desireable --
"PT005", # -- flags mock fixtures with names intentionally matching private method name --
"PT011", # -- pytest.raises({exc}) too broad, use match param or more specific exception --
"PT012", # -- pytest.raises() block should contain a single simple statement --
"SIM117", # -- merge `with` statements for context managers that have same scope --

View File

@ -1,13 +1,42 @@
"""Utilities that ease unit-testing."""
from __future__ import annotations
import datetime as dt
import difflib
import pathlib
from typing import List, Optional
from typing import Any, List, Optional
from unittest.mock import (
ANY,
MagicMock,
Mock,
PropertyMock,
call,
create_autospec,
mock_open,
patch,
)
from pytest import FixtureRequest, LogCaptureFixture # noqa: PT013
from unstructured.documents.elements import Element
from unstructured.staging.base import elements_from_json, elements_to_json
# -- Names exported by `from <this module> import *`. Besides helpers defined in this module,
# -- this re-exports selected `unittest.mock` and `pytest` names imported above so tests can
# -- get them from one place. Note `cls_attr_mock`, `loose_mock`, `open_mock` and `var_mock`
# -- are defined in this module but are not listed here.
__all__ = (
"ANY",
"FixtureRequest",
"LogCaptureFixture",
"MagicMock",
"Mock",
"call",
"class_mock",
"function_mock",
"initializer_mock",
"instance_mock",
"method_mock",
"property_mock",
)
def assert_round_trips_through_JSON(elements: List[Element]) -> None:
"""Raises AssertionError if `elements -> JSON -> List[Element] -> JSON` are not equal.
@ -54,3 +83,136 @@ def example_doc_path(file_name: str) -> str:
def parse_optional_datetime(datetime_str: Optional[str]) -> Optional[dt.datetime]:
    """Return the `datetime.datetime` parsed from ISO-format `datetime_str`.

    A falsy `datetime_str` (None or the empty string) is not parsed and produces None.
    """
    # -- nothing to parse for None or "" --
    if not datetime_str:
        return None
    return dt.datetime.fromisoformat(datetime_str)
# ------------------------------------------------------------------------------------------------
# MOCKING FIXTURES
# ------------------------------------------------------------------------------------------------
# These allow full-featured and type-safe mocks to be created simply by adding a unit-test
# fixture.
# ------------------------------------------------------------------------------------------------
def class_mock(
    request: FixtureRequest, q_class_name: str, autospec: bool = True, **kwargs: Any
) -> Mock:
    """Patch the class found at dotted path `q_class_name` and return its stand-in mock.

    `autospec` (True by default) and any extra keyword arguments are forwarded to `patch()`.
    The patcher's `stop()` is registered as a finalizer on `request`, so the original class
    is restored when that finalizer runs.
    """
    patcher = patch(q_class_name, autospec=autospec, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    mocked_class = patcher.start()
    return mocked_class
def cls_attr_mock(
    request: FixtureRequest,
    cls: type,
    attr_name: str,
    name: str | None = None,
    **kwargs: Any,
):
    """Replace attribute `attr_name` of `cls` with a mock and return that mock.

    The mock is named `name`, falling back to `request.fixturename` when `name` is None. Extra
    keyword arguments go to `patch.object()`. The patcher's `stop()` is registered as a
    finalizer on `request` so the attribute is restored afterward.
    """
    if name is None:
        name = request.fixturename
    patcher = patch.object(cls, attr_name, name=name, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    attr_mock = patcher.start()
    return attr_mock
def function_mock(
    request: FixtureRequest, q_function_name: str, autospec: bool = True, **kwargs: Any
):
    """Patch the function at dotted path `q_function_name` and return its replacement.

    `autospec` (True by default) and any extra keyword arguments are forwarded to `patch()`.
    The patcher's `stop()` is registered as a finalizer on `request`, which restores the
    original function when it runs.
    """
    patcher = patch(q_function_name, autospec=autospec, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    mocked_function = patcher.start()
    return mocked_function
def initializer_mock(request: FixtureRequest, cls: type, autospec: bool = True, **kwargs: Any):
    """Patch `cls.__init__` and return the mock that replaces it.

    The replacement is created with `return_value=None`. `autospec` (True by default) and any
    extra keyword arguments are forwarded to `patch.object()`. The patcher's `stop()` is
    registered as a finalizer on `request` so the real initializer is restored afterward.
    """
    patcher = patch.object(cls, "__init__", autospec=autospec, return_value=None, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    init_mock = patcher.start()
    return init_mock
def instance_mock(
    request: FixtureRequest,
    cls: type,
    name: str | None = None,
    spec_set: bool = True,
    **kwargs: Any,
):
    """Return an autospec'ed mock standing in for an instance of `cls`.

    The mock is built with `create_autospec(..., instance=True)`. With the default
    `spec_set=True` the mock rejects setting attributes that are not on the spec. The mock's
    name is `name`, or `request.fixturename` when `name` is None. Remaining keyword arguments
    are forwarded to `create_autospec()`.
    """
    if name is None:
        name = request.fixturename
    return create_autospec(cls, _name=name, spec_set=spec_set, instance=True, **kwargs)
def loose_mock(request: FixtureRequest, name: str | None = None, **kwargs: Any):
    """Return a spec-less ("loose") `Mock`, so nothing constrains how it is used.

    The mock is named `name`; when `name` is None the name of the requesting fixture
    (`request.fixturename`) is used. Extra keyword arguments are forwarded to `Mock()`.
    """
    mock_name = request.fixturename if name is None else name
    return Mock(name=mock_name, **kwargs)
def method_mock(
    request: FixtureRequest,
    cls: type,
    method_name: str,
    autospec: bool = True,
    **kwargs: Any,
):
    """Patch method `method_name` on `cls` and return the mock that replaces it.

    `autospec` (True by default) and any extra keyword arguments are forwarded to
    `patch.object()`. The patcher's `stop()` is registered as a finalizer on `request` so the
    original method is restored afterward.
    """
    patcher = patch.object(cls, method_name, autospec=autospec, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    mocked_method = patcher.start()
    return mocked_method
def open_mock(request: FixtureRequest, module_name: str, **kwargs: Any):
    """Return a `mock_open()` mock patched in as `open` within module `module_name`.

    The patch target is `"{module_name}.open"`. `create=True` is passed to `patch()` so the
    name is added to the module when the module does not define its own `open`. Extra keyword
    arguments are forwarded to `patch()`. The patcher's `stop()` is registered as a finalizer
    on `request` so the patch is undone afterward.
    """
    # -- f-string rather than `%` so the argument is always formatted as a single value --
    target = f"{module_name}.open"
    patcher = patch(target, mock_open(), create=True, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    return patcher.start()
def property_mock(request: FixtureRequest, cls: type, prop_name: str, **kwargs: Any) -> Mock:
    """Patch property `prop_name` on `cls` with a `PropertyMock` and return that mock.

    Extra keyword arguments are forwarded to `patch.object()`. The patcher's `stop()` is
    registered as a finalizer on `request` so the real property is restored afterward.
    """
    patcher = patch.object(cls, prop_name, new_callable=PropertyMock, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    prop_mock = patcher.start()
    return prop_mock
def var_mock(request: FixtureRequest, q_var_name: str, **kwargs: Any):
    """Patch the variable at dotted path `q_var_name` and return what `patch().start()` yields.

    All keyword arguments are forwarded to `patch()`. The patcher's `stop()` is registered as
    a finalizer on `request` so the original value is restored afterward.
    """
    patcher = patch(q_var_name, **kwargs)
    # -- teardown is registered before the patch is started --
    request.addfinalizer(patcher.stop)
    replacement = patcher.start()
    return replacement

View File

@ -1 +1 @@
__version__ = "0.12.7-dev2" # pragma: no cover
__version__ = "0.12.7-dev3" # pragma: no cover

View File

@ -12,7 +12,7 @@ import pathlib
import re
import uuid
from types import MappingProxyType
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Callable, FrozenSet, Optional, Sequence, cast
from typing_extensions import ParamSpec, TypeAlias, TypedDict
@ -24,8 +24,8 @@ from unstructured.documents.coordinates import (
from unstructured.partition.utils.constants import UNSTRUCTURED_INCLUDE_DEBUG_METADATA
from unstructured.utils import lazyproperty
Point: TypeAlias = Tuple[float, float]
Points: TypeAlias = Tuple[Point, ...]
Point: TypeAlias = "tuple[float, float]"
Points: TypeAlias = "tuple[Point, ...]"
class NoID(abc.ABC):
@ -42,17 +42,17 @@ class DataSourceMetadata:
url: Optional[str] = None
version: Optional[str] = None
record_locator: Optional[Dict[str, Any]] = None # Values must be JSON-serializable
record_locator: Optional[dict[str, Any]] = None # Values must be JSON-serializable
date_created: Optional[str] = None
date_modified: Optional[str] = None
date_processed: Optional[str] = None
permissions_data: Optional[List[Dict[str, Any]]] = None
permissions_data: Optional[list[dict[str, Any]]] = None
def to_dict(self):
return {key: value for key, value in self.__dict__.items() if value is not None}
@classmethod
def from_dict(cls, input_dict: Dict[str, Any]):
def from_dict(cls, input_dict: dict[str, Any]):
# Only use existing fields when constructing
supported_fields = [f.name for f in dc.fields(cls)]
args = {k: v for k, v in input_dict.items() if k in supported_fields}
@ -95,10 +95,10 @@ class CoordinatesMetadata:
}
@classmethod
def from_dict(cls, input_dict: Dict[str, Any]):
def from_dict(cls, input_dict: dict[str, Any]):
# `input_dict` may contain a tuple of tuples or a list of lists
def convert_to_points(sequence_of_sequences: Sequence[Sequence[float]]) -> Points:
points: List[Point] = []
points: list[Point] = []
for seq in sequence_of_sequences:
if isinstance(seq, list):
points.append(cast(Point, tuple(seq)))
@ -172,8 +172,8 @@ class ElementMetadata:
detection_class_prob: Optional[float]
# -- DEBUG field, the detection mechanism that emitted this element --
detection_origin: Optional[str]
emphasized_text_contents: Optional[List[str]]
emphasized_text_tags: Optional[List[str]]
emphasized_text_contents: Optional[list[str]]
emphasized_text_tags: Optional[list[str]]
file_directory: Optional[str]
filename: Optional[str]
filetype: Optional[str]
@ -184,24 +184,24 @@ class ElementMetadata:
header_footer_type: Optional[str]
# -- used in chunks only, when chunk must be split mid-text to fit window --
is_continuation: Optional[bool]
languages: Optional[List[str]]
languages: Optional[list[str]]
last_modified: Optional[str]
link_texts: Optional[List[str]]
link_urls: Optional[List[str]]
links: Optional[List[Link]]
link_texts: Optional[list[str]]
link_urls: Optional[list[str]]
links: Optional[list[Link]]
# -- the worksheet name in XLSX documents --
page_name: Optional[str]
# -- page numbers currently supported for DOCX, HTML, PDF, and PPTX documents --
page_number: Optional[int]
parent_id: Optional[str | uuid.UUID | NoID | UUID]
# -- "fields" e.g. status, dept.no, etc. extracted from text via regex --
regex_metadata: Optional[Dict[str, List[RegexMetadata]]]
regex_metadata: Optional[dict[str, list[RegexMetadata]]]
# -- EPUB document section --
section: Optional[str]
# -- e-mail specific metadata fields --
sent_from: Optional[List[str]]
sent_to: Optional[List[str]]
sent_from: Optional[list[str]]
sent_to: Optional[list[str]]
subject: Optional[str]
signature: Optional[str]
@ -221,26 +221,26 @@ class ElementMetadata:
coordinates: Optional[CoordinatesMetadata] = None,
data_source: Optional[DataSourceMetadata] = None,
detection_class_prob: Optional[float] = None,
emphasized_text_contents: Optional[List[str]] = None,
emphasized_text_tags: Optional[List[str]] = None,
emphasized_text_contents: Optional[list[str]] = None,
emphasized_text_tags: Optional[list[str]] = None,
file_directory: Optional[str] = None,
filename: Optional[str | pathlib.Path] = None,
filetype: Optional[str] = None,
header_footer_type: Optional[str] = None,
image_path: Optional[str] = None,
is_continuation: Optional[bool] = None,
languages: Optional[List[str]] = None,
languages: Optional[list[str]] = None,
last_modified: Optional[str] = None,
link_texts: Optional[List[str]] = None,
link_urls: Optional[List[str]] = None,
links: Optional[List[Link]] = None,
link_texts: Optional[list[str]] = None,
link_urls: Optional[list[str]] = None,
links: Optional[list[Link]] = None,
page_name: Optional[str] = None,
page_number: Optional[int] = None,
parent_id: Optional[str | uuid.UUID | NoID | UUID] = None,
regex_metadata: Optional[Dict[str, List[RegexMetadata]]] = None,
regex_metadata: Optional[dict[str, list[RegexMetadata]]] = None,
section: Optional[str] = None,
sent_from: Optional[List[str]] = None,
sent_to: Optional[List[str]] = None,
sent_from: Optional[list[str]] = None,
sent_to: Optional[list[str]] = None,
signature: Optional[str] = None,
subject: Optional[str] = None,
text_as_html: Optional[str] = None,
@ -311,7 +311,7 @@ class ElementMetadata:
super().__setattr__(__name, __value)
@classmethod
def from_dict(cls, meta_dict: Dict[str, Any]) -> ElementMetadata:
def from_dict(cls, meta_dict: dict[str, Any]) -> ElementMetadata:
"""Construct from a metadata-dict.
This would generally be a dict formed using the `.to_dict()` method and stored as JSON
@ -362,7 +362,7 @@ class ElementMetadata:
}
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert this metadata to dict form, suitable for JSON serialization.
The returned dict is "sparse" in that no key-value pair appears for a field with value
@ -375,7 +375,7 @@ class ElementMetadata:
meta_dict.pop(field_name, None)
# -- don't serialize empty lists --
meta_dict: Dict[str, Any] = {
meta_dict: dict[str, Any] = {
field_name: value
for field_name, value in meta_dict.items()
if value != [] and value != {}
@ -445,7 +445,7 @@ class ConsolidationStrategy(enum.Enum):
"""Combine regex-metadata of elements, adjust start and stop offsets for concatenated text."""
@classmethod
def field_consolidation_strategies(cls) -> Dict[str, ConsolidationStrategy]:
def field_consolidation_strategies(cls) -> dict[str, ConsolidationStrategy]:
"""Mapping from ElementMetadata field-name to its consolidation strategy.
Note that only _TextSection objects ("pre-chunks" containing only `Text` elements that are
@ -492,7 +492,7 @@ class ConsolidationStrategy(enum.Enum):
_P = ParamSpec("_P")
def process_metadata() -> Callable[[Callable[_P, List[Element]]], Callable[_P, List[Element]]]:
def process_metadata() -> Callable[[Callable[_P, list[Element]]], Callable[_P, list[Element]]]:
"""Post-process element-metadata for this document.
This decorator adds a post-processing step to a document partitioner. It adds documentation for
@ -501,7 +501,7 @@ def process_metadata() -> Callable[[Callable[_P, List[Element]]], Callable[_P, L
`unique_element_ids` argument is provided and True.
"""
def decorator(func: Callable[_P, List[Element]]) -> Callable[_P, List[Element]]:
def decorator(func: Callable[_P, list[Element]]) -> Callable[_P, list[Element]]:
if func.__doc__:
if (
"metadata_filename" in func.__code__.co_varnames
@ -522,15 +522,15 @@ def process_metadata() -> Callable[[Callable[_P, List[Element]]], Callable[_P, L
)
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> List[Element]:
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> list[Element]:
elements = func(*args, **kwargs)
sig = inspect.signature(func)
params: Dict[str, Any] = dict(**dict(zip(sig.parameters, args)), **kwargs)
params: dict[str, Any] = dict(**dict(zip(sig.parameters, args)), **kwargs)
for param in sig.parameters.values():
if param.name not in params and param.default is not param.empty:
params[param.name] = param.default
regex_metadata: Dict["str", "str"] = params.get("regex_metadata", {})
regex_metadata: dict["str", "str"] = params.get("regex_metadata", {})
# -- don't write an empty `{}` to metadata.regex_metadata when no regex-metadata was
# -- requested, otherwise it will serialize (because it's not None) when it has no
# -- meaning or is even misleading. Also it complicates tests that don't use regex-meta.
@ -549,18 +549,18 @@ def process_metadata() -> Callable[[Callable[_P, List[Element]]], Callable[_P, L
def _add_regex_metadata(
elements: List[Element],
regex_metadata: Dict[str, str] = {},
) -> List[Element]:
elements: list[Element],
regex_metadata: dict[str, str] = {},
) -> list[Element]:
"""Adds metadata based on a user provided regular expression.
The additional metadata will be added to the regex_metadata attribute in the element metadata.
"""
for element in elements:
if isinstance(element, Text):
_regex_metadata: Dict["str", List[RegexMetadata]] = {}
_regex_metadata: dict["str", list[RegexMetadata]] = {}
for field_name, pattern in regex_metadata.items():
results: List[RegexMetadata] = []
results: list[RegexMetadata] = []
for result in re.finditer(pattern, element.text):
start, end = result.span()
results.append(
@ -637,13 +637,13 @@ class Element(abc.ABC):
def __init__(
self,
element_id: Union[str, uuid.UUID, NoID, UUID] = NoID(),
coordinates: Optional[Tuple[Tuple[float, float], ...]] = None,
element_id: str | uuid.UUID | NoID | UUID = NoID(),
coordinates: Optional[tuple[tuple[float, float], ...]] = None,
coordinate_system: Optional[CoordinateSystem] = None,
metadata: Optional[ElementMetadata] = None,
detection_origin: Optional[str] = None,
):
self.id: Union[str, uuid.UUID, NoID, UUID] = element_id
self.id: str | uuid.UUID | NoID | UUID = element_id
self.metadata = ElementMetadata() if metadata is None else metadata
if coordinates is not None or coordinate_system is not None:
self.metadata.coordinates = CoordinatesMetadata(
@ -657,7 +657,7 @@ class Element(abc.ABC):
def id_to_uuid(self):
self.id = str(uuid.uuid4())
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"type": None,
"element_id": self.id,
@ -703,8 +703,8 @@ class CheckBox(Element):
def __init__(
self,
element_id: Union[str, uuid.UUID, NoID, UUID] = NoID(),
coordinates: Optional[Tuple[Tuple[float, float], ...]] = None,
element_id: str | uuid.UUID | NoID | UUID = NoID(),
coordinates: Optional[tuple[tuple[float, float], ...]] = None,
coordinate_system: Optional[CoordinateSystem] = None,
checked: bool = False,
metadata: Optional[ElementMetadata] = None,
@ -730,7 +730,7 @@ class CheckBox(Element):
)
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize to JSON-compatible (str keys) dict."""
out = super().to_dict()
out["type"] = "CheckBox"
@ -747,16 +747,16 @@ class Text(Element):
def __init__(
self,
text: str,
element_id: Union[str, uuid.UUID, NoID, UUID] = NoID(),
coordinates: Optional[Tuple[Tuple[float, float], ...]] = None,
element_id: str | uuid.UUID | NoID | UUID = NoID(),
coordinates: Optional[tuple[tuple[float, float], ...]] = None,
coordinate_system: Optional[CoordinateSystem] = None,
metadata: Optional[ElementMetadata] = None,
detection_origin: Optional[str] = None,
embeddings: Optional[List[float]] = None,
embeddings: Optional[list[float]] = None,
):
metadata = metadata if metadata else ElementMetadata()
self.text: str = text
self.embeddings: Optional[List[float]] = embeddings
self.embeddings: Optional[list[float]] = embeddings
if isinstance(element_id, NoID):
# NOTE(robinson) - Cut the SHA256 hex in half to get the first 128 bits
@ -788,7 +788,7 @@ class Text(Element):
def __str__(self):
return self.text
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize to JSON-compatible (str keys) dict."""
out = super().to_dict()
out["element_id"] = self.id
@ -899,7 +899,7 @@ class Footer(Text):
category = "Footer"
TYPE_TO_TEXT_ELEMENT_MAP: Dict[str, type[Text]] = {
TYPE_TO_TEXT_ELEMENT_MAP: dict[str, type[Text]] = {
ElementType.TITLE: Title,
ElementType.SECTION_HEADER: Title,
ElementType.HEADLINE: Title,