# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'FileStorage' and 'FilePipelineStorage' models."""
|
|
|
|
import logging
import os
import re
import shutil
from collections.abc import Iterator
from pathlib import Path
from typing import Any, cast

import aiofiles
from aiofiles.os import remove
from aiofiles.ospath import exists
from datashaper import Progress

from graphrag.logging import ProgressReporter

from .pipeline_storage import PipelineStorage

log = logging.getLogger(__name__)


class FilePipelineStorage(PipelineStorage):
    """File storage class definition."""

    _root_dir: str
    _encoding: str

    def __init__(self, root_dir: str | None = None, encoding: str | None = None):
        """Initialize the storage with a root directory and default text encoding."""
        self._root_dir = root_dir or ""
        self._encoding = encoding or "utf-8"
        # Ensure the storage root exists before any reads or writes.
        Path(self._root_dir).mkdir(parents=True, exist_ok=True)

    def find(
        self,
        file_pattern: re.Pattern[str],
        base_dir: str | None = None,
        progress: ProgressReporter | None = None,
        file_filter: dict[str, Any] | None = None,
        max_count=-1,
    ) -> Iterator[tuple[str, dict[str, Any]]]:
        """Find files in the storage using a file pattern, as well as a custom filter function."""

        def item_filter(item: dict[str, Any]) -> bool:
            if file_filter is None:
                return True

            return all(
                re.match(value, item[key]) for key, value in file_filter.items()
            )

        search_path = Path(self._root_dir) / (base_dir or "")
        log.info("search %s for files matching %s", search_path, file_pattern.pattern)
        all_files = list(search_path.rglob("**/*"))
        num_loaded = 0
        num_total = len(all_files)
        num_filtered = 0
        for file in all_files:
            match = file_pattern.match(f"{file}")
            if match:
                group = match.groupdict()
                if item_filter(group):
                    filename = f"{file}".replace(self._root_dir, "")
                    if filename.startswith(os.sep):
                        filename = filename[1:]
                    yield (filename, group)
                    num_loaded += 1
                    if max_count > 0 and num_loaded >= max_count:
                        break
                else:
                    num_filtered += 1
            else:
                num_filtered += 1
            if progress is not None:
                progress(_create_progress_status(num_loaded, num_filtered, num_total))
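
    # A minimal usage sketch (the directory and pattern are hypothetical, not
    # part of this module). `find` matches the compiled pattern against each
    # full path and yields (relative-filename, match-groupdict) pairs:
    #
    #     storage = FilePipelineStorage("output")
    #     pattern = re.compile(r".*\.(?P<ext>csv|txt)$")
    #     for filename, groups in storage.find(pattern, max_count=5):
    #         print(filename, groups["ext"])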

    async def get(
        self, key: str, as_bytes: bool | None = False, encoding: str | None = None
    ) -> Any:
        """Get the value for the given key, or None if it does not exist."""
        file_path = join_path(self._root_dir, key)

        if await self.has(key):
            return await self._read_file(file_path, as_bytes, encoding)
        if await exists(key):
            # Look up the key directly, as it is presumably a new file loaded
            # from inputs and not yet written to storage.
            return await self._read_file(key, as_bytes, encoding)

        return None

    async def _read_file(
        self,
        path: str | Path,
        as_bytes: bool | None = False,
        encoding: str | None = None,
    ) -> Any:
        """Read the contents of a file."""
        read_type = "rb" if as_bytes else "r"
        encoding = None if as_bytes else (encoding or self._encoding)

        async with aiofiles.open(
            path,
            cast(Any, read_type),
            encoding=encoding,
        ) as f:
            return await f.read()

    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
        """Set the value for the given key, in binary or text mode as appropriate."""
        is_bytes = isinstance(value, bytes)
        write_type = "wb" if is_bytes else "w"
        encoding = None if is_bytes else (encoding or self._encoding)
        async with aiofiles.open(
            join_path(self._root_dir, key),
            cast(Any, write_type),
            encoding=encoding,
        ) as f:
            await f.write(value)
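
    # A round-trip sketch (key names are hypothetical): text values are
    # written with the configured encoding, bytes values in binary mode, and
    # `get` mirrors that choice via `as_bytes`:
    #
    #     await storage.set("report.txt", "hello")
    #     text = await storage.get("report.txt")              # "hello"
    #     await storage.set("blob.bin", b"\x00\x01")
    #     raw = await storage.get("blob.bin", as_bytes=True)  # b"\x00\x01"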

    async def has(self, key: str) -> bool:
        """Return True if the given key exists in the storage."""
        return await exists(join_path(self._root_dir, key))

    async def delete(self, key: str) -> None:
        """Delete the given key from the storage, if present."""
        if await self.has(key):
            await remove(join_path(self._root_dir, key))

    async def clear(self) -> None:
        """Remove all files and directories under the storage root."""
        for file in Path(self._root_dir).glob("*"):
            if file.is_dir():
                shutil.rmtree(file)
            else:
                file.unlink()

    def child(self, name: str | None) -> "PipelineStorage":
        """Create a child storage instance rooted at a subdirectory."""
        if name is None:
            return self
        # Propagate the encoding so child storages behave like their parent.
        return FilePipelineStorage(
            str(Path(self._root_dir) / Path(name)), self._encoding
        )
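
    # A scoping sketch (names are hypothetical): a child storage re-roots all
    # keys under a subdirectory of the parent, so
    #
    #     reports = storage.child("reports")
    #     await reports.set("summary.txt", "done")
    #
    # writes to `<root>/reports/summary.txt`.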

    def keys(self) -> list[str]:
        """Return the keys in the storage."""
        return os.listdir(self._root_dir)


def join_path(file_path: str, file_name: str) -> Path:
    """Join a path and a filename, independent of the OS."""
    return Path(file_path) / Path(file_name).parent / Path(file_name).name
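
# A behavior sketch (arguments are hypothetical): splitting the file name into
# parent and name parts means keys may themselves contain subdirectories:
#
#     join_path("root", "sub/file.txt")  # -> Path("root/sub/file.txt")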


def create_file_storage(out_dir: str | None) -> PipelineStorage:
    """Create a file-based storage."""
    log.info("Creating file storage at %s", out_dir)
    return FilePipelineStorage(out_dir)


def _create_progress_status(
    num_loaded: int, num_filtered: int, num_total: int
) -> Progress:
    return Progress(
        total_items=num_total,
        completed_items=num_loaded + num_filtered,
        description=f"{num_loaded} files loaded ({num_filtered} filtered)",
    )
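
# An end-to-end usage sketch (the directory and key are hypothetical, and it
# assumes the surrounding graphrag package is importable):
#
#     import asyncio
#
#     async def main() -> None:
#         storage = create_file_storage("./pipeline_output")
#         await storage.set("hello.txt", "hello world")
#         print(await storage.get("hello.txt"))
#         await storage.delete("hello.txt")
#
#     asyncio.run(main())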