From ce462515d8467b03ddc124ef8c65ec79de74e34e Mon Sep 17 00:00:00 2001
From: Alonso Guevara
Date: Mon, 15 Jul 2024 13:01:56 -0600
Subject: [PATCH] Local search llm params (#533)

* initialize config with LocalSearchConfig and GlobalSearchConfig

* init_content LocalSearchConfig and GlobalSearchConfig

* rollback MAP_SYSTEM_PROMPT

* Small changes before merging. Notebook rollback

* Semver

---------

Co-authored-by: glide-the <2533736852@qq.com>
---
 .../next-release/patch-20240712235357550877.json |  4 ++++
 graphrag/config/create_graphrag_config.py        | 11 ++++++++---
 graphrag/config/defaults.py                      |  6 ++++++
 graphrag/config/models/global_search_config.py   |  6 +++---
 graphrag/config/models/local_search_config.py    | 12 ++++++++++++
 graphrag/index/init_content.py                   |  6 ++++++
 graphrag/index/utils/json.py                     |  2 ++
 graphrag/query/factories.py                      |  4 +++-
 tests/unit/config/test_default_config.py         | 13 +++++++++++++
 9 files changed, 57 insertions(+), 7 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20240712235357550877.json

diff --git a/.semversioner/next-release/patch-20240712235357550877.json b/.semversioner/next-release/patch-20240712235357550877.json
new file mode 100644
index 00000000..818d6098
--- /dev/null
+++ b/.semversioner/next-release/patch-20240712235357550877.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add llm params to local and global search"
+}
diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index 45edfc84..37a4477e 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -495,6 +495,10 @@ def create_graphrag_config(
             or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
             top_k_relationships=reader.int("top_k_relationships")
             or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
+            temperature=reader.float("llm_temperature")
+            or defs.LOCAL_SEARCH_LLM_TEMPERATURE,
+            top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N,
             max_tokens=reader.int(Fragment.max_tokens)
             or defs.LOCAL_SEARCH_MAX_TOKENS,
             llm_max_tokens=reader.int("llm_max_tokens")
@@ -506,9 +510,10 @@ def create_graphrag_config(
         reader.envvar_prefix(Section.global_search),
     ):
         global_search_model = GlobalSearchConfig(
-            temperature=reader.float(Fragment.temperature) or defs.LLM_TEMPERATURE,
-            top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
-            n=reader.int(Fragment.n) or defs.LLM_N,
+            temperature=reader.float("llm_temperature")
+            or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
+            top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
             max_tokens=reader.int(Fragment.max_tokens)
             or defs.GLOBAL_SEARCH_MAX_TOKENS,
             data_max_tokens=reader.int("data_max_tokens")
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index a2a23e80..4d648914 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -90,9 +90,15 @@ LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS = 5
 LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10
 LOCAL_SEARCH_TOP_K_RELATIONSHIPS = 10
 LOCAL_SEARCH_MAX_TOKENS = 12_000
+LOCAL_SEARCH_LLM_TEMPERATURE = 0
+LOCAL_SEARCH_LLM_TOP_P = 1
+LOCAL_SEARCH_LLM_N = 1
 LOCAL_SEARCH_LLM_MAX_TOKENS = 2000
 
 # Global Search
+GLOBAL_SEARCH_LLM_TEMPERATURE = 0
+GLOBAL_SEARCH_LLM_TOP_P = 1
+GLOBAL_SEARCH_LLM_N = 1
 GLOBAL_SEARCH_MAX_TOKENS = 12_000
 GLOBAL_SEARCH_DATA_MAX_TOKENS = 12_000
 GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000
diff --git a/graphrag/config/models/global_search_config.py b/graphrag/config/models/global_search_config.py
index c7483f7c..9eb388c3 100644
--- a/graphrag/config/models/global_search_config.py
+++ b/graphrag/config/models/global_search_config.py
@@ -13,15 +13,15 @@ class GlobalSearchConfig(BaseModel):
 
     temperature: float | None = Field(
         description="The temperature to use for token generation.",
-        default=defs.LLM_TEMPERATURE,
+        default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
     )
     top_p: float | None = Field(
         description="The top-p value to use for token generation.",
-        default=defs.LLM_TOP_P,
+        default=defs.GLOBAL_SEARCH_LLM_TOP_P,
     )
     n: int | None = Field(
         description="The number of completions to generate.",
-        default=defs.LLM_N,
+        default=defs.GLOBAL_SEARCH_LLM_N,
     )
     max_tokens: int = Field(
         description="The maximum context size in tokens.",
diff --git a/graphrag/config/models/local_search_config.py b/graphrag/config/models/local_search_config.py
index 5fa5dd1e..c41344da 100644
--- a/graphrag/config/models/local_search_config.py
+++ b/graphrag/config/models/local_search_config.py
@@ -31,6 +31,18 @@ class LocalSearchConfig(BaseModel):
         description="The top k mapped relations.",
         default=defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
     )
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TOP_P,
+    )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.LOCAL_SEARCH_LLM_N,
+    )
     max_tokens: int = Field(
         description="The maximum tokens.", default=defs.LOCAL_SEARCH_MAX_TOKENS
     )
diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py
index 13df1828..d2ad3906 100644
--- a/graphrag/index/init_content.py
+++ b/graphrag/index/init_content.py
@@ -144,9 +144,15 @@ local_search:
   # conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS}
   # top_k_mapped_entities: {defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES}
   # top_k_relationships: {defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS}
+  # llm_temperature: {defs.LOCAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.LOCAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.LOCAL_SEARCH_LLM_N} # Number of completions to generate
   # max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS}
 
 global_search:
+  # llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate
   # max_tokens: {defs.GLOBAL_SEARCH_MAX_TOKENS}
   # data_max_tokens: {defs.GLOBAL_SEARCH_DATA_MAX_TOKENS}
   # map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
diff --git a/graphrag/index/utils/json.py b/graphrag/index/utils/json.py
index 11d5b020..ed6c0666 100644
--- a/graphrag/index/utils/json.py
+++ b/graphrag/index/utils/json.py
@@ -19,6 +19,8 @@ def clean_up_json(json_str: str):
     # Remove JSON Markdown Frame
     if json_str.startswith("```json"):
         json_str = json_str[len("```json") :]
+    if json_str.startswith("json"):
+        json_str = json_str[len("json") :]
     if json_str.endswith("```"):
         json_str = json_str[: len(json_str) - len("```")]
 
diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py
index ef630cfb..533fa153 100644
--- a/graphrag/query/factories.py
+++ b/graphrag/query/factories.py
@@ -130,7 +130,9 @@ def get_local_search_engine(
         token_encoder=token_encoder,
         llm_params={
             "max_tokens": ls_config.llm_max_tokens,  # change this based on the token limit you have on your model (if you are using a model with an 8k limit, a good setting could be 1000-1500)
-            "temperature": 0.0,
+            "temperature": ls_config.temperature,
+            "top_p": ls_config.top_p,
+            "n": ls_config.n,
         },
         context_builder_params={
             "text_unit_prop": ls_config.text_unit_prop,
diff --git a/tests/unit/config/test_default_config.py b/tests/unit/config/test_default_config.py
index 27a80cbd..6cde9475 100644
--- a/tests/unit/config/test_default_config.py
+++ b/tests/unit/config/test_default_config.py
@@ -170,11 +170,17 @@ ALL_ENV_VARS = {
     "GRAPHRAG_UMAP_ENABLED": "true",
     "GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
     "GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
+    "GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
+    "GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
+    "GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
     "GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
     "GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
     "GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
     "GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
     "GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
+    "GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
+    "GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
+    "GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
     "GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
     "GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
     "GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",
@@ -605,7 +611,14 @@ class TestDefaultConfig(unittest.TestCase):
         assert parameters.local_search.top_k_relationships == 15
         assert parameters.local_search.conversation_history_max_turns == 2
         assert parameters.local_search.top_k_entities == 14
+        assert parameters.local_search.temperature == 0.1
+        assert parameters.local_search.top_p == 0.9
+        assert parameters.local_search.n == 2
         assert parameters.local_search.max_tokens == 142435
+
+        assert parameters.global_search.temperature == 0.1
+        assert parameters.global_search.top_p == 0.9
+        assert parameters.global_search.n == 2
         assert parameters.global_search.max_tokens == 5123
         assert parameters.global_search.data_max_tokens == 123
         assert parameters.global_search.map_max_tokens == 4123
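
Note for anyone picking this change up: once merged, the sampling parameters can be tuned per search mode in the generated settings.yaml, using the keys this patch adds to the init template. A minimal sketch of the relevant fragment with the new keys uncommented (the 0.5/0.9 values are illustrative only; the shipped defaults are temperature 0, top_p 1, n 1):

```yaml
local_search:
  llm_temperature: 0.5 # temperature for sampling
  llm_top_p: 0.9 # top-p sampling
  llm_n: 1 # number of completions to generate

global_search:
  llm_temperature: 0.5 # temperature for sampling
  llm_top_p: 0.9 # top-p sampling
  llm_n: 1 # number of completions to generate
```

The same values can also be supplied through the environment variables exercised in the test above (GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE, GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P, and so on), since both paths flow through the same reader in create_graphrag_config.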