Mirror of https://github.com/microsoft/graphrag.git, synced 2025-12-28 23:49:20 +00:00
Fix/feat: Implementation of Minute-Based Rate Limiting in CommunityReportsExtractor Using asyncio and async_mode (#373)
* RateLimiter: the original TpmRpmLLMLimiter strategy did not account for minute-based rate limiting when scheduled. The RateLimiter was introduced so that the CommunityReportsExtractor can be scheduled to adhere to rate configurations on a per-minute basis, using asyncio and async_mode.
* Fixed key-loading issues for rpm = "REQUESTS_PER_MINUTE" and tpm = "TOKENS_PER_MINUTE".
* Enhanced configuration loading to include temperature = "TEMPERATURE" and top_p = "TOP_P" settings.
* Format
* Semversioner
* Format and cleanup
---------
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
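Not part of the commit itself: a minimal sketch of the minute-based behavior the RateLimiter provides (its full source appears in the graphrag/index/utils/rate_limiter.py hunk below); the values mirror the rate=1, per=60 used by the community reports strategy.

import asyncio

from graphrag.index.utils.rate_limiter import RateLimiter


async def demo() -> None:
    # One permit per 60 seconds, as used by the community reports strategy below.
    limiter = RateLimiter(rate=1, per=60)
    await limiter.acquire()  # first call: allowance is available, returns immediately
    await limiter.acquire()  # second call: sleeps until roughly a minute has elapsed


asyncio.run(demo())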
parent daca75ff79
commit 6865d60133
.gitignore (vendored): 1 line changed
@@ -7,6 +7,7 @@ docsite/*/docsTemp/
 docsite/*/build/
 .swc/
 dist/
+.idea
 # https://yarnpkg.com/advanced/qa#which-files-should-be-gitignored
 docsite/.yarn/*
 !docsite/.yarn/patches
(new file: semversioner patch note)
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add Minute-based Rate Limiting and fix rpm, tpm settings"
+}
@@ -112,6 +112,8 @@ def create_graphrag_config(
             proxy=reader.str("proxy") or base.proxy,
             model=reader.str("model") or base.model,
             max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
+            temperature=reader.float(Fragment.temperature) or base.temperature,
+            top_p=reader.float(Fragment.top_p) or base.top_p,
             model_supports_json=reader.bool(Fragment.model_supports_json)
             or base.model_supports_json,
             request_timeout=reader.float(Fragment.request_timeout)
@@ -246,6 +248,9 @@ def create_graphrag_config(
             type=llm_type,
             model=reader.str(Fragment.model) or defs.LLM_MODEL,
             max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
+            temperature=reader.float(Fragment.temperature)
+            or defs.LLM_TEMPERATURE,
+            top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
             model_supports_json=reader.bool(Fragment.model_supports_json),
             request_timeout=reader.float(Fragment.request_timeout)
             or defs.LLM_REQUEST_TIMEOUT,
@@ -485,6 +490,8 @@ def create_graphrag_config(
         reader.envvar_prefix(Section.global_search),
     ):
         global_search_model = GlobalSearchConfig(
+            temperature=reader.float(Fragment.temperature) or defs.LLM_TEMPERATURE,
+            top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P,
             max_tokens=reader.int(Fragment.max_tokens)
             or defs.GLOBAL_SEARCH_MAX_TOKENS,
             data_max_tokens=reader.int("data_max_tokens")
@@ -550,16 +557,18 @@ class Fragment(str, Enum):
     max_retries = "MAX_RETRIES"
     max_retry_wait = "MAX_RETRY_WAIT"
     max_tokens = "MAX_TOKENS"
+    temperature = "TEMPERATURE"
+    top_p = "TOP_P"
     model = "MODEL"
     model_supports_json = "MODEL_SUPPORTS_JSON"
     prompt_file = "PROMPT_FILE"
     request_timeout = "REQUEST_TIMEOUT"
-    rpm = "RPM"
+    rpm = "REQUESTS_PER_MINUTE"
     sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
     storage_account_blob_url = "STORAGE_ACCOUNT_BLOB_URL"
     thread_count = "THREAD_COUNT"
     thread_stagger = "THREAD_STAGGER"
-    tpm = "TPM"
+    tpm = "TOKENS_PER_MINUTE"
     type = "TYPE"
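For context, Fragment values are the key suffixes the configuration reader resolves, so with the old "RPM"/"TPM" values the documented per-minute settings were never picked up. A rough sketch of the corrected lookup follows; the GRAPHRAG_LLM_* environment variable names and the create_graphrag_config import and signature are assumptions, not shown in this diff.

import os

from graphrag.config import create_graphrag_config

# Hypothetical values; the corrected key names are the point of the fix.
os.environ["GRAPHRAG_API_KEY"] = "<api-key>"
os.environ["GRAPHRAG_LLM_REQUESTS_PER_MINUTE"] = "30"   # resolved via Fragment.rpm
os.environ["GRAPHRAG_LLM_TOKENS_PER_MINUTE"] = "50000"  # resolved via Fragment.tpm
os.environ["GRAPHRAG_LLM_TEMPERATURE"] = "0"            # resolved via Fragment.temperature
os.environ["GRAPHRAG_LLM_TOP_P"] = "1"                  # resolved via Fragment.top_p

config = create_graphrag_config(root_dir=".")
print(config.llm.requests_per_minute, config.llm.tokens_per_minute)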
@@ -23,6 +23,8 @@ ENCODING_MODEL = "cl100k_base"
 LLM_TYPE = LLMType.OpenAIChat
 LLM_MODEL = "gpt-4-turbo-preview"
 LLM_MAX_TOKENS = 4000
+LLM_TEMPERATURE = 0
+LLM_TOP_P = 1
 LLM_REQUEST_TIMEOUT = 180.0
 LLM_TOKENS_PER_MINUTE = 0
 LLM_REQUESTS_PER_MINUTE = 0
@@ -11,6 +11,14 @@ import graphrag.config.defaults as defs
 class GlobalSearchConfig(BaseModel):
     """The default configuration section for Cache."""

+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.LLM_TOP_P,
+    )
     max_tokens: int = Field(
         description="The maximum context size in tokens.",
         default=defs.GLOBAL_SEARCH_MAX_TOKENS,
@@ -25,6 +25,14 @@ class LLMParameters(BaseModel):
         description="The maximum number of tokens to generate.",
         default=defs.LLM_MAX_TOKENS,
     )
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.LLM_TOP_P,
+    )
     request_timeout: float = Field(
         description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
     )
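A quick illustration of the new model fields; this is a sketch only, and it assumes the LLMParameters import path and defaults for the fields not shown in this hunk.

from graphrag.config.models.llm_parameters import LLMParameters

# Values are hypothetical; only the field names come from the hunk above.
params = LLMParameters(temperature=0.2, top_p=0.95, max_tokens=2000)
print(params.temperature, params.top_p, params.request_timeout)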
@@ -110,7 +110,7 @@ def _load_openai_completion_llm(
             "frequency_penalty": config.get("frequency_penalty", 0),
             "presence_penalty": config.get("presence_penalty", 0),
             "top_p": config.get("top_p", 1),
-            "max_tokens": config.get("max_tokens"),
+            "max_tokens": config.get("max_tokens", 4000),
         }),
         on_error,
         cache,
graphrag/index/utils/rate_limiter.py (new file): 40 lines added
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Rate limiter utility."""
+
+import asyncio
+import time
+
+
+class RateLimiter:
+    """
+    The original TpmRpmLLMLimiter strategy did not account for minute-based rate limiting when scheduled.
+
+    The RateLimiter was introduced to ensure that the CommunityReportsExtractor could be scheduled to adhere to rate configurations on a per-minute basis.
+    """
+
+    # TODO: RateLimiter scheduled: using asyncio for async_mode
+
+    def __init__(self, rate: int, per: int):
+        self.rate = rate
+        self.per = per
+        self.allowance = rate
+        self.last_check = time.monotonic()
+
+    async def acquire(self):
+        """Acquire a token from the rate limiter."""
+        current = time.monotonic()
+        elapsed = current - self.last_check
+        self.last_check = current
+        self.allowance += elapsed * (self.rate / self.per)
+
+        if self.allowance > self.rate:
+            self.allowance = self.rate
+
+        if self.allowance < 1.0:
+            sleep_time = (1.0 - self.allowance) * (self.per / self.rate)
+            await asyncio.sleep(sleep_time)
+            self.allowance = 0.0
+        else:
+            self.allowance -= 1.0
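Not part of the diff: a self-contained sketch of how the limiter behaves when one instance is shared by several concurrent coroutines. The toy rate/per values and the task body are hypothetical; the hunks below instead construct one limiter per _run_extractor call with rate=1, per=60.

import asyncio
import time

from graphrag.index.utils.rate_limiter import RateLimiter


async def limited_call(limiter: RateLimiter, i: int, start: float) -> None:
    # Each task waits for a permit before doing its (simulated) LLM request.
    await limiter.acquire()
    print(f"call {i} started at t={time.monotonic() - start:.1f}s")


async def main() -> None:
    # 3 permits per 6 seconds: the first three calls start immediately,
    # the fourth waits about 2 seconds for new allowance to accrue.
    limiter = RateLimiter(rate=3, per=6)
    start = time.monotonic()
    await asyncio.gather(*(limited_call(limiter, i, start) for i in range(4)))


asyncio.run(main())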
@@ -15,6 +15,7 @@ from graphrag.index.graph.extractors.community_reports import (
     CommunityReportsExtractor,
 )
 from graphrag.index.llm import load_llm
+from graphrag.index.utils.rate_limiter import RateLimiter
 from graphrag.index.verbs.graph.report.strategies.typing import (
     CommunityReport,
     StrategyConfig,
@@ -53,6 +54,8 @@ async def _run_extractor(
     args: StrategyConfig,
     reporter: VerbCallbacks,
 ) -> CommunityReport | None:
+    # RateLimiter
+    rate_limiter = RateLimiter(rate=1, per=60)
     extractor = CommunityReportsExtractor(
         llm,
         extraction_prompt=args.get("extraction_prompt", None),
@@ -63,6 +66,7 @@ async def _run_extractor(
     )

     try:
+        await rate_limiter.acquire()
         results = await extractor({"input_text": input})
         report = results.structured_output
         if report is None or len(report.keys()) == 0:
@@ -165,11 +165,13 @@ def get_global_search_engine(
         max_data_tokens=gs_config.data_max_tokens,
         map_llm_params={
             "max_tokens": gs_config.map_max_tokens,
-            "temperature": 0.0,
+            "temperature": gs_config.temperature,
+            "top_p": gs_config.top_p,
         },
         reduce_llm_params={
             "max_tokens": gs_config.reduce_max_tokens,
-            "temperature": 0.0,
+            "temperature": gs_config.temperature,
+            "top_p": gs_config.top_p,
         },
         allow_general_knowledge=False,
         json_mode=False,
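For illustration, a rough sketch of what this change means downstream: global search map/reduce parameters now follow GlobalSearchConfig rather than a hard-coded temperature of 0.0. The import path and the assumption that the remaining GlobalSearchConfig fields have defaults are not confirmed by this diff.

from graphrag.config.models.global_search_config import GlobalSearchConfig

# Hypothetical values; only temperature and top_p are introduced by this commit.
gs_config = GlobalSearchConfig(temperature=0.3, top_p=0.9)

map_llm_params = {
    "max_tokens": gs_config.map_max_tokens,
    "temperature": gs_config.temperature,  # previously hard-coded to 0.0
    "top_p": gs_config.top_p,
}
print(map_llm_params)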