mirror of
https://github.com/microsoft/graphrag.git
synced 2025-07-03 07:04:19 +00:00
Fix storage class instantiation (#1582)
This commit is contained in:
parent
a35cb12741
commit
cbb8f8788e
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "patch",
|
||||||
|
"description": "fix instantiation of storage classes."
|
||||||
|
}
|
7
graphrag/cache/factory.py
vendored
7
graphrag/cache/factory.py
vendored
@ -8,7 +8,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, ClassVar
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
from graphrag.config.enums import CacheType
|
from graphrag.config.enums import CacheType
|
||||||
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
from graphrag.storage.blob_pipeline_storage import create_blob_storage
|
||||||
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
|
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
|
||||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||||
|
|
||||||
@ -24,6 +24,9 @@ class CacheFactory:
|
|||||||
"""A factory class for cache implementations.
|
"""A factory class for cache implementations.
|
||||||
|
|
||||||
Includes a method for users to register a custom cache implementation.
|
Includes a method for users to register a custom cache implementation.
|
||||||
|
|
||||||
|
Configuration arguments are passed to each cache implementation as kwargs (where possible)
|
||||||
|
for individual enforcement of required/optional arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cache_types: ClassVar[dict[str, type]] = {}
|
cache_types: ClassVar[dict[str, type]] = {}
|
||||||
@ -50,7 +53,7 @@ class CacheFactory:
|
|||||||
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
|
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
|
||||||
)
|
)
|
||||||
case CacheType.blob:
|
case CacheType.blob:
|
||||||
return JsonPipelineCache(BlobPipelineStorage(**kwargs))
|
return JsonPipelineCache(create_blob_storage(**kwargs))
|
||||||
case CacheType.cosmosdb:
|
case CacheType.cosmosdb:
|
||||||
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
|
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
|
||||||
case _:
|
case _:
|
||||||
|
@ -290,13 +290,12 @@ class BlobPipelineStorage(PipelineStorage):
|
|||||||
return f"abfs://{path}"
|
return f"abfs://{path}"
|
||||||
|
|
||||||
|
|
||||||
def create_blob_storage(
|
def create_blob_storage(**kwargs: Any) -> PipelineStorage:
|
||||||
connection_string: str | None,
|
|
||||||
storage_account_blob_url: str | None,
|
|
||||||
container_name: str,
|
|
||||||
base_dir: str | None,
|
|
||||||
) -> PipelineStorage:
|
|
||||||
"""Create a blob based storage."""
|
"""Create a blob based storage."""
|
||||||
|
connection_string = kwargs.get("connection_string")
|
||||||
|
storage_account_blob_url = kwargs.get("storage_account_blob_url")
|
||||||
|
base_dir = kwargs.get("base_dir")
|
||||||
|
container_name = kwargs["container_name"]
|
||||||
log.info("Creating blob storage at %s", container_name)
|
log.info("Creating blob storage at %s", container_name)
|
||||||
if container_name is None:
|
if container_name is None:
|
||||||
msg = "No container name provided for blob storage."
|
msg = "No container name provided for blob storage."
|
||||||
|
@ -21,6 +21,9 @@ class StorageFactory:
|
|||||||
"""A factory class for storage implementations.
|
"""A factory class for storage implementations.
|
||||||
|
|
||||||
Includes a method for users to register a custom storage implementation.
|
Includes a method for users to register a custom storage implementation.
|
||||||
|
|
||||||
|
Configuration arguments are passed to each storage implementation as kwargs
|
||||||
|
for individual enforcement of required/optional arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
storage_types: ClassVar[dict[str, type]] = {}
|
storage_types: ClassVar[dict[str, type]] = {}
|
||||||
|
75
tests/integration/storage/test_factory.py
Normal file
75
tests/integration/storage/test_factory.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License
|
||||||
|
"""StorageFactory Tests.
|
||||||
|
|
||||||
|
These tests will test the StorageFactory class and the creation of each storage type that is natively supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphrag.config.enums import StorageType
|
||||||
|
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
||||||
|
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
|
||||||
|
from graphrag.storage.factory import StorageFactory
|
||||||
|
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||||
|
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
|
||||||
|
|
||||||
|
# cspell:disable-next-line well-known-key
|
||||||
|
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
|
||||||
|
# cspell:disable-next-line well-known-key
|
||||||
|
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_blob_storage():
|
||||||
|
kwargs = {
|
||||||
|
"type": "blob",
|
||||||
|
"connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
|
||||||
|
"base_dir": "testbasedir",
|
||||||
|
"container_name": "testcontainer",
|
||||||
|
}
|
||||||
|
storage = StorageFactory.create_storage(StorageType.blob, kwargs)
|
||||||
|
assert isinstance(storage, BlobPipelineStorage)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not sys.platform.startswith("win"),
|
||||||
|
reason="cosmosdb emulator is only available on windows runners at this time",
|
||||||
|
)
|
||||||
|
def test_create_cosmosdb_storage():
|
||||||
|
kwargs = {
|
||||||
|
"type": "cosmosdb",
|
||||||
|
"connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
|
||||||
|
"base_dir": "testdatabase",
|
||||||
|
"container_name": "testcontainer",
|
||||||
|
}
|
||||||
|
storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
|
||||||
|
assert isinstance(storage, CosmosDBPipelineStorage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_file_storage():
|
||||||
|
kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
|
||||||
|
storage = StorageFactory.create_storage(StorageType.file, kwargs)
|
||||||
|
assert isinstance(storage, FilePipelineStorage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_memory_storage():
|
||||||
|
kwargs = {"type": "memory"}
|
||||||
|
storage = StorageFactory.create_storage(StorageType.memory, kwargs)
|
||||||
|
assert isinstance(storage, MemoryPipelineStorage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_and_create_custom_storage():
|
||||||
|
class CustomStorage:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
StorageFactory.register("custom", CustomStorage)
|
||||||
|
storage = StorageFactory.create_storage("custom", {})
|
||||||
|
assert isinstance(storage, CustomStorage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_unknown_storage():
|
||||||
|
with pytest.raises(ValueError, match="Unknown storage type: unknown"):
|
||||||
|
StorageFactory.create_storage("unknown", {})
|
Loading…
x
Reference in New Issue
Block a user