mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-26 06:28:52 +00:00
Fix vector store logic and refactor audience parameter (#1259)
parent 6aae386b30
commit e0840a2dc4
2
.gitignore
vendored
@@ -1,6 +1,7 @@
# Python Artifacts
python/*/lib/
dist/

# Test Output
.coverage
coverage/
@@ -20,7 +21,6 @@ venv/
.conda
.tmp

.env
build.zip

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "refactor use of vector stores and update support for managed identity"
}

@@ -52,7 +52,7 @@ This is the base LLM configuration section. Other steps may override this config
- `api_version` **str** - The API version
- `organization` **str** - The client organization.
- `proxy` **str** - The proxy URL to use.
- `cognitive_services_endpoint` **str** - The url endpoint for cognitive services.
- `audience` **str** - (Azure OpenAI only) The URI of the target Azure resource/service for which a managed identity token is requested. Used if `api_key` is not defined. Default=`https://cognitiveservices.azure.com/.default`
- `deployment_name` **str** - The deployment name to use (Azure).
- `model_supports_json` **bool** - Whether the model supports JSON-mode output.
- `tokens_per_minute` **int** - Set a leaky-bucket throttle on tokens-per-minute.
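The `audience` value above replaces `cognitive_services_endpoint` as the scope requested for a managed identity token whenever `api_key` is not set. A minimal sketch of that fallback, mirroring the `create_openai_client` change later in this commit (the helper name is illustrative, not part of the codebase):

```python
# Sketch: how `audience` feeds managed-identity auth when no api_key is configured.
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

def make_token_provider(audience: str | None):
    # Fall back to the Cognitive Services default scope, exactly as the commit does.
    scope = audience or "https://cognitiveservices.azure.com/.default"
    return get_bearer_token_provider(DefaultAzureCredential(), scope)
```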
@@ -84,9 +84,17 @@ This is the base LLM configuration section. Other steps may override this config
- `parallelization` (see Parallelization top-level config)
- `async_mode` (see Async Mode top-level config)
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch #-tokens.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip.
- `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
  - `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
  - `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
  - `url` **str** (only for AI Search) - AI Search endpoint
  - `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
  - `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
  - `overwrite` **bool** (only used at index creation time) - Overwrite collection if it exist. Default=`True`
  - `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
- `strategy` **dict** - Fully override the text-embedding strategy.
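At query and index time, the `vector_store` block documented above is handed straight to the vector-store factory. A hedged sketch of that wiring, using the factory calls introduced later in this commit (the `db_uri` value is an assumption; a relative path is resolved against `root_dir` by `build_index`):

```python
# Sketch: building a store from an embeddings.vector_store config dict.
from graphrag.vector_stores import VectorStoreFactory

config_args = {
    "type": "lancedb",                                   # or "azure_ai_search"
    "db_uri": "./output/lancedb",                        # assumed location
    "collection_name": "entity_description_embeddings",  # default collection
    "overwrite": True,
}
store = VectorStoreFactory.get_vector_store(
    vector_store_type=config_args["type"], kwargs=config_args
)
store.connect(**config_args)  # connect() reads db_uri / url / audience from these kwargs
```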
## chunks
@@ -214,7 +222,7 @@ This is the base LLM configuration section. Other steps may override this config

## encoding_model

**str** - The text encoding model to use. Default is `cl100k_base`.
**str** - The text encoding model to use. Default=`cl100k_base`.

## skip_workflows

@@ -7,9 +7,9 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""

from .index_api import build_index
from .prompt_tune_api import DocSelectionType, generate_indexing_prompts
from .query_api import (
from graphrag.api.index import build_index
from graphrag.api.prompt_tune import DocSelectionType, generate_indexing_prompts
from graphrag.api.query import (
global_search,
global_search_streaming,
local_search,
@@ -8,6 +8,8 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""

from pathlib import Path

from graphrag.config import CacheType, GraphRagConfig
from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.index.create_pipeline_config import create_pipeline_config
@@ -15,6 +17,7 @@ from graphrag.index.emit.types import TableEmitterType
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging import ProgressReporter
from graphrag.vector_stores.factory import VectorStoreType


async def build_index(
@@ -30,7 +33,7 @@ async def build_index(

Parameters
----------
config : PipelineConfig
config : GraphRagConfig
The configuration.
run_id : str
The run id. Creates a output directory with this name.
@@ -55,6 +58,14 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)

# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore

pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
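The new block above rewrites a relative LanceDB `db_uri` to an absolute path under the project root before the pipeline config is built. A small illustration of the same path arithmetic (the paths are made up):

```python
# Illustration only: resolving a relative db_uri against the project root.
from pathlib import Path

root_dir = "/projects/my-graphrag"   # assumed root_dir
db_uri = "output/lancedb"            # relative value from settings
print(Path(root_dir).resolve() / db_uri)  # -> /projects/my-graphrag/output/lancedb
```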
@@ -26,7 +26,6 @@ from pydantic import validate_call

from graphrag.config import GraphRagConfig
from graphrag.logging import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.factories import get_global_search_engine, get_local_search_engine
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
@@ -35,10 +34,9 @@ from graphrag.query.indexer_adapters import (
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
from graphrag.utils.cli import redact
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType

reporter = PrintProgressReporter("")

@@ -184,24 +182,20 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == "lancedb":
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
reporter.info(f"Vector Store Args: {vector_store_args}")

vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)

lancedb_dir = Path(config.storage.base_dir) / "lancedb"

vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)

_covariates = read_indexer_covariates(covariates) if covariates is not None else []

search_engine = get_local_search_engine(
@@ -257,24 +251,20 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
reporter.info(f"Vector Store Args: {vector_store_args}")

vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

_entities = read_indexer_entities(nodes, entities, community_level)

lancedb_dir = Path(config.storage.base_dir) / "lancedb"

vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)

_covariates = read_indexer_covariates(covariates) if covariates is not None else []

search_engine = get_local_search_engine(
@@ -303,49 +293,14 @@ async def local_search_streaming(


def _get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
config_args: dict,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}

collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
vector_store_type = config_args["type"]
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)

description_embedding_store.connect(**config_args)

if config_args.get("overwrite", True):
# this step assumes the embeddings were originally stored in a file rather
# than a vector database

# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# and connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)

# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)

return description_embedding_store
@@ -83,10 +83,7 @@ def create_graphrag_config(
llm_type = LLMType(llm_type) if llm_type else base.type
api_key = reader.str(Fragment.api_key) or base.api_key
api_base = reader.str(Fragment.api_base) or base.api_base
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience) or base.audience
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
@@ -119,7 +116,7 @@ def create_graphrag_config(
or base.model_supports_json,
request_timeout=reader.float(Fragment.request_timeout)
or base.request_timeout,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or base.tokens_per_minute,
@@ -141,7 +138,7 @@ def create_graphrag_config(
api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
api_key = reader.str(Fragment.api_key) or base.api_key

# In a unique events where:
# Account for various permutations of config settings such as:
# - same api_bases for LLM and embeddings (both Azure)
# - different api_bases for LLM and embeddings (both Azure)
# - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
@@ -158,10 +155,7 @@ def create_graphrag_config(
)
api_organization = reader.str("organization") or base.organization
api_proxy = reader.str("proxy") or base.proxy
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)

if api_key is None and not _is_azure(api_type):
@@ -186,7 +180,7 @@ def create_graphrag_config(
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
@@ -237,9 +231,7 @@ def create_graphrag_config(
api_base = reader.str(Fragment.api_base) or fallback_oai_base
api_version = reader.str(Fragment.api_version) or fallback_oai_version
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
cognitive_services_endpoint = reader.str(
Fragment.cognitive_services_endpoint
)
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)

if api_key is None and not _is_azure(llm_type):
@@ -270,7 +262,7 @@ def create_graphrag_config(
model_supports_json=reader.bool(Fragment.model_supports_json),
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
cognitive_services_endpoint=cognitive_services_endpoint,
audience=audience,
deployment_name=deployment_name,
tokens_per_minute=reader.int(Fragment.tpm)
or defs.LLM_TOKENS_PER_MINUTE,
@@ -294,13 +286,15 @@ def create_graphrag_config(
embeddings_config = values.get("embeddings") or {}
with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
embeddings_target = reader.str("target")
# TODO: remove the type ignore annotations below once the new config engine has been refactored
embeddings_model = TextEmbeddingConfig(
llm=hydrate_embeddings_params(embeddings_config, llm_model),
llm=hydrate_embeddings_params(embeddings_config, llm_model), # type: ignore
parallelization=hydrate_parallelization_params(
embeddings_config, llm_parallelization_model
embeddings_config, # type: ignore
llm_parallelization_model, # type: ignore
),
vector_store=embeddings_config.get("vector_store", None),
async_mode=hydrate_async_type(embeddings_config, async_mode),
async_mode=hydrate_async_type(embeddings_config, async_mode), # type: ignore
target=(
TextEmbeddingTarget(embeddings_target)
if embeddings_target
@@ -579,8 +573,8 @@ class Fragment(str, Enum):
api_organization = "API_ORGANIZATION"
api_proxy = "API_PROXY"
async_mode = "ASYNC_MODE"
audience = "AUDIENCE"
base_dir = "BASE_DIR"
cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
concurrent_requests = "CONCURRENT_REQUESTS"
conn_string = "CONNECTION_STRING"
container_name = "CONTAINER_NAME"
@@ -3,8 +3,12 @@

"""Common default configuration values."""

from pathlib import Path

from datashaper import AsyncType

from graphrag.vector_stores import VectorStoreType

from .enums import (
CacheType,
InputFileType,
@@ -74,7 +78,7 @@ NODE2VEC_WINDOW_SIZE = 2
NODE2VEC_ITERATIONS = 3
NODE2VEC_RANDOM_SEED = 597832
REPORTING_TYPE = ReportingType.file
REPORTING_BASE_DIR = "output"
REPORTING_BASE_DIR = "logs"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
@@ -83,6 +87,13 @@ STORAGE_TYPE = StorageType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
UMAP_ENABLED = False

VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value}
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
collection_name: entity_description_embeddings
overwrite: true\
"""

# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
LOCAL_SEARCH_COMMUNITY_PROP = 0.1
@@ -20,7 +20,7 @@ class LLMParametersInput(TypedDict):
api_version: NotRequired[str | None]
organization: NotRequired[str | None]
proxy: NotRequired[str | None]
cognitive_services_endpoint: NotRequired[str | None]
audience: NotRequired[str | None]
deployment_name: NotRequired[str | None]
model_supports_json: NotRequired[bool | str | None]
tokens_per_minute: NotRequired[int | str | None]

@@ -52,8 +52,9 @@ class LLMParameters(BaseModel):
proxy: str | None = Field(
description="The proxy to use for the LLM service.", default=None
)
cognitive_services_endpoint: str | None = Field(
description="The endpoint to reach cognitives services.", default=None
audience: str | None = Field(
description="Azure resource URI to use with managed identity for the llm connection.",
default=None,
)
deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.", default=None

@@ -27,7 +27,7 @@ class TextEmbeddingConfig(LLMConfig):
)
skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
vector_store: dict | None = Field(
description="The vector storage configuration", default=None
description="The vector storage configuration", default=defs.VECTOR_STORE
)
strategy: dict | None = Field(
description="The override strategy to use.", default=None
@@ -4,7 +4,6 @@

"""Main definition."""

import asyncio
import json
import logging
import sys
import time
@@ -19,6 +18,7 @@ from graphrag.config import (
resolve_paths,
)
from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter
from graphrag.utils.cli import redact

from .emit.types import TableEmitterType
from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@@ -34,36 +34,6 @@ warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
log = logging.getLogger(__name__)


def _redact(input: dict) -> str:
"""Sanitize the config json."""

# Redact any sensitive configuration
def redact_dict(input: dict) -> dict:
if not isinstance(input, dict):
return input

result = {}
for key, value in input.items():
if key in {
"api_key",
"connection_string",
"container_name",
"organization",
}:
if value is not None:
result[key] = "==== REDACTED ===="
elif isinstance(value, dict):
result[key] = redact_dict(value)
elif isinstance(value, list):
result[key] = [redact_dict(i) for i in value]
else:
result[key] = value
return result

redacted_dict = redact_dict(input)
return json.dumps(redacted_dict, indent=4)


def _logger(reporter: ProgressReporter):
def info(msg: str, verbose: bool = False):
log.info(msg)
@@ -140,7 +110,7 @@ def index_cli(
info(f"Logging enabled at {log_path}", True)
else:
info(
f"Logging not enabled for config {_redact(config.model_dump())}",
f"Logging not enabled for config {redact(config.model_dump())}",
True,
)

@@ -149,7 +119,7 @@ def index_cli(

info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
info(
f"Using default configuration: {_redact(config.model_dump())}",
f"Using default configuration: {redact(config.model_dump())}",
verbose,
)
@@ -1,10 +1,11 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Content for the init CLI command."""

"""Content for the init CLI command to generate a default configuration."""

import graphrag.config.defaults as defs

INIT_YAML = f"""
INIT_YAML = f"""\
encoding_model: cl100k_base
skip_workflows: []
llm:
@@ -12,6 +13,7 @@ llm:
type: {defs.LLM_TYPE.value} # or azure_openai_chat
model: {defs.LLM_MODEL}
model_supports_json: true # recommended if this is available for your model.
# audience: "https://cognitiveservices.azure.com/.default"
# max_tokens: {defs.LLM_MAX_TOKENS}
# request_timeout: {defs.LLM_REQUEST_TIMEOUT}
# api_base: https://<instance>.openai.azure.com
@@ -40,12 +42,21 @@ embeddings:
# target: {defs.EMBEDDING_TARGET.value} # or all
# batch_size: {defs.EMBEDDING_BATCH_SIZE} # the number of documents to send in a single request
# batch_max_tokens: {defs.EMBEDDING_BATCH_MAX_TOKENS} # the maximum number of tokens to send in a single request
vector_store:{defs.VECTOR_STORE}
# vector_store: # configuration for AI Search
# type: azure_ai_search
# url: <ai_search_endpoint>
# api_key: <api_key> # if not set, will attempt to use managed identity. Expects the `Search Index Data Contributor` RBAC role in this case.
# audience: <optional> # if using managed identity, the audience to use for the token
# overwrite: true # or false. Only applicable at index creation time
# collection_name: <collection_name> # the name of the collection to use
llm:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
model: {defs.EMBEDDING_MODEL}
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
# deployment_name: <azure_model_deployment_name>
# tokens_per_minute: 150_000 # set a leaky bucket throttle
@@ -160,6 +171,6 @@ global_search:
# concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
"""

INIT_DOTENV = """
INIT_DOTENV = """\
GRAPHRAG_API_KEY=<API_KEY>
"""
@@ -199,7 +199,7 @@ def _get_base_config(config: dict[str, Any]) -> dict[str, Any]:
"model_supports_json": config.get("model_supports_json"),
"concurrent_requests": config.get("concurrent_requests", 4),
"encoding_model": config.get("encoding_model", "cl100k_base"),
"cognitive_services_endpoint": config.get("cognitive_services_endpoint"),
"audience": config.get("audience"),
}


@@ -231,7 +231,7 @@ def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
collection_names = vector_store_config.get("collection_names", {})
collection_name = collection_names.get(embedding_name, embedding_name)

msg = f"using {vector_store_config.get('type')} collection_name {collection_name} for embedding {embedding_name}"
msg = f"using vector store {vector_store_config.get('type')} with collection_name {collection_name} for embedding {embedding_name}"
log.info(msg)
return collection_name
@@ -32,15 +32,16 @@ def create_openai_client(
api_base,
configuration.deployment_name,
)
if configuration.cognitive_services_endpoint is None:
cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
else:
cognitive_services_endpoint = configuration.cognitive_services_endpoint
audience = (
configuration.audience
if configuration.audience
else "https://cognitiveservices.azure.com/.default"
)

return AsyncAzureOpenAI(
api_key=configuration.api_key if configuration.api_key else None,
azure_ad_token_provider=get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
DefaultAzureCredential(), audience
)
if not configuration.api_key
else None,

@@ -26,7 +26,7 @@ class OpenAIConfiguration(Hashable, LLMConfig):

_api_base: str | None
_api_version: str | None
_cognitive_services_endpoint: str | None
_audience: str | None
_deployment_name: str | None
_organization: str | None
_proxy: str | None
@@ -103,7 +103,7 @@ class OpenAIConfiguration(Hashable, LLMConfig):
self._deployment_name = lookup_str("deployment_name")
self._api_base = lookup_str("api_base")
self._api_version = lookup_str("api_version")
self._cognitive_services_endpoint = lookup_str("cognitive_services_endpoint")
self._audience = lookup_str("audience")
self._organization = lookup_str("organization")
self._proxy = lookup_str("proxy")
self._n = lookup_int("n")
@@ -156,9 +156,9 @@ class OpenAIConfiguration(Hashable, LLMConfig):
return _non_blank(self._api_version)

@property
def cognitive_services_endpoint(self) -> str | None:
def audience(self) -> str | None:
"""API version property definition."""
return _non_blank(self._cognitive_services_endpoint)
return _non_blank(self._audience)

@property
def organization(self) -> str | None:
@@ -44,17 +44,16 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
**config.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
if config.llm.cognitive_services_endpoint is None:
cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
else:
cognitive_services_endpoint = config.llm.cognitive_services_endpoint
audience = (
config.llm.audience
if config.llm.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.llm.api_key
else None
),
@@ -77,17 +76,15 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
**config.embeddings.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if config.embeddings.llm.cognitive_services_endpoint is None:
cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
if config.embeddings.llm.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
cognitive_services_endpoint = config.embeddings.llm.cognitive_services_endpoint
audience = config.embeddings.llm.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.embeddings.llm.api_key
else None
),
@@ -4,6 +4,7 @@

"""CLI functions for the GraphRAG module."""

import argparse
import json
from pathlib import Path


@@ -21,3 +22,33 @@ def dir_exist(path):
msg = f"Directory not found: {path}"
raise argparse.ArgumentTypeError(msg)
return path


def redact(config: dict) -> str:
"""Sanitize secrets in a config object."""

# Redact any sensitive configuration
def redact_dict(config: dict) -> dict:
if not isinstance(config, dict):
return config

result = {}
for key, value in config.items():
if key in {
"api_key",
"connection_string",
"container_name",
"organization",
}:
if value is not None:
result[key] = "==== REDACTED ===="
elif isinstance(value, dict):
result[key] = redact_dict(value)
elif isinstance(value, list):
result[key] = [redact_dict(i) for i in value]
else:
result[key] = value
return result

redacted_dict = redact_dict(config)
return json.dumps(redacted_dict, indent=4)
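The `_redact` helper removed from the index CLI now lives here as `graphrag.utils.cli.redact`. A quick usage sketch (the sample config values are invented):

```python
# Example: secret-bearing keys are masked, everything else passes through as JSON.
from graphrag.utils.cli import redact

sample = {
    "llm": {"api_key": "sk-not-a-real-key", "model": "gpt-4-turbo-preview"},
    "storage": {"connection_string": "DefaultEndpointsProtocol=https;..."},
}
print(redact(sample))  # api_key and connection_string become "==== REDACTED ===="
```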
@@ -3,15 +3,15 @@

"""A module containing vector storage implementations."""

from .azure_ai_search import AzureAISearch
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
from .lancedb import LanceDBVectorStore
from .typing import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)
from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType

__all__ = [
"AzureAISearch",
"BaseVectorStore",
"LanceDBVectorStore",
"VectorStoreDocument",
"VectorStoreFactory",
"VectorStoreSearchResult",
@@ -35,13 +35,16 @@ from .base import (


class AzureAISearch(BaseVectorStore):
"""The Azure AI Search vector storage implementation."""
"""Azure AI Search vector storage implementation."""

index_client: SearchIndexClient

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def connect(self, **kwargs: Any) -> Any:
"""Connect to the AzureAI vector store."""
url = kwargs.get("url")
"""Connect to AI search vector storage."""
url = kwargs["url"]
api_key = kwargs.get("api_key")
audience = kwargs.get("audience")
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
@@ -51,7 +54,7 @@ class AzureAISearch(BaseVectorStore):
)

if url:
audience_arg = {"audience": audience} if audience else {}
audience_arg = {"audience": audience} if audience and not api_key else {}
self.db_connection = SearchClient(
endpoint=url,
index_name=self.collection_name,
@@ -68,18 +71,18 @@ class AzureAISearch(BaseVectorStore):
**audience_arg,
)
else:
not_supported_error = "AAISearchDBClient is not supported on local host."
not_supported_error = "Azure AI Search expects `url`."
raise ValueError(not_supported_error)

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into the Azure AI Search index."""
"""Load documents into an Azure AI Search index."""
if overwrite:
if self.collection_name in self.index_client.list_index_names():
self.index_client.delete_index(self.collection_name)

# Configure the vector search profile
# Configure vector search profile
vector_search = VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
@@ -96,7 +99,7 @@ class AzureAISearch(BaseVectorStore):
)
],
)

# Configure the index
index = SearchIndex(
name=self.collection_name,
fields=[
@@ -120,7 +123,6 @@ class AzureAISearch(BaseVectorStore):
],
vector_search=vector_search,
)

self.index_client.create_or_update_index(
index,
)
@@ -136,7 +138,7 @@ class AzureAISearch(BaseVectorStore):
if doc.vector is not None
]

if batch and len(batch) > 0:
if len(batch) > 0:
self.db_connection.upload_documents(batch)

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
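With this change, `AzureAISearch.connect` requires `url`, treats `api_key` as optional, and only forwards `audience` to the Azure SDK clients when no key is supplied, i.e. when managed identity is the auth path. A hedged sketch of a query-time connection (the endpoint and audience values are placeholders, not from the commit):

```python
# Sketch: connecting the AI Search vector store; placeholders, not real endpoints.
from graphrag.vector_stores import AzureAISearch

store = AzureAISearch(collection_name="entity_description_embeddings")
store.connect(
    url="https://<search-instance>.search.windows.net",
    # No api_key given, so credential-based auth is used and the audience below
    # is passed through to the SearchClient/SearchIndexClient.
    audience="https://search.azure.com",  # assumed audience value
)
```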
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A package containing the supported vector store types."""
"""A package containing a factory and supported vector store types."""

from enum import Enum
from typing import ClassVar
@@ -18,7 +18,7 @@ class VectorStoreType(str, Enum):


class VectorStoreFactory:
"""A factory class for creating vector stores."""
"""A factory class for vector stores."""

vector_store_types: ClassVar[dict[str, type]] = {}
@@ -3,14 +3,14 @@

"""The LanceDB vector storage implementation package."""

import lancedb as lancedb # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests)
from graphrag.model.types import TextEmbedder

import json
from typing import Any

import lancedb as lancedb
import pyarrow as pa

from graphrag.model.types import TextEmbedder

from .base import (
BaseVectorStore,
VectorStoreDocument,
@@ -19,12 +19,21 @@ from .base import (


class LanceDBVectorStore(BaseVectorStore):
"""The LanceDB vector storage implementation."""
"""LanceDB vector storage implementation."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def connect(self, **kwargs: Any) -> Any:
"""Connect to the vector storage."""
db_uri = kwargs.get("db_uri", "./lancedb")
self.db_connection = lancedb.connect(db_uri) # type: ignore
self.db_connection = lancedb.connect(kwargs["db_uri"])
if (
self.collection_name
and self.collection_name in self.db_connection.table_names()
):
self.document_collection = self.db_connection.open_table(
self.collection_name
)

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
@@ -50,6 +59,9 @@ class LanceDBVectorStore(BaseVectorStore):
pa.field("vector", pa.list_(pa.float64())),
pa.field("attributes", pa.string()),
])
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
# The pyarrow format of the 'vector' field may change if the order of operations is changed
# and will break vector search.
if overwrite:
if data:
self.document_collection = self.db_connection.create_table(
@@ -87,14 +99,18 @@ class LanceDBVectorStore(BaseVectorStore):
"""Perform a vector-based similarity search."""
if self.query_filter:
docs = (
self.document_collection.search(query=query_embedding)
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
)
.where(self.query_filter, prefilter=True)
.limit(k)
.to_list()
)
else:
docs = (
self.document_collection.search(query=query_embedding)
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
)
.limit(k)
.to_list()
)
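After this change, `LanceDBVectorStore.connect` no longer defaults `db_uri` to `./lancedb`; callers must pass it explicitly, and an existing collection is reopened automatically. A short usage sketch (the path is an assumption):

```python
# Sketch: db_uri is now required when connecting to LanceDB.
from graphrag.vector_stores import LanceDBVectorStore

store = LanceDBVectorStore(collection_name="entity_description_embeddings")
store.connect(db_uri="./output/lancedb")  # KeyError if db_uri is omitted
# If that collection already exists in the database, document_collection is
# opened and ready for similarity search.
```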
2
tests/fixtures/azure/settings.yml
vendored
@@ -7,8 +7,6 @@ embeddings:
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
collection_name: "azure_ci"
query_collection_name: "azure_ci_query"

entity_name_description:
title_column: "name"
3
tests/fixtures/min-csv/settings.yml
vendored
@@ -4,7 +4,8 @@ input:
embeddings:
vector_store:
type: "lancedb"
uri_db: "./tests/fixtures/min-csv/lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
collection_name: "lancedb_ci"
store_in_table: True

entity_name_description:
2
tests/fixtures/text/settings.yml
vendored
@@ -7,9 +7,7 @@ embeddings:
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
collection_name: "simple_text_ci"
query_collection_name: "simple_text_ci_query"
store_in_table: True

entity_name_description:
title_column: "name"