Local search llm params (#533)

* Initialize config with LocalSearchConfig and GlobalSearchConfig

* Update init_content with LocalSearchConfig and GlobalSearchConfig settings

* Roll back MAP_SYSTEM_PROMPT

* Small changes before merging; notebook rollback

* Semver

---------

Co-authored-by: glide-the <2533736852@qq.com>
Alonso Guevara 2024-07-15 13:01:56 -06:00 committed by GitHub
parent 5b283f2d29
commit ce462515d8
9 changed files with 57 additions and 7 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add llm params to local and global search"
+}

View File

@@ -495,6 +495,10 @@ def create_graphrag_config(
             or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
             top_k_relationships=reader.int("top_k_relationships")
             or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
+            temperature=reader.float("llm_temperature")
+            or defs.LOCAL_SEARCH_LLM_TEMPERATURE,
+            top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N,
             max_tokens=reader.int(Fragment.max_tokens)
             or defs.LOCAL_SEARCH_MAX_TOKENS,
             llm_max_tokens=reader.int("llm_max_tokens")

@@ -506,9 +510,10 @@ def create_graphrag_config(
         reader.envvar_prefix(Section.global_search),
     ):
         global_search_model = GlobalSearchConfig(
-            temperature=reader.float(Fragment.temperature) or defs.LLM_TEMPERATURE,
-            top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
-            n=reader.int(Fragment.n) or defs.LLM_N,
+            temperature=reader.float("llm_temperature")
+            or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
+            top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
             max_tokens=reader.int(Fragment.max_tokens)
             or defs.GLOBAL_SEARCH_MAX_TOKENS,
             data_max_tokens=reader.int("data_max_tokens")
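A note on the reader pattern above: each value is resolved as reader.float(...) or default, so a missing setting falls back to the engine-specific default. A minimal runnable sketch of that fallback (the dict-based reader below is hypothetical; graphrag's real reader layers env vars over settings.yaml):

    # Minimal sketch of the "value or default" resolution used above.
    GLOBAL_SEARCH_LLM_TEMPERATURE = 0  # mirrors defs.GLOBAL_SEARCH_LLM_TEMPERATURE

    def read_float(settings: dict, key: str) -> float | None:
        # Stand-in for reader.float(): None when the key is absent.
        raw = settings.get(key)
        return float(raw) if raw is not None else None

    settings = {"llm_top_p": "0.9"}  # e.g. parsed from settings.yaml
    temperature = read_float(settings, "llm_temperature") or GLOBAL_SEARCH_LLM_TEMPERATURE
    top_p = read_float(settings, "llm_top_p") or 1  # 1 = defs.GLOBAL_SEARCH_LLM_TOP_P
    print(temperature, top_p)  # -> 0 0.9

    # Caveat: `or` treats an explicit 0.0 as "unset"; that is benign here only
    # because the temperature default is itself 0.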

View File

@@ -90,9 +90,15 @@ LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS = 5
 LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10
 LOCAL_SEARCH_TOP_K_RELATIONSHIPS = 10
 LOCAL_SEARCH_MAX_TOKENS = 12_000
+LOCAL_SEARCH_LLM_TEMPERATURE = 0
+LOCAL_SEARCH_LLM_TOP_P = 1
+LOCAL_SEARCH_LLM_N = 1
 LOCAL_SEARCH_LLM_MAX_TOKENS = 2000

 # Global Search
+GLOBAL_SEARCH_LLM_TEMPERATURE = 0
+GLOBAL_SEARCH_LLM_TOP_P = 1
+GLOBAL_SEARCH_LLM_N = 1
 GLOBAL_SEARCH_MAX_TOKENS = 12_000
 GLOBAL_SEARCH_DATA_MAX_TOKENS = 12_000
 GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000

View File

@@ -13,15 +13,15 @@ class GlobalSearchConfig(BaseModel):
     temperature: float | None = Field(
         description="The temperature to use for token generation.",
-        default=defs.LLM_TEMPERATURE,
+        default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
     )
     top_p: float | None = Field(
         description="The top-p value to use for token generation.",
-        default=defs.LLM_TOP_P,
+        default=defs.GLOBAL_SEARCH_LLM_TOP_P,
     )
     n: int | None = Field(
         description="The number of completions to generate.",
-        default=defs.LLM_N,
+        default=defs.GLOBAL_SEARCH_LLM_N,
     )
     max_tokens: int = Field(
         description="The maximum context size in tokens.",

View File

@@ -31,6 +31,18 @@ class LocalSearchConfig(BaseModel):
         description="The top k mapped relations.",
         default=defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
     )
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TOP_P,
+    )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.LOCAL_SEARCH_LLM_N,
+    )
     max_tokens: int = Field(
         description="The maximum tokens.", default=defs.LOCAL_SEARCH_MAX_TOKENS
     )
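Both config models now expose the three sampling knobs with engine-specific defaults. A quick self-contained check of the resulting behavior, assuming pydantic v2 as used by graphrag (defs values inlined here rather than imported):

    from pydantic import BaseModel, Field

    class LocalSearchConfig(BaseModel):
        # Trimmed copy of the model above, with defs values inlined.
        temperature: float | None = Field(
            description="The temperature to use for token generation.", default=0
        )
        top_p: float | None = Field(
            description="The top-p value to use for token generation.", default=1
        )
        n: int | None = Field(
            description="The number of completions to generate.", default=1
        )

    print(LocalSearchConfig())                 # temperature=0.0 top_p=1.0 n=1
    print(LocalSearchConfig(temperature=0.7))  # temperature=0.7 top_p=1.0 n=1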

View File

@@ -144,9 +144,15 @@ local_search:
   # conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS}
   # top_k_mapped_entities: {defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES}
   # top_k_relationships: {defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS}
+  # llm_temperature: {defs.LOCAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.LOCAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.LOCAL_SEARCH_LLM_N} # Number of completions to generate
   # max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS}

 global_search:
+  # llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate
   # max_tokens: {defs.GLOBAL_SEARCH_MAX_TOKENS}
   # data_max_tokens: {defs.GLOBAL_SEARCH_DATA_MAX_TOKENS}
   # map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
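These commented-out keys live in an f-string template, so the {defs....} placeholders are interpolated when graphrag writes the starter settings.yaml during project initialization. A rough sketch of that rendering, with plain constants standing in for the defs module:

    # Stand-ins for graphrag.config.defaults values.
    LOCAL_SEARCH_LLM_TEMPERATURE = 0
    LOCAL_SEARCH_LLM_TOP_P = 1
    LOCAL_SEARCH_LLM_N = 1

    # The template is an f-string, so the defaults are baked into the YAML text.
    snippet = f"""local_search:
      # llm_temperature: {LOCAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
      # llm_top_p: {LOCAL_SEARCH_LLM_TOP_P} # top-p sampling
      # llm_n: {LOCAL_SEARCH_LLM_N} # Number of completions to generate
    """
    print(snippet)  # users uncomment and edit these keys to override the defaults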

View File

@@ -19,6 +19,8 @@ def clean_up_json(json_str: str):
     # Remove JSON Markdown Frame
     if json_str.startswith("```json"):
         json_str = json_str[len("```json") :]
+    if json_str.startswith("json"):
+        json_str = json_str[len("json") :]
     if json_str.endswith("```"):
         json_str = json_str[: len(json_str) - len("```")]
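The new branch catches model output where the opening fence survived only as a bare json tag. A standalone copy of the stripping logic shown above, for illustration (the real function may do additional cleanup around these branches):

    def clean_up_json(json_str: str) -> str:
        # Remove a Markdown code fence wrapped around a JSON payload.
        if json_str.startswith("```json"):
            json_str = json_str[len("```json") :]
        if json_str.startswith("json"):  # fence reduced to a bare "json" tag
            json_str = json_str[len("json") :]
        if json_str.endswith("```"):
            json_str = json_str[: len(json_str) - len("```")]
        return json_str

    print(repr(clean_up_json('```json\n{"answer": 1}\n```')))  # '\n{"answer": 1}\n'
    print(repr(clean_up_json('json{"answer": 1}')))            # '{"answer": 1}'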

View File

@@ -130,7 +130,9 @@ def get_local_search_engine(
         token_encoder=token_encoder,
         llm_params={
             "max_tokens": ls_config.llm_max_tokens,  # change this based on the token limit you have on your model (if you are using a model with an 8k limit, a good setting could be 1000-1500)
-            "temperature": 0.0,
+            "temperature": ls_config.temperature,
+            "top_p": ls_config.top_p,
+            "n": ls_config.n,
         },
         context_builder_params={
             "text_unit_prop": ls_config.text_unit_prop,

View File

@@ -170,11 +170,17 @@ ALL_ENV_VARS = {
     "GRAPHRAG_UMAP_ENABLED": "true",
     "GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
     "GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
+    "GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
+    "GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
+    "GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
     "GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
     "GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
     "GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
     "GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
     "GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
+    "GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
+    "GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
+    "GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
     "GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
     "GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
     "GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",

@@ -605,7 +611,14 @@ class TestDefaultConfig(unittest.TestCase):
         assert parameters.local_search.top_k_relationships == 15
         assert parameters.local_search.conversation_history_max_turns == 2
         assert parameters.local_search.top_k_entities == 14
+        assert parameters.local_search.temperature == 0.1
+        assert parameters.local_search.top_p == 0.9
+        assert parameters.local_search.n == 2
         assert parameters.local_search.max_tokens == 142435
+        assert parameters.global_search.temperature == 0.1
+        assert parameters.global_search.top_p == 0.9
+        assert parameters.global_search.n == 2
         assert parameters.global_search.max_tokens == 5123
         assert parameters.global_search.data_max_tokens == 123
         assert parameters.global_search.map_max_tokens == 4123
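The six new env vars follow the GRAPHRAG_<SECTION>_<KEY> naming convention visible throughout ALL_ENV_VARS: the section prefix (local_search or global_search) plus the YAML key (llm_temperature, llm_top_p, llm_n) read by create_graphrag_config. A condensed sketch of that mapping (the env_key helper is hypothetical, not graphrag's API):

    import os

    def env_key(section: str, key: str) -> str:
        # Hypothetical helper mirroring the GRAPHRAG_<SECTION>_<KEY> convention.
        return f"GRAPHRAG_{section}_{key}".upper()

    os.environ[env_key("local_search", "llm_top_p")] = "0.9"
    assert os.environ["GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P"] == "0.9"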