diff --git a/packages/markitdown/src/markitdown/__main__.py b/packages/markitdown/src/markitdown/__main__.py index 6a24391..8266f5c 100644 --- a/packages/markitdown/src/markitdown/__main__.py +++ b/packages/markitdown/src/markitdown/__main__.py @@ -3,10 +3,11 @@ # SPDX-License-Identifier: MIT import argparse import sys +import codecs from textwrap import dedent from importlib.metadata import entry_points from .__about__ import __version__ -from ._markitdown import MarkItDown, DocumentConverterResult +from ._markitdown import MarkItDown, StreamInfo, DocumentConverterResult def main(): @@ -58,6 +59,24 @@ def main(): help="Output file name. If not provided, output is written to stdout.", ) + parser.add_argument( + "-x", + "--extension", + help="Provide a hint about the file extension (e.g., when reading from stdin).", + ) + + parser.add_argument( + "-m", + "--mime-type", + help="Provide a hint about the file's MIME type.", + ) + + parser.add_argument( + "-c", + "--charset", + help="Provide a hint about the file's charset (e.g, UTF-8).", + ) + parser.add_argument( "-d", "--use-docintel", @@ -88,6 +107,48 @@ def main(): parser.add_argument("filename", nargs="?") args = parser.parse_args() + # Parse the extension hint + extension_hint = args.extension + if extension_hint is not None: + extension_hint = extension_hint.strip().lower() + if len(extension_hint) > 0: + if not extension_hint.startswith("."): + extension_hint = "." + extension_hint + else: + extension_hint = None + + # Parse the mime type + mime_type_hint = args.mime_type + if mime_type_hint is not None: + mime_type_hint = mime_type_hint.strip() + if len(mime_type_hint) > 0: + if mime_type_hint.count("/") != 1: + _exit_with_error(f"Invalid MIME type: {mime_type_hint}") + else: + mime_type_hint = None + + # Parse the charset + charset_hint = args.charset + if charset_hint is not None: + charset_hint = charset_hint.strip() + if len(charset_hint) > 0: + try: + charset_hint = codecs.lookup(charset_hint).name + except LookupError: + _exit_with_error(f"Invalid charset: {charset_hint}") + else: + charset_hint = None + + stream_info: str | None = None + if ( + extension_hint is not None + or mime_type_hint is not None + or charset_hint is not None + ): + stream_info = StreamInfo( + extension=extension_hint, mimetype=mime_type_hint, charset=charset_hint + ) + if args.list_plugins: # List installed plugins, then exit print("Installed MarkItDown 3rd-party Plugins:\n") @@ -107,11 +168,12 @@ def main(): if args.use_docintel: if args.endpoint is None: - raise ValueError( + _exit_with_error( "Document Intelligence Endpoint is required when using Document Intelligence." ) elif args.filename is None: - raise ValueError("Filename is required when using Document Intelligence.") + _exit_with_error("Filename is required when using Document Intelligence.") + markitdown = MarkItDown( enable_plugins=args.use_plugins, docintel_endpoint=args.endpoint ) @@ -119,9 +181,9 @@ def main(): markitdown = MarkItDown(enable_plugins=args.use_plugins) if args.filename is None: - result = markitdown.convert_stream(sys.stdin.buffer) + result = markitdown.convert_stream(sys.stdin.buffer, stream_info=stream_info) else: - result = markitdown.convert(args.filename) + result = markitdown.convert(args.filename, stream_info=stream_info) _handle_output(args, result) @@ -135,5 +197,10 @@ def _handle_output(args, result: DocumentConverterResult): print(result.text_content) +def _exit_with_error(message: str): + print(message) + sys.exit(1) + + if __name__ == "__main__": main() diff --git a/packages/markitdown/src/markitdown/_markitdown.py b/packages/markitdown/src/markitdown/_markitdown.py index 825643c..b116927 100644 --- a/packages/markitdown/src/markitdown/_markitdown.py +++ b/packages/markitdown/src/markitdown/_markitdown.py @@ -244,7 +244,7 @@ class MarkItDown: or source.startswith("https://") or source.startswith("file://") ): - return self.convert_url(source, **kwargs) + return self.convert_url(source, stream_info=stream_info, *kwargs) else: return self.convert_local(source, stream_info=stream_info, **kwargs) # Path object @@ -252,14 +252,14 @@ class MarkItDown: return self.convert_local(source, stream_info=stream_info, **kwargs) # Request response elif isinstance(source, requests.Response): - return self.convert_response(source, **kwargs) + return self.convert_response(source, stream_info=stream_info, **kwargs) # Binary stream elif ( hasattr(source, "read") and callable(source.read) and not isinstance(source, io.TextIOBase) ): - return self.convert_stream(source, **kwargs) + return self.convert_stream(source, stream_info=stream_info, **kwargs) else: raise TypeError( f"Invalid source type: {type(source)}. Expected str, requests.Response, BinaryIO."