diff --git a/.semversioner/next-release/patch-20240705235656897489.json b/.semversioner/next-release/patch-20240705235656897489.json
new file mode 100644
index 00000000..bb76d708
--- /dev/null
+++ b/.semversioner/next-release/patch-20240705235656897489.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add N parameter support"
+}
diff --git a/docsite/posts/config/env_vars.md b/docsite/posts/config/env_vars.md
index a7256ee9..1523f05a 100644
--- a/docsite/posts/config/env_vars.md
+++ b/docsite/posts/config/env_vars.md
@@ -59,11 +59,14 @@ These settings control the text generation model used by the pipeline. Any setti
 | `GRAPHRAG_LLM_THREAD_COUNT` | | The number of threads to use for LLM parallelization. | `int` | 50 |
 | `GRAPHRAG_LLM_THREAD_STAGGER` | | The time to wait (in seconds) between starting each thread. | `float` | 0.3 |
 | `GRAPHRAG_LLM_CONCURRENT_REQUESTS` | | The number of concurrent requests to allow for the LLM client. | `int` | 25 |
-| `GRAPHRAG_LLM_TPM` | | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
-| `GRAPHRAG_LLM_RPM` | | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
+| `GRAPHRAG_LLM_TOKENS_PER_MINUTE` | | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
+| `GRAPHRAG_LLM_REQUESTS_PER_MINUTE` | | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
 | `GRAPHRAG_LLM_MAX_RETRIES` | | The maximum number of retries to attempt when a request fails. | `int` | 10 |
 | `GRAPHRAG_LLM_MAX_RETRY_WAIT` | | The maximum number of seconds to wait between retries. | `int` | 10 |
 | `GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION` | | Whether to sleep on rate limit recommendation. (Azure Only) | `bool` | `True` |
+| `GRAPHRAG_LLM_TEMPERATURE` | | The temperature to use for token generation. | `float` | 0 |
+| `GRAPHRAG_LLM_TOP_P` | | The top_p value to use for sampling. | `float` | 1 |
+| `GRAPHRAG_LLM_N` | | The number of responses to generate. | `int` | 1 |
 
 ## Text Embedding Settings
 
@@ -86,8 +89,8 @@ These settings control the text embedding model used by the pipeline. Any settin
 | `GRAPHRAG_EMBEDDING_THREAD_COUNT` | | The number of threads to use for parallelization for embeddings. | `int` | |
 | `GRAPHRAG_EMBEDDING_THREAD_STAGGER` | | The time to wait (in seconds) between starting each thread for embeddings. | `float` | 50 |
 | `GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS` | | The number of concurrent requests to allow for the embedding client. | `int` | 25 |
-| `GRAPHRAG_EMBEDDING_TPM` | | The number of tokens per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
-| `GRAPHRAG_EMBEDDING_RPM` | | The number of requests per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
+| `GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE` | | The number of tokens per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
+| `GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE` | | The number of requests per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
 | `GRAPHRAG_EMBEDDING_MAX_RETRIES` | | The maximum number of retries to attempt when a request fails. | `int` | 10 |
 | `GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT` | | The maximum number of seconds to wait between retries. | `int` | 10 |
 | `GRAPHRAG_EMBEDDING_TARGET` | | The target fields to embed. Either `required` or `all`. | `str` | `required` |
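For reference, the renamed rate-limit variables and the new sampling variables documented above can be exercised together. The snippet below is illustrative only (it is not part of this change); it sets the values from Python, which is equivalent to exporting them in the shell before running the pipeline.

```python
# Illustrative only: wiring the renamed and newly added variables from Python,
# equivalent to `export ...` in the shell before invoking the pipeline.
import os

os.environ["GRAPHRAG_LLM_TOKENS_PER_MINUTE"] = "8000"   # formerly GRAPHRAG_LLM_TPM
os.environ["GRAPHRAG_LLM_REQUESTS_PER_MINUTE"] = "900"  # formerly GRAPHRAG_LLM_RPM
os.environ["GRAPHRAG_LLM_TEMPERATURE"] = "0"  # new in this change
os.environ["GRAPHRAG_LLM_TOP_P"] = "1"        # new in this change
os.environ["GRAPHRAG_LLM_N"] = "1"            # new in this change
```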
diff --git a/docsite/posts/config/json_yaml.md b/docsite/posts/config/json_yaml.md
index a9bb4464..b1350cde 100644
--- a/docsite/posts/config/json_yaml.md
+++ b/docsite/posts/config/json_yaml.md
@@ -67,6 +67,9 @@ This is the base LLM configuration section. Other steps may override this config
 - `max_retry_wait` **float** - The maximum backoff time.
 - `sleep_on_rate_limit_recommendation` **bool** - Whether to adhere to sleep recommendations (Azure).
 - `concurrent_requests` **int** The number of open requests to allow at once.
+- `temperature` **float** - The temperature to use.
+- `top_p` **float** - The top-p value to use.
+- `n` **int** - The number of completions to generate.
 
 ## parallelization
 
diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index af56f749..e712411c 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -114,6 +114,7 @@ def create_graphrag_config(
             max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
             temperature=reader.float(Fragment.temperature) or base.temperature,
             top_p=reader.float(Fragment.top_p) or base.top_p,
+            n=reader.int(Fragment.n) or base.n,
             model_supports_json=reader.bool(Fragment.model_supports_json)
             or base.model_supports_json,
             request_timeout=reader.float(Fragment.request_timeout)
@@ -251,6 +252,7 @@ def create_graphrag_config(
             temperature=reader.float(Fragment.temperature)
             or defs.LLM_TEMPERATURE,
             top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
+            n=reader.int(Fragment.n) or defs.LLM_N,
             model_supports_json=reader.bool(Fragment.model_supports_json),
             request_timeout=reader.float(Fragment.request_timeout)
             or defs.LLM_REQUEST_TIMEOUT,
@@ -492,6 +494,7 @@ def create_graphrag_config(
         global_search_model = GlobalSearchConfig(
             temperature=reader.float(Fragment.temperature) or defs.LLM_TEMPERATURE,
             top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
+            n=reader.int(Fragment.n) or defs.LLM_N,
             max_tokens=reader.int(Fragment.max_tokens)
             or defs.GLOBAL_SEARCH_MAX_TOKENS,
             data_max_tokens=reader.int("data_max_tokens")
@@ -559,6 +562,7 @@ class Fragment(str, Enum):
     max_tokens = "MAX_TOKENS"
     temperature = "TEMPERATURE"
     top_p = "TOP_P"
+    n = "N"
     model = "MODEL"
     model_supports_json = "MODEL_SUPPORTS_JSON"
     prompt_file = "PROMPT_FILE"
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index 1ce15fda..62dfacdb 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -25,6 +25,7 @@ LLM_MODEL = "gpt-4-turbo-preview"
 LLM_MAX_TOKENS = 4000
 LLM_TEMPERATURE = 0
 LLM_TOP_P = 1
+LLM_N = 1
 LLM_REQUEST_TIMEOUT = 180.0
 LLM_TOKENS_PER_MINUTE = 0
 LLM_REQUESTS_PER_MINUTE = 0
diff --git a/graphrag/config/models/global_search_config.py b/graphrag/config/models/global_search_config.py
index 19c52a03..c7483f7c 100644
--- a/graphrag/config/models/global_search_config.py
+++ b/graphrag/config/models/global_search_config.py
@@ -19,6 +19,10 @@ class GlobalSearchConfig(BaseModel):
         description="The top-p value to use for token generation.",
         default=defs.LLM_TOP_P,
     )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.LLM_N,
+    )
     max_tokens: int = Field(
         description="The maximum context size in tokens.",
         default=defs.GLOBAL_SEARCH_MAX_TOKENS,
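A note on the `reader.int(Fragment.n) or defs.LLM_N` pattern used above: the reader returns `None` when the corresponding setting is absent, so the `or` falls through to the package default. A minimal sketch of that behaviour, with a hypothetical `resolve_n` helper standing in for the reader plumbing (not GraphRAG code):

```python
# Hypothetical stand-in for `reader.int(Fragment.n) or defs.LLM_N`; not GraphRAG code.
def resolve_n(raw: str | None, default: int = 1) -> int:
    """Parse an env-style string, falling back to the default when unset."""
    value = int(raw) if raw is not None else None
    return value or default

assert resolve_n(None) == 1  # GRAPHRAG_LLM_N not set -> defs.LLM_N
assert resolve_n("3") == 3   # explicit setting wins
```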
diff --git a/graphrag/config/models/llm_parameters.py b/graphrag/config/models/llm_parameters.py
index b558890f..df81138a 100644
--- a/graphrag/config/models/llm_parameters.py
+++ b/graphrag/config/models/llm_parameters.py
@@ -33,6 +33,10 @@ class LLMParameters(BaseModel):
         description="The top-p value to use for token generation.",
         default=defs.LLM_TOP_P,
     )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.LLM_N,
+    )
     request_timeout: float = Field(
         description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
     )
diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py
index c1342119..9e65b852 100644
--- a/graphrag/index/init_content.py
+++ b/graphrag/index/init_content.py
@@ -24,6 +24,9 @@ llm:
   # max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
   # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
   # concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
+  # temperature: {defs.LLM_TEMPERATURE} # temperature for sampling
+  # top_p: {defs.LLM_TOP_P} # top-p sampling
+  # n: {defs.LLM_N} # number of completions to generate
 
 parallelization:
   stagger: {defs.PARALLELIZATION_STAGGER}
diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py
index 9af26efb..264229c8 100644
--- a/graphrag/index/llm/load_llm.py
+++ b/graphrag/index/llm/load_llm.py
@@ -111,6 +111,7 @@ def _load_openai_completion_llm(
             "presence_penalty": config.get("presence_penalty", 0),
             "top_p": config.get("top_p", 1),
             "max_tokens": config.get("max_tokens", 4000),
+            "n": config.get("n"),
         }),
         on_error,
         cache,
@@ -135,6 +136,7 @@ def _load_openai_chat_llm(
             "presence_penalty": config.get("presence_penalty", 0),
             "top_p": config.get("top_p", 1),
             "max_tokens": config.get("max_tokens"),
+            "n": config.get("n"),
         }),
         on_error,
         cache,
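The `n` value threaded through the configuration dictionaries in `load_llm.py` ends up as the standard `n` parameter of the OpenAI chat completions API, i.e. the number of choices returned per request. A standalone illustration using the OpenAI Python SDK (not GraphRAG code; the model name is just the default from `defaults.py` above):

```python
# Illustration only (OpenAI Python SDK, not GraphRAG code): `n` asks the API for
# several completions in one request; each comes back as a separate choice.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
response = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[{"role": "user", "content": "Name one community-detection algorithm."}],
    temperature=0,
    top_p=1,
    n=3,  # request three candidate completions
)
print([choice.message.content for choice in response.choices])  # three candidates
```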
diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py
index 75211502..ef630cfb 100644
--- a/graphrag/query/factories.py
+++ b/graphrag/query/factories.py
@@ -50,11 +50,13 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
     print(f"creating llm client with {llm_debug_info}")  # noqa T201
     return ChatOpenAI(
         api_key=config.llm.api_key,
-        azure_ad_token_provider=get_bearer_token_provider(
-            DefaultAzureCredential(), cognitive_services_endpoint
-        )
-        if is_azure_client and not config.llm.api_key
-        else None,
+        azure_ad_token_provider=(
+            get_bearer_token_provider(
+                DefaultAzureCredential(), cognitive_services_endpoint
+            )
+            if is_azure_client and not config.llm.api_key
+            else None
+        ),
         api_base=config.llm.api_base,
         model=config.llm.model,
         api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
@@ -79,11 +81,13 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
     print(f"creating embedding llm client with {llm_debug_info}")  # noqa T201
     return OpenAIEmbedding(
         api_key=config.embeddings.llm.api_key,
-        azure_ad_token_provider=get_bearer_token_provider(
-            DefaultAzureCredential(), cognitive_services_endpoint
-        )
-        if is_azure_client and not config.embeddings.llm.api_key
-        else None,
+        azure_ad_token_provider=(
+            get_bearer_token_provider(
+                DefaultAzureCredential(), cognitive_services_endpoint
+            )
+            if is_azure_client and not config.embeddings.llm.api_key
+            else None
+        ),
         api_base=config.embeddings.llm.api_base,
         api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
         model=config.embeddings.llm.model,
@@ -167,11 +171,13 @@ def get_global_search_engine(
             "max_tokens": gs_config.map_max_tokens,
             "temperature": gs_config.temperature,
             "top_p": gs_config.top_p,
+            "n": gs_config.n,
         },
         reduce_llm_params={
             "max_tokens": gs_config.reduce_max_tokens,
             "temperature": gs_config.temperature,
             "top_p": gs_config.top_p,
+            "n": gs_config.n,
         },
         allow_general_knowledge=False,
         json_mode=False,
diff --git a/tests/unit/config/test_default_config.py b/tests/unit/config/test_default_config.py
index e98d39ee..2e009ad1 100644
--- a/tests/unit/config/test_default_config.py
+++ b/tests/unit/config/test_default_config.py
@@ -101,13 +101,13 @@ ALL_ENV_VARS = {
     "GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
     "GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
     "GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
-    "GRAPHRAG_EMBEDDING_RPM": "500",
+    "GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
     "GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
     "GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
     "GRAPHRAG_EMBEDDING_TARGET": "all",
     "GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
     "GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
-    "GRAPHRAG_EMBEDDING_TPM": "7000",
+    "GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
     "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
     "GRAPHRAG_ENCODING_MODEL": "test123",
     "GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
@@ -134,12 +134,13 @@ ALL_ENV_VARS = {
     "GRAPHRAG_LLM_MAX_TOKENS": "15000",
     "GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
     "GRAPHRAG_LLM_MODEL": "test-llm",
+    "GRAPHRAG_LLM_N": "1",
     "GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
-    "GRAPHRAG_LLM_RPM": "900",
+    "GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
     "GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
     "GRAPHRAG_LLM_THREAD_COUNT": "987",
     "GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
-    "GRAPHRAG_LLM_TPM": "8000",
+    "GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
     "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
     "GRAPHRAG_MAX_CLUSTER_SIZE": "123",
     "GRAPHRAG_NODE2VEC_ENABLED": "true",
@@ -164,6 +165,8 @@ ALL_ENV_VARS = {
     "GRAPHRAG_STORAGE_TYPE": "blob",
     "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
     "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
+    "GRAPHRAG_LLM_TEMPERATURE": "0.0",
+    "GRAPHRAG_LLM_TOP_P": "1.0",
     "GRAPHRAG_UMAP_ENABLED": "true",
     "GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
     "GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
@@ -562,11 +565,14 @@ class TestDefaultConfig(unittest.TestCase):
         assert parameters.llm.max_tokens == 15000
         assert parameters.llm.model == "test-llm"
         assert parameters.llm.model_supports_json
+        assert parameters.llm.n == 1
         assert parameters.llm.organization == "test_org"
         assert parameters.llm.proxy == "http://some/proxy"
         assert parameters.llm.request_timeout == 12.7
         assert parameters.llm.requests_per_minute == 900
         assert parameters.llm.sleep_on_rate_limit_recommendation is False
+        assert parameters.llm.temperature == 0.0
+        assert parameters.llm.top_p == 1.0
         assert parameters.llm.tokens_per_minute == 8000
         assert parameters.llm.type == "azure_openai_chat"
         assert parameters.parallelization.num_threads == 987
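For reviewers who want to poke at the new settings outside the unit-test harness, a rough end-to-end check along the same lines as the assertions above. The import path and the need to supply an API key are assumptions on my part, not guarantees from this change.

```python
# Rough manual check, mirroring the unit-test assertions above; assumes the
# package exposes create_graphrag_config at graphrag.config and that a (dummy)
# API key must be present for config creation to succeed.
import os

from graphrag.config import create_graphrag_config

os.environ["GRAPHRAG_API_KEY"] = "dummy-key"
os.environ["GRAPHRAG_LLM_N"] = "2"
os.environ["GRAPHRAG_LLM_TEMPERATURE"] = "0.5"
os.environ["GRAPHRAG_LLM_TOKENS_PER_MINUTE"] = "8000"

config = create_graphrag_config()
assert config.llm.n == 2
assert config.llm.temperature == 0.5
assert config.llm.tokens_per_minute == 8000
```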