mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-06 03:48:01 +00:00
85 lines
3.1 KiB
Python
85 lines
3.1 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
"""A module containing create_input method definition."""
|
|
|
|
import logging
|
|
from collections.abc import Awaitable, Callable
|
|
from pathlib import Path
|
|
from typing import cast
|
|
|
|
import pandas as pd
|
|
|
|
from graphrag.config.enums import InputType
|
|
from graphrag.config.models.input_config import InputConfig
|
|
from graphrag.index.config.input import PipelineInputConfig
|
|
from graphrag.index.input.csv import input_type as csv
|
|
from graphrag.index.input.csv import load as load_csv
|
|
from graphrag.index.input.text import input_type as text
|
|
from graphrag.index.input.text import load as load_text
|
|
from graphrag.logger.base import ProgressLogger
|
|
from graphrag.logger.null_progress import NullProgressLogger
|
|
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
|
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
|
|
|
log = logging.getLogger(__name__)

# Dispatch table used by create_input: each loader module's `input_type`
# key maps to that module's `load` callable, which is awaited to obtain a
# DataFrame.
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
    text: load_text,
    csv: load_csv,
}
|
|
|
|
|
|
async def create_input(
    config: PipelineInputConfig | InputConfig,
    progress_reporter: ProgressLogger | None = None,
    root_dir: str | None = None,
) -> pd.DataFrame:
    """Instantiate input data for a pipeline.

    Builds a storage backend from ``config.type`` (blob storage, otherwise
    file storage rooted at ``root_dir / config.base_dir``), then looks up the
    loader registered in ``loaders`` for ``config.file_type`` and awaits it.

    Args:
        config: Input configuration; its ``type``, ``file_type``, ``base_dir``
            and (for blob input) ``container_name``, ``connection_string`` and
            ``storage_account_blob_url`` attributes are read.
        progress_reporter: Parent progress logger; a ``NullProgressLogger`` is
            used when omitted.
        root_dir: Directory that ``config.base_dir`` is resolved against for
            file storage; treated as ``""`` when omitted.

    Returns:
        The value produced by the selected loader, typed as a DataFrame.

    Raises:
        ValueError: If ``config`` is None, if blob input lacks a container
            name or lacks both a connection string and a storage account blob
            url, or if no loader is registered for ``config.file_type``.
    """
    # Validate before touching any attribute: otherwise a missing config
    # surfaces as an AttributeError instead of the intended ValueError.
    if config is None:
        msg = "No input specified!"
        raise ValueError(msg)

    root_dir = root_dir or ""
    # NOTE(review): the label says root_dir but the value logged is
    # config.base_dir; the message text is deliberately left unchanged.
    log.info("loading input from root_dir=%s", config.base_dir)
    progress_reporter = progress_reporter or NullProgressLogger()

    if config.type == InputType.blob:
        log.info("using blob storage input")
        if config.container_name is None:
            msg = "Container name required for blob storage"
            raise ValueError(msg)
        if (
            config.connection_string is None
            and config.storage_account_blob_url is None
        ):
            msg = "Connection string or storage account blob url required for blob storage"
            raise ValueError(msg)
        storage = BlobPipelineStorage(
            connection_string=config.connection_string,
            storage_account_blob_url=config.storage_account_blob_url,
            container_name=config.container_name,
            path_prefix=config.base_dir,
        )
    else:
        # InputType.file and any other value share the same file-storage path.
        log.info("using file storage for input")
        storage = FilePipelineStorage(
            root_dir=str(Path(root_dir) / (config.base_dir or ""))
        )

    # The loader lookup stays after storage construction so that the order in
    # which configuration errors are reported is unchanged.
    loader = loaders.get(config.file_type)
    if loader is None:
        msg = f"Unknown input type {config.file_type}"
        raise ValueError(msg)

    progress = progress_reporter.child(
        f"Loading Input ({config.file_type})", transient=False
    )
    results = await loader(config, progress, storage)
    return cast("pd.DataFrame", results)
|