# Mirror of https://github.com/allenai/olmocr.git
# Synced 2025-09-03 22:06:16 +00:00
# 610 lines, 18 KiB, Python
import glob
import os
import re
import shutil
from functools import partial, reduce
from hashlib import sha256
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from urllib.parse import urlparse

import platformdirs
import smart_open
from fsspec import AbstractFileSystem, get_filesystem_class
from smart_open.compression import get_supported_extensions

from .loggers import LOGGER_PREFIX, get_logger
|
|
|
|
# Names exported when this module is star-imported.
__all__ = [
    "glob_path",
    "sub_prefix",
    "add_suffix",
    "sub_suffix",
    "make_relative",
    "mkdir_p",
    "split_path",
    "join_path",
    "is_glob",
    "split_glob",
    "partition_path",
]
|
|
|
|
|
|
# Extra keyword arguments handed to the filesystem constructor in ``get_fs``,
# keyed by URL scheme; the empty scheme covers plain paths with no protocol.
FS_KWARGS: Dict[str, Dict[str, Any]] = {
    "": {"auto_mkdir": True},
}

# Each pattern matches a glob metacharacter that is NOT preceded by a backslash.
# RE_ANY_ESCAPE matches any single one of ``*``, ``?``, ``[`` or ``]`` (the
# previous pattern only matched the literal four-character sequence ``*?[]``).
RE_ANY_ESCAPE = re.compile(r"(?<!\\)([*?\[\]])")
RE_GLOB_STAR_ESCAPE = re.compile(r"(?<!\\)\*")
RE_GLOB_ONE_ESCAPE = re.compile(r"(?<!\\)\?")
RE_GLOB_OPEN_ESCAPE = re.compile(r"(?<!\\)\[")
RE_GLOB_CLOSE_ESCAPE = re.compile(r"(?<!\\)\]")

# Placeholder characters substituted for glob metacharacters while a path is
# being parsed, and the inverse mapping used to restore them afterwards.
ESCAPE_SYMBOLS_MAP = {"*": "\u2581", "?": "\u2582", "[": "\u2583", "]": "\u2584"}
REVERSE_ESCAPE_SYMBOLS_MAP = {v: k for k, v in ESCAPE_SYMBOLS_MAP.items()}

# Set to True by ``get_fs`` once the local filesystem's ``glob`` has been patched.
PATCHED_GLOB = False
|
|
|
|
|
|
# Module-level logger obtained from the project's ``get_logger`` for this module's name.
LOGGER = get_logger(__name__)
|
|
|
|
|
|
def get_fs(path: Union[Path, str]) -> AbstractFileSystem:
    """
    Instantiate the filesystem that handles the protocol of ``path``.

    The URL scheme of ``path`` selects the filesystem class, which is built
    with the keyword arguments registered for that scheme in ``FS_KWARGS``.
    The first time a protocol-less path is seen, the returned filesystem's
    ``glob`` is replaced by ``glob.glob`` with ``recursive=True``.
    """
    global PATCHED_GLOB  # pylint: disable=global-statement

    scheme = urlparse(str(path)).scheme
    fs_class = get_filesystem_class(scheme)
    filesystem = fs_class(**FS_KWARGS.get(scheme, {}))

    # patch glob method to support recursive globbing; only done once
    needs_patch = scheme == "" and not PATCHED_GLOB
    if needs_patch:
        filesystem.glob = partial(glob.glob, recursive=True)
        PATCHED_GLOB = True

    return filesystem
|
|
|
|
|
|
def _escape_glob(s: Union[str, Path]) -> str:
    """
    Replace every un-backslashed glob metacharacter in ``s`` with its placeholder.

    The placeholders come from ``ESCAPE_SYMBOLS_MAP``; substitutions are applied
    in the order ``*``, ``?``, ``[``, ``]``.
    """
    text = str(s)
    substitutions = (
        (RE_GLOB_STAR_ESCAPE, "*"),
        (RE_GLOB_ONE_ESCAPE, "?"),
        (RE_GLOB_OPEN_ESCAPE, "["),
        (RE_GLOB_CLOSE_ESCAPE, "]"),
    )
    for pattern, symbol in substitutions:
        text = pattern.sub(ESCAPE_SYMBOLS_MAP[symbol], text)
    return text
|
|
|
|
|
|
def _unescape_glob(s: Union[str, Path]) -> str:
    """
    Turn placeholder characters back into the glob metacharacters they stand for.

    This is the inverse of ``_escape_glob`` and uses ``REVERSE_ESCAPE_SYMBOLS_MAP``.
    """
    # every key is a single character, so one translate pass equals the chained replaces
    table = str.maketrans(REVERSE_ESCAPE_SYMBOLS_MAP)
    return str(s).translate(table)
|
|
|
|
|
|
def _pathify(path: Union[Path, str]) -> Tuple[str, Path]:
    """
    Split ``path`` into its URL scheme and a ``Path`` holding the remainder.

    Glob metacharacters are swapped for placeholders before parsing, so the
    returned ``Path`` is still in escaped form. When the URL has a network
    location (e.g. a bucket name) it becomes the first path component.
    """
    escaped = _escape_glob(str(path))
    parsed = urlparse(escaped)
    if parsed.netloc:
        location = Path(f"{parsed.netloc}/{parsed.path}")
    else:
        location = Path(parsed.path)
    return parsed.scheme, location
|
|
|
|
|
|
def _unpathify(protocol: str, path: Path) -> str:
    """
    Rebuild a path string from a protocol and an (escaped) ``Path``.

    Placeholders are converted back to glob characters. When ``protocol`` is
    non-empty the result is ``"<protocol>://<path>"`` with leading slashes of
    the path removed; otherwise the unescaped path is returned as is.
    """
    unescaped = _unescape_glob(str(path))
    if not protocol:
        return unescaped
    return f"{protocol}://{unescaped.lstrip('/')}"
|
|
|
|
|
|
def remove_params(path: str) -> str:
    """
    Strip URL parameters, query string and fragment from ``path``.

    Only the scheme (when present), network location and path are kept.
    """
    parsed = urlparse(path)
    prefix = f"{parsed.scheme}://" if parsed.scheme else ""
    return f"{prefix}{parsed.netloc}{parsed.path}"
|
|
|
|
|
|
def is_local(path: str) -> bool:
    """
    Tell whether ``path`` refers to the local filesystem.

    A path counts as local when it has no protocol or uses the ``file`` protocol.
    """
    protocol, _ = _pathify(path)
    return protocol in ("", "file")
|
|
|
|
|
|
def copy_file(src: str, dest: str) -> None:
    """
    Copy the file at ``src`` to ``dest``.

    Both locations are opened through ``smart_open`` in binary mode. The data
    is streamed in chunks so that large files are never held in memory in full.

    Args:
        src (str): Location of the file to read.
        dest (str): Location of the file to write.
    """
    with smart_open.open(src, "rb") as src_file, smart_open.open(dest, "wb") as dest_file:
        shutil.copyfileobj(src_file, dest_file)
|
|
|
|
|
|
def copy_dir(
    src: str, dst: str, src_fs: Optional[AbstractFileSystem] = None, dst_fs: Optional[AbstractFileSystem] = None
):
    """
    Copy the contents of ``src`` into ``dst``, descending into sub-directories.

    Args:
        src (str): Directory (or glob) to copy from.
        dst (str): Directory to copy into.
        src_fs (AbstractFileSystem, optional): Filesystem used to list ``src``;
            derived from ``src`` when omitted.
        dst_fs (AbstractFileSystem, optional): Filesystem used to create
            destination directories; derived from ``dst`` when omitted.
    """
    src_fs = src_fs or get_fs(src)
    dst_fs = dst_fs or get_fs(dst)

    for src_path in glob_path(src, yield_dirs=True, fs=src_fs):
        rel_path = sub_prefix(src_path, src)
        dest_path = join_path("", dst, rel_path)

        if is_dir(src_path, fs=src_fs):
            # recursively copy directories
            copy_dir(src=src_path, dst=dest_path, src_fs=src_fs, dst_fs=dst_fs)
            continue

        # Opening a local file for writing fails when its parent directory is
        # missing, so create it first (same as cached_path does for downloads).
        if is_local(dest_path):
            mkdir_p(parent(dest_path), fs=dst_fs)

        # file; copy over
        LOGGER.info(f"Copying {src_path} to {dest_path}")
        copy_file(src_path, dest_path)
|
|
|
|
|
|
def delete_file(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
    """
    Remove the file at ``path``.

    Args:
        path (str): File to remove.
        ignore_missing (bool): When True, a missing file is not an error.
        fs (AbstractFileSystem, optional): Filesystem to use; derived from
            ``path`` when omitted.

    Returns:
        bool: True when the file was removed, False when it was missing and
            ``ignore_missing`` is set.

    Raises:
        FileNotFoundError: The file is missing and ``ignore_missing`` is False.
    """
    filesystem = fs or get_fs(path)
    try:
        filesystem.rm(path)
    except FileNotFoundError:
        if ignore_missing:
            return False
        raise
    return True
|
|
|
|
|
|
def get_size(path: str, fs: Optional[AbstractFileSystem] = None) -> int:
    """
    Return the ``size`` entry the filesystem reports for the file at ``path``.

    Raises:
        ValueError: ``path`` does not exist or is a directory.
    """
    filesystem = fs or get_fs(path)

    if not exists(path, fs=filesystem):
        raise ValueError(f"Path {path} does not exist")
    if is_dir(path, fs=filesystem):
        raise ValueError(f"Path {path} is a directory")

    info = filesystem.info(path)
    return info["size"]
|
|
|
|
|
|
def delete_dir(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
    """
    Recursively remove the directory at ``path``.

    Args:
        path (str): Directory to remove.
        ignore_missing (bool): When True, a missing directory is not an error.
        fs (AbstractFileSystem, optional): Filesystem to use; derived from
            ``path`` when omitted.

    Returns:
        bool: True when the directory was removed, False when it was missing
            and ``ignore_missing`` is set.

    Raises:
        FileNotFoundError: The directory is missing and ``ignore_missing`` is False.
    """
    filesystem = fs or get_fs(path)
    try:
        filesystem.rm(path, recursive=True)
    except FileNotFoundError:
        if ignore_missing:
            return False
        raise
    return True
|
|
|
|
|
|
def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]:
    """
    Split ``path`` around its first component that contains a glob character.

    Returns:
        A tuple ``(protocol, before, after)`` where ``before`` holds the path
        components preceding the first glob component and ``after`` the
        components following it. The glob component itself is in neither.
        Without any glob, ``before`` is the whole path and ``after`` is empty.
    """
    prot, path_obj = _pathify(path)
    components = path_obj.parts

    # _pathify leaves placeholders where glob characters were, so a component
    # is a glob exactly when it contains one of the placeholder symbols
    first_glob = next(
        (
            idx
            for idx, component in enumerate(components)
            if any(symbol in component for symbol in REVERSE_ESCAPE_SYMBOLS_MAP)
        ),
        None,
    )

    if first_glob is None:
        before, after = components, ()
    else:
        before, after = components[:first_glob], components[first_glob + 1 :]

    pre_glob_path = tuple(_unescape_glob(part) for part in before)
    post_glob_path = tuple(_unescape_glob(part) for part in after)
    return prot, pre_glob_path, post_glob_path
|
|
|
|
|
|
def split_path(path: str) -> Tuple[str, Tuple[str, ...]]:
    """
    Break ``path`` into its protocol and a tuple of its path components.

    Glob characters are preserved in the returned components.
    """
    protocol, escaped_path = _pathify(path)
    components = tuple(_unescape_glob(part) for part in escaped_path.parts)
    return protocol, components
|
|
|
|
|
|
def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> str:
    """
    Assemble a path from a protocol and one or more path pieces.

    Each element of ``parts`` is either a string or an iterable of strings.
    When ``protocol`` is empty or None, the protocol of the first piece is
    used instead. Trailing slashes are dropped from the joined path.
    """
    # flatten the mix of strings and iterables of strings into a single list
    flattened = []
    for part in parts:
        if isinstance(part, str):
            flattened.append(part)
        else:
            flattened.extend(part)

    all_prots, all_parts = zip(*(_pathify(piece) for piece in flattened))

    joined = str(Path(*all_parts)).rstrip("/")
    scheme = protocol or str(all_prots[0])
    if scheme:
        joined = f"{scheme}://{joined.lstrip('/')}"

    return _unescape_glob(joined)
|
|
|
|
|
|
def glob_path(
    path: Union[Path, str],
    hidden_files: bool = False,
    autoglob_dirs: bool = True,
    recursive_dirs: bool = False,
    yield_dirs: bool = True,
    fs: Optional[AbstractFileSystem] = None,
) -> Iterator[str]:
    """
    Expand a glob path into the paths it matches.

    Args:
        path: Path or glob pattern to expand. Only patterns containing ``*``
            are expanded; anything else is yielded unchanged.
        hidden_files: When False, matches whose name starts with ``.`` are skipped.
        autoglob_dirs: When True and ``path`` is a directory, its direct
            children are globbed (``*`` is appended).
        recursive_dirs: When True, matched directories are expanded recursively.
        yield_dirs: When True, matched directories are yielded themselves.
        fs: Filesystem to use; derived from ``path`` when omitted.

    Yields:
        str: Matching paths, carrying the protocol of ``path``.
    """
    protocol, parsed_path = _pathify(path)
    fs = fs or get_fs(path)

    if autoglob_dirs and fs.isdir(path):
        path = join_path(protocol, _unescape_glob(parsed_path), "*")

    if "*" not in str(path):
        # nothing to glob
        yield str(path)
        return

    for gl in fs.glob(path):
        gl = str(gl)

        if not hidden_files and Path(gl).name.startswith("."):
            continue

        # The filesystem may return matches without the protocol; put it back
        # so both the yielded value and the recursive call keep it.
        full_path = join_path(protocol, gl)

        if fs.isdir(gl):
            if recursive_dirs:
                # Recurse on the protocol-qualified path: recursing on the raw
                # match would make every nested result lose its protocol.
                yield from glob_path(
                    full_path,
                    hidden_files=hidden_files,
                    autoglob_dirs=autoglob_dirs,
                    recursive_dirs=recursive_dirs,
                    yield_dirs=yield_dirs,
                    fs=fs,
                )
            if yield_dirs:
                yield full_path
        else:
            yield full_path
|
|
|
|
|
|
def sub_prefix(a: str, b: str) -> str:
    """
    Return ``a`` expressed relative to the prefix ``b``.

    When ``b`` is not a prefix of ``a``, ``a`` is returned (re-joined with its
    protocol) instead.

    Raises:
        ValueError: ``a`` and ``b`` use different protocols.
    """
    prot_a, path_a = _pathify(a)
    prot_b, path_b = _pathify(b)

    if prot_a != prot_b:
        raise ValueError(f"Protocols of {a} and {b} do not match")

    try:
        remainder = str(path_a.relative_to(path_b))
    except ValueError:
        # b is not a prefix of a; fall back to the full path
        remainder = join_path(prot_a, path_a.parts)

    return _unescape_glob(remainder)
|
|
|
|
|
|
def sub_suffix(a: str, b: str) -> str:
    """
    Remove ``b`` from the end of ``a``.

    The suffix is matched literally, so characters such as ``.`` or ``+`` in
    ``b`` have no special meaning. An empty ``b`` removes nothing.

    Args:
        a (str): Path to trim; may carry a protocol.
        b (str): Relative path to strip from the end of ``a``.

    Returns:
        str: ``a`` without the trailing ``b``; ``a`` unchanged when it does not
            end with ``b``.

    Raises:
        ValueError: ``b`` carries a protocol and so is not a relative path.
    """
    prot_a, path_a = _pathify(a)
    prot_b, path_b = _pathify(b)

    if prot_b:
        raise ValueError(f"{b} is not a relative path")

    sub_path = str(path_a)
    suffix = str(path_b)

    # An empty path renders as "." and has no parts; it must not be treated as
    # a suffix to strip.
    if path_b.parts and sub_path.endswith(suffix):
        sub_path = sub_path[: -len(suffix)]

    sub_prot = f"{prot_a}://" if prot_a else ""

    # need to trim '/' from the end if (a) '/' is not the only symbol in the path or
    # (b) there is a protocol so absolute paths don't make sense
    if sub_path != "/" or sub_prot:
        sub_path = sub_path.rstrip("/")

    return _unescape_glob(sub_prot + sub_path)
|
|
|
|
|
|
def add_suffix(a: str, b: str) -> str:
    """
    Append the relative path ``b`` to ``a``, keeping the protocol of ``a``.

    Raises:
        ValueError: ``b`` carries a protocol and so is not a relative path.
    """
    prot_a, path_a = _pathify(a)
    prot_b, path_b = _pathify(b)

    if prot_b:
        raise ValueError(f"{b} is not a relative path")

    combined = path_a / path_b
    return join_path(prot_a, str(combined))
|
|
|
|
|
|
def exists(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
    """Report whether ``path`` exists on ``fs`` (derived from ``path`` when omitted)."""
    filesystem = fs or get_fs(path)
    return filesystem.exists(path)
|
|
|
|
|
|
def is_dir(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
    """Report whether ``path`` exists and is a directory; False when it does not exist."""
    filesystem = fs or get_fs(path)
    return filesystem.isdir(path) if exists(path, fs=filesystem) else False
|
|
|
|
|
|
def is_file(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
    """Report whether ``path`` exists and is a regular file; False when it does not exist."""
    filesystem = fs or get_fs(path)
    return filesystem.isfile(path) if exists(path, fs=filesystem) else False
|
|
|
|
|
|
def parent(path: str) -> str:
    """
    Get the parent directory of a path; if the parent is the root, return the root.

    A path with one component or none (such as an empty string) has no parent
    to step up to and is returned unchanged.

    Args:
        path (str): Path whose parent is wanted; may carry a protocol.

    Returns:
        str: The path with its last component removed.
    """
    prot, parts = split_path(path)
    # `<= 1` rather than `== 1`: an empty path has zero parts, and joining zero
    # parts would raise inside join_path.
    if len(parts) <= 1:
        return path
    return join_path(prot, *parts[:-1])
|
|
|
|
|
|
def mkdir_p(path: str, fs: Optional[AbstractFileSystem] = None) -> None:
    """
    Create the directory ``path``; an already existing directory is not an error.

    Raises:
        ValueError: ``path`` contains a glob wildcard.
    """
    if is_glob(path):
        raise ValueError(f"Cannot create directory with glob pattern: {path}")

    filesystem = fs or get_fs(path)
    filesystem.makedirs(path, exist_ok=True)
|
|
|
|
|
|
def make_relative(paths: List[str]) -> Tuple[str, List[str]]:
    """
    Find the longest root shared among all paths and express each path relative to it.

    For every path only the components before its first glob component are
    considered when computing the shared root.

    Args:
        paths (List[str]): Paths to compare; all must use the same protocol.

    Returns:
        Tuple[str, List[str]]: The shared root (including the protocol, if
            any) and the input paths made relative to that root.

    Raises:
        ValueError: ``paths`` is empty, or the paths use different protocols.
    """
    if not paths:
        raise ValueError("Cannot make relative path of empty list")

    common_prot, common_parts, _ = partition_path(paths[0])

    for path in paths:
        current_prot, current_parts, _ = partition_path(path)
        if current_prot != common_prot:
            raise ValueError(f"Protocols of {path} and {paths[0]} do not match")

        # Count the leading components both share. zip stops at the shorter of
        # the two, so a path that is a strict prefix of the current root also
        # shortens the root (previously the root was left too long).
        shared = 0
        for left, right in zip(common_parts, current_parts):
            if left != right:
                break
            shared += 1
        common_parts = common_parts[:shared]

    if len(common_parts) > 0:
        common_path = (f"{common_prot}://" if common_prot else "") + str(Path(*common_parts))
        relative_paths = [sub_prefix(path, common_path) for path in paths]
    else:
        common_path = f"{common_prot}://" if common_prot else ""
        relative_paths = [_unpathify("", _pathify(path)[1]) for path in paths]

    return common_path, relative_paths
|
|
|
|
|
|
def is_glob(path: str) -> bool:
    """
    Tell whether ``path`` contains a glob wildcard.

    A wildcard is any of ``*``, ``?``, ``[`` or ``]`` that is not preceded by
    a backslash.
    """
    unescaped_wildcard = re.search(r"(?<!\\)[*?[\]]", path)
    return unescaped_wildcard is not None
|
|
|
|
|
|
def split_glob(path: str) -> Tuple[str, str]:
    """
    Partition ``path`` at its first component that contains a wildcard.

    Returns:
        A tuple ``(prefix, rest)``: the part before the first glob component
        and everything from that component onwards. ``rest`` is empty when
        ``path`` has no wildcard.
    """
    if not is_glob(path):
        # it's not a glob, so it's all path
        return path, ""

    if path[0] == "*":
        # starts with a glob, so it's all glob
        return "", path

    protocol, parts = split_path(path)
    first_glob = min(idx for idx, component in enumerate(parts) if is_glob(component))

    if first_glob == 0:
        # no path, so it's all glob
        return protocol, join_path("", *parts)

    prefix = join_path(protocol, *parts[:first_glob])
    remainder = join_path("", *parts[first_glob:])
    return prefix, remainder
|
|
|
|
|
|
def get_cache_dir() -> str:
    """
    Return the user cache directory for this toolkit, creating it if needed.

    The directory is the platform's user cache directory for the application
    name ``LOGGER_PREFIX``.

    Returns:
        str: The path to the cache directory.
    """
    cache_dir = platformdirs.user_cache_dir(LOGGER_PREFIX)
    mkdir_p(cache_dir)
    return cache_dir
|
|
|
|
|
|
def resource_to_filename(resource: Union[str, bytes]) -> str:
    """
    Derive a repeatable hashed filename from ``resource``.

    The name is the SHA-256 hex digest of the resource string followed by the
    extensions of the resource's file name (URL parameters are ignored when
    looking for the extensions).
    """
    resource_str = str(resource)

    _, components = split_path(remove_params(resource_str))
    *_, orig_filename = components
    _, extensions = split_basename_and_extension(orig_filename)

    digest = sha256(resource_str.encode("utf-8")).hexdigest()
    return digest + extensions
|
|
|
|
|
|
def cached_path(path: str, fs: Optional[AbstractFileSystem] = None) -> str:
    """
    Return a local path for ``path``, downloading it into the cache if needed.

    Local paths are returned unchanged. A remote resource is stored in the
    cache directory under a hashed name; when that name already exists, it is
    reused without downloading again. For a remote directory, each entry
    listed by ``glob_path`` is downloaded beneath the cached directory.

    Args:
        path (str): The path to the resource.
        fs (AbstractFileSystem, optional): Filesystem for the remote resource;
            derived from ``path`` when omitted.

    Returns:
        str: The cached path of the resource.
    """
    if is_local(path):
        return path

    destination = f"{get_cache_dir()}/{resource_to_filename(path)}"

    remote_fs = fs or get_fs(path)
    local_fs = get_fs(destination)

    if exists(destination, fs=local_fs):
        LOGGER.info(f"Using cached file {destination} for {path}")
        return destination

    if not is_dir(path, fs=remote_fs):
        LOGGER.info(f"Downloading {path} to {destination}")
        copy_file(path, destination)
        return destination

    for sub_path in glob_path(path, fs=remote_fs):
        rel_path = sub_prefix(sub_path, path)
        dest_path = join_path("", destination, rel_path)
        mkdir_p(parent(dest_path), fs=local_fs)
        LOGGER.info(f"Downloading {sub_path} to {dest_path}")
        copy_file(sub_path, dest_path)

    return destination
|
|
|
|
|
|
def split_basename_and_extension(path: str) -> Tuple[str, str]:
    """
    Separate a file path from its (possibly multi-part) extension.

    Everything from the first period of the file name onwards is the
    extension, e.g. "foo/bar/baz.tar.gz" gives ("foo/bar/baz", ".tar.gz").
    A file name without a period gives an empty extension. Works with both
    local and remote (e.g. s3://) paths.

    Args:
        path (str): The file path.

    Returns:
        Tuple[str, str]: A tuple containing the path and extension.
    """
    prot, components = split_path(path)
    *parts, filename = components

    # partition keeps the separator, so `dot + remainder` is "" when the name
    # has no period and ".<rest>" otherwise
    base, dot, remainder = filename.partition(".")
    return join_path(prot, *parts, base), dot + remainder
|
|
|
|
|
|
def decompress_path(path: str, dest: Optional[str] = None) -> str:
    """
    Decompresses a file at the given path and returns the path to the decompressed file.

    Args:
        path (str): The path to the file to be decompressed.
        dest (str, optional): The destination path for the decompressed file.
            If not provided, a destination path will be computed based on the
            hashed file name and the cache directory.

    Returns:
        str: The path to the decompressed file. If the path does not end with
            a supported compression extension, the original path is returned.
    """
    for supported_ext in get_supported_extensions():
        # not the supported extension
        if not path.endswith(supported_ext):
            continue

        if dest is None:
            # compute the name for the decompressed file; to do this, we first hash
            # the resource and then split off the extension.
            base_fn, ext = split_basename_and_extension(resource_to_filename(path))

            # Drop only the trailing compression extension. str.replace would
            # also remove earlier occurrences inside the extension chain.
            if supported_ext and ext.endswith(supported_ext):
                ext = ext[: -len(supported_ext)]

            # finally, we get cache directory and join the decompressed file name to it
            dest = join_path("", get_cache_dir(), base_fn + ext)

        # here we do the actual decompression, streaming in chunks so that the
        # whole file is never held in memory at once
        with smart_open.open(path, "rb") as fr, smart_open.open(dest, "wb") as fw:
            shutil.copyfileobj(fr, fw)

        # return the path to the decompressed file
        return dest

    # already decompressed or can't be decompressed
    return path
|
|
|
|
|
|
def split_ext(path: str) -> Tuple[str, Tuple[str, ...], str]:
    """
    Split a path into its protocol, its components, and its extensions.

    Returns:
        A tuple ``(protocol, parts, extensions)`` where the last element of
        ``parts`` is the file name with all extensions removed and
        ``extensions`` is the concatenation of those extensions in their
        original order. A path without components gives ``(protocol, (), "")``.
    """
    prot, parts = split_path(path)
    if not parts:
        return prot, (), ""

    *parents, stem = parts

    # peel extensions off the right-hand side, keeping them in original order
    suffixes: List[str] = []
    stem, ext = os.path.splitext(stem)
    while ext:
        suffixes.insert(0, ext)
        stem, ext = os.path.splitext(stem)

    return prot, (*parents, stem), "".join(suffixes)
|
|
|
|
|
|
def get_unified_path(paths: List[str]) -> str:
    """
    Get a unified path for a list of paths.

    A single path is returned as is. Otherwise the result lives under the root
    shared by all paths and is named after the SHA-256 hex digest of the
    sorted relative paths, followed by the extension of the first path.

    Args:
        paths (List[str]): Paths to unify.

    Returns:
        str: The unified path.
    """
    if len(paths) == 1:
        # if there is only one path, we don't need to unify anything
        return paths[0]

    # get shared root for all paths; we will put the unified path here
    root, relative = make_relative(paths)

    # get the extension from the first path; assume all paths have the same extension
    _, _, ext = split_ext(relative[0])

    # hash all the sorted relative paths in order to get a unique name
    digest = sha256()
    for rel_path in sorted(relative):
        digest.update(rel_path.encode())

    # The first argument of join_path is the protocol, so the root must be
    # passed as a path piece; its own protocol (if any) is picked up from it.
    return join_path("", root, digest.hexdigest() + ext)
|