# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration, loaded from environment variables."""

import os
from enum import Enum
from pathlib import Path
from typing import cast

from datashaper import AsyncType
from environs import Env
from pydantic import TypeAdapter

import graphrag.config.defaults as defs

from .enums import (
    CacheType,
    InputFileType,
    InputType,
    LLMType,
    ReportingType,
    StorageType,
    TextEmbeddingTarget,
)
from .environment_reader import EnvironmentReader
from .errors import (
    ApiKeyMissingError,
    AzureApiBaseMissingError,
    AzureDeploymentNameMissingError,
)
from .input_models import (
    GraphRagConfigInput,
    LLMConfigInput,
)
from .models import (
    CacheConfig,
    ChunkingConfig,
    ClaimExtractionConfig,
    ClusterGraphConfig,
    CommunityReportsConfig,
    EmbedGraphConfig,
    EntityExtractionConfig,
    GlobalSearchConfig,
    GraphRagConfig,
    InputConfig,
    LLMParameters,
    LocalSearchConfig,
    ParallelizationParameters,
    ReportingConfig,
    SnapshotsConfig,
    StorageConfig,
    SummarizeDescriptionsConfig,
    TextEmbeddingConfig,
    UmapConfig,
)
from .read_dotenv import read_dotenv

InputModelValidator = TypeAdapter(GraphRagConfigInput)


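# Example usage (a minimal sketch; assumes an API key is provided via `values`,
# GRAPHRAG_API_KEY / OPENAI_API_KEY, or a .env file under root_dir):
#
#   config = create_graphrag_config(
#       values={"llm": {"model": "gpt-4-turbo-preview"}},
#       root_dir="./ragtest",
#   )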
def create_graphrag_config(
    values: GraphRagConfigInput | None = None, root_dir: str | None = None
) -> GraphRagConfig:
    """Load configuration parameters from a dictionary."""
    values = values or {}
    root_dir = root_dir or str(Path.cwd())
    env = _make_env(root_dir)
    _token_replace(cast(dict, values))
    InputModelValidator.validate_python(values, strict=True)

    reader = EnvironmentReader(env)

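    # Each hydrate_* helper below overlays a section's own settings on top of
    # already-resolved base values: the section config (or its environment
    # variables) is read first, falling back to the base model.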
    def hydrate_async_type(input: LLMConfigInput, base: AsyncType) -> AsyncType:
        value = input.get(Fragment.async_mode)
        return AsyncType(value) if value else base

    def hydrate_llm_params(
        config: LLMConfigInput, base: LLMParameters
    ) -> LLMParameters:
        with reader.use(config.get("llm")):
            llm_type = reader.str(Fragment.type)
            llm_type = LLMType(llm_type) if llm_type else base.type
            api_key = reader.str(Fragment.api_key) or base.api_key
            api_base = reader.str(Fragment.api_base) or base.api_base
            cognitive_services_endpoint = (
                reader.str(Fragment.cognitive_services_endpoint)
                or base.cognitive_services_endpoint
            )
            deployment_name = (
                reader.str(Fragment.deployment_name) or base.deployment_name
            )

            if api_key is None and not _is_azure(llm_type):
                raise ApiKeyMissingError
            if _is_azure(llm_type):
                if api_base is None:
                    raise AzureApiBaseMissingError
                if deployment_name is None:
                    raise AzureDeploymentNameMissingError

            sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
            if sleep_on_rate_limit is None:
                sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation

            return LLMParameters(
                api_key=api_key,
                type=llm_type,
                api_base=api_base,
                api_version=reader.str(Fragment.api_version) or base.api_version,
                organization=reader.str("organization") or base.organization,
                proxy=reader.str("proxy") or base.proxy,
                model=reader.str("model") or base.model,
                max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
                temperature=reader.float(Fragment.temperature) or base.temperature,
                top_p=reader.float(Fragment.top_p) or base.top_p,
                n=reader.int(Fragment.n) or base.n,
                model_supports_json=reader.bool(Fragment.model_supports_json)
                or base.model_supports_json,
                request_timeout=reader.float(Fragment.request_timeout)
                or base.request_timeout,
                cognitive_services_endpoint=cognitive_services_endpoint,
                deployment_name=deployment_name,
                tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
                or base.tokens_per_minute,
                requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
                or base.requests_per_minute,
                max_retries=reader.int(Fragment.max_retries) or base.max_retries,
                max_retry_wait=reader.float(Fragment.max_retry_wait)
                or base.max_retry_wait,
                sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
                concurrent_requests=reader.int(Fragment.concurrent_requests)
                or base.concurrent_requests,
            )

    def hydrate_embeddings_params(
        config: LLMConfigInput, base: LLMParameters
    ) -> LLMParameters:
        with reader.use(config.get("llm")):
            api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE
            api_type = LLMType(api_type)  # always truthy, thanks to the default above
            api_key = reader.str(Fragment.api_key) or base.api_key

            # Account for the cases where:
            # - LLM and embeddings share the same api_base (both Azure)
            # - LLM and embeddings use different api_bases (both Azure)
            # - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
            # - LLM uses Azure OpenAI, while embeddings uses a third-party OpenAI-like API
            api_base = (
                reader.str(Fragment.api_base) or base.api_base
                if _is_azure(api_type)
                else reader.str(Fragment.api_base)
            )
            api_version = (
                reader.str(Fragment.api_version) or base.api_version
                if _is_azure(api_type)
                else reader.str(Fragment.api_version)
            )
            api_organization = reader.str("organization") or base.organization
            api_proxy = reader.str("proxy") or base.proxy
            cognitive_services_endpoint = (
                reader.str(Fragment.cognitive_services_endpoint)
                or base.cognitive_services_endpoint
            )
            deployment_name = reader.str(Fragment.deployment_name)

            if api_key is None and not _is_azure(api_type):
                raise ApiKeyMissingError(embedding=True)
            if _is_azure(api_type):
                if api_base is None:
                    raise AzureApiBaseMissingError(embedding=True)
                if deployment_name is None:
                    raise AzureDeploymentNameMissingError(embedding=True)

            sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
            if sleep_on_rate_limit is None:
                sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation

            return LLMParameters(
                api_key=api_key,
                type=api_type,
                api_base=api_base,
                api_version=api_version,
                organization=api_organization,
                proxy=api_proxy,
                model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
                request_timeout=reader.float(Fragment.request_timeout)
                or defs.LLM_REQUEST_TIMEOUT,
                cognitive_services_endpoint=cognitive_services_endpoint,
                deployment_name=deployment_name,
                tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
                or defs.LLM_TOKENS_PER_MINUTE,
                requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
                or defs.LLM_REQUESTS_PER_MINUTE,
                max_retries=reader.int(Fragment.max_retries) or defs.LLM_MAX_RETRIES,
                max_retry_wait=reader.float(Fragment.max_retry_wait)
                or defs.LLM_MAX_RETRY_WAIT,
                sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
                concurrent_requests=reader.int(Fragment.concurrent_requests)
                or defs.LLM_CONCURRENT_REQUESTS,
            )

    def hydrate_parallelization_params(
        config: LLMConfigInput, base: ParallelizationParameters
    ) -> ParallelizationParameters:
        with reader.use(config.get("parallelization")):
            return ParallelizationParameters(
                num_threads=reader.int("num_threads", Fragment.thread_count)
                or base.num_threads,
                stagger=reader.float("stagger", Fragment.thread_stagger)
                or base.stagger,
            )

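    # Fall back to the standard OpenAI / Azure OpenAI environment variables
    # when no GRAPHRAG_-prefixed equivalents are provided.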
    fallback_oai_key = env("OPENAI_API_KEY", env("AZURE_OPENAI_API_KEY", None))
    fallback_oai_org = env("OPENAI_ORG_ID", None)
    fallback_oai_base = env("OPENAI_BASE_URL", None)
    fallback_oai_version = env("OPENAI_API_VERSION", None)

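    # Each section below reads its settings from the `values` dict and from
    # GRAPHRAG_<SECTION>-prefixed environment variables (see the Section and
    # Fragment enums at the bottom of this file), falling back to the package
    # defaults in graphrag.config.defaults.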
    with reader.envvar_prefix(Section.graphrag), reader.use(values):
        async_mode = reader.str(Fragment.async_mode)
        async_mode = AsyncType(async_mode) if async_mode else defs.ASYNC_MODE

        fallback_oai_key = reader.str(Fragment.api_key) or fallback_oai_key
        fallback_oai_org = reader.str(Fragment.api_organization) or fallback_oai_org
        fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
        fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
        fallback_oai_proxy = reader.str(Fragment.api_proxy)

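        # Resolve the root-level LLM settings first; the per-section configs
        # below hydrate their own LLM parameters from this model.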
        with reader.envvar_prefix(Section.llm):
            with reader.use(values.get("llm")):
                llm_type = reader.str(Fragment.type)
                llm_type = LLMType(llm_type) if llm_type else defs.LLM_TYPE
                api_key = reader.str(Fragment.api_key) or fallback_oai_key
                api_organization = (
                    reader.str(Fragment.api_organization) or fallback_oai_org
                )
                api_base = reader.str(Fragment.api_base) or fallback_oai_base
                api_version = reader.str(Fragment.api_version) or fallback_oai_version
                api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
                cognitive_services_endpoint = reader.str(
                    Fragment.cognitive_services_endpoint
                )
                deployment_name = reader.str(Fragment.deployment_name)

                if api_key is None and not _is_azure(llm_type):
                    raise ApiKeyMissingError
                if _is_azure(llm_type):
                    if api_base is None:
                        raise AzureApiBaseMissingError
                    if deployment_name is None:
                        raise AzureDeploymentNameMissingError

                sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
                if sleep_on_rate_limit is None:
                    sleep_on_rate_limit = defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION

                llm_model = LLMParameters(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    organization=api_organization,
                    proxy=api_proxy,
                    type=llm_type,
                    model=reader.str(Fragment.model) or defs.LLM_MODEL,
                    max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
                    temperature=reader.float(Fragment.temperature)
                    or defs.LLM_TEMPERATURE,
                    top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
                    n=reader.int(Fragment.n) or defs.LLM_N,
                    model_supports_json=reader.bool(Fragment.model_supports_json),
                    request_timeout=reader.float(Fragment.request_timeout)
                    or defs.LLM_REQUEST_TIMEOUT,
                    cognitive_services_endpoint=cognitive_services_endpoint,
                    deployment_name=deployment_name,
                    tokens_per_minute=reader.int(Fragment.tpm)
                    or defs.LLM_TOKENS_PER_MINUTE,
                    requests_per_minute=reader.int(Fragment.rpm)
                    or defs.LLM_REQUESTS_PER_MINUTE,
                    max_retries=reader.int(Fragment.max_retries)
                    or defs.LLM_MAX_RETRIES,
                    max_retry_wait=reader.float(Fragment.max_retry_wait)
                    or defs.LLM_MAX_RETRY_WAIT,
                    sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
                    concurrent_requests=reader.int(Fragment.concurrent_requests)
                    or defs.LLM_CONCURRENT_REQUESTS,
                )
            with reader.use(values.get("parallelization")):
                llm_parallelization_model = ParallelizationParameters(
                    stagger=reader.float("stagger", Fragment.thread_stagger)
                    or defs.PARALLELIZATION_STAGGER,
                    num_threads=reader.int("num_threads", Fragment.thread_count)
                    or defs.PARALLELIZATION_NUM_THREADS,
                )
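        # Embeddings may use a different endpoint or service than the chat LLM;
        # hydrate_embeddings_params accounts for the Azure/OpenAI combinations.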
        embeddings_config = values.get("embeddings") or {}
        with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
            embeddings_target = reader.str("target")
            embeddings_model = TextEmbeddingConfig(
                llm=hydrate_embeddings_params(embeddings_config, llm_model),
                parallelization=hydrate_parallelization_params(
                    embeddings_config, llm_parallelization_model
                ),
                vector_store=embeddings_config.get("vector_store", None),
                async_mode=hydrate_async_type(embeddings_config, async_mode),
                target=(
                    TextEmbeddingTarget(embeddings_target)
                    if embeddings_target
                    else defs.EMBEDDING_TARGET
                ),
                batch_size=reader.int("batch_size") or defs.EMBEDDING_BATCH_SIZE,
                batch_max_tokens=reader.int("batch_max_tokens")
                or defs.EMBEDDING_BATCH_MAX_TOKENS,
                skip=reader.list("skip") or [],
            )
        with (
            reader.envvar_prefix(Section.node2vec),
            reader.use(values.get("embed_graph")),
        ):
            embed_graph_model = EmbedGraphConfig(
                enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED,
                num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS,
                walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH,
                window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE,
                iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS,
                random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED,
            )
        with reader.envvar_prefix(Section.input), reader.use(values.get("input")):
            input_type = reader.str("type")
            file_type = reader.str(Fragment.file_type)
            input_model = InputConfig(
                file_type=(
                    InputFileType(file_type) if file_type else defs.INPUT_FILE_TYPE
                ),
                type=(InputType(input_type) if input_type else defs.INPUT_TYPE),
                encoding=reader.str("file_encoding", Fragment.encoding)
                or defs.INPUT_FILE_ENCODING,
                base_dir=reader.str(Fragment.base_dir) or defs.INPUT_BASE_DIR,
                file_pattern=reader.str("file_pattern")
                or (
                    defs.INPUT_TEXT_PATTERN
                    if file_type == InputFileType.text
                    else defs.INPUT_CSV_PATTERN
                ),
                source_column=reader.str("source_column"),
                timestamp_column=reader.str("timestamp_column"),
                timestamp_format=reader.str("timestamp_format"),
                text_column=reader.str("text_column") or defs.INPUT_TEXT_COLUMN,
                title_column=reader.str("title_column"),
                document_attribute_columns=reader.list("document_attribute_columns")
                or [],
                connection_string=reader.str(Fragment.conn_string),
                storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
                container_name=reader.str(Fragment.container_name),
            )
        with reader.envvar_prefix(Section.cache), reader.use(values.get("cache")):
            c_type = reader.str(Fragment.type)
            cache_model = CacheConfig(
                type=CacheType(c_type) if c_type else defs.CACHE_TYPE,
                connection_string=reader.str(Fragment.conn_string),
                storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
                container_name=reader.str(Fragment.container_name),
                base_dir=reader.str(Fragment.base_dir) or defs.CACHE_BASE_DIR,
            )
        with (
            reader.envvar_prefix(Section.reporting),
            reader.use(values.get("reporting")),
        ):
            r_type = reader.str(Fragment.type)
            reporting_model = ReportingConfig(
                type=ReportingType(r_type) if r_type else defs.REPORTING_TYPE,
                connection_string=reader.str(Fragment.conn_string),
                storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
                container_name=reader.str(Fragment.container_name),
                base_dir=reader.str(Fragment.base_dir) or defs.REPORTING_BASE_DIR,
            )
        with reader.envvar_prefix(Section.storage), reader.use(values.get("storage")):
            s_type = reader.str(Fragment.type)
            storage_model = StorageConfig(
                type=StorageType(s_type) if s_type else defs.STORAGE_TYPE,
                connection_string=reader.str(Fragment.conn_string),
                storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
                container_name=reader.str(Fragment.container_name),
                base_dir=reader.str(Fragment.base_dir) or defs.STORAGE_BASE_DIR,
            )
        with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")):
            chunks_model = ChunkingConfig(
                size=reader.int("size") or defs.CHUNK_SIZE,
                overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
                group_by_columns=reader.list("group_by_columns", "BY_COLUMNS")
                or defs.CHUNK_GROUP_BY_COLUMNS,
            )
        with (
            reader.envvar_prefix(Section.snapshot),
            reader.use(values.get("snapshots")),
        ):
            snapshots_model = SnapshotsConfig(
                graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML,
                raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
                top_level_nodes=reader.bool("top_level_nodes")
                or defs.SNAPSHOTS_TOP_LEVEL_NODES,
            )
        with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
            umap_model = UmapConfig(
                enabled=reader.bool(Fragment.enabled) or defs.UMAP_ENABLED,
            )

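        # Per-workflow configs: each hydrates its LLM, parallelization, and
        # async settings from the root-level models resolved above.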
        entity_extraction_config = values.get("entity_extraction") or {}
        with (
            reader.envvar_prefix(Section.entity_extraction),
            reader.use(entity_extraction_config),
        ):
            entity_extraction_model = EntityExtractionConfig(
                llm=hydrate_llm_params(entity_extraction_config, llm_model),
                parallelization=hydrate_parallelization_params(
                    entity_extraction_config, llm_parallelization_model
                ),
                async_mode=hydrate_async_type(entity_extraction_config, async_mode),
                entity_types=reader.list("entity_types")
                or defs.ENTITY_EXTRACTION_ENTITY_TYPES,
                max_gleanings=reader.int(Fragment.max_gleanings)
                or defs.ENTITY_EXTRACTION_MAX_GLEANINGS,
                prompt=reader.str("prompt", Fragment.prompt_file),
            )

        claim_extraction_config = values.get("claim_extraction") or {}
        with (
            reader.envvar_prefix(Section.claim_extraction),
            reader.use(claim_extraction_config),
        ):
            claim_extraction_model = ClaimExtractionConfig(
                enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
                llm=hydrate_llm_params(claim_extraction_config, llm_model),
                parallelization=hydrate_parallelization_params(
                    claim_extraction_config, llm_parallelization_model
                ),
                async_mode=hydrate_async_type(claim_extraction_config, async_mode),
                description=reader.str("description") or defs.CLAIM_DESCRIPTION,
                prompt=reader.str("prompt", Fragment.prompt_file),
                max_gleanings=reader.int(Fragment.max_gleanings)
                or defs.CLAIM_MAX_GLEANINGS,
            )

        community_report_config = values.get("community_reports") or {}
        with (
            reader.envvar_prefix(Section.community_reports),
            reader.use(community_report_config),
        ):
            community_reports_model = CommunityReportsConfig(
                llm=hydrate_llm_params(community_report_config, llm_model),
                parallelization=hydrate_parallelization_params(
                    community_report_config, llm_parallelization_model
                ),
                async_mode=hydrate_async_type(community_report_config, async_mode),
                prompt=reader.str("prompt", Fragment.prompt_file),
                max_length=reader.int(Fragment.max_length)
                or defs.COMMUNITY_REPORT_MAX_LENGTH,
                max_input_length=reader.int("max_input_length")
                or defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
            )

        summarize_description_config = values.get("summarize_descriptions") or {}
        with (
            reader.envvar_prefix(Section.summarize_descriptions),
            reader.use(summarize_description_config),
        ):
            summarize_descriptions_model = SummarizeDescriptionsConfig(
                llm=hydrate_llm_params(summarize_description_config, llm_model),
                parallelization=hydrate_parallelization_params(
                    summarize_description_config, llm_parallelization_model
                ),
                async_mode=hydrate_async_type(summarize_description_config, async_mode),
                prompt=reader.str("prompt", Fragment.prompt_file),
                max_length=reader.int(Fragment.max_length)
                or defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
            )

        with reader.use(values.get("cluster_graph")):
            cluster_graph_model = ClusterGraphConfig(
                max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE
            )

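        # Query-time settings for the local and global search modes.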
        with (
            reader.use(values.get("local_search")),
            reader.envvar_prefix(Section.local_search),
        ):
            local_search_model = LocalSearchConfig(
                text_unit_prop=reader.float("text_unit_prop")
                or defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
                community_prop=reader.float("community_prop")
                or defs.LOCAL_SEARCH_COMMUNITY_PROP,
                conversation_history_max_turns=reader.int(
                    "conversation_history_max_turns"
                )
                or defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
                top_k_entities=reader.int("top_k_entities")
                or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
                top_k_relationships=reader.int("top_k_relationships")
                or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
                temperature=reader.float("llm_temperature")
                or defs.LOCAL_SEARCH_LLM_TEMPERATURE,
                top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P,
                n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N,
                max_tokens=reader.int(Fragment.max_tokens)
                or defs.LOCAL_SEARCH_MAX_TOKENS,
                llm_max_tokens=reader.int("llm_max_tokens")
                or defs.LOCAL_SEARCH_LLM_MAX_TOKENS,
            )

        with (
            reader.use(values.get("global_search")),
            reader.envvar_prefix(Section.global_search),
        ):
            global_search_model = GlobalSearchConfig(
                temperature=reader.float("llm_temperature")
                or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
                top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
                n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
                max_tokens=reader.int(Fragment.max_tokens)
                or defs.GLOBAL_SEARCH_MAX_TOKENS,
                data_max_tokens=reader.int("data_max_tokens")
                or defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
                map_max_tokens=reader.int("map_max_tokens")
                or defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
                reduce_max_tokens=reader.int("reduce_max_tokens")
                or defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
                concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
            )

        encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
        skip_workflows = reader.list("skip_workflows") or []

        return GraphRagConfig(
            root_dir=root_dir,
            llm=llm_model,
            parallelization=llm_parallelization_model,
            async_mode=async_mode,
            embeddings=embeddings_model,
            embed_graph=embed_graph_model,
            reporting=reporting_model,
            storage=storage_model,
            cache=cache_model,
            input=input_model,
            chunks=chunks_model,
            snapshots=snapshots_model,
            entity_extraction=entity_extraction_model,
            claim_extraction=claim_extraction_model,
            community_reports=community_reports_model,
            summarize_descriptions=summarize_descriptions_model,
            umap=umap_model,
            cluster_graph=cluster_graph_model,
            encoding_model=encoding_model,
            skip_workflows=skip_workflows,
            local_search=local_search_model,
            global_search=global_search_model,
        )


class Fragment(str, Enum):
    """Configuration Fragments."""

    api_base = "API_BASE"
    api_key = "API_KEY"
    api_organization = "API_ORGANIZATION"
    api_proxy = "API_PROXY"
    api_version = "API_VERSION"
    async_mode = "ASYNC_MODE"
    base_dir = "BASE_DIR"
    cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
    concurrent_requests = "CONCURRENT_REQUESTS"
    conn_string = "CONNECTION_STRING"
    container_name = "CONTAINER_NAME"
    deployment_name = "DEPLOYMENT_NAME"
    description = "DESCRIPTION"
    enabled = "ENABLED"
    encoding = "ENCODING"
    encoding_model = "ENCODING_MODEL"
    file_type = "FILE_TYPE"
    max_gleanings = "MAX_GLEANINGS"
    max_length = "MAX_LENGTH"
    max_retries = "MAX_RETRIES"
    max_retry_wait = "MAX_RETRY_WAIT"
    max_tokens = "MAX_TOKENS"
    model = "MODEL"
    model_supports_json = "MODEL_SUPPORTS_JSON"
    n = "N"
    prompt_file = "PROMPT_FILE"
    request_timeout = "REQUEST_TIMEOUT"
    rpm = "REQUESTS_PER_MINUTE"
    sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
    storage_account_blob_url = "STORAGE_ACCOUNT_BLOB_URL"
    temperature = "TEMPERATURE"
    thread_count = "THREAD_COUNT"
    thread_stagger = "THREAD_STAGGER"
    top_p = "TOP_P"
    tpm = "TOKENS_PER_MINUTE"
    type = "TYPE"


class Section(str, Enum):
    """Configuration Sections."""

    base = "BASE"
    cache = "CACHE"
    chunk = "CHUNK"
    claim_extraction = "CLAIM_EXTRACTION"
    community_reports = "COMMUNITY_REPORTS"
    embedding = "EMBEDDING"
    entity_extraction = "ENTITY_EXTRACTION"
    global_search = "GLOBAL_SEARCH"
    graphrag = "GRAPHRAG"
    input = "INPUT"
    llm = "LLM"
    local_search = "LOCAL_SEARCH"
    node2vec = "NODE2VEC"
    reporting = "REPORTING"
    snapshot = "SNAPSHOT"
    storage = "STORAGE"
    summarize_descriptions = "SUMMARIZE_DESCRIPTIONS"
    umap = "UMAP"


def _is_azure(llm_type: LLMType | None) -> bool:
    return llm_type in (
        LLMType.AzureOpenAI,
        LLMType.AzureOpenAIChat,
        LLMType.AzureOpenAIEmbedding,
    )


def _make_env(root_dir: str) -> Env:
    read_dotenv(root_dir)
    env = Env(expand_vars=True)
    env.read_env()
    return env


def _token_replace(data: dict):
    """Replace env-var tokens in a dictionary object."""
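    # e.g. (illustrative values) {"llm": {"api_key": "${MY_KEY}"}} becomes
    # {"llm": {"api_key": "<value of MY_KEY>"}} via os.path.expandvars.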
    for key, value in data.items():
        if isinstance(value, dict):
            _token_replace(value)
        elif isinstance(value, str):
            data[key] = os.path.expandvars(value)