get params with defaults (#3004)

Extract repeated logic into `get_call_args_applying_defaults` function
This commit is contained in:
John 2024-05-13 08:56:55 -05:00 committed by GitHub
parent e4c895923d
commit 45d7bcb399
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 32 additions and 42 deletions

View File

@ -1,4 +1,4 @@
## 0.13.8-dev4
## 0.13.8-dev5
### Enhancements

View File

@ -1 +1 @@
__version__ = "0.13.8-dev4" # pragma: no cover
__version__ = "0.13.8-dev5" # pragma: no cover

View File

@ -16,7 +16,7 @@ from typing_extensions import ParamSpec
from unstructured.chunking.basic import chunk_elements
from unstructured.chunking.title import chunk_by_title
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
from unstructured.utils import get_call_args_applying_defaults, lazyproperty
_P = ParamSpec("_P")
@ -70,20 +70,11 @@ def add_chunking_strategy(func: Callable[_P, list[Element]]) -> Callable[_P, lis
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> list[Element]:
"""The decorated function is replaced with this one."""
def get_call_args_applying_defaults() -> dict[str, Any]:
"""Map both explicit and default arguments of decorated func call by param name."""
sig = inspect.signature(func)
call_args: dict[str, Any] = dict(**dict(zip(sig.parameters, args)), **kwargs)
for param in sig.parameters.values():
if param.name not in call_args and param.default is not param.empty:
call_args[param.name] = param.default
return call_args
# -- call the partitioning function to get the elements --
elements = func(*args, **kwargs)
# -- look for a chunking-strategy argument --
call_args = get_call_args_applying_defaults()
call_args = get_call_args_applying_defaults(func, *args, **kwargs)
chunking_strategy = call_args.pop("chunking_strategy", None)
# -- no chunking-strategy means no chunking --

View File

@ -6,7 +6,6 @@ import dataclasses as dc
import enum
import functools
import hashlib
import inspect
import os
import pathlib
import re
@ -23,7 +22,7 @@ from unstructured.documents.coordinates import (
RelativeCoordinateSystem,
)
from unstructured.partition.utils.constants import UNSTRUCTURED_INCLUDE_DEBUG_METADATA
from unstructured.utils import lazyproperty
from unstructured.utils import get_call_args_applying_defaults, lazyproperty
Point: TypeAlias = "tuple[float, float]"
Points: TypeAlias = "tuple[Point, ...]"
@ -568,20 +567,16 @@ def process_metadata() -> Callable[[Callable[_P, list[Element]]], Callable[_P, l
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> list[Element]:
elements = func(*args, **kwargs)
sig = inspect.signature(func)
params: dict[str, Any] = dict(**dict(zip(sig.parameters, args)), **kwargs)
for param in sig.parameters.values():
if param.name not in params and param.default is not param.empty:
params[param.name] = param.default
call_args = get_call_args_applying_defaults(func, *args, **kwargs)
regex_metadata: dict["str", "str"] = params.get("regex_metadata", {})
regex_metadata: dict["str", "str"] = call_args.get("regex_metadata", {})
# -- don't write an empty `{}` to metadata.regex_metadata when no regex-metadata was
# -- requested, otherwise it will serialize (because it's not None) when it has no
# -- meaning or is even misleading. Also it complicates tests that don't use regex-meta.
if regex_metadata:
elements = _add_regex_metadata(elements, regex_metadata)
unique_element_ids: bool = params.get("unique_element_ids", False)
unique_element_ids: bool = call_args.get("unique_element_ids", False)
if unique_element_ids is False:
elements = assign_and_map_hash_ids(elements)

View File

@ -2,12 +2,11 @@ from __future__ import annotations
import enum
import functools
import inspect
import json
import os
import re
import zipfile
from typing import IO, Any, Callable, Dict, List, Optional
from typing import IO, Callable, List, Optional
from typing_extensions import ParamSpec
@ -20,6 +19,7 @@ from unstructured.partition.common import (
remove_element_metadata,
set_element_hierarchy,
)
from unstructured.utils import get_call_args_applying_defaults
try:
import magic
@ -580,18 +580,14 @@ def add_metadata(func: Callable[_P, List[Element]]) -> Callable[_P, List[Element
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> List[Element]:
elements = func(*args, **kwargs)
sig = inspect.signature(func)
params: Dict[str, Any] = dict(**dict(zip(sig.parameters, args)), **kwargs)
for param in sig.parameters.values():
if param.name not in params and param.default is not param.empty:
params[param.name] = param.default
include_metadata = params.get("include_metadata", True)
call_args = get_call_args_applying_defaults(func, *args, **kwargs)
include_metadata = call_args.get("include_metadata", True)
if include_metadata:
if params.get("metadata_filename"):
params["filename"] = params.get("metadata_filename")
if call_args.get("metadata_filename"):
call_args["filename"] = call_args.get("metadata_filename")
metadata_kwargs = {
kwarg: params.get(kwarg) for kwarg in ("filename", "url", "text_as_html")
kwarg: call_args.get(kwarg) for kwarg in ("filename", "url", "text_as_html")
}
# NOTE (yao): do not use cast here as cast(None) still is None
if not str(kwargs.get("model_name", "")).startswith("chipper"):
@ -620,16 +616,9 @@ def add_filetype(
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> List[Element]:
elements = func(*args, **kwargs)
sig = inspect.signature(func)
params: Dict[str, Any] = dict(**dict(zip(sig.parameters, args)), **kwargs)
for param in sig.parameters.values():
if param.name not in params and param.default is not param.empty:
params[param.name] = param.default
params = get_call_args_applying_defaults(func, *args, **kwargs)
include_metadata = params.get("include_metadata", True)
if include_metadata:
if params.get("metadata_filename"):
params["filename"] = params.get("metadata_filename")
for element in elements:
# NOTE(robinson) - Attached files have already run through this logic
# in their own partitioning function

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import functools
import html
import importlib
import inspect
import json
import os
import platform
@ -33,7 +34,7 @@ from typing_extensions import ParamSpec, TypeAlias
from unstructured.__version__ import __version__
if TYPE_CHECKING:
from unstructured.documents.elements import Text
from unstructured.documents.elements import Element, Text
# Box format: [x_bottom_left, y_bottom_left, x_top_right, y_top_right]
Box: TypeAlias = Tuple[float, float, float, float]
@ -46,6 +47,20 @@ _T = TypeVar("_T")
_P = ParamSpec("_P")
def get_call_args_applying_defaults(
    func: Callable[_P, List[Element]],
    *args: _P.args,
    **kwargs: _P.kwargs,
) -> dict[str, Any]:
    """Map both explicit and default arguments of decorated func call by param name.

    Args:
        func: The callable whose signature is inspected. It is never called here.
        *args: Positional arguments of the call being described.
        **kwargs: Keyword arguments of the call being described.

    Returns:
        A new dict keyed by parameter name. Positional arguments are bound to the
        positional parameters of `func` in declaration order; surplus positionals are
        collected as a tuple under the `*args`-style parameter when `func` declares one.
        Keyword arguments are merged in flat (extras destined for a `**kwargs`
        parameter appear as top-level keys), and any remaining parameter that declares
        a default is filled in with that default.

    Raises:
        TypeError: When a name is supplied both positionally and by keyword.
    """
    sig = inspect.signature(func)

    # -- bind positionals by parameter kind rather than zipping blindly, so a keyword-only
    # -- or `**kwargs` parameter name can never be paired with a positional value
    positionals = list(args)
    bound: dict[str, Any] = {}
    for param in sig.parameters.values():
        if not positionals:
            break
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
            bound[param.name] = positionals.pop(0)
        elif param.kind is param.VAR_POSITIONAL:
            # -- `*args` swallows every remaining positional, not just the first one --
            bound[param.name] = tuple(positionals)
            positionals = []
        else:
            # -- keyword-only and `**kwargs` parameters cannot accept positionals --
            break

    # -- `dict(**a, **b)` (rather than `{**a, **b}`) keeps the TypeError on a name given
    # -- both positionally and by keyword
    call_args: dict[str, Any] = dict(**bound, **kwargs)

    # -- fill in declared defaults for everything the caller left unspecified --
    for param in sig.parameters.values():
        if param.name not in call_args and param.default is not param.empty:
            call_args[param.name] = param.default

    return call_args
def htmlify_matrix_of_cell_texts(matrix: Sequence[Sequence[str]]) -> str:
"""Form an HTML table from "rows" and "columns" of `matrix`.