Fix vector store logic and refactor audience parameter (#1259)

This commit is contained in:
KennyZhang1 2024-10-21 16:56:56 -04:00 committed by GitHub
parent 6aae386b30
commit e0840a2dc4
27 changed files with 203 additions and 194 deletions

.gitignore (vendored): 2 changes
View File

@ -1,6 +1,7 @@
# Python Artifacts
python/*/lib/
dist/
# Test Output
.coverage
coverage/
@ -20,7 +21,6 @@ venv/
.conda
.tmp
.env
build.zip

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "refactor use of vector stores and update support for managed identity"
}

View File

@ -52,7 +52,7 @@ This is the base LLM configuration section. Other steps may override this config
- `api_version` **str** - The API version
- `organization` **str** - The client organization.
- `proxy` **str** - The proxy URL to use.
- `cognitive_services_endpoint` **str** - The url endpoint for cognitive services.
- `audience` **str** - (Azure OpenAI only) The URI of the target Azure resource/service for which a managed identity token is requested. Used if `api_key` is not defined. Default=`https://cognitiveservices.azure.com/.default` (see the token sketch after this list)
- `deployment_name` **str** - The deployment name to use (Azure).
- `model_supports_json` **bool** - Whether the model supports JSON-mode output.
- `tokens_per_minute` **int** - Set a leaky-bucket throttle on tokens-per-minute.
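For orientation, the sketch below is an illustration (not code from this commit) of how an `audience` value is typically turned into a managed identity token provider with the `azure-identity` package; the helper name and fallback scope mirror the pattern used in the client code later in this diff.

```python
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

# Hypothetical helper: fall back to the Cognitive Services default scope when no
# audience is configured, then build a token provider the Azure OpenAI client
# can call to obtain managed identity tokens.
def make_token_provider(audience: str | None):
    scope = audience or "https://cognitiveservices.azure.com/.default"
    return get_bearer_token_provider(DefaultAzureCredential(), scope)
```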
@ -84,9 +84,17 @@ This is the base LLM configuration section. Other steps may override this config
- `parallelization` (see Parallelization top-level config)
- `async_mode` (see Async Mode top-level config)
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch #-tokens.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip.
- `vector_store` **dict** - The vector store to use. Configured for lancedb by default (see the sketch after this list).
- `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
- `db_uri` **str** (only for lancedb) - The database URI. Default=`storage.base_dir/lancedb`
- `url` **str** (only for AI Search) - The AI Search endpoint.
- `api_key` **str** (optional - only for AI Search) - The AI Search API key to use.
- `audience` **str** (only for AI Search) - Audience for the managed identity token if managed identity authentication is used.
- `overwrite` **bool** (only used at index creation time) - Overwrite the collection if it exists. Default=`True`
- `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
- `strategy` **dict** - Fully override the text-embedding strategy.
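To make the shape of this mapping concrete, here is a rough sketch (values are illustrative placeholders, not project defaults except where noted) of the two `vector_store` variants as they would appear once loaded into Python:

```python
# LanceDB (the default): db_uri is resolved against the project root at run time.
lancedb_vector_store = {
    "type": "lancedb",
    "db_uri": "output/lancedb",
    "collection_name": "entity_description_embeddings",
    "overwrite": True,
}

# Azure AI Search: api_key is optional; when it is omitted, managed identity is
# used and `audience` (if set) is passed through to the credential.
# The endpoint below is a placeholder.
ai_search_vector_store = {
    "type": "azure_ai_search",
    "url": "https://<search-resource>.search.windows.net",
    "api_key": "<api_key>",
    "collection_name": "entity_description_embeddings",
    "overwrite": True,
}
```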
## chunks
@ -214,7 +222,7 @@ This is the base LLM configuration section. Other steps may override this config
## encoding_model
**str** - The text encoding model to use. Default is `cl100k_base`.
**str** - The text encoding model to use. Default=`cl100k_base`.
## skip_workflows

View File

@ -7,9 +7,9 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""
from .index_api import build_index
from .prompt_tune_api import DocSelectionType, generate_indexing_prompts
from .query_api import (
from graphrag.api.index import build_index
from graphrag.api.prompt_tune import DocSelectionType, generate_indexing_prompts
from graphrag.api.query import (
global_search,
global_search_streaming,
local_search,

View File

@ -8,6 +8,8 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""
from pathlib import Path
from graphrag.config import CacheType, GraphRagConfig
from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.index.create_pipeline_config import create_pipeline_config
@ -15,6 +17,7 @@ from graphrag.index.emit.types import TableEmitterType
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging import ProgressReporter
from graphrag.vector_stores.factory import VectorStoreType
async def build_index(
@ -30,7 +33,7 @@ async def build_index(
Parameters
----------
config : PipelineConfig
config : GraphRagConfig
The configuration.
run_id : str
The run id. Creates an output directory with this name.
@ -55,6 +58,14 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore
pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none else None

View File

@ -26,7 +26,6 @@ from pydantic import validate_call
from graphrag.config import GraphRagConfig
from graphrag.logging import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.factories import get_global_search_engine, get_local_search_engine
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
@ -35,10 +34,9 @@ from graphrag.query.indexer_adapters import (
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
from graphrag.utils.cli import redact
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
reporter = PrintProgressReporter("")
@ -184,24 +182,20 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
    config.embeddings.vector_store if config.embeddings.vector_store else {}
)
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == "lancedb":
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
reporter.info(f"Vector Store Args: {vector_store_args}")
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
_entities = read_indexer_entities(nodes, entities, community_level)
lancedb_dir = Path(config.storage.base_dir) / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
search_engine = get_local_search_engine(
@ -257,24 +251,20 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
    config.embeddings.vector_store if config.embeddings.vector_store else {}
)
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
reporter.info(f"Vector Store Args: {vector_store_args}")
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
_entities = read_indexer_entities(nodes, entities, community_level)
lancedb_dir = Path(config.storage.base_dir) / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
search_engine = get_local_search_engine(
@ -303,49 +293,14 @@ async def local_search_streaming(
def _get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
config_args: dict,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}
collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
vector_store_type = config_args["type"]
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)
description_embedding_store.connect(**config_args)
if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)
# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)
return description_embedding_store

View File

@ -83,10 +83,7 @@ def create_graphrag_config(
llm_type = LLMType(llm_type) if llm_type else base.type
api_key = reader.str(Fragment.api_key) or base.api_key
api_base = reader.str(Fragment.api_base) or base.api_base
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience) or base.audience
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
@ -119,7 +116,7 @@ def create_graphrag_config(
or base.model_supports_json,
request_timeout=reader.float(Fragment.request_timeout)
or base.request_timeout,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or base.tokens_per_minute,
@ -141,7 +138,7 @@ def create_graphrag_config(
api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
api_key = reader.str(Fragment.api_key) or base.api_key
# In a unique events where:
# Account for various permutations of config settings such as:
# - same api_bases for LLM and embeddings (both Azure)
# - different api_bases for LLM and embeddings (both Azure)
# - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
@ -158,10 +155,7 @@ def create_graphrag_config(
)
api_organization = reader.str("organization") or base.organization
api_proxy = reader.str("proxy") or base.proxy
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)
if api_key is None and not _is_azure(api_type):
@ -186,7 +180,7 @@ def create_graphrag_config(
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
@ -237,9 +231,7 @@ def create_graphrag_config(
api_base = reader.str(Fragment.api_base) or fallback_oai_base
api_version = reader.str(Fragment.api_version) or fallback_oai_version
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
cognitive_services_endpoint = reader.str(
Fragment.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)
if api_key is None and not _is_azure(llm_type):
@ -270,7 +262,7 @@ def create_graphrag_config(
model_supports_json=reader.bool(Fragment.model_supports_json),
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int(Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
@ -294,13 +286,15 @@ def create_graphrag_config(
embeddings_config = values.get("embeddings") or {}
with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
embeddings_target = reader.str("target")
# TODO: remove the type ignore annotations below once the new config engine has been refactored
embeddings_model = TextEmbeddingConfig(
llm=hydrate_embeddings_params(embeddings_config, llm_model),
llm=hydrate_embeddings_params(embeddings_config, llm_model), # type: ignore
parallelization=hydrate_parallelization_params(
embeddings_config, llm_parallelization_model
embeddings_config, # type: ignore
llm_parallelization_model, # type: ignore
),
vector_store=embeddings_config.get("vector_store", None),
async_mode=hydrate_async_type(embeddings_config, async_mode),
async_mode=hydrate_async_type(embeddings_config, async_mode), # type: ignore
target=(
TextEmbeddingTarget(embeddings_target)
if embeddings_target
@ -579,8 +573,8 @@ class Fragment(str, Enum):
api_organization = "API_ORGANIZATION"
api_proxy = "API_PROXY"
async_mode = "ASYNC_MODE"
audience = "AUDIENCE"
base_dir = "BASE_DIR"
cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
concurrent_requests = "CONCURRENT_REQUESTS"
conn_string = "CONNECTION_STRING"
container_name = "CONTAINER_NAME"

View File

@ -3,8 +3,12 @@
"""Common default configuration values."""
from pathlib import Path
from datashaper import AsyncType
from graphrag.vector_stores import VectorStoreType
from .enums import (
CacheType,
InputFileType,
@ -74,7 +78,7 @@ NODE2VEC_WINDOW_SIZE = 2
NODE2VEC_ITERATIONS = 3
NODE2VEC_RANDOM_SEED = 597832
REPORTING_TYPE = ReportingType.file
REPORTING_BASE_DIR = "output"
REPORTING_BASE_DIR = "logs"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
@ -83,6 +87,13 @@ STORAGE_TYPE = StorageType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
UMAP_ENABLED = False
VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value}
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
collection_name: entity_description_embeddings
overwrite: true\
"""
# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
LOCAL_SEARCH_COMMUNITY_PROP = 0.1

View File

@ -20,7 +20,7 @@ class LLMParametersInput(TypedDict):
api_version: NotRequired[str | None]
organization: NotRequired[str | None]
proxy: NotRequired[str | None]
cognitive_services_endpoint: NotRequired[str | None]
audience: NotRequired[str | None]
deployment_name: NotRequired[str | None]
model_supports_json: NotRequired[bool | str | None]
tokens_per_minute: NotRequired[int | str | None]

View File

@ -52,8 +52,9 @@ class LLMParameters(BaseModel):
proxy: str | None = Field(
description="The proxy to use for the LLM service.", default=None
)
cognitive_services_endpoint: str | None = Field(
description="The endpoint to reach cognitives services.", default=None
audience: str | None = Field(
description="Azure resource URI to use with managed identity for the llm connection.",
default=None,
)
deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.", default=None

View File

@ -27,7 +27,7 @@ class TextEmbeddingConfig(LLMConfig):
)
skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
vector_store: dict | None = Field(
description="The vector storage configuration", default=None
description="The vector storage configuration", default=defs.VECTOR_STORE
)
strategy: dict | None = Field(
description="The override strategy to use.", default=None

View File

@ -4,7 +4,6 @@
"""Main definition."""
import asyncio
import json
import logging
import sys
import time
@ -19,6 +18,7 @@ from graphrag.config import (
resolve_paths,
)
from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter
from graphrag.utils.cli import redact
from .emit.types import TableEmitterType
from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@ -34,36 +34,6 @@ warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
log = logging.getLogger(__name__)
def _redact(input: dict) -> str:
"""Sanitize the config json."""
# Redact any sensitive configuration
def redact_dict(input: dict) -> dict:
if not isinstance(input, dict):
return input
result = {}
for key, value in input.items():
if key in {
"api_key",
"connection_string",
"container_name",
"organization",
}:
if value is not None:
result[key] = "==== REDACTED ===="
elif isinstance(value, dict):
result[key] = redact_dict(value)
elif isinstance(value, list):
result[key] = [redact_dict(i) for i in value]
else:
result[key] = value
return result
redacted_dict = redact_dict(input)
return json.dumps(redacted_dict, indent=4)
def _logger(reporter: ProgressReporter):
def info(msg: str, verbose: bool = False):
log.info(msg)
@ -140,7 +110,7 @@ def index_cli(
info(f"Logging enabled at {log_path}", True)
else:
info(
f"Logging not enabled for config {_redact(config.model_dump())}",
f"Logging not enabled for config {redact(config.model_dump())}",
True,
)
@ -149,7 +119,7 @@ def index_cli(
info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
info(
f"Using default configuration: {_redact(config.model_dump())}",
f"Using default configuration: {redact(config.model_dump())}",
verbose,
)

View File

@ -1,10 +1,11 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Content for the init CLI command."""
"""Content for the init CLI command to generate a default configuration."""
import graphrag.config.defaults as defs
INIT_YAML = f"""
INIT_YAML = f"""\
encoding_model: cl100k_base
skip_workflows: []
llm:
@ -12,6 +13,7 @@ llm:
type: {defs.LLM_TYPE.value} # or azure_openai_chat
model: {defs.LLM_MODEL}
model_supports_json: true # recommended if this is available for your model.
# audience: "https://cognitiveservices.azure.com/.default"
# max_tokens: {defs.LLM_MAX_TOKENS}
# request_timeout: {defs.LLM_REQUEST_TIMEOUT}
# api_base: https://<instance>.openai.azure.com
@ -40,12 +42,21 @@ embeddings:
# target: {defs.EMBEDDING_TARGET.value} # or all
# batch_size: {defs.EMBEDDING_BATCH_SIZE} # the number of documents to send in a single request
# batch_max_tokens: {defs.EMBEDDING_BATCH_MAX_TOKENS} # the maximum number of tokens to send in a single request
vector_store:{defs.VECTOR_STORE}
# vector_store: # configuration for AI Search
# type: azure_ai_search
# url: <ai_search_endpoint>
# api_key: <api_key> # if not set, will attempt to use managed identity. Expects the `Search Index Data Contributor` RBAC role in this case.
# audience: <optional> # if using managed identity, the audience to use for the token
# overwrite: true # or false. Only applicable at index creation time
# collection_name: <collection_name> # the name of the collection to use
llm:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
model: {defs.EMBEDDING_MODEL}
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
# deployment_name: <azure_model_deployment_name>
# tokens_per_minute: 150_000 # set a leaky bucket throttle
@ -160,6 +171,6 @@ global_search:
# concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
"""
INIT_DOTENV = """
INIT_DOTENV = """\
GRAPHRAG_API_KEY=<API_KEY>
"""

View File

@ -199,7 +199,7 @@ def _get_base_config(config: dict[str, Any]) -> dict[str, Any]:
"model_supports_json": config.get("model_supports_json"),
"concurrent_requests": config.get("concurrent_requests", 4),
"encoding_model": config.get("encoding_model", "cl100k_base"),
"cognitive_services_endpoint": config.get("cognitive_services_endpoint"),
"audience": config.get("audience"),
}

View File

@ -231,7 +231,7 @@ def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
collection_names = vector_store_config.get("collection_names", {})
collection_name = collection_names.get(embedding_name, embedding_name)
msg = f"using {vector_store_config.get('type')} collection_name {collection_name} for embedding {embedding_name}"
msg = f"using vector store {vector_store_config.get('type')} with collection_name {collection_name} for embedding {embedding_name}"
log.info(msg)
return collection_name

View File

@ -32,15 +32,16 @@ def create_openai_client(
api_base,
configuration.deployment_name,
)
if configuration.cognitive_services_endpoint is None:
cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
else:
cognitive_services_endpoint = configuration.cognitive_services_endpoint
audience = (
configuration.audience
if configuration.audience
else "https://cognitiveservices.azure.com/.default"
)
return AsyncAzureOpenAI(
api_key=configuration.api_key if configuration.api_key else None,
azure_ad_token_provider=get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
DefaultAzureCredential(), audience
)
if not configuration.api_key
else None,

View File

@ -26,7 +26,7 @@ class OpenAIConfiguration(Hashable, LLMConfig):
_api_base: str | None
_api_version: str | None
_cognitive_services_endpoint: str | None
_audience: str | None
_deployment_name: str | None
_organization: str | None
_proxy: str | None
@ -103,7 +103,7 @@ class OpenAIConfiguration(Hashable, LLMConfig):
self._deployment_name = lookup_str("deployment_name")
self._api_base = lookup_str("api_base")
self._api_version = lookup_str("api_version")
self._cognitive_services_endpoint = lookup_str("cognitive_services_endpoint")
self._audience = lookup_str("audience")
self._organization = lookup_str("organization")
self._proxy = lookup_str("proxy")
self._n = lookup_int("n")
@ -156,9 +156,9 @@ class OpenAIConfiguration(Hashable, LLMConfig):
return _non_blank(self._api_version)
@property
def cognitive_services_endpoint(self) -> str | None:
def audience(self) -> str | None:
"""API version property definition."""
return _non_blank(self._cognitive_services_endpoint)
return _non_blank(self._audience)
@property
def organization(self) -> str | None:

View File

@ -44,17 +44,16 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
**config.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
if config.llm.cognitive_services_endpoint is None:
cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
else:
cognitive_services_endpoint = config.llm.cognitive_services_endpoint
audience = (
config.llm.audience
if config.llm.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.llm.api_key
else None
),
@ -77,17 +76,15 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
**config.embeddings.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if config.embeddings.llm.cognitive_services_endpoint is None:
cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
if config.embeddings.llm.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
cognitive_services_endpoint = config.embeddings.llm.cognitive_services_endpoint
audience = config.embeddings.llm.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.embeddings.llm.api_key
else None
),

View File

@ -4,6 +4,7 @@
"""CLI functions for the GraphRAG module."""
import argparse
import json
from pathlib import Path
@ -21,3 +22,33 @@ def dir_exist(path):
msg = f"Directory not found: {path}"
raise argparse.ArgumentTypeError(msg)
return path
def redact(config: dict) -> str:
"""Sanitize secrets in a config object."""
# Redact any sensitive configuration
def redact_dict(config: dict) -> dict:
if not isinstance(config, dict):
return config
result = {}
for key, value in config.items():
if key in {
"api_key",
"connection_string",
"container_name",
"organization",
}:
if value is not None:
result[key] = "==== REDACTED ===="
elif isinstance(value, dict):
result[key] = redact_dict(value)
elif isinstance(value, list):
result[key] = [redact_dict(i) for i in value]
else:
result[key] = value
return result
redacted_dict = redact_dict(config)
return json.dumps(redacted_dict, indent=4)

View File

@ -3,15 +3,15 @@
"""A module containing vector storage implementations."""
from .azure_ai_search import AzureAISearch
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
from .lancedb import LanceDBVectorStore
from .typing import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)
from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType
__all__ = [
"AzureAISearch",
"BaseVectorStore",
"LanceDBVectorStore",
"VectorStoreDocument",
"VectorStoreFactory",
"VectorStoreSearchResult",

View File

@ -35,13 +35,16 @@ from .base import (
class AzureAISearch(BaseVectorStore):
"""The Azure AI Search vector storage implementation."""
"""Azure AI Search vector storage implementation."""
index_client: SearchIndexClient
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def connect(self, **kwargs: Any) -> Any:
"""Connect to the AzureAI vector store."""
url = kwargs.get("url")
"""Connect to AI search vector storage."""
url = kwargs["url"]
api_key = kwargs.get("api_key")
audience = kwargs.get("audience")
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
@ -51,7 +54,7 @@ class AzureAISearch(BaseVectorStore):
)
if url:
audience_arg = {"audience": audience} if audience else {}
audience_arg = {"audience": audience} if audience and not api_key else {}
self.db_connection = SearchClient(
endpoint=url,
index_name=self.collection_name,
@ -68,18 +71,18 @@ class AzureAISearch(BaseVectorStore):
**audience_arg,
)
else:
not_supported_error = "AAISearchDBClient is not supported on local host."
not_supported_error = "Azure AI Search expects `url`."
raise ValueError(not_supported_error)
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into the Azure AI Search index."""
"""Load documents into an Azure AI Search index."""
if overwrite:
if self.collection_name in self.index_client.list_index_names():
self.index_client.delete_index(self.collection_name)
# Configure the vector search profile
# Configure vector search profile
vector_search = VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
@ -96,7 +99,7 @@ class AzureAISearch(BaseVectorStore):
)
],
)
# Configure the index
index = SearchIndex(
name=self.collection_name,
fields=[
@ -120,7 +123,6 @@ class AzureAISearch(BaseVectorStore):
],
vector_search=vector_search,
)
self.index_client.create_or_update_index(
index,
)
@ -136,7 +138,7 @@ class AzureAISearch(BaseVectorStore):
if doc.vector is not None
]
if batch and len(batch) > 0:
if len(batch) > 0:
self.db_connection.upload_documents(batch)
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing the supported vector store types."""
"""A package containing a factory and supported vector store types."""
from enum import Enum
from typing import ClassVar
@ -18,7 +18,7 @@ class VectorStoreType(str, Enum):
class VectorStoreFactory:
"""A factory class for creating vector stores."""
"""A factory class for vector stores."""
vector_store_types: ClassVar[dict[str, type]] = {}
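For orientation (not part of this commit), a minimal sketch of the factory-plus-connect pattern that the query API changes above rely on; the `config_args` values are placeholders.

```python
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType

# Illustrative config args only; in the query API these come from
# config.embeddings.vector_store.
config_args = {
    "type": VectorStoreType.LanceDB,
    "db_uri": "./lancedb",
    "collection_name": "entity_description_embeddings",
}

# Create the registered store class for the given type, then connect with the
# same args (mirrors _get_embedding_description_store in the query API diff).
store = VectorStoreFactory.get_vector_store(
    vector_store_type=config_args["type"], kwargs=config_args
)
store.connect(**config_args)
```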

View File

@ -3,14 +3,14 @@
"""The LanceDB vector storage implementation package."""
import lancedb as lancedb # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests)
from graphrag.model.types import TextEmbedder
import json
from typing import Any
import lancedb as lancedb
import pyarrow as pa
from graphrag.model.types import TextEmbedder
from .base import (
BaseVectorStore,
VectorStoreDocument,
@ -19,12 +19,21 @@ from .base import (
class LanceDBVectorStore(BaseVectorStore):
"""The LanceDB vector storage implementation."""
"""LanceDB vector storage implementation."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def connect(self, **kwargs: Any) -> Any:
"""Connect to the vector storage."""
db_uri = kwargs.get("db_uri", "./lancedb")
self.db_connection = lancedb.connect(db_uri) # type: ignore
self.db_connection = lancedb.connect(kwargs["db_uri"])
if (
self.collection_name
and self.collection_name in self.db_connection.table_names()
):
self.document_collection = self.db_connection.open_table(
self.collection_name
)
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
@ -50,6 +59,9 @@ class LanceDBVectorStore(BaseVectorStore):
pa.field("vector", pa.list_(pa.float64())),
pa.field("attributes", pa.string()),
])
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
# The pyarrow format of the 'vector' field may change if the order of operations is changed
# and will break vector search.
if overwrite:
if data:
self.document_collection = self.db_connection.create_table(
@ -87,14 +99,18 @@ class LanceDBVectorStore(BaseVectorStore):
"""Perform a vector-based similarity search."""
if self.query_filter:
docs = (
self.document_collection.search(query=query_embedding)
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
)
.where(self.query_filter, prefilter=True)
.limit(k)
.to_list()
)
else:
docs = (
self.document_collection.search(query=query_embedding)
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
)
.limit(k)
.to_list()
)

View File

@ -7,8 +7,6 @@ embeddings:
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
collection_name: "azure_ci"
query_collection_name: "azure_ci_query"
entity_name_description:
title_column: "name"

View File

@ -4,7 +4,8 @@ input:
embeddings:
vector_store:
type: "lancedb"
uri_db: "./tests/fixtures/min-csv/lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
collection_name: "lancedb_ci"
store_in_table: True
entity_name_description:

View File

@ -7,9 +7,7 @@ embeddings:
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
collection_name: "simple_text_ci"
query_collection_name: "simple_text_ci_query"
store_in_table: True
entity_name_description:
title_column: "name"