feat/tqdm ingest support (#3199)

### Description
Add tqdm support to show a progress bar for the status of each job while it runs. Supported in all three modes (serial, async, multiprocess). Also adds a small timing wrapper around jobs to log how long they took in total.
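
As a quick sketch (not part of the diff), this is how the new flag is switched on in code, assuming the v2 `ProcessorConfig` field added in this change:

```python
# Minimal sketch: enable progress bars via the new ProcessorConfig.tqdm flag.
# The import path matches the v2 interfaces module touched in this PR.
from unstructured.ingest.v2.interfaces import ProcessorConfig

context = ProcessorConfig(tqdm=True)  # progress bars in serial, async, and multiprocess modes
# Pass `context` to Pipeline.from_configs(...), as in the S3 example further down.
```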
Roman Isecke, 2024-06-13 14:41:54 -04:00 (committed by GitHub)
parent f5ebb209a4 · commit dadc9c6d0b
7 changed files with 48 additions and 2 deletions

@@ -4,6 +4,8 @@
### Features
* **tqdm ingest support** Add an optional flag to the ingest flow that prints a progress bar for each step in the process.
### Fixes
* **Remove deprecated `overwrite_schema` kwarg from Delta Table connector.** The `overwrite_schema` kwarg is deprecated in `deltalake>=0.18.0`; `schema_mode` should be used instead. `schema_mode="overwrite"` is equivalent to `overwrite_schema=True`, and `schema_mode="merge"` is equivalent to `overwrite_schema=False`. `schema_mode` defaults to `None`. You can also now specify `engine`, which defaults to `"pyarrow"`; you must specify `engine="rust"` to use `schema_mode`.
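
For readers migrating their Delta Table settings, here is a hedged sketch of the underlying `deltalake.write_deltalake` call the note refers to; the table path and sample data are placeholders:

```python
import pyarrow as pa
from deltalake import write_deltalake

# deltalake>=0.18.0: schema_mode replaces the deprecated overwrite_schema kwarg,
# and (per the note above) it requires the rust engine.
data = pa.table({"id": [1, 2, 3]})  # placeholder data
write_deltalake(
    "/tmp/example-delta-table",     # placeholder table location
    data,
    mode="overwrite",
    schema_mode="overwrite",        # previously: overwrite_schema=True
    engine="rust",
)
```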

@@ -17,3 +17,4 @@ backoff
typing-extensions
unstructured-client
wrapt
tqdm

@@ -83,5 +83,6 @@ class ProcessorCliConfig(CliConfig):
                "files based on file extension.",
            ),
            click.Option(["--verbose"], is_flag=True, default=False),
            click.Option(["--tqdm"], is_flag=True, default=False, help="Show progress bar"),
        ]
        return options

@@ -24,7 +24,7 @@ download_path = work_dir / "download"
if __name__ == "__main__":
    logger.info(f"Writing all content in: {work_dir.resolve()}")
    Pipeline.from_configs(
        context=ProcessorConfig(work_dir=str(work_dir.resolve())),
        context=ProcessorConfig(work_dir=str(work_dir.resolve()), tqdm=True),
        indexer_config=S3IndexerConfig(remote_url="s3://utic-dev-tech-fixtures/small-pdf-set/"),
        downloader_config=S3DownloaderConfig(download_dir=download_path),
        source_connection_config=S3ConnectionConfig(anonymous=True),

@@ -13,6 +13,7 @@ DEFAULT_WORK_DIR = str((Path.home() / ".cache" / "unstructured" / "ingest" / "pi
class ProcessorConfig(EnhancedDataClassJsonMixin):
    reprocess: bool = False
    verbose: bool = False
    tqdm: bool = False
    work_dir: str = field(default_factory=lambda: DEFAULT_WORK_DIR)
    num_processes: int = 2
    max_connections: Optional[int] = None

@@ -3,9 +3,14 @@ import logging
import multiprocessing as mp
from abc import ABC
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from time import time
from typing import Any, Optional, TypeVar
from tqdm import tqdm
from tqdm.asyncio import tqdm as tqdm_asyncio
from unstructured.ingest.v2.interfaces import BaseProcess, ProcessorConfig
from unstructured.ingest.v2.logger import logger
@@ -13,6 +18,22 @@ BaseProcessT = TypeVar("BaseProcessT", bound=BaseProcess)
iterable_input = list[dict[str, Any]]


def timed(func):
    @wraps(func)
    def time_it(self, *args, **kwargs):
        start = time()
        try:
            return func(self, *args, **kwargs)
        finally:
            if func.__name__ == "__call__":
                reported_name = f"{self.__class__.__name__} [cls]"
            else:
                reported_name = func.__name__
            logger.info(f"{reported_name} took {time() - start} seconds")
    return time_it


@dataclass
class PipelineStep(ABC):
    process: BaseProcessT
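
Outside the diff, a self-contained sketch of what the `timed` wrapper above does when applied to a step's `__call__`; the class and workload are toy stand-ins:

```python
import logging
from functools import wraps
from time import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def timed(func):
    # Same shape as the decorator added above: time the call, log the duration,
    # and report the class name when the wrapped function is __call__.
    @wraps(func)
    def time_it(self, *args, **kwargs):
        start = time()
        try:
            return func(self, *args, **kwargs)
        finally:
            name = f"{self.__class__.__name__} [cls]" if func.__name__ == "__call__" else func.__name__
            logger.info(f"{name} took {time() - start} seconds")
    return time_it


class ToyStep:
    @timed
    def __call__(self):
        return sum(range(1_000_000))


ToyStep()()  # logs e.g. "ToyStep [cls] took 0.02... seconds"
```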
@@ -25,6 +46,10 @@ class PipelineStep(ABC):
    def process_serially(self, iterable: iterable_input) -> Any:
        logger.info("processing content serially")
        if iterable:
            if len(iterable) == 1:
                return [self.run(**iterable[0])]
            if self.context.tqdm:
                return [self.run(**it) for it in tqdm(iterable, desc=self.identifier)]
            return [self.run(**it) for it in iterable]
        return [self.run()]
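
The serial branch reduces to the standard pattern of wrapping an iterable in `tqdm`; a standalone sketch with a stand-in `run`:

```python
from tqdm import tqdm


def run(n: int) -> int:  # stand-in for PipelineStep.run
    return n * n


items = [{"n": i} for i in range(100)]
# One bar, one tick per job, labeled like the step identifier above.
results = [run(**it) for it in tqdm(items, desc="example-step")]
```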
@@ -32,6 +57,10 @@ class PipelineStep(ABC):
        if iterable:
            if len(iterable) == 1:
                return [await self.run_async(**iterable[0])]
            if self.context.tqdm:
                return await tqdm_asyncio.gather(
                    *[self.run_async(**i) for i in iterable], desc=self.identifier
                )
            return await asyncio.gather(*[self.run_async(**i) for i in iterable])
        return [await self.run_async()]
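
`tqdm.asyncio` exposes a `gather` that mirrors `asyncio.gather` while advancing the bar as tasks finish; a standalone sketch with a stand-in coroutine:

```python
import asyncio

from tqdm.asyncio import tqdm as tqdm_asyncio


async def run_async(n: int) -> int:  # stand-in for PipelineStep.run_async
    await asyncio.sleep(0.01)
    return n * n


async def main() -> list:
    items = [{"n": i} for i in range(100)]
    # Drop-in replacement for asyncio.gather; the bar ticks as each task completes.
    return await tqdm_asyncio.gather(*[run_async(**i) for i in items], desc="example-step")


asyncio.run(main())
```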
@@ -44,7 +73,7 @@ class PipelineStep(ABC):
        if iterable:
            if len(iterable) == 1:
                return [self.run(**iterable[0])]
                return [self.process_serially(iterable)]
            if self.context.num_processes == 1:
                return self.process_serially(iterable)
            with mp.Pool(
@@ -52,6 +81,14 @@ class PipelineStep(ABC):
                initializer=self._set_log_level,
                initargs=(logging.DEBUG if self.context.verbose else logging.INFO,),
            ) as pool:
                if self.context.tqdm:
                    return list(
                        tqdm(
                            pool.imap_unordered(func=self._wrap_mp, iterable=iterable),
                            total=len(iterable),
                            desc=self.identifier,
                        )
                    )
                return pool.map(self._wrap_mp, iterable)
        return [self.run()]
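
With a multiprocessing pool, `imap_unordered` streams results back as workers finish, which lets `tqdm` tick incrementally; `total=` sizes the bar because the iterator has no length. A standalone sketch:

```python
import multiprocessing as mp

from tqdm import tqdm


def work(payload: dict) -> int:  # stand-in for PipelineStep._wrap_mp
    return payload["n"] * payload["n"]


if __name__ == "__main__":
    items = [{"n": i} for i in range(100)]
    with mp.Pool(processes=2) as pool:
        # Results arrive out of order as workers finish; tqdm advances per result.
        results = list(tqdm(pool.imap_unordered(work, items), total=len(items), desc="example-step"))
```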
@@ -63,6 +100,7 @@ class PipelineStep(ABC):
        # Set the log level for each spawned process when using multiprocessing pool
        logger.setLevel(log_level)

    @timed
    def __call__(self, iterable: Optional[iterable_input] = None) -> Any:
        iterable = iterable or []
        if iterable:

@@ -1,6 +1,7 @@
import logging
import multiprocessing as mp
from dataclasses import InitVar, dataclass, field
from time import time
from typing import Any, Optional, Union
from unstructured.ingest.v2.interfaces import ProcessorConfig
@@ -82,7 +83,9 @@ class Pipeline:
    def run(self):
        try:
            start_time = time()
            self._run()
            logger.info(f"Finished ingest process in {time() - start_time}s")
        finally:
            self.log_statuses()
            self.cleanup()