Improve default llm retry logic to be more optimized (#1701)

Josh Bradley authored on 2025-02-13 16:56:37 -05:00, committed by GitHub
parent b8b949f3bb
commit f14cda2b6d
50 changed files with 606 additions and 567 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "add dynamic retry logic."
}

View File

@ -13,6 +13,7 @@ Backwards compatibility is not guaranteed at this time.
from pydantic import PositiveInt, validate_call
import graphrag.config.defaults as defs
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
@ -95,8 +96,14 @@ async def generate_indexing_prompts(
)
# Create LLM from config
# TODO: Expose way to specify Prompt Tuning model ID through config
# TODO: Expose a way to specify Prompt Tuning model ID through config
default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
# if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
# to be made, or fall back to a default value in the worst case
if default_llm_settings.max_retries == -1:
default_llm_settings.max_retries = min(len(doc_list), defs.LLM_MAX_RETRIES)
llm = load_llm(
"prompt_tuning",
default_llm_settings,

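A minimal sketch of the resolution rule this hunk introduces (the helper name is hypothetical; the min() cap against defs.LLM_MAX_RETRIES mirrors the code above):

import graphrag.config.defaults as defs

def resolve_prompt_tuning_retries(configured_max_retries: int, doc_list: list[str]) -> int:
    # -1 is the "not set" sentinel: derive retries from the number of input
    # documents (a proxy for expected LLM calls), capped at the library default.
    if configured_max_retries == -1:
        return min(len(doc_list), defs.LLM_MAX_RETRIES)
    return configured_max_retries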
View File

@ -24,7 +24,7 @@ DEFAULT_CHAT_MODEL_ID = "default_chat_model"
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
ASYNC_MODE = AsyncType.Threaded
ENCODING_MODEL = "cl100k_base"
AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
AUTH_TYPE = AuthType.APIKey
#
# LLM Parameters
@ -39,15 +39,12 @@ LLM_N = 1
LLM_REQUEST_TIMEOUT = 180.0
LLM_TOKENS_PER_MINUTE = 50_000
LLM_REQUESTS_PER_MINUTE = 1_000
RETRY_STRATEGY = "native"
LLM_MAX_RETRIES = 10
LLM_MAX_RETRY_WAIT = 10.0
LLM_PRESENCE_PENALTY = 0.0
LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True
LLM_CONCURRENT_REQUESTS = 25
PARALLELIZATION_STAGGER = 0.3
PARALLELIZATION_NUM_THREADS = 50
#
# Text embedding
#

View File

@ -14,32 +14,41 @@ INIT_YAML = f"""\
models:
{defs.DEFAULT_CHAT_MODEL_ID}:
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
type: {defs.LLM_TYPE.value} # or azure_openai_chat
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
model: {defs.LLM_MODEL}
# deployment_name: <azure_model_deployment_name>
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
model_supports_json: true # recommended if this is available for your model.
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
async_mode: {defs.ASYNC_MODE.value} # or asyncio
# audience: "https://cognitiveservices.azure.com/.default"
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# organization: <organization_id>
# deployment_name: <azure_model_deployment_name>
retry_strategy: native
max_retries: -1 # set to -1 for dynamic retry logic (optimal setting based on server response)
tokens_per_minute: 0 # set to 0 to disable rate limiting
requests_per_minute: 0 # set to 0 to disable rate limiting
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
model: {defs.EMBEDDING_MODEL}
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
async_mode: {defs.ASYNC_MODE.value} # or asyncio
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-02-15-preview
# api_version: 2024-05-01-preview
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}}
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
model: {defs.EMBEDDING_MODEL}
# deployment_name: <azure_model_deployment_name>
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
model_supports_json: true # recommended if this is available for your model.
concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
async_mode: {defs.ASYNC_MODE.value} # or asyncio
retry_strategy: native
max_retries: -1 # set to -1 for dynamic retry logic (optimal setting based on server response)
tokens_per_minute: 0 # set to 0 to disable rate limiting
requests_per_minute: 0 # set to 0 to disable rate limiting
vector_store:
{defs.VECTOR_STORE_DEFAULT_ID}:

View File

@ -49,8 +49,7 @@ class CommunityReportsConfig(BaseModel):
return self.strategy or {
"type": CreateCommunityReportsStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"num_threads": model_config.concurrent_requests,
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
encoding="utf-8"
)

View File

@ -46,8 +46,7 @@ class ClaimExtractionConfig(BaseModel):
"""Get the resolved claim extraction strategy."""
return self.strategy or {
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"num_threads": model_config.concurrent_requests,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)

View File

@ -47,8 +47,7 @@ class ExtractGraphConfig(BaseModel):
return self.strategy or {
"type": ExtractEntityStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"num_threads": model_config.concurrent_requests,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)

View File

@ -64,7 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
text_analyzer: TextAnalyzerConfig = Field(
description="The text analyzer configuration.", default=TextAnalyzerConfig()
)
parallelization_num_threads: int = Field(
concurrent_requests: int = Field(
description="The number of threads to use for the extraction process.",
default=defs.PARALLELIZATION_NUM_THREADS,
default=defs.LLM_CONCURRENT_REQUESTS,
)

View File

@ -31,7 +31,7 @@ class LanguageModelConfig(BaseModel):
API Key is required when using OpenAI API
or when using Azure API with API Key authentication.
For the time being, this check is extra verbose for clarity.
It will also through an exception if an API Key is provided
It will also raise an exception if an API Key is provided
when one is not expected such as the case of using Azure
Managed Identity.
@ -199,6 +199,10 @@ class LanguageModelConfig(BaseModel):
description="The number of requests per minute to use for the LLM service.",
default=defs.LLM_REQUESTS_PER_MINUTE,
)
retry_strategy: str = Field(
description="The retry strategy to use for the LLM service.",
default=defs.RETRY_STRATEGY,
)
max_retries: int = Field(
description="The maximum number of retries to use for the LLM service.",
default=defs.LLM_MAX_RETRIES,
@ -207,10 +211,6 @@ class LanguageModelConfig(BaseModel):
description="The maximum retry wait to use for the LLM service.",
default=defs.LLM_MAX_RETRY_WAIT,
)
sleep_on_rate_limit_recommendation: bool = Field(
description="Whether to sleep on rate limit recommendations.",
default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
)
concurrent_requests: int = Field(
description="Whether to use concurrent requests for the LLM service.",
default=defs.LLM_CONCURRENT_REQUESTS,
@ -218,14 +218,6 @@ class LanguageModelConfig(BaseModel):
responses: list[str | BaseModel] | None = Field(
default=None, description="Static responses to use in mock mode."
)
parallelization_stagger: float = Field(
description="The stagger to use for the LLM service.",
default=defs.PARALLELIZATION_STAGGER,
)
parallelization_num_threads: int = Field(
description="The number of threads to use for the LLM service.",
default=defs.PARALLELIZATION_NUM_THREADS,
)
async_mode: AsyncType = Field(
description="The async mode to use.", default=defs.ASYNC_MODE
)

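The net effect of this hunk: retry_strategy and the -1 max_retries sentinel are added, while sleep_on_rate_limit_recommendation and the parallelization_* fields are dropped. A sketch of how callers consume the surviving fields (hypothetical helper; every attribute referenced appears in the model above):

def retry_settings(model_config) -> dict:
    return {
        "strategy": model_config.retry_strategy,          # e.g. "native"
        "max_retries": model_config.max_retries,          # -1 means "derive dynamically"
        "max_retry_wait": model_config.max_retry_wait,
        "num_threads": model_config.concurrent_requests,  # replaces parallelization_num_threads
    }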
View File

@ -40,8 +40,7 @@ class SummarizeDescriptionsConfig(BaseModel):
return self.strategy or {
"type": SummarizeStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"num_threads": model_config.concurrent_requests,
"summarize_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)

View File

@ -48,8 +48,7 @@ class TextEmbeddingConfig(BaseModel):
return self.strategy or {
"type": TextEmbedStrategyType.openai,
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"num_threads": model_config.concurrent_requests,
"batch_size": self.batch_size,
"batch_max_tokens": self.batch_max_tokens,
}

View File

@ -37,7 +37,7 @@ async def extract_graph_nlp(
text_units,
text_analyzer=text_analyzer,
normalize_edge_weights=extraction_config.normalize_edge_weights,
num_threads=extraction_config.parallelization_num_threads,
num_threads=extraction_config.concurrent_requests,
cache=cache,
)

View File

@ -8,8 +8,9 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from fnllm import ChatLLM, EmbeddingsLLM, JsonStrategy, LLMEvents
from fnllm.base.config import JsonStrategy, RetryStrategy
from fnllm.caching import Cache as LLMCache
from fnllm.events import LLMEvents
from fnllm.openai import (
AzureOpenAIConfig,
OpenAIConfig,
@ -30,6 +31,8 @@ from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
from .mock_llm import MockChatLLM
if TYPE_CHECKING:
from fnllm.types import ChatLLM, EmbeddingsLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing import ErrorHandlerFn
@ -209,7 +212,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
msg = "Azure OpenAI Chat LLM requires an API base"
raise ValueError(msg)
audience = config.audience or defs.AZURE_AUDIENCE
audience = config.audience or defs.COGNITIVE_SERVICES_AUDIENCE
return AzureOpenAIConfig(
api_key=config.api_key,
endpoint=config.api_base,
@ -220,19 +223,22 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
max_retry_wait=config.max_retry_wait,
requests_per_minute=config.requests_per_minute,
tokens_per_minute=config.tokens_per_minute,
cognitive_services_endpoint=audience,
audience=audience,
retry_strategy=RetryStrategy(config.retry_strategy),
timeout=config.request_timeout,
max_concurrency=config.concurrent_requests,
model=config.model,
encoding=encoding_model,
deployment=config.deployment_name,
chat_parameters=chat_parameters,
sleep_on_rate_limit_recommendation=True,
)
return PublicOpenAIConfig(
api_key=config.api_key,
base_url=config.api_base,
json_strategy=json_strategy,
organization=config.organization,
retry_strategy=RetryStrategy(config.retry_strategy),
max_retries=config.max_retries,
max_retry_wait=config.max_retry_wait,
requests_per_minute=config.requests_per_minute,
@ -242,6 +248,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
model=config.model,
encoding=encoding_model,
chat_parameters=chat_parameters,
sleep_on_rate_limit_recommendation=True,
)

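A small sketch of the string-to-enum conversion used above, assuming only what the hunk itself relies on (fnllm's RetryStrategy accepts the configured strategy name as its value):

from fnllm.base.config import RetryStrategy

# An unrecognized name (e.g. a typo such as "natvie") raises ValueError at config
# load time rather than failing later inside the retry loop.
strategy = RetryStrategy("native")  # matches defs.RETRY_STRATEGY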
View File

@ -5,7 +5,7 @@
from functools import cache
from fnllm import ChatLLM, EmbeddingsLLM
from fnllm.types import ChatLLM, EmbeddingsLLM
@cache

View File

@ -5,8 +5,9 @@
from dataclasses import dataclass
from typing import Any, cast
from fnllm import ChatLLM, LLMInput, LLMOutput
from fnllm.types import ChatLLM
from fnllm.types.generics import THistoryEntry, TJsonModel, TModelParameters
from fnllm.types.io import LLMInput, LLMOutput
from pydantic import BaseModel
from typing_extensions import Unpack

View File

@ -116,10 +116,10 @@ async def _text_embed_in_memory(
):
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
strategy_config = {**strategy}
texts: list[str] = input[embed_column].to_numpy().tolist()
result = await strategy_exec(texts, callbacks, cache, strategy_args)
result = await strategy_exec(texts, callbacks, cache, strategy_config)
return result.embeddings
@ -137,7 +137,11 @@ async def _text_embed_with_vector_store(
):
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(input)
# Get vector-storage configuration
insert_batch_size: int = (
@ -176,7 +180,7 @@ async def _text_embed_with_vector_store(
texts: list[str] = batch[embed_column].to_numpy().tolist()
titles: list[str] = batch[title].to_numpy().tolist()
ids: list[str] = batch[id_column].to_numpy().tolist()
result = await strategy_exec(texts, callbacks, cache, strategy_args)
result = await strategy_exec(texts, callbacks, cache, strategy_config)
if result.embeddings:
embeddings = [
embedding for embedding in result.embeddings if embedding is not None

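The same sentinel-resolution pattern recurs in the covariate, graph-extraction, community-report, and description-summarization hunks below; a small helper capturing it (hypothetical, the diff inlines the check at each call site):

def inject_dynamic_max_retries(strategy_config: dict, expected_llm_calls: int) -> dict:
    # -1 means the user did not set max_retries; size it to the number of LLM
    # calls the workflow expects to make so transient failures can be retried
    # without stalling the whole run.
    llm_settings = strategy_config.get("llm")
    if llm_settings and llm_settings.get("max_retries") == -1:
        llm_settings["max_retries"] = expected_llm_calls
    return strategy_config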
View File

@ -8,7 +8,7 @@ import logging
from typing import Any
import numpy as np
from fnllm import EmbeddingsLLM
from fnllm.types import EmbeddingsLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

View File

@ -9,7 +9,7 @@ from dataclasses import dataclass
from typing import Any
import tiktoken
from fnllm import ChatLLM
from fnllm.types import ChatLLM
import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn

View File

@ -50,6 +50,10 @@ async def extract_covariates(
strategy = strategy or {}
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(input)
async def run_strategy(row):
text = row[column]
result = await run_extract_claims(

View File

@ -94,6 +94,10 @@ async def extract_graph(
)
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(text_units)
num_started = 0
async def run_strategy(row):

View File

@ -12,7 +12,7 @@ from typing import Any
import networkx as nx
import tiktoken
from fnllm import ChatLLM
from fnllm.types import ChatLLM
import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn

View File

@ -4,7 +4,7 @@
"""A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""
import networkx as nx
from fnllm import ChatLLM
from fnllm.types import ChatLLM
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache

View File

@ -8,7 +8,7 @@ import traceback
from dataclasses import dataclass
from typing import Any
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from pydantic import BaseModel, Field
from graphrag.index.typing import ErrorHandlerFn

View File

@ -6,7 +6,7 @@
import logging
import traceback
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

View File

@ -42,7 +42,12 @@ async def summarize_communities(
"""Generate community summaries."""
reports: list[CommunityReport | None] = []
tick = progress_ticker(callbacks.progress, len(local_contexts))
runner = load_strategy(strategy["type"])
strategy_exec = load_strategy(strategy["type"])
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(nodes)
community_hierarchy = restore_community_hierarchy(nodes)
levels = get_levels(nodes)
@ -62,13 +67,13 @@ async def summarize_communities(
async def run_generate(record):
result = await _generate_report(
runner,
strategy_exec,
community_id=record[schemas.COMMUNITY_ID],
community_level=record[schemas.COMMUNITY_LEVEL],
community_context=record[schemas.CONTEXT_STRING],
callbacks=callbacks,
cache=cache,
strategy=strategy,
strategy=strategy_config,
)
tick()
return result

View File

@ -6,7 +6,7 @@
import json
from dataclasses import dataclass
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string

View File

@ -3,7 +3,7 @@
"""A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

View File

@ -76,6 +76,10 @@ async def summarize_descriptions(
)
strategy_config = {**strategy}
# if max_retries is not set, inject a dynamically assigned value based on the maximum number of expected LLM calls to be made
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(entities_df) + len(relationships_df)
async def get_summarized(
nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore
):

View File

@ -6,6 +6,7 @@
import asyncio
import sys
import graphrag.config.defaults as defs
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
@ -17,6 +18,9 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
# Validate Chat LLM configs
# TODO: Replace default_chat_model with a way to select the model
default_llm_settings = parameters.get_language_model_config("default_chat_model")
# if max_retries is not set, set it to the default value
if default_llm_settings.max_retries == -1:
default_llm_settings.max_retries = defs.LLM_MAX_RETRIES
llm = load_llm(
name="test-llm",
config=default_llm_settings,

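Config validation cannot know how many LLM calls a run will make, so here the -1 sentinel falls back to the static default instead of a dynamically sized value; the guard above amounts to this (sketch):

max_retries = (
    default_llm_settings.max_retries
    if default_llm_settings.max_retries != -1
    else defs.LLM_MAX_RETRIES
)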
View File

@ -39,7 +39,7 @@ async def run_workflow(
config.community_reports.model_id
)
async_mode = community_reports_llm_settings.async_mode
num_threads = community_reports_llm_settings.parallelization_num_threads
num_threads = community_reports_llm_settings.concurrent_requests
summarization_strategy = config.community_reports.resolved_strategy(
config.root_dir, community_reports_llm_settings
)

View File

@ -31,7 +31,7 @@ async def run_workflow(
config.community_reports.model_id
)
async_mode = community_reports_llm_settings.async_mode
num_threads = community_reports_llm_settings.parallelization_num_threads
num_threads = community_reports_llm_settings.concurrent_requests
summarization_strategy = config.community_reports.resolved_strategy(
config.root_dir, community_reports_llm_settings
)

View File

@ -32,7 +32,7 @@ async def run_workflow(
)
async_mode = extract_claims_llm_settings.async_mode
num_threads = extract_claims_llm_settings.parallelization_num_threads
num_threads = extract_claims_llm_settings.concurrent_requests
output = await extract_covariates(
text_units,

View File

@ -32,9 +32,6 @@ async def run_workflow(
extraction_strategy = config.extract_graph.resolved_strategy(
config.root_dir, extract_graph_llm_settings
)
extraction_num_threads = extract_graph_llm_settings.parallelization_num_threads
extraction_async_mode = extract_graph_llm_settings.async_mode
entity_types = config.extract_graph.entity_types
summarization_llm_settings = config.get_language_model_config(
config.summarize_descriptions.model_id
@ -42,18 +39,17 @@ async def run_workflow(
summarization_strategy = config.summarize_descriptions.resolved_strategy(
config.root_dir, summarization_llm_settings
)
summarization_num_threads = summarization_llm_settings.parallelization_num_threads
entities, relationships = await extract_graph(
text_units=text_units,
callbacks=callbacks,
cache=context.cache,
extraction_strategy=extraction_strategy,
extraction_num_threads=extraction_num_threads,
extraction_async_mode=extraction_async_mode,
entity_types=entity_types,
extraction_num_threads=extract_graph_llm_settings.concurrent_requests,
extraction_async_mode=extract_graph_llm_settings.async_mode,
entity_types=config.extract_graph.entity_types,
summarization_strategy=summarization_strategy,
summarization_num_threads=summarization_num_threads,
summarization_num_threads=summarization_llm_settings.concurrent_requests,
embed_config=config.embed_graph,
layout_enabled=config.umap.enabled,
)

View File

@ -3,7 +3,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.prompt_tune.prompt.community_report_rating import (
GENERATE_REPORT_RATING_PROMPT,

View File

@ -3,7 +3,7 @@
"""Generate a community reporter role for community summarization."""
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.prompt_tune.prompt.community_reporter_role import (
GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT,

View File

@ -3,7 +3,7 @@
"""Domain generation for GraphRAG prompts."""
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.prompt_tune.prompt.domain import GENERATE_DOMAIN_PROMPT

View File

@ -6,7 +6,7 @@
import asyncio
import json
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.prompt_tune.prompt.entity_relationship import (
ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT,

View File

@ -3,7 +3,7 @@
"""Entity type generation module for fine-tuning."""
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from pydantic import BaseModel
from graphrag.prompt_tune.defaults import DEFAULT_TASK

View File

@ -3,7 +3,7 @@
"""Language detection for GraphRAG prompts."""
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.prompt_tune.prompt.language import DETECT_LANGUAGE_PROMPT

View File

@ -3,7 +3,7 @@
"""Persona generating module for fine-tuning GraphRAG prompts."""
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from graphrag.prompt_tune.defaults import DEFAULT_TASK
from graphrag.prompt_tune.prompt.persona import GENERATE_PERSONA_PROMPT

View File

@ -5,7 +5,7 @@
import numpy as np
import pandas as pd
from fnllm import ChatLLM
from fnllm.types import ChatLLM
import graphrag.config.defaults as defs
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks

View File

@ -5,6 +5,7 @@
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import graphrag.config.defaults as defs
from graphrag.config.enums import AuthType, LLMType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
@ -14,67 +15,68 @@ from graphrag.query.llm.oai.typing import OpenaiApiType
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
"""Get the LLM client."""
default_llm_settings = config.get_language_model_config("default_chat_model")
is_azure_client = default_llm_settings.type == LLMType.AzureOpenAIChat
debug_llm_key = default_llm_settings.api_key or ""
llm_config = config.get_language_model_config("default_chat_model")
is_azure_client = llm_config.type == LLMType.AzureOpenAIChat
debug_llm_key = llm_config.api_key or ""
llm_debug_info = {
**default_llm_settings.model_dump(),
**llm_config.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
audience = (
default_llm_settings.audience
if default_llm_settings.audience
llm_config.audience
if llm_config.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=default_llm_settings.api_key,
api_key=llm_config.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client
and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
if is_azure_client and llm_config.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=default_llm_settings.api_base,
organization=default_llm_settings.organization,
model=default_llm_settings.model,
api_base=llm_config.api_base,
organization=llm_config.organization,
model=llm_config.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=default_llm_settings.deployment_name,
api_version=default_llm_settings.api_version,
max_retries=default_llm_settings.max_retries,
request_timeout=default_llm_settings.request_timeout,
deployment_name=llm_config.deployment_name,
api_version=llm_config.api_version,
max_retries=llm_config.max_retries
if llm_config.max_retries != -1
else defs.LLM_MAX_RETRIES,
request_timeout=llm_config.request_timeout,
)
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
"""Get the LLM client for embeddings."""
embeddings_llm_settings = config.get_language_model_config(
config.embed_text.model_id
)
is_azure_client = embeddings_llm_settings.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = embeddings_llm_settings.api_key or ""
embeddings_llm_config = config.get_language_model_config(config.embed_text.model_id)
is_azure_client = embeddings_llm_config.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = embeddings_llm_config.api_key or ""
llm_debug_info = {
**embeddings_llm_settings.model_dump(),
**embeddings_llm_config.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if embeddings_llm_settings.audience is None:
if embeddings_llm_config.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
audience = embeddings_llm_settings.audience
audience = embeddings_llm_config.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=embeddings_llm_settings.api_key,
api_key=embeddings_llm_config.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client
and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
and embeddings_llm_config.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=embeddings_llm_settings.api_base,
organization=embeddings_llm_settings.organization,
api_base=embeddings_llm_config.api_base,
organization=embeddings_llm_config.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=embeddings_llm_settings.model,
deployment_name=embeddings_llm_settings.deployment_name,
api_version=embeddings_llm_settings.api_version,
max_retries=embeddings_llm_settings.max_retries,
model=embeddings_llm_config.model,
deployment_name=embeddings_llm_config.deployment_name,
api_version=embeddings_llm_config.api_version,
max_retries=embeddings_llm_config.max_retries
if embeddings_llm_config.max_retries != -1
else defs.LLM_MAX_RETRIES,
)

poetry.lock (generated), 864 changes. Diff suppressed because it is too large.

View File

@ -57,11 +57,15 @@ lancedb = "^0.17.0"
aiofiles = "^24.1.0"
# LLM
fnllm = {extras = ["azure", "openai"], version = "^0.1.2"}
httpx = "^0.28.1"
json-repair = "^0.30.3"
openai = "^1.57.0"
nltk = "3.9.1"
tenacity = "^9.0.0"
tiktoken = "^0.8.0"
# Data-Sci
# Data-Science
numpy = "^1.25.2"
graspologic = "^3.4.1"
networkx = "^3.4.2"
@ -78,22 +82,18 @@ rich = "^13.9.4"
devtools = "^0.12.2"
typing-extensions = "^4.12.2"
#Azure
# Azure
azure-cosmos = "^4.9.0"
azure-identity = "^1.19.0"
azure-storage-blob = "^12.24.0"
future = "^1.0.0" # Needed until graspologic fixes their dependency
typer = "^0.15.1"
fnllm = "^0.0.10"
tenacity = "^9.0.0"
json-repair = "^0.30.3"
tqdm = "^4.67.1"
httpx = "^0.28.1"
textblob = "^0.18.0.post0"
spacy = "^3.8.4"
[tool.poetry.group.dev.dependencies]
coverage = "^7.6.9"
ipykernel = "^6.29.5"

View File

@ -15,7 +15,7 @@
1,
2500
],
"max_runtime": 300,
"max_runtime": 500,
"expected_artifacts": 2
},
"create_communities": {

View File

@ -10,8 +10,7 @@ models:
tokens_per_minute: ${GRAPHRAG_LLM_TPM}
requests_per_minute: ${GRAPHRAG_LLM_RPM}
model_supports_json: true
parallelization_num_threads: 50
parallelization_stagger: 0.3
concurrent_requests: 50
async_mode: threaded
default_embedding_model:
azure_auth_type: api_key
@ -23,8 +22,7 @@ models:
model: ${GRAPHRAG_EMBEDDING_MODEL}
tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
parallelization_num_threads: 50
parallelization_stagger: 0.3
concurrent_requests: 50
async_mode: threaded
vector_store:

View File

@ -15,7 +15,7 @@
1,
2500
],
"max_runtime": 300,
"max_runtime": 500,
"expected_artifacts": 2
},
"extract_covariates": {

View File

@ -10,8 +10,7 @@ models:
tokens_per_minute: ${GRAPHRAG_LLM_TPM}
requests_per_minute: ${GRAPHRAG_LLM_RPM}
model_supports_json: true
parallelization_num_threads: 50
parallelization_stagger: 0.3
concurrent_requests: 50
async_mode: threaded
default_embedding_model:
azure_auth_type: api_key
@ -23,8 +22,7 @@ models:
model: ${GRAPHRAG_EMBEDDING_MODEL}
tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
parallelization_num_threads: 50
parallelization_stagger: 0.3
concurrent_requests: 50
async_mode: threaded
vector_store:

View File

@ -265,13 +265,7 @@ def assert_language_model_configs(
assert actual.requests_per_minute == expected.requests_per_minute
assert actual.max_retries == expected.max_retries
assert actual.max_retry_wait == expected.max_retry_wait
assert (
actual.sleep_on_rate_limit_recommendation
== expected.sleep_on_rate_limit_recommendation
)
assert actual.concurrent_requests == expected.concurrent_requests
assert actual.parallelization_stagger == expected.parallelization_stagger
assert actual.parallelization_num_threads == expected.parallelization_num_threads
assert actual.async_mode == expected.async_mode
if actual.responses is not None:
assert expected.responses is not None

View File

@ -1,6 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from fnllm import ChatLLM
from fnllm.types import ChatLLM
from pydantic import BaseModel
from graphrag.index.llm.mock_llm import MockChatLLM