mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Fix storage class instantiation (#1582)
This commit is contained in:
parent
a35cb12741
commit
cbb8f8788e
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "fix instantiation of storage classes."
|
||||
}
|
7
graphrag/cache/factory.py
vendored
7
graphrag/cache/factory.py
vendored
@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
||||
from graphrag.storage.blob_pipeline_storage import create_blob_storage
|
||||
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
|
||||
@ -24,6 +24,9 @@ class CacheFactory:
|
||||
"""A factory class for cache implementations.
|
||||
|
||||
Includes a method for users to register a custom cache implementation.
|
||||
|
||||
Configuration arguments are passed to each cache implementation as kwargs (where possible)
|
||||
for individual enforcement of required/optional arguments.
|
||||
"""
|
||||
|
||||
cache_types: ClassVar[dict[str, type]] = {}
|
||||
@ -50,7 +53,7 @@ class CacheFactory:
|
||||
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
|
||||
)
|
||||
case CacheType.blob:
|
||||
return JsonPipelineCache(BlobPipelineStorage(**kwargs))
|
||||
return JsonPipelineCache(create_blob_storage(**kwargs))
|
||||
case CacheType.cosmosdb:
|
||||
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
|
||||
case _:
|
||||
|
@ -290,13 +290,12 @@ class BlobPipelineStorage(PipelineStorage):
|
||||
return f"abfs://{path}"
|
||||
|
||||
|
||||
def create_blob_storage(
|
||||
connection_string: str | None,
|
||||
storage_account_blob_url: str | None,
|
||||
container_name: str,
|
||||
base_dir: str | None,
|
||||
) -> PipelineStorage:
|
||||
def create_blob_storage(**kwargs: Any) -> PipelineStorage:
|
||||
"""Create a blob based storage."""
|
||||
connection_string = kwargs.get("connection_string")
|
||||
storage_account_blob_url = kwargs.get("storage_account_blob_url")
|
||||
base_dir = kwargs.get("base_dir")
|
||||
container_name = kwargs["container_name"]
|
||||
log.info("Creating blob storage at %s", container_name)
|
||||
if container_name is None:
|
||||
msg = "No container name provided for blob storage."
|
||||
|
@ -21,6 +21,9 @@ class StorageFactory:
|
||||
"""A factory class for storage implementations.
|
||||
|
||||
Includes a method for users to register a custom storage implementation.
|
||||
|
||||
Configuration arguments are passed to each storage implementation as kwargs
|
||||
for individual enforcement of required/optional arguments.
|
||||
"""
|
||||
|
||||
storage_types: ClassVar[dict[str, type]] = {}
|
||||
|
75
tests/integration/storage/test_factory.py
Normal file
75
tests/integration/storage/test_factory.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""StorageFactory Tests.
|
||||
|
||||
These tests exercise the StorageFactory class and the creation of each natively supported storage type.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.config.enums import StorageType
|
||||
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
||||
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
|
||||
from graphrag.storage.factory import StorageFactory
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
|
||||
|
||||
# cspell:disable-next-line well-known-key
|
||||
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
|
||||
# cspell:disable-next-line well-known-key
|
||||
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||
|
||||
|
||||
def test_create_blob_storage():
    """Check that a blob configuration yields a BlobPipelineStorage instance."""
    blob_config = {
        "type": "blob",
        "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
        "base_dir": "testbasedir",
        "container_name": "testcontainer",
    }
    # The configuration is handed to the factory as a single dict argument.
    assert isinstance(
        StorageFactory.create_storage(StorageType.blob, blob_config),
        BlobPipelineStorage,
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not sys.platform.startswith("win"),
    reason="cosmosdb emulator is only available on windows runners at this time",
)
def test_create_cosmosdb_storage():
    """Check that a cosmosdb configuration yields a CosmosDBPipelineStorage instance.

    Skipped on any platform whose ``sys.platform`` does not start with "win".
    """
    cosmos_config = {
        "type": "cosmosdb",
        "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
        "base_dir": "testdatabase",
        "container_name": "testcontainer",
    }
    assert isinstance(
        StorageFactory.create_storage(StorageType.cosmosdb, cosmos_config),
        CosmosDBPipelineStorage,
    )
|
||||
|
||||
|
||||
def test_create_file_storage():
    """Check that a file configuration yields a FilePipelineStorage instance.

    A throwaway temporary directory is used as ``base_dir`` instead of a
    hard-coded ``/tmp`` path, so the test does not rely on a fixed, shared
    location and anything written under the directory is removed afterwards.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as base_dir:
        kwargs = {"type": "file", "base_dir": base_dir}
        storage = StorageFactory.create_storage(StorageType.file, kwargs)
        # Assert inside the context so the directory still exists while the
        # storage object is being checked.
        assert isinstance(storage, FilePipelineStorage)
|
||||
|
||||
|
||||
def test_create_memory_storage():
    """Check that a memory configuration yields a MemoryPipelineStorage instance."""
    memory_config = {"type": "memory"}
    assert isinstance(
        StorageFactory.create_storage(StorageType.memory, memory_config),
        MemoryPipelineStorage,
    )
|
||||
|
||||
|
||||
def test_register_and_create_custom_storage():
    """Check that a class registered under a custom name can be created by the factory."""

    class CustomStorage:
        """Stand-in storage that accepts and ignores any keyword arguments."""

        def __init__(self, **kwargs):
            pass

    StorageFactory.register("custom", CustomStorage)
    # An empty configuration dict is enough because CustomStorage needs no arguments.
    created = StorageFactory.create_storage("custom", {})
    assert isinstance(created, CustomStorage)
|
||||
|
||||
|
||||
def test_create_unknown_storage():
    """Check that asking for an unregistered storage type raises ValueError."""
    expected_message = "Unknown storage type: unknown"
    with pytest.raises(ValueError, match=expected_message):
        StorageFactory.create_storage("unknown", {})
|
Loading…
x
Reference in New Issue
Block a user