From eeee84e9d9da2078946577813c1ec66388c2f4fd Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 28 Jan 2025 10:46:41 -0800 Subject: [PATCH] Add vector store id reference to embeddings config. (#1662) --- .../patch-20250127224919088925.json | 4 +++ graphrag/config/defaults.py | 2 +- graphrag/config/embeddings.py | 14 +++------- graphrag/config/init_content.py | 3 ++- graphrag/config/models/graph_rag_config.py | 26 ++++++++++++++++++- .../config/models/text_embedding_config.py | 4 +++ tests/fixtures/azure/settings.yml | 2 +- tests/fixtures/min-csv/settings.yml | 2 +- tests/fixtures/text/settings.yml | 2 +- tests/unit/config/utils.py | 2 +- 10 files changed, 43 insertions(+), 18 deletions(-) create mode 100644 .semversioner/next-release/patch-20250127224919088925.json diff --git a/.semversioner/next-release/patch-20250127224919088925.json b/.semversioner/next-release/patch-20250127224919088925.json new file mode 100644 index 00000000..5e0d8904 --- /dev/null +++ b/.semversioner/next-release/patch-20250127224919088925.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add vector store id reference to embeddings config." +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index cb961b8b..f33985db 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -106,7 +106,7 @@ VECTOR_STORE_TYPE = VectorStoreType.LanceDB.value VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb") VECTOR_STORE_CONTAINER_NAME = "default" VECTOR_STORE_OVERWRITE = True -VECTOR_STORE_INDEX_NAME = "output" +VECTOR_STORE_DEFAULT_ID = "default_vector_store" # Local Search LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5 diff --git a/graphrag/config/embeddings.py b/graphrag/config/embeddings.py index a3222901..11fa82ef 100644 --- a/graphrag/config/embeddings.py +++ b/graphrag/config/embeddings.py @@ -57,18 +57,10 @@ def get_embedding_settings( embeddings_llm_settings = settings.get_language_model_config( settings.embeddings.model_id ) - num_entries = len(settings.vector_store) - if num_entries == 1: - store = next(iter(settings.vector_store.values())) - vector_store_settings = store.model_dump() - else: - # The vector_store dict should only have more than one entry for multi-index query - vector_store_settings = None + vector_store_settings = settings.get_vector_store_config( + settings.embeddings.vector_store_id + ).model_dump() - if vector_store_settings is None: - return { - "strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings) - } # # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. # settings.vector_store.base contains connection information, or may be undefined diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index f510ef7f..eccd05e4 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -40,7 +40,7 @@ models: # deployment_name: vector_store: - {defs.VECTOR_STORE_INDEX_NAME}: + {defs.VECTOR_STORE_DEFAULT_ID}: type: {defs.VECTOR_STORE_TYPE} db_uri: {defs.VECTOR_STORE_DB_URI} container_name: {defs.VECTOR_STORE_CONTAINER_NAME} @@ -48,6 +48,7 @@ vector_store: embeddings: model_id: {defs.DEFAULT_EMBEDDING_MODEL_ID} + vector_store_id: {defs.VECTOR_STORE_DEFAULT_ID} ### Input settings ### diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index 9c50714f..1e5fd84c 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -226,7 +226,7 @@ class GraphRagConfig(BaseModel): vector_store: dict[str, VectorStoreConfig] = Field( description="The vector store configuration.", - default={"output": VectorStoreConfig()}, + default={defs.VECTOR_STORE_DEFAULT_ID: VectorStoreConfig()}, ) """The vector store configuration.""" @@ -263,6 +263,30 @@ class GraphRagConfig(BaseModel): return self.models[model_id] + def get_vector_store_config(self, vector_store_id: str) -> VectorStoreConfig: + """Get a vector store configuration by ID. + + Parameters + ---------- + vector_store_id : str + The ID of the vector store to get. Should match an ID in the vector_store list. + + Returns + ------- + VectorStoreConfig + The vector store configuration if found. + + Raises + ------ + ValueError + If the vector store ID is not found in the configuration. + """ + if vector_store_id not in self.vector_store: + err_msg = f"Vector Store ID {vector_store_id} not found in configuration. Please rerun `graphrag init` and set the vector store configuration." + raise ValueError(err_msg) + + return self.vector_store[vector_store_id] + @model_validator(mode="after") def _validate_model(self): """Validate the model configuration.""" diff --git a/graphrag/config/models/text_embedding_config.py b/graphrag/config/models/text_embedding_config.py index c26e13ee..9a8763fd 100644 --- a/graphrag/config/models/text_embedding_config.py +++ b/graphrag/config/models/text_embedding_config.py @@ -34,6 +34,10 @@ class TextEmbeddingConfig(BaseModel): description="The model ID to use for text embeddings.", default=defs.EMBEDDING_MODEL_ID, ) + vector_store_id: str = Field( + description="The vector store ID to use for text embeddings.", + default=defs.VECTOR_STORE_DEFAULT_ID, + ) def resolved_strategy(self, model_config: LanguageModelConfig) -> dict: """Get the resolved text embedding strategy.""" diff --git a/tests/fixtures/azure/settings.yml b/tests/fixtures/azure/settings.yml index 3f054b67..6303c771 100644 --- a/tests/fixtures/azure/settings.yml +++ b/tests/fixtures/azure/settings.yml @@ -3,7 +3,7 @@ claim_extraction: embeddings: vector_store: - output: + default_vector_store: type: "azure_ai_search" url: ${AZURE_AI_SEARCH_URL_ENDPOINT} api_key: ${AZURE_AI_SEARCH_API_KEY} diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml index 09642c92..ebd9b5f3 100644 --- a/tests/fixtures/min-csv/settings.yml +++ b/tests/fixtures/min-csv/settings.yml @@ -26,7 +26,7 @@ models: async_mode: threaded vector_store: - output: + default_vector_store: type: "lancedb" db_uri: "./tests/fixtures/min-csv/lancedb" container_name: "lancedb_ci" diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml index 09b5f13d..d05d384d 100644 --- a/tests/fixtures/text/settings.yml +++ b/tests/fixtures/text/settings.yml @@ -26,7 +26,7 @@ models: async_mode: threaded vector_store: - output: + default_vector_store: type: "azure_ai_search" url: ${AZURE_AI_SEARCH_URL_ENDPOINT} api_key: ${AZURE_AI_SEARCH_API_KEY} diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index d231b5c2..6535f448 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -50,7 +50,7 @@ DEFAULT_MODEL_CONFIG = { DEFAULT_GRAPHRAG_CONFIG_SETTINGS = { "models": DEFAULT_MODEL_CONFIG, "vector_store": { - "output": { + defs.VECTOR_STORE_DEFAULT_ID: { "type": defs.VECTOR_STORE_TYPE, "db_uri": defs.VECTOR_STORE_DB_URI, "container_name": defs.VECTOR_STORE_CONTAINER_NAME,