79 lines
2.5 KiB
Python
Raw Normal View History

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Loader module."""
import json
import logging
import os
from knowledge_loader.data_sources.blob_source import (
BlobDatasource,
load_blob_file,
load_blob_prompt_config,
)
from knowledge_loader.data_sources.default import (
LISTING_FILE,
blob_account_name,
local_data_root,
)
from knowledge_loader.data_sources.local_source import (
LocalDatasource,
load_local_prompt_config,
)
from knowledge_loader.data_sources.typing import DatasetConfig, Datasource
# Configure the root logger at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)
# Raise the threshold of the "azure" logger to WARNING so that its
# lower-severity records are not emitted.
logging.getLogger("azure").setLevel(logging.WARNING)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _get_base_path(
dataset: str | None, root: str | None, extra_path: str | None = None
) -> str:
"""Construct and return the base path for the given dataset and extra path."""
return os.path.join( # noqa: PTH118
os.path.dirname(os.path.realpath(__file__)), # noqa: PTH120
root if root else "",
dataset if dataset else "",
*(extra_path.split("/") if extra_path else []),
)
def create_datasource(dataset_folder: str) -> Datasource:
    """Create the datasource for ``dataset_folder``.

    A ``BlobDatasource`` is returned when ``blob_account_name`` is set to a
    non-empty value; otherwise a ``LocalDatasource`` is returned, rooted at the
    dataset folder under ``local_data_root``.

    Args:
        dataset_folder: Name of the dataset folder to read from.

    Returns:
        The blob-backed or local datasource for the dataset.
    """
    blob_configured = blob_account_name not in (None, "")
    if not blob_configured:
        local_root = _get_base_path(dataset_folder, local_data_root)
        return LocalDatasource(local_root)
    return BlobDatasource(dataset_folder)
def load_dataset_listing() -> list[DatasetConfig]:
    """Load the dataset listing file and return its entries.

    When ``blob_account_name`` is set to a non-empty value, the listing is
    read through ``load_blob_file(None, LISTING_FILE)``, decoded as UTF-8 and
    parsed as JSON. Any failure on that path is logged (with traceback) and an
    empty list is returned, so a missing or broken blob listing does not
    raise. An empty blob also yields an empty list.

    Otherwise the listing is read from ``LISTING_FILE`` under
    ``local_data_root``; errors on this path (missing file, invalid JSON)
    propagate to the caller.

    Returns:
        One ``DatasetConfig`` per entry in the listing, each built by passing
        the entry's keys as keyword arguments.
    """
    datasets = []
    if blob_account_name is not None and blob_account_name != "":
        try:
            blob = load_blob_file(None, LISTING_FILE)
            datasets_str = blob.getvalue().decode("utf-8")
            if datasets_str:
                datasets = json.loads(datasets_str)
        except Exception:  # noqa: BLE001
            # Best-effort: report through the module logger (keeps the
            # traceback) instead of printing to stdout, then fall back to an
            # empty listing.
            logger.exception("Error loading dataset config")
            return []
    else:
        base_path = _get_base_path(None, local_data_root, LISTING_FILE)
        # Decode explicitly as UTF-8 so the result matches the blob branch and
        # does not depend on the platform's default locale encoding.
        with open(base_path, encoding="utf-8") as file:  # noqa: PTH123
            datasets = json.load(file)
    return [DatasetConfig(**d) for d in datasets]
def load_prompts(dataset: str) -> dict[str, str]:
    """Load the prompts configuration for ``dataset``.

    The configuration comes from ``load_blob_prompt_config`` when
    ``blob_account_name`` is set to a non-empty value; otherwise it is loaded
    by ``load_local_prompt_config`` from the dataset's ``prompts`` folder
    under ``local_data_root``.

    Args:
        dataset: Name of the dataset whose prompts are requested.

    Returns:
        The prompts configuration returned by the selected loader.
    """
    blob_configured = blob_account_name not in (None, "")
    if not blob_configured:
        prompts_dir = _get_base_path(dataset, local_data_root, "prompts")
        return load_local_prompt_config(prompts_dir)
    return load_blob_prompt_config(dataset)