Mirror of https://github.com/microsoft/graphrag.git, synced 2025-12-10 14:31:15 +00:00
Fix/encoding model config (#1527)
* fix: include encoding_model option when initializing LLMParameters
* chore: add semver patch description
* Fix encoding model parsing
* Fix unit tests

Co-authored-by: Nico Reinartz <nico.reinartz@rwth-aachen.de>
This commit is contained in:
parent 329b83cf7f
commit f7cd155dbc
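In plain terms, the patch makes the `encoding_model` setting actually take effect, with a three-level precedence for every consumer: a section-level value wins, then the root-level value, then the library default. A minimal sketch of that chain, using hypothetical dicts in place of graphrag's EnvironmentReader plumbing and assuming defs.ENCODING_MODEL is the tiktoken name "cl100k_base":

    DEFAULT_ENCODING = "cl100k_base"  # stand-in for defs.ENCODING_MODEL

    def resolve_encoding_model(section_cfg: dict, root_cfg: dict) -> str:
        # 1. a section-level override wins,
        # 2. otherwise fall back to the root-level setting,
        # 3. otherwise use the library default.
        return (
            section_cfg.get("encoding_model")
            or root_cfg.get("encoding_model")
            or DEFAULT_ENCODING
        )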
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Respect encoding_model option"
+}

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model config parsing"
+}
@@ -85,6 +85,7 @@ def create_graphrag_config(
         deployment_name = (
             reader.str(Fragment.deployment_name) or base.deployment_name
         )
+        encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
 
         if api_key is None and not _is_azure(llm_type):
             raise ApiKeyMissingError
@@ -106,6 +107,7 @@ def create_graphrag_config(
             organization=reader.str("organization") or base.organization,
             proxy=reader.str("proxy") or base.proxy,
             model=reader.str("model") or base.model,
+            encoding_model=encoding_model,
             max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
             temperature=reader.float(Fragment.temperature) or base.temperature,
             top_p=reader.float(Fragment.top_p) or base.top_p,
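The two hunks above extend the existing `override or base.field` idiom in hydrate_llm_params to `encoding_model`, so a per-section LLM block inherits the field from the base LLM parameters when it doesn't set one itself. A self-contained sketch of the idiom, with BaseParams as a hypothetical stand-in for graphrag's LLMParameters:

    from dataclasses import dataclass

    @dataclass
    class BaseParams:  # stand-in for LLMParameters
        model: str
        encoding_model: str | None = None

    def hydrate(overrides: dict, base: BaseParams) -> BaseParams:
        # Each field takes the override when present and falls back
        # to the base value -- the `x or base.x` idiom in the diff.
        return BaseParams(
            model=overrides.get("model") or base.model,
            encoding_model=overrides.get("encoding_model") or base.encoding_model,
        )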
@@ -155,6 +157,7 @@ def create_graphrag_config(
         api_proxy = reader.str("proxy") or base.proxy
         audience = reader.str(Fragment.audience) or base.audience
         deployment_name = reader.str(Fragment.deployment_name)
+        encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
 
         if api_key is None and not _is_azure(api_type):
             raise ApiKeyMissingError(embedding=True)
@@ -176,6 +179,7 @@ def create_graphrag_config(
             organization=api_organization,
             proxy=api_proxy,
             model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
+            encoding_model=encoding_model,
             request_timeout=reader.float(Fragment.request_timeout)
             or defs.LLM_REQUEST_TIMEOUT,
             audience=audience,
@@ -217,6 +221,9 @@ def create_graphrag_config(
     fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
     fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
     fallback_oai_proxy = reader.str(Fragment.api_proxy)
+    global_encoding_model = (
+        reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+    )
 
     with reader.envvar_prefix(Section.llm):
         with reader.use(values.get("llm")):
@@ -231,6 +238,9 @@ def create_graphrag_config(
             api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
             audience = reader.str(Fragment.audience)
             deployment_name = reader.str(Fragment.deployment_name)
+            encoding_model = (
+                reader.str(Fragment.encoding_model) or global_encoding_model
+            )
 
             if api_key is None and not _is_azure(llm_type):
                 raise ApiKeyMissingError
@@ -252,6 +262,7 @@ def create_graphrag_config(
                 proxy=api_proxy,
                 type=llm_type,
                 model=reader.str(Fragment.model) or defs.LLM_MODEL,
+                encoding_model=encoding_model,
                 max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
                 temperature=reader.float(Fragment.temperature)
                 or defs.LLM_TEMPERATURE,
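Here the root-level value is resolved once as `global_encoding_model`, and the root LLM block (like every later section) resolves against it rather than re-reading a default. A small worked example of the resulting behavior, with assumed settings values:

    # Assumed inputs, for illustration only.
    root = {"encoding_model": "o200k_base"}   # top-level settings entry
    llm_section = {}                          # llm block sets no override

    global_encoding_model = root.get("encoding_model") or "cl100k_base"
    encoding_model = llm_section.get("encoding_model") or global_encoding_model
    assert encoding_model == "o200k_base"     # root setting now propagates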
@@ -396,12 +407,15 @@ def create_graphrag_config(
         group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
         if group_by_columns is None:
             group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )
 
         chunks_model = ChunkingConfig(
             size=reader.int("size") or defs.CHUNK_SIZE,
             overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
             group_by_columns=group_by_columns,
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )
     with (
         reader.envvar_prefix(Section.snapshot),
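This hunk is the heart of the first bug: ChunkingConfig used to receive `reader.str(Fragment.encoding_model)` directly, which is None whenever the chunks section itself sets nothing, so a root-level `encoding_model` was silently dropped. Resolving the fallback before building the model fixes that. A sketch of the before/after difference under an assumed root setting:

    section_value = None                     # reader returned nothing for chunks
    global_encoding_model = "o200k_base"     # assumed root-level setting

    old_kwarg = section_value                            # -> None, root ignored
    new_kwarg = section_value or global_encoding_model   # -> "o200k_base"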
@@ -428,6 +442,9 @@ def create_graphrag_config(
             if max_gleanings is not None
             else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
         )
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )
 
         entity_extraction_model = EntityExtractionConfig(
             llm=hydrate_llm_params(entity_extraction_config, llm_model),
@@ -440,7 +457,7 @@ def create_graphrag_config(
             max_gleanings=max_gleanings,
             prompt=reader.str("prompt", Fragment.prompt_file),
             strategy=entity_extraction_config.get("strategy"),
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )
 
     claim_extraction_config = values.get("claim_extraction") or {}
@@ -452,6 +469,9 @@ def create_graphrag_config(
         max_gleanings = (
             max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
         )
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )
         claim_extraction_model = ClaimExtractionConfig(
             enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
             llm=hydrate_llm_params(claim_extraction_config, llm_model),
@@ -462,7 +482,7 @@ def create_graphrag_config(
             description=reader.str("description") or defs.CLAIM_DESCRIPTION,
             prompt=reader.str("prompt", Fragment.prompt_file),
             max_gleanings=max_gleanings,
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )
 
     community_report_config = values.get("community_reports") or {}
@@ -603,7 +623,6 @@ def create_graphrag_config(
             or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
         )
 
-    encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
     skip_workflows = reader.list("skip_workflows") or []
 
     return GraphRagConfig(
@@ -626,7 +645,7 @@ def create_graphrag_config(
        summarize_descriptions=summarize_descriptions_model,
        umap=umap_model,
        cluster_graph=cluster_graph_model,
-       encoding_model=encoding_model,
+       encoding_model=global_encoding_model,
        skip_workflows=skip_workflows,
        local_search=local_search_model,
        global_search=global_search_model,
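Two things happen above: the late top-level re-read of `encoding_model` is deleted (the value now exists as `global_encoding_model` from the start), and GraphRagConfig is handed `global_encoding_model` explicitly. One plausible reason the rename matters is that Python blocks don't create scope, so the plain name `encoding_model` at this point still holds whatever the last section block bound, as this tiny sketch shows:

    # A name bound inside a block stays bound afterwards, so reusing
    # `encoding_model` per section leaves the last section's value visible.
    for encoding_model in ("cl100k_base", "o200k_base"):
        pass
    print(encoding_model)  # -> "o200k_base", the leftover binding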
@@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
         default=None, description="The encoding model to use."
     )
 
-    def resolved_strategy(self, encoding_model: str) -> dict:
+    def resolved_strategy(self, encoding_model: str | None) -> dict:
         """Get the resolved chunking strategy."""
         from graphrag.index.operations.chunk_text import ChunkStrategyType
 
@@ -36,5 +36,5 @@ class ChunkingConfig(BaseModel):
             "chunk_size": self.size,
             "chunk_overlap": self.overlap,
             "group_by_columns": self.group_by_columns,
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
         }
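Besides widening the signature to `str | None`, the second bug fix flips the precedence in resolved_strategy: the value the caller passes (already resolved section-or-global) now wins, and the config object's own field only backs up a None argument. A trimmed, runnable sketch of the flipped resolution, where ChunkingSketch is a hypothetical stand-in keeping only the field at issue:

    from pydantic import BaseModel, Field

    class ChunkingSketch(BaseModel):
        encoding_model: str | None = Field(
            default=None, description="The encoding model to use."
        )

        def resolved_encoding(self, encoding_model: str | None) -> str | None:
            # New order: the argument wins; the stored field is the fallback.
            return encoding_model or self.encoding_model

    cfg = ChunkingSketch(encoding_model="cl100k_base")
    assert cfg.resolved_encoding("o200k_base") == "o200k_base"  # argument wins
    assert cfg.resolved_encoding(None) == "cl100k_base"         # fallback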
@@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )
 
-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved claim extraction strategy."""
         from graphrag.index.operations.extract_covariates import (
             ExtractClaimsStrategyType,
@@ -52,5 +52,5 @@ class ClaimExtractionConfig(LLMConfig):
             else None,
             "claim_description": self.description,
             "max_gleanings": self.max_gleanings,
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
         }
@@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )
 
-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved entity extraction strategy."""
         from graphrag.index.operations.extract_entities import (
             ExtractEntityStrategyType,
@@ -49,6 +49,6 @@ class EntityExtractionConfig(LLMConfig):
             else None,
             "max_gleanings": self.max_gleanings,
             # It's prechunked in create_base_text_units
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
             "prechunked": True,
         }
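The same signature-and-precedence change is applied uniformly to ChunkingConfig, ClaimExtractionConfig, and EntityExtractionConfig. Downstream, the resolved `encoding_name` is a tiktoken encoding name used for token counting; a minimal sketch of that final step (tiktoken's real public API, though the wiring around it here is assumed):

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")  # the resolved encoding_name
    tokens = encoding.encode("GraphRAG chunks text by token count.")
    print(len(tokens))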