mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-16 17:48:48 +00:00
Fix/encoding model config (#1527)
* fix: include encoding_model option when initializing LLMParameters * chore: add semver patch description * Fix encoding model parsing * Fix unit tests --------- Co-authored-by: Nico Reinartz <nico.reinartz@rwth-aachen.de>
This commit is contained in:
parent
329b83cf7f
commit
f7cd155dbc
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "patch",
|
||||||
|
"description": "Respect encoding_model option"
|
||||||
|
}
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "patch",
|
||||||
|
"description": "Fix encoding model config parsing"
|
||||||
|
}
|
||||||
@ -85,6 +85,7 @@ def create_graphrag_config(
|
|||||||
deployment_name = (
|
deployment_name = (
|
||||||
reader.str(Fragment.deployment_name) or base.deployment_name
|
reader.str(Fragment.deployment_name) or base.deployment_name
|
||||||
)
|
)
|
||||||
|
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
|
||||||
|
|
||||||
if api_key is None and not _is_azure(llm_type):
|
if api_key is None and not _is_azure(llm_type):
|
||||||
raise ApiKeyMissingError
|
raise ApiKeyMissingError
|
||||||
@ -106,6 +107,7 @@ def create_graphrag_config(
|
|||||||
organization=reader.str("organization") or base.organization,
|
organization=reader.str("organization") or base.organization,
|
||||||
proxy=reader.str("proxy") or base.proxy,
|
proxy=reader.str("proxy") or base.proxy,
|
||||||
model=reader.str("model") or base.model,
|
model=reader.str("model") or base.model,
|
||||||
|
encoding_model=encoding_model,
|
||||||
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
|
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
|
||||||
temperature=reader.float(Fragment.temperature) or base.temperature,
|
temperature=reader.float(Fragment.temperature) or base.temperature,
|
||||||
top_p=reader.float(Fragment.top_p) or base.top_p,
|
top_p=reader.float(Fragment.top_p) or base.top_p,
|
||||||
@ -155,6 +157,7 @@ def create_graphrag_config(
|
|||||||
api_proxy = reader.str("proxy") or base.proxy
|
api_proxy = reader.str("proxy") or base.proxy
|
||||||
audience = reader.str(Fragment.audience) or base.audience
|
audience = reader.str(Fragment.audience) or base.audience
|
||||||
deployment_name = reader.str(Fragment.deployment_name)
|
deployment_name = reader.str(Fragment.deployment_name)
|
||||||
|
encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model
|
||||||
|
|
||||||
if api_key is None and not _is_azure(api_type):
|
if api_key is None and not _is_azure(api_type):
|
||||||
raise ApiKeyMissingError(embedding=True)
|
raise ApiKeyMissingError(embedding=True)
|
||||||
@ -176,6 +179,7 @@ def create_graphrag_config(
|
|||||||
organization=api_organization,
|
organization=api_organization,
|
||||||
proxy=api_proxy,
|
proxy=api_proxy,
|
||||||
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
|
model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
|
||||||
|
encoding_model=encoding_model,
|
||||||
request_timeout=reader.float(Fragment.request_timeout)
|
request_timeout=reader.float(Fragment.request_timeout)
|
||||||
or defs.LLM_REQUEST_TIMEOUT,
|
or defs.LLM_REQUEST_TIMEOUT,
|
||||||
audience=audience,
|
audience=audience,
|
||||||
@ -217,6 +221,9 @@ def create_graphrag_config(
|
|||||||
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
|
fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
|
||||||
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
|
fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
|
||||||
fallback_oai_proxy = reader.str(Fragment.api_proxy)
|
fallback_oai_proxy = reader.str(Fragment.api_proxy)
|
||||||
|
global_encoding_model = (
|
||||||
|
reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
|
||||||
|
)
|
||||||
|
|
||||||
with reader.envvar_prefix(Section.llm):
|
with reader.envvar_prefix(Section.llm):
|
||||||
with reader.use(values.get("llm")):
|
with reader.use(values.get("llm")):
|
||||||
@ -231,6 +238,9 @@ def create_graphrag_config(
|
|||||||
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
|
api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
|
||||||
audience = reader.str(Fragment.audience)
|
audience = reader.str(Fragment.audience)
|
||||||
deployment_name = reader.str(Fragment.deployment_name)
|
deployment_name = reader.str(Fragment.deployment_name)
|
||||||
|
encoding_model = (
|
||||||
|
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||||
|
)
|
||||||
|
|
||||||
if api_key is None and not _is_azure(llm_type):
|
if api_key is None and not _is_azure(llm_type):
|
||||||
raise ApiKeyMissingError
|
raise ApiKeyMissingError
|
||||||
@ -252,6 +262,7 @@ def create_graphrag_config(
|
|||||||
proxy=api_proxy,
|
proxy=api_proxy,
|
||||||
type=llm_type,
|
type=llm_type,
|
||||||
model=reader.str(Fragment.model) or defs.LLM_MODEL,
|
model=reader.str(Fragment.model) or defs.LLM_MODEL,
|
||||||
|
encoding_model=encoding_model,
|
||||||
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
|
max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
|
||||||
temperature=reader.float(Fragment.temperature)
|
temperature=reader.float(Fragment.temperature)
|
||||||
or defs.LLM_TEMPERATURE,
|
or defs.LLM_TEMPERATURE,
|
||||||
@ -396,12 +407,15 @@ def create_graphrag_config(
|
|||||||
group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
|
group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
|
||||||
if group_by_columns is None:
|
if group_by_columns is None:
|
||||||
group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
|
group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
|
||||||
|
encoding_model = (
|
||||||
|
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||||
|
)
|
||||||
|
|
||||||
chunks_model = ChunkingConfig(
|
chunks_model = ChunkingConfig(
|
||||||
size=reader.int("size") or defs.CHUNK_SIZE,
|
size=reader.int("size") or defs.CHUNK_SIZE,
|
||||||
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
|
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
|
||||||
group_by_columns=group_by_columns,
|
group_by_columns=group_by_columns,
|
||||||
encoding_model=reader.str(Fragment.encoding_model),
|
encoding_model=encoding_model,
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
reader.envvar_prefix(Section.snapshot),
|
reader.envvar_prefix(Section.snapshot),
|
||||||
@ -428,6 +442,9 @@ def create_graphrag_config(
|
|||||||
if max_gleanings is not None
|
if max_gleanings is not None
|
||||||
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
|
else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||||
)
|
)
|
||||||
|
encoding_model = (
|
||||||
|
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||||
|
)
|
||||||
|
|
||||||
entity_extraction_model = EntityExtractionConfig(
|
entity_extraction_model = EntityExtractionConfig(
|
||||||
llm=hydrate_llm_params(entity_extraction_config, llm_model),
|
llm=hydrate_llm_params(entity_extraction_config, llm_model),
|
||||||
@ -440,7 +457,7 @@ def create_graphrag_config(
|
|||||||
max_gleanings=max_gleanings,
|
max_gleanings=max_gleanings,
|
||||||
prompt=reader.str("prompt", Fragment.prompt_file),
|
prompt=reader.str("prompt", Fragment.prompt_file),
|
||||||
strategy=entity_extraction_config.get("strategy"),
|
strategy=entity_extraction_config.get("strategy"),
|
||||||
encoding_model=reader.str(Fragment.encoding_model),
|
encoding_model=encoding_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
claim_extraction_config = values.get("claim_extraction") or {}
|
claim_extraction_config = values.get("claim_extraction") or {}
|
||||||
@ -452,6 +469,9 @@ def create_graphrag_config(
|
|||||||
max_gleanings = (
|
max_gleanings = (
|
||||||
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
|
max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
|
||||||
)
|
)
|
||||||
|
encoding_model = (
|
||||||
|
reader.str(Fragment.encoding_model) or global_encoding_model
|
||||||
|
)
|
||||||
claim_extraction_model = ClaimExtractionConfig(
|
claim_extraction_model = ClaimExtractionConfig(
|
||||||
enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
|
enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
|
||||||
llm=hydrate_llm_params(claim_extraction_config, llm_model),
|
llm=hydrate_llm_params(claim_extraction_config, llm_model),
|
||||||
@ -462,7 +482,7 @@ def create_graphrag_config(
|
|||||||
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
|
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
|
||||||
prompt=reader.str("prompt", Fragment.prompt_file),
|
prompt=reader.str("prompt", Fragment.prompt_file),
|
||||||
max_gleanings=max_gleanings,
|
max_gleanings=max_gleanings,
|
||||||
encoding_model=reader.str(Fragment.encoding_model),
|
encoding_model=encoding_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
community_report_config = values.get("community_reports") or {}
|
community_report_config = values.get("community_reports") or {}
|
||||||
@ -603,7 +623,6 @@ def create_graphrag_config(
|
|||||||
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
|
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
|
|
||||||
skip_workflows = reader.list("skip_workflows") or []
|
skip_workflows = reader.list("skip_workflows") or []
|
||||||
|
|
||||||
return GraphRagConfig(
|
return GraphRagConfig(
|
||||||
@ -626,7 +645,7 @@ def create_graphrag_config(
|
|||||||
summarize_descriptions=summarize_descriptions_model,
|
summarize_descriptions=summarize_descriptions_model,
|
||||||
umap=umap_model,
|
umap=umap_model,
|
||||||
cluster_graph=cluster_graph_model,
|
cluster_graph=cluster_graph_model,
|
||||||
encoding_model=encoding_model,
|
encoding_model=global_encoding_model,
|
||||||
skip_workflows=skip_workflows,
|
skip_workflows=skip_workflows,
|
||||||
local_search=local_search_model,
|
local_search=local_search_model,
|
||||||
global_search=global_search_model,
|
global_search=global_search_model,
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
|
|||||||
default=None, description="The encoding model to use."
|
default=None, description="The encoding model to use."
|
||||||
)
|
)
|
||||||
|
|
||||||
def resolved_strategy(self, encoding_model: str) -> dict:
|
def resolved_strategy(self, encoding_model: str | None) -> dict:
|
||||||
"""Get the resolved chunking strategy."""
|
"""Get the resolved chunking strategy."""
|
||||||
from graphrag.index.operations.chunk_text import ChunkStrategyType
|
from graphrag.index.operations.chunk_text import ChunkStrategyType
|
||||||
|
|
||||||
@ -36,5 +36,5 @@ class ChunkingConfig(BaseModel):
|
|||||||
"chunk_size": self.size,
|
"chunk_size": self.size,
|
||||||
"chunk_overlap": self.overlap,
|
"chunk_overlap": self.overlap,
|
||||||
"group_by_columns": self.group_by_columns,
|
"group_by_columns": self.group_by_columns,
|
||||||
"encoding_name": self.encoding_model or encoding_model,
|
"encoding_name": encoding_model or self.encoding_model,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
|
|||||||
default=None, description="The encoding model to use."
|
default=None, description="The encoding model to use."
|
||||||
)
|
)
|
||||||
|
|
||||||
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
|
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
|
||||||
"""Get the resolved claim extraction strategy."""
|
"""Get the resolved claim extraction strategy."""
|
||||||
from graphrag.index.operations.extract_covariates import (
|
from graphrag.index.operations.extract_covariates import (
|
||||||
ExtractClaimsStrategyType,
|
ExtractClaimsStrategyType,
|
||||||
@ -52,5 +52,5 @@ class ClaimExtractionConfig(LLMConfig):
|
|||||||
else None,
|
else None,
|
||||||
"claim_description": self.description,
|
"claim_description": self.description,
|
||||||
"max_gleanings": self.max_gleanings,
|
"max_gleanings": self.max_gleanings,
|
||||||
"encoding_name": self.encoding_model or encoding_model,
|
"encoding_name": encoding_model or self.encoding_model,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
|
|||||||
default=None, description="The encoding model to use."
|
default=None, description="The encoding model to use."
|
||||||
)
|
)
|
||||||
|
|
||||||
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
|
def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
|
||||||
"""Get the resolved entity extraction strategy."""
|
"""Get the resolved entity extraction strategy."""
|
||||||
from graphrag.index.operations.extract_entities import (
|
from graphrag.index.operations.extract_entities import (
|
||||||
ExtractEntityStrategyType,
|
ExtractEntityStrategyType,
|
||||||
@ -49,6 +49,6 @@ class EntityExtractionConfig(LLMConfig):
|
|||||||
else None,
|
else None,
|
||||||
"max_gleanings": self.max_gleanings,
|
"max_gleanings": self.max_gleanings,
|
||||||
# It's prechunked in create_base_text_units
|
# It's prechunked in create_base_text_units
|
||||||
"encoding_name": self.encoding_model or encoding_model,
|
"encoding_name": encoding_model or self.encoding_model,
|
||||||
"prechunked": True,
|
"prechunked": True,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user