Fix/encoding model config (#1527)

* fix: include encoding_model option when initializing LLMParameters

* chore: add semver patch description

* Fix encoding model parsing

* Fix unit tests

---------

Co-authored-by: Nico Reinartz <nico.reinartz@rwth-aachen.de>
This commit is contained in:
Alonso Guevara 2024-12-16 21:03:56 -06:00 committed by GitHub
parent 329b83cf7f
commit f7cd155dbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 38 additions and 11 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Respect encoding_model option"
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix encoding model config parsing"
}

View File

@ -85,6 +85,7 @@ def create_graphrag_config(
deployment_name = (
reader.str(Fragment.deployment_name) or base.deployment_name
)
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
@ -106,6 +107,7 @@ def create_graphrag_config(
organization=reader.str("organization") or base.organization,
proxy=reader.str("proxy") or base.proxy,
model=reader.str("model") or base.model,
encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
temperature=reader.float(Fragment.temperature) or base.temperature,
top_p=reader.float(Fragment.top_p) or base.top_p,
@ -155,6 +157,7 @@ def create_graphrag_config(
api_proxy = reader.str("proxy") or base.proxy
audience = reader.str(Fragment.audience) or base.audience
deployment_name = reader.str(Fragment.deployment_name)
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
if api_key is None and not _is_azure(api_type):
raise ApiKeyMissingError(embedding=True)
@ -176,6 +179,7 @@ def create_graphrag_config(
organization=api_organization,
proxy=api_proxy,
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
encoding_model=encoding_model,
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
audience=audience,
@ -217,6 +221,9 @@ def create_graphrag_config(
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
fallback_oai_proxy = reader.str(Fragment.api_proxy)
global_encoding_model = (
reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
)
with reader.envvar_prefix(Section.llm):
with reader.use(values.get("llm")):
@ -231,6 +238,9 @@ def create_graphrag_config(
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
audience = reader.str(Fragment.audience)
deployment_name = reader.str(Fragment.deployment_name)
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
if api_key is None and not _is_azure(llm_type):
raise ApiKeyMissingError
@ -252,6 +262,7 @@ def create_graphrag_config(
proxy=api_proxy,
type=llm_type,
model=reader.str(Fragment.model) or defs.LLM_MODEL,
encoding_model=encoding_model,
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
temperature=reader.float(Fragment.temperature)
or defs.LLM_TEMPERATURE,
@ -396,12 +407,15 @@ def create_graphrag_config(
group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
if group_by_columns is None:
group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
chunks_model = ChunkingConfig(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=reader.str(Fragment.encoding_model),
encoding_model=encoding_model,
)
with (
reader.envvar_prefix(Section.snapshot),
@ -428,6 +442,9 @@ def create_graphrag_config(
if max_gleanings is not None
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
)
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
entity_extraction_model = EntityExtractionConfig(
llm=hydrate_llm_params(entity_extraction_config, llm_model),
@ -440,7 +457,7 @@ def create_graphrag_config(
max_gleanings=max_gleanings,
prompt=reader.str("prompt", Fragment.prompt_file),
strategy=entity_extraction_config.get("strategy"),
encoding_model=reader.str(Fragment.encoding_model),
encoding_model=encoding_model,
)
claim_extraction_config = values.get("claim_extraction") or {}
@ -452,6 +469,9 @@ def create_graphrag_config(
max_gleanings = (
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
)
encoding_model = (
reader.str(Fragment.encoding_model) or global_encoding_model
)
claim_extraction_model = ClaimExtractionConfig(
enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
llm=hydrate_llm_params(claim_extraction_config, llm_model),
@ -462,7 +482,7 @@ def create_graphrag_config(
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
prompt=reader.str("prompt", Fragment.prompt_file),
max_gleanings=max_gleanings,
encoding_model=reader.str(Fragment.encoding_model),
encoding_model=encoding_model,
)
community_report_config = values.get("community_reports") or {}
@ -603,7 +623,6 @@ def create_graphrag_config(
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
)
encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
skip_workflows = reader.list("skip_workflows") or []
return GraphRagConfig(
@ -626,7 +645,7 @@ def create_graphrag_config(
summarize_descriptions=summarize_descriptions_model,
umap=umap_model,
cluster_graph=cluster_graph_model,
encoding_model=encoding_model,
encoding_model=global_encoding_model,
skip_workflows=skip_workflows,
local_search=local_search_model,
global_search=global_search_model,

View File

@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
default=None, description="The encoding model to use."
)
def resolved_strategy(self, encoding_model: str) -> dict:
def resolved_strategy(self, encoding_model: str | None) -> dict:
"""Get the resolved chunking strategy."""
from graphrag.index.operations.chunk_text import ChunkStrategyType
@ -36,5 +36,5 @@ class ChunkingConfig(BaseModel):
"chunk_size": self.size,
"chunk_overlap": self.overlap,
"group_by_columns": self.group_by_columns,
"encoding_name": self.encoding_model or encoding_model,
"encoding_name": encoding_model or self.encoding_model,
}

View File

@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
default=None, description="The encoding model to use."
)
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
"""Get the resolved claim extraction strategy."""
from graphrag.index.operations.extract_covariates import (
ExtractClaimsStrategyType,
@ -52,5 +52,5 @@ class ClaimExtractionConfig(LLMConfig):
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
"encoding_name": self.encoding_model or encoding_model,
"encoding_name": encoding_model or self.encoding_model,
}

View File

@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
default=None, description="The encoding model to use."
)
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
"""Get the resolved entity extraction strategy."""
from graphrag.index.operations.extract_entities import (
ExtractEntityStrategyType,
@ -49,6 +49,6 @@ class EntityExtractionConfig(LLMConfig):
else None,
"max_gleanings": self.max_gleanings,
# It's prechunked in create_base_text_units
"encoding_name": self.encoding_model or encoding_model,
"encoding_name": encoding_model or self.encoding_model,
"prechunked": True,
}