Add N parameter support (#390)

* Add N parameter support

* Fix unit tests

* Add new env vars to param testing
Alonso Guevara 2024-07-08 14:04:49 -06:00 committed by GitHub
parent a22003c302
commit b912081f1b
11 changed files with 58 additions and 18 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add N parameter support"
}

View File

@@ -59,11 +59,14 @@ These settings control the text generation model used by the pipeline. Any setti
| `GRAPHRAG_LLM_THREAD_COUNT` | | The number of threads to use for LLM parallelization. | `int` | 50 |
| `GRAPHRAG_LLM_THREAD_STAGGER` | | The time to wait (in seconds) between starting each thread. | `float` | 0.3 |
| `GRAPHRAG_LLM_CONCURRENT_REQUESTS` | | The number of concurrent requests to allow for the LLM client. | `int` | 25 |
| `GRAPHRAG_LLM_TPM` | | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_RPM` | | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_TOKENS_PER_MINUTE` | | The number of tokens per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_REQUESTS_PER_MINUTE` | | The number of requests per minute to allow for the LLM client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_LLM_MAX_RETRIES` | | The maximum number of retries to attempt when a request fails. | `int` | 10 |
| `GRAPHRAG_LLM_MAX_RETRY_WAIT` | | The maximum number of seconds to wait between retries. | `int` | 10 |
| `GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION` | | Whether to sleep on rate limit recommendation. (Azure Only) | `bool` | `True` |
| `GRAPHRAG_LLM_TEMPERATURE` | | The temperature to use for generation. | `float` | 0 |
| `GRAPHRAG_LLM_TOP_P` | | The top_p to use for sampling. | `float` | 1 |
| `GRAPHRAG_LLM_N` | | The number of responses to generate. | `int` | 1 |
## Text Embedding Settings
@@ -86,8 +89,8 @@ These settings control the text embedding model used by the pipeline. Any settin
| `GRAPHRAG_EMBEDDING_THREAD_COUNT` | | The number of threads to use for parallelization for embeddings. | `int` | |
| `GRAPHRAG_EMBEDDING_THREAD_STAGGER` | | The time to wait (in seconds) between starting each thread for embeddings. | `float` | 50 |
| `GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS` | | The number of concurrent requests to allow for the embedding client. | `int` | 25 |
| `GRAPHRAG_EMBEDDING_TPM` | | The number of tokens per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_RPM` | | The number of requests per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE` | | The number of tokens per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE` | | The number of requests per minute to allow for the embedding client. 0 = Bypass | `int` | 0 |
| `GRAPHRAG_EMBEDDING_MAX_RETRIES` | | The maximum number of retries to attempt when a request fails. | `int` | 10 |
| `GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT` | | The maximum number of seconds to wait between retries. | `int` | 10 |
| `GRAPHRAG_EMBEDDING_TARGET` | | The target fields to embed. Either `required` or `all`. | `str` | `required` |
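For illustration, the renamed throughput variables and the new `GRAPHRAG_LLM_N` are ordinary environment variables; a minimal standard-library sketch of setting them before a run (the values are examples, not recommendations):

import os

# Example values only.
os.environ["GRAPHRAG_LLM_TOKENS_PER_MINUTE"] = "0"    # 0 = bypass the token limiter
os.environ["GRAPHRAG_LLM_REQUESTS_PER_MINUTE"] = "0"  # 0 = bypass the request limiter
os.environ["GRAPHRAG_LLM_TEMPERATURE"] = "0"
os.environ["GRAPHRAG_LLM_TOP_P"] = "1"
os.environ["GRAPHRAG_LLM_N"] = "1"                    # number of completions to generate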

View File

@@ -67,6 +67,9 @@ This is the base LLM configuration section. Other steps may override this config
- `max_retry_wait` **float** - The maximum backoff time.
- `sleep_on_rate_limit_recommendation` **bool** - Whether to adhere to sleep recommendations (Azure).
- `concurrent_requests` **int** - The number of open requests to allow at once.
- `temperature` **float** - The temperature to use.
- `top_p` **float** - The top-p value to use.
- `n` **int** - The number of completions to generate.
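For illustration, the three generation settings documented above correspond to plain key/value overrides; expressed as a Python mapping with the defaults listed elsewhere in this change (keys mirror the documentation, the variable name is just for the example):

# Illustrative only; keys mirror the documented llm settings.
llm_generation_settings = {
    "temperature": 0,  # sampling temperature
    "top_p": 1,        # nucleus-sampling cutoff
    "n": 1,            # completions to generate per request
}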
## parallelization

View File

@@ -114,6 +114,7 @@ def create_graphrag_config(
max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
temperature=reader.float(Fragment.temperature) or base.temperature,
top_p=reader.float(Fragment.top_p) or base.top_p,
n=reader.int(Fragment.n) or base.n,
model_supports_json=reader.bool(Fragment.model_supports_json)
or base.model_supports_json,
request_timeout=reader.float(Fragment.request_timeout)
@@ -251,6 +252,7 @@ def create_graphrag_config(
temperature=reader.float(Fragment.temperature)
or defs.LLM_TEMPERATURE,
top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
n=reader.int(Fragment.n) or defs.LLM_N,
model_supports_json=reader.bool(Fragment.model_supports_json),
request_timeout=reader.float(Fragment.request_timeout)
or defs.LLM_REQUEST_TIMEOUT,
@@ -492,6 +494,7 @@ def create_graphrag_config(
global_search_model = GlobalSearchConfig(
temperature=reader.float(Fragment.temperature) or defs.LLM_TEMPERATURE,
top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
n=reader.int(Fragment.n) or defs.LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.GLOBAL_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
@@ -559,6 +562,7 @@ class Fragment(str, Enum):
max_tokens = "MAX_TOKENS"
temperature = "TEMPERATURE"
top_p = "TOP_P"
n = "N"
model = "MODEL"
model_supports_json = "MODEL_SUPPORTS_JSON"
prompt_file = "PROMPT_FILE"
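One detail of the `reader.int(Fragment.n) or base.n` pattern above is worth spelling out: the reader yields None when the variable is unset, so `or` falls back to the default; because this is ordinary Python truthiness, an explicit 0 would fall back as well. A self-contained sketch of the pattern (the helper below is illustrative, not the repo's reader):

# Stand-in for the env-reader fallback used in create_graphrag_config.
def read_int(raw: str | None) -> int | None:
    return int(raw) if raw is not None else None

DEFAULT_N = 1  # stands in for defs.LLM_N

print(read_int(None) or DEFAULT_N)  # 1 -> unset falls back to the default
print(read_int("3") or DEFAULT_N)   # 3 -> an explicit value wins
print(read_int("0") or DEFAULT_N)   # 1 -> 0 is falsy, so it also falls back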

View File

@@ -25,6 +25,7 @@ LLM_MODEL = "gpt-4-turbo-preview"
LLM_MAX_TOKENS = 4000
LLM_TEMPERATURE = 0
LLM_TOP_P = 1
LLM_N = 1
LLM_REQUEST_TIMEOUT = 180.0
LLM_TOKENS_PER_MINUTE = 0
LLM_REQUESTS_PER_MINUTE = 0
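These defaults are plain module-level constants; assuming the module referenced as `defs` elsewhere in this change is `graphrag.config.defaults`, they can be inspected directly:

import graphrag.config.defaults as defs  # assumed module path

print(defs.LLM_TEMPERATURE, defs.LLM_TOP_P, defs.LLM_N)  # expected: 0 1 1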

View File

@@ -19,6 +19,10 @@ class GlobalSearchConfig(BaseModel):
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int | None = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
max_tokens: int = Field(
description="The maximum context size in tokens.",
default=defs.GLOBAL_SEARCH_MAX_TOKENS,
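The new field is typed `int | None` and defaults to `defs.LLM_N`, so an unconfigured global search still carries `n=1`. A self-contained stand-in using the same pydantic `Field` shape (the class name and default constant below are placeholders, not graphrag imports):

from pydantic import BaseModel, Field

LLM_N = 1  # placeholder for defs.LLM_N

class GlobalSearchSketch(BaseModel):
    n: int | None = Field(
        description="The number of completions to generate.",
        default=LLM_N,
    )

print(GlobalSearchSketch().n)     # 1 (default applied)
print(GlobalSearchSketch(n=3).n)  # 3 (explicit override)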

View File

@@ -33,6 +33,10 @@ class LLMParameters(BaseModel):
description="The top-p value to use for token generation.",
default=defs.LLM_TOP_P,
)
n: int | None = Field(
description="The number of completions to generate.",
default=defs.LLM_N,
)
request_timeout: float = Field(
description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
)
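Because LLMParameters is a pydantic model, the new `n` field travels with the other generation settings when the model is dumped to a dict (the loader code later in this commit reads such values via `config.get`). A reduced stand-in, not the full graphrag model:

from pydantic import BaseModel, Field

class LLMParamsSketch(BaseModel):
    temperature: float = Field(default=0)
    top_p: float = Field(default=1)
    n: int | None = Field(default=1)

params = LLMParamsSketch(n=2)
print(params.model_dump())  # {'temperature': 0.0, 'top_p': 1.0, 'n': 2}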

View File

@@ -24,6 +24,9 @@ llm:
# max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
# sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
# concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
# temperature: {defs.LLM_TEMPERATURE} # temperature for sampling
# top_p: {defs.LLM_TOP_P} # top-p sampling
# n: {defs.LLM_N} # Number of completions to generate
parallelization:
stagger: {defs.PARALLELIZATION_STAGGER}
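The commented `llm` lines added to the template above are placeholders filled in from the defaults module when the starter settings file is generated; a rough sketch of that substitution with stand-in values (not the repo's actual template-rendering code):

from types import SimpleNamespace

defs = SimpleNamespace(LLM_TEMPERATURE=0, LLM_TOP_P=1, LLM_N=1)  # stand-in defaults

fragment = (
    "  # temperature: {defs.LLM_TEMPERATURE} # temperature for sampling\n"
    "  # top_p: {defs.LLM_TOP_P} # top-p sampling\n"
    "  # n: {defs.LLM_N} # number of completions to generate\n"
)
print(fragment.format(defs=defs))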

View File

@@ -111,6 +111,7 @@ def _load_openai_completion_llm(
"presence_penalty": config.get("presence_penalty", 0),
"top_p": config.get("top_p", 1),
"max_tokens": config.get("max_tokens", 4000),
"n": config.get("n"),
}),
on_error,
cache,
@@ -135,6 +136,7 @@ def _load_openai_chat_llm(
"presence_penalty": config.get("presence_penalty", 0),
"top_p": config.get("top_p", 1),
"max_tokens": config.get("max_tokens"),
"n": config.get("n"),
}),
on_error,
cache,
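Note the asymmetry in the lookups above: keys such as `top_p` carry an explicit fallback, while `config.get("n")` returns None when unset, which downstream code can treat as "not specified". A tiny illustration of the difference:

config = {"top_p": 0.9}  # "n" deliberately absent

print(config.get("top_p", 1))  # 0.9 -- falls back to 1 only if the key is missing
print(config.get("n"))         # None -- no fallback; left as "not specified"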

View File

@@ -50,11 +50,13 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
azure_ad_token_provider=get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.llm.api_key
else None,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.llm.api_key
else None
),
api_base=config.llm.api_base,
model=config.llm.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
@@ -79,11 +81,13 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
azure_ad_token_provider=get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.embeddings.llm.api_key
else None,
azure_ad_token_provider=(
get_bearer_token_provider(
DefaultAzureCredential(), cognitive_services_endpoint
)
if is_azure_client and not config.embeddings.llm.api_key
else None
),
api_base=config.embeddings.llm.api_base,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=config.embeddings.llm.model,
@@ -167,11 +171,13 @@ def get_global_search_engine(
"max_tokens": gs_config.map_max_tokens,
"temperature": gs_config.temperature,
"top_p": gs_config.top_p,
"n": gs_config.n,
},
reduce_llm_params={
"max_tokens": gs_config.reduce_max_tokens,
"temperature": gs_config.temperature,
"top_p": gs_config.top_p,
"n": gs_config.n,
},
allow_general_knowledge=False,
json_mode=False,
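As the diff shows, the same generation settings, including the new `n`, fan out to both the map and reduce LLM parameter dicts. A small sketch of that fan-out with a stand-in config object (attribute names mirror GlobalSearchConfig; the token budgets are arbitrary example values):

from types import SimpleNamespace

gs_config = SimpleNamespace(  # stand-in for the parsed GlobalSearchConfig
    temperature=0, top_p=1, n=1, map_max_tokens=1000, reduce_max_tokens=2000
)

map_llm_params = {
    "max_tokens": gs_config.map_max_tokens,
    "temperature": gs_config.temperature,
    "top_p": gs_config.top_p,
    "n": gs_config.n,
}
reduce_llm_params = {**map_llm_params, "max_tokens": gs_config.reduce_max_tokens}
print(map_llm_params)
print(reduce_llm_params)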

View File

@@ -101,13 +101,13 @@ ALL_ENV_VARS = {
"GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
"GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
"GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
"GRAPHRAG_EMBEDDING_RPM": "500",
"GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
"GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
"GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
"GRAPHRAG_EMBEDDING_TARGET": "all",
"GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
"GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
"GRAPHRAG_EMBEDDING_TPM": "7000",
"GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
"GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
"GRAPHRAG_ENCODING_MODEL": "test123",
"GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
@@ -134,12 +134,13 @@ ALL_ENV_VARS = {
"GRAPHRAG_LLM_MAX_TOKENS": "15000",
"GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
"GRAPHRAG_LLM_MODEL": "test-llm",
"GRAPHRAG_LLM_N": "1",
"GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
"GRAPHRAG_LLM_RPM": "900",
"GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
"GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
"GRAPHRAG_LLM_THREAD_COUNT": "987",
"GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
"GRAPHRAG_LLM_TPM": "8000",
"GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
"GRAPHRAG_LLM_TYPE": "azure_openai_chat",
"GRAPHRAG_MAX_CLUSTER_SIZE": "123",
"GRAPHRAG_NODE2VEC_ENABLED": "true",
@@ -164,6 +165,8 @@ ALL_ENV_VARS = {
"GRAPHRAG_STORAGE_TYPE": "blob",
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
"GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
"GRAPHRAG_LLM_TEMPERATURE": "0.0",
"GRAPHRAG_LLM_TOP_P": "1.0",
"GRAPHRAG_UMAP_ENABLED": "true",
"GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
"GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
@@ -562,11 +565,14 @@ class TestDefaultConfig(unittest.TestCase):
assert parameters.llm.max_tokens == 15000
assert parameters.llm.model == "test-llm"
assert parameters.llm.model_supports_json
assert parameters.llm.n == 1
assert parameters.llm.organization == "test_org"
assert parameters.llm.proxy == "http://some/proxy"
assert parameters.llm.request_timeout == 12.7
assert parameters.llm.requests_per_minute == 900
assert parameters.llm.sleep_on_rate_limit_recommendation is False
assert parameters.llm.temperature == 0.0
assert parameters.llm.top_p == 1.0
assert parameters.llm.tokens_per_minute == 8000
assert parameters.llm.type == "azure_openai_chat"
assert parameters.parallelization.num_threads == 987
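For reference, env-var-driven settings like the ones exercised above can be tested in isolation by patching the process environment; a generic sketch (the test class is hypothetical and the config-building call is omitted, so this is not the repo's actual suite):

import os
import unittest
from unittest import mock

class EnvVarRoundTripSketch(unittest.TestCase):  # hypothetical test class
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_LLM_N": "1",
            "GRAPHRAG_LLM_TOP_P": "1.0",
            "GRAPHRAG_LLM_TEMPERATURE": "0.0",
        },
        clear=False,
    )
    def test_new_llm_params_visible(self):
        # In the real suite, the configuration would be built here and
        # llm.n / llm.top_p / llm.temperature asserted, as above.
        assert os.environ["GRAPHRAG_LLM_N"] == "1"

if __name__ == "__main__":
    unittest.main()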