Mirror of https://github.com/microsoft/graphrag.git (synced 2025-12-27 15:10:00 +00:00)
Add N parameter support (#390)
* Add N parameter support
* Fix unit tests
* Add new env vars to param testing
parent a22003c302
commit b912081f1b
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Add N parameter support"
}

@@ -59,11 +59,14 @@ These settings control the text generation model used by the pipeline. Any setti
| `GRAPHRAG_LLM_THREAD_COUNT` | | The number of threads to use for LLM parallelization. | `int` | 50 |
| `GRAPHRAG_LLM_THREAD_STAGGER` | | The time to wait (in seconds) between starting each thread. | `float` | 0.3 |
| `GRAPHRAG_LLM_CONCURRENT_REQUESTS` | | The number of concurrent requests to allow for the LLM client. | `int` | 25 |
| `GRAPHRAG_LLM_TPM` | | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_RPM` | | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_TOKENS_PER_MINUTE` | | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_REQUESTS_PER_MINUTE` | | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_MAX_RETRIES` | | The maximum number of retries to attempt when a request fails. | `int` | 10 |
| `GRAPHRAG_LLM_MAX_RETRY_WAIT` | | The maximum number of seconds to wait between retries. | `int` | 10 |
| `GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION` | | Whether to sleep on rate limit recommendation. (Azure Only) | `bool` | `True` |
| `GRAPHRAG_LLM_TEMPERATURE` | | The temperature to use for generation. | `float` | 0 |
| `GRAPHRAG_LLM_TOP_P` | | The top_p to use for sampling. | `float` | 1 |
| `GRAPHRAG_LLM_N` | | The number of responses to generate. | `int` | 1 |

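The new sampling controls can be smoke-tested outside the pipeline. A minimal sketch (not part of this commit): the `create_graphrag_config` entry point and its import path are assumed from the config changes later in this diff, and the API key value is a placeholder.

```python
# Hypothetical usage sketch only: export the new env vars, then build the config.
import os

from graphrag.config import create_graphrag_config  # import path assumed

os.environ["GRAPHRAG_API_KEY"] = "sk-placeholder"  # placeholder; config validation expects a key
os.environ["GRAPHRAG_LLM_TEMPERATURE"] = "0"
os.environ["GRAPHRAG_LLM_TOP_P"] = "1"
os.environ["GRAPHRAG_LLM_N"] = "1"  # number of responses per request

config = create_graphrag_config()
print(config.llm.n)  # expected: 1
```
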
## Text Embedding Settings

@@ -86,8 +89,8 @@ These settings control the text embedding model used by the pipeline. Any settin
| `GRAPHRAG_EMBEDDING_THREAD_COUNT` | | The number of threads to use for parallelization for embeddings. | `int` | |
| `GRAPHRAG_EMBEDDING_THREAD_STAGGER` | | The time to wait (in seconds) between starting each thread for embeddings. | `float` | 50 |
| `GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS` | | The number of concurrent requests to allow for the embedding client. | `int` | 25 |
| `GRAPHRAG_EMBEDDING_TPM` | | The number of tokens per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_RPM` | | The number of requests per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE` | | The number of tokens per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE` | | The number of requests per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_MAX_RETRIES` | | The maximum number of retries to attempt when a request fails. | `int` | 10 |
| `GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT` | | The maximum number of seconds to wait between retries. | `int` | 10 |
| `GRAPHRAG_EMBEDDING_TARGET` | | The target fields to embed. Either `required` or `all`. | `str` | `required` |

@@ -67,6 +67,9 @@ This is the base LLM configuration section. Other steps may override this config
- `max_retry_wait` **float** - The maximum backoff time.
- `sleep_on_rate_limit_recommendation` **bool** - Whether to adhere to sleep recommendations (Azure).
- `concurrent_requests` **int** - The number of open requests to allow at once.
- `temperature` **float** - The temperature to use.
- `top_p` **float** - The top-p value to use.
- `n` **int** - The number of completions to generate.

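For orientation only (not part of the diff), these keys correspond one-to-one to fields on the `LLMParameters` model changed further down; a minimal sketch with an assumed import path:

```python
# Hypothetical sketch: set the documented keys programmatically instead of in
# settings.yml. The module path is assumed from the model file in this diff.
from graphrag.config.models.llm_parameters import LLMParameters

params = LLMParameters(temperature=0, top_p=1, n=1)
print(params.n)  # 1
```
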
## parallelization

@@ -114,6 +114,7 @@ def create_graphrag_config(
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
temperature=reader.float(Fragment.temperature) or base.temperature,
top_p=reader.float(Fragment.top_p) or base.top_p,
n=reader.int(Fragment.n) or base.n,
model_supports_json=reader.bool(Fragment.model_supports_json)
or base.model_supports_json,
request_timeout=reader.float(Fragment.request_timeout)

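The `reader.int(Fragment.n) or base.n` expressions above (and in the following hunks) rely on Python's `or` to fall back to a default when the environment value is absent. A standalone sketch of that behaviour, using a simplified stand-in rather than the project's real reader:

```python
# Simplified illustration of the fallback pattern; not the real reader API.
def read_int(raw: str | None, default: int) -> int:
    parsed = int(raw) if raw is not None else None
    # `parsed or default` yields the default when parsed is None
    # (and also when parsed is 0, since 0 is falsy).
    return parsed or default

print(read_int("4", 1))   # 4
print(read_int(None, 1))  # 1
```
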
@@ -251,6 +252,7 @@ def create_graphrag_config(
temperature=reader.float(Fragment.temperature)
or defs.LLM_TEMPERATURE,
top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
n=reader.int(Fragment.n) or defs.LLM_N,
model_supports_json=reader.bool(Fragment.model_supports_json),
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,

@@ -492,6 +494,7 @@ def create_graphrag_config(
global_search_model = GlobalSearchConfig(
temperature=reader.float(Fragment.temperature) or defs.LLM_TEMPERATURE,
top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
n=reader.int(Fragment.n) or defs.LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.GLOBAL_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")

@@ -559,6 +562,7 @@ class Fragment(str, Enum):
max_tokens = "MAX_TOKENS"
temperature = "TEMPERATURE"
top_p = "TOP_P"
n = "N"
model = "MODEL"
model_supports_json = "MODEL_SUPPORTS_JSON"
prompt_file = "PROMPT_FILE"

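Each `Fragment` member's value appears to be the suffix of the matching `GRAPHRAG_*` environment variable (so `Fragment.n` pairs with `GRAPHRAG_LLM_N` in the LLM section). A small sketch of that naming relationship; the lookup helper is illustrative, not project code:

```python
# Illustrative only: mirror a few Fragment members and derive env var names.
from enum import Enum


class Fragment(str, Enum):
    temperature = "TEMPERATURE"
    top_p = "TOP_P"
    n = "N"


def env_key(section: str, fragment: Fragment) -> str:
    return f"GRAPHRAG_{section}_{fragment.value}"


print(env_key("LLM", Fragment.n))  # GRAPHRAG_LLM_N
```
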
@@ -25,6 +25,7 @@ LLM_MODEL = "gpt-4-turbo-preview"
LLM_MAX_TOKENS = 4000
LLM_TEMPERATURE = 0
LLM_TOP_P = 1
LLM_N = 1
LLM_REQUEST_TIMEOUT = 180.0
LLM_TOKENS_PER_MINUTE = 0
LLM_REQUESTS_PER_MINUTE = 0

@@ -19,6 +19,10 @@ class GlobalSearchConfig(BaseModel):
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int | None = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
max_tokens: int = Field(
description="The maximum context size in tokens.",
default=defs.GLOBAL_SEARCH_MAX_TOKENS,

@@ -33,6 +33,10 @@ class LLMParameters(BaseModel):
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int | None = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
request_timeout: float = Field(
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
)

@@ -24,6 +24,9 @@ llm:
  # max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  # concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
  # temperature: {defs.LLM_TEMPERATURE} # temperature for sampling
  # top_p: {defs.LLM_TOP_P} # top-p sampling
  # n: {defs.LLM_N} # number of completions to generate

parallelization:
  stagger: {defs.PARALLELIZATION_STAGGER}

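The `{defs.*}` placeholders above are not literal YAML; they are filled from the defaults module when the settings template is rendered. A hedged sketch of the substitution (the f-string rendering here is illustrative, not the project's templating code):

```python
# Illustrative only: show how the {defs.LLM_N} placeholder resolves.
from graphrag.config import defaults as defs  # import path assumed

print(f"# n: {defs.LLM_N} # number of completions to generate")
# -> "# n: 1 # number of completions to generate"
```
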
@@ -111,6 +111,7 @@ def _load_openai_completion_llm(
"presence_penalty": config.get("presence_penalty", 0),
"top_p": config.get("top_p", 1),
"max_tokens": config.get("max_tokens", 4000),
"n": config.get("n"),
}),
on_error,
cache,

@@ -135,6 +136,7 @@ def _load_openai_chat_llm(
"presence_penalty": config.get("presence_penalty", 0),
"top_p": config.get("top_p", 1),
"max_tokens": config.get("max_tokens"),
"n": config.get("n"),
}),
on_error,
cache,

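Note that the new `"n": config.get("n")` lookups above pass no fallback, unlike the neighbouring `top_p` and `max_tokens` lookups. A plain-Python illustration of the difference; the `config` dict here is made up, not the real configuration object:

```python
# Made-up config dict, for illustration of dict.get with and without a default.
config = {"top_p": 1, "max_tokens": 4000}

print(config.get("top_p", 1))  # 1 (explicit fallback, as in the diff)
print(config.get("n"))         # None (no fallback supplied for "n")
```
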
@@ -50,11 +50,13 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
azure_ad_token_provider=get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.llm.api_key
else None,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.llm.api_key
else None
),
api_base=config.llm.api_base,
model=config.llm.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,

@@ -79,11 +81,13 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
azure_ad_token_provider=get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.embeddings.llm.api_key
else None,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.embeddings.llm.api_key
else None
),
api_base=config.embeddings.llm.api_base,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=config.embeddings.llm.model,

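The two hunks above appear to rewrap the conditional `azure_ad_token_provider` argument in parentheses rather than change its behaviour. A standalone sketch of the same pattern; `DefaultAzureCredential` and `get_bearer_token_provider` are the real azure-identity helpers, while the wrapper function is illustrative:

```python
# Sketch of a parenthesized conditional argument: return an AAD token provider
# only when targeting Azure without an explicit API key.
from azure.identity import DefaultAzureCredential, get_bearer_token_provider


def make_token_provider(is_azure_client: bool, api_key: str | None, endpoint: str):
    return (
        get_bearer_token_provider(DefaultAzureCredential(), endpoint)
        if is_azure_client and not api_key
        else None
    )
```
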
@@ -167,11 +171,13 @@ def get_global_search_engine(
"max_tokens": gs_config.map_max_tokens,
"temperature": gs_config.temperature,
"top_p": gs_config.top_p,
"n": gs_config.n,
},
reduce_llm_params={
"max_tokens": gs_config.reduce_max_tokens,
"temperature": gs_config.temperature,
"top_p": gs_config.top_p,
"n": gs_config.n,
},
allow_general_knowledge=False,
json_mode=False,

@@ -101,13 +101,13 @@ ALL_ENV_VARS = {
"GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
"GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
"GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
"GRAPHRAG_EMBEDDING_RPM": "500",
"GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
"GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
"GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
"GRAPHRAG_EMBEDDING_TARGET": "all",
"GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
"GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
"GRAPHRAG_EMBEDDING_TPM": "7000",
"GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
"GRAPHRAG_ENCODING_MODEL": "test123",
"GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",

@@ -134,12 +134,13 @@ ALL_ENV_VARS = {
"GRAPHRAG_LLM_MAX_TOKENS": "15000",
"GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
"GRAPHRAG_LLM_MODEL": "test-llm",
"GRAPHRAG_LLM_N": "1",
"GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
"GRAPHRAG_LLM_RPM": "900",
"GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
"GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
"GRAPHRAG_LLM_THREAD_COUNT": "987",
"GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
"GRAPHRAG_LLM_TPM": "8000",
"GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_MAX_CLUSTER_SIZE": "123",
"GRAPHRAG_NODE2VEC_ENABLED": "true",

@@ -164,6 +165,8 @@ ALL_ENV_VARS = {
"GRAPHRAG_STORAGE_TYPE": "blob",
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
"GRAPHRAG_LLM_TEMPERATURE": "0.0",
"GRAPHRAG_LLM_TOP_P": "1.0",
"GRAPHRAG_UMAP_ENABLED": "true",
"GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
"GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",

@@ -562,11 +565,14 @@ class TestDefaultConfig(unittest.TestCase):
assert parameters.llm.max_tokens == 15000
assert parameters.llm.model == "test-llm"
assert parameters.llm.model_supports_json
assert parameters.llm.n == 1
assert parameters.llm.organization == "test_org"
assert parameters.llm.proxy == "http://some/proxy"
assert parameters.llm.request_timeout == 12.7
assert parameters.llm.requests_per_minute == 900
assert parameters.llm.sleep_on_rate_limit_recommendation is False
assert parameters.llm.temperature == 0.0
assert parameters.llm.top_p == 1.0
assert parameters.llm.tokens_per_minute == 8000
assert parameters.llm.type == "azure_openai_chat"
assert parameters.parallelization.num_threads == 987