Mirror of https://github.com/microsoft/graphrag.git (synced 2025-06-26 23:19:58 +00:00)

Improve default llm retry logic to be more optimized (#1701)

This commit is contained in:
parent b8b949f3bb
commit f14cda2b6d
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add dynamic retry logic."
+}
@@ -13,6 +13,7 @@ Backwards compatibility is not guaranteed at this time.
 from pydantic import PositiveInt, validate_call

+import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm
@@ -95,8 +96,14 @@ async def generate_indexing_prompts(
     )

     # Create LLM from config
-    # TODO: Expose way to specify Prompt Tuning model ID through config
+    # TODO: Expose a way to specify Prompt Tuning model ID through config
     default_llm_settings = config.get_language_model_config(PROMPT_TUNING_MODEL_ID)
+
+    # if max_retries is not set, inject a dynamically assigned value based on the number of expected LLM calls
+    # to be made or fallback to a default value in the worst case
+    if default_llm_settings.max_retries == -1:
+        default_llm_settings.max_retries = min(len(doc_list), defs.LLM_MAX_RETRIES)
+
     llm = load_llm(
         "prompt_tuning",
         default_llm_settings,
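The hunk above introduces the sentinel convention this whole commit relies on: a `max_retries` of `-1` means "not configured", and the effective retry budget is then derived from the number of LLM calls the operation expects to make, capped by the library default. A minimal sketch of that rule follows; `resolve_max_retries` and `expected_llm_calls` are hypothetical names used only for illustration, not part of graphrag.

```python
# Illustrative sketch of the sentinel-based retry rule used above.
# Assumption: -1 means "not set"; graphrag then bounds retries by the number
# of expected LLM calls, capped by a default value in the worst case.

LLM_MAX_RETRIES_DEFAULT = 10  # mirrors defs.LLM_MAX_RETRIES in this commit


def resolve_max_retries(configured: int, expected_llm_calls: int) -> int:
    """Hypothetical helper: pick a retry budget when none was configured."""
    if configured != -1:
        return configured  # user explicitly set a value; respect it
    # dynamic case: never budget more retries than there are expected calls
    return min(expected_llm_calls, LLM_MAX_RETRIES_DEFAULT)


print(resolve_max_retries(-1, 3))   # -> 3
print(resolve_max_retries(-1, 50))  # -> 10
print(resolve_max_retries(5, 50))   # -> 5
```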
@@ -24,7 +24,7 @@ DEFAULT_CHAT_MODEL_ID = "default_chat_model"
 DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
 ASYNC_MODE = AsyncType.Threaded
 ENCODING_MODEL = "cl100k_base"
-AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
+COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
 AUTH_TYPE = AuthType.APIKey
 #
 # LLM Parameters
@@ -39,15 +39,12 @@ LLM_N = 1
 LLM_REQUEST_TIMEOUT = 180.0
 LLM_TOKENS_PER_MINUTE = 50_000
 LLM_REQUESTS_PER_MINUTE = 1_000
+RETRY_STRATEGY = "native"
 LLM_MAX_RETRIES = 10
 LLM_MAX_RETRY_WAIT = 10.0
 LLM_PRESENCE_PENALTY = 0.0
-LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True
 LLM_CONCURRENT_REQUESTS = 25

-PARALLELIZATION_STAGGER = 0.3
-PARALLELIZATION_NUM_THREADS = 50
-
 #
 # Text embedding
 #
@@ -14,32 +14,41 @@ INIT_YAML = f"""\

 models:
   {defs.DEFAULT_CHAT_MODEL_ID}:
-    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
     type: {defs.LLM_TYPE.value} # or azure_openai_chat
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-05-01-preview
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
+    api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
+    # audience: "https://cognitiveservices.azure.com/.default"
+    # organization: <organization_id>
     model: {defs.LLM_MODEL}
+    # deployment_name: <azure_model_deployment_name>
+    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
     model_supports_json: true # recommended if this is available for your model.
-    parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
-    parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
     concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
     async_mode: {defs.ASYNC_MODE.value} # or asyncio
-    # audience: "https://cognitiveservices.azure.com/.default"
-    # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-02-15-preview
-    # organization: <organization_id>
-    # deployment_name: <azure_model_deployment_name>
+    retry_strategy: native
+    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
+    tokens_per_minute: 0 # set to 0 to disable rate limiting
+    requests_per_minute: 0 # set to 0 to disable rate limiting
   {defs.DEFAULT_EMBEDDING_MODEL_ID}:
-    api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
-    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
-    model: {defs.EMBEDDING_MODEL}
-    parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
-    parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
-    async_mode: {defs.ASYNC_MODE.value} # or asyncio
     # api_base: https://<instance>.openai.azure.com
-    # api_version: 2024-02-15-preview
+    # api_version: 2024-05-01-preview
+    auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
+    api_key: ${{GRAPHRAG_API_KEY}}
     # audience: "https://cognitiveservices.azure.com/.default"
     # organization: <organization_id>
+    model: {defs.EMBEDDING_MODEL}
     # deployment_name: <azure_model_deployment_name>
+    # encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
+    model_supports_json: true # recommended if this is available for your model.
+    concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # max number of simultaneous LLM requests allowed
+    async_mode: {defs.ASYNC_MODE.value} # or asyncio
+    retry_strategy: native
+    max_retries: -1 # set to -1 for dynamic retry logic (most optimal setting based on server response)
+    tokens_per_minute: 0 # set to 0 to disable rate limiting
+    requests_per_minute: 0 # set to 0 to disable rate limiting

 vector_store:
   {defs.VECTOR_STORE_DEFAULT_ID}:
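The regenerated settings template above documents two sentinel conventions: `max_retries: -1` opts into the dynamic retry logic, and `tokens_per_minute: 0` / `requests_per_minute: 0` disable rate limiting. A small sketch of how a loader might interpret those values follows; it uses PyYAML on an inlined snippet and is illustrative only, not graphrag's actual configuration code.

```python
# Illustrative only: interpreting the sentinel values the new template documents.
import yaml  # assumes PyYAML is available

SETTINGS = """
models:
  default_chat_model:
    retry_strategy: native
    max_retries: -1          # -1 => dynamic retry logic
    tokens_per_minute: 0     # 0  => rate limiting disabled
    requests_per_minute: 0   # 0  => rate limiting disabled
"""

model = yaml.safe_load(SETTINGS)["models"]["default_chat_model"]
dynamic_retries = model["max_retries"] == -1
tpm_limit = model["tokens_per_minute"] or None   # 0 becomes None (no limit)
rpm_limit = model["requests_per_minute"] or None

print(model["retry_strategy"], dynamic_retries, tpm_limit, rpm_limit)
# -> native True None None
```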
@@ -49,8 +49,7 @@ class CommunityReportsConfig(BaseModel):
         return self.strategy or {
             "type": CreateCommunityReportsStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
                 encoding="utf-8"
             )
@@ -46,8 +46,7 @@ class ClaimExtractionConfig(BaseModel):
         """Get the resolved claim extraction strategy."""
         return self.strategy or {
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
@@ -47,8 +47,7 @@ class ExtractGraphConfig(BaseModel):
         return self.strategy or {
             "type": ExtractEntityStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
@@ -64,7 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
     text_analyzer: TextAnalyzerConfig = Field(
         description="The text analyzer configuration.", default=TextAnalyzerConfig()
     )
-    parallelization_num_threads: int = Field(
+    concurrent_requests: int = Field(
         description="The number of threads to use for the extraction process.",
-        default=defs.PARALLELIZATION_NUM_THREADS,
+        default=defs.LLM_CONCURRENT_REQUESTS,
     )
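All of the config hunks above collapse the per-strategy `stagger`/`num_threads` pair into a single `num_threads` value sourced from `concurrent_requests`. A small before/after sketch of that resolved-strategy shape follows; `ModelCfg` is a stand-in for graphrag's LanguageModelConfig reduced to the one relevant field, not the real class.

```python
# Sketch of the resolved-strategy change above; illustrative only.
from dataclasses import dataclass


@dataclass
class ModelCfg:
    concurrent_requests: int = 25  # stand-in for LanguageModelConfig.concurrent_requests


def resolved_strategy(model_config: ModelCfg) -> dict:
    # After this commit the strategy carries a single "num_threads" value,
    # taken from concurrent_requests; the separate "stagger" entry is gone.
    return {
        "llm": {"concurrent_requests": model_config.concurrent_requests},
        "num_threads": model_config.concurrent_requests,
    }


print(resolved_strategy(ModelCfg(concurrent_requests=8))["num_threads"])  # -> 8
```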
@@ -31,7 +31,7 @@ class LanguageModelConfig(BaseModel):
         API Key is required when using OpenAI API
         or when using Azure API with API Key authentication.
         For the time being, this check is extra verbose for clarity.
-        It will also through an exception if an API Key is provided
+        It will also raise an exception if an API Key is provided
         when one is not expected such as the case of using Azure
         Managed Identity.

@@ -199,6 +199,10 @@ class LanguageModelConfig(BaseModel):
         description="The number of requests per minute to use for the LLM service.",
         default=defs.LLM_REQUESTS_PER_MINUTE,
     )
+    retry_strategy: str = Field(
+        description="The retry strategy to use for the LLM service.",
+        default=defs.RETRY_STRATEGY,
+    )
     max_retries: int = Field(
         description="The maximum number of retries to use for the LLM service.",
         default=defs.LLM_MAX_RETRIES,
@@ -207,10 +211,6 @@ class LanguageModelConfig(BaseModel):
         description="The maximum retry wait to use for the LLM service.",
         default=defs.LLM_MAX_RETRY_WAIT,
     )
-    sleep_on_rate_limit_recommendation: bool = Field(
-        description="Whether to sleep on rate limit recommendations.",
-        default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
-    )
     concurrent_requests: int = Field(
         description="Whether to use concurrent requests for the LLM service.",
         default=defs.LLM_CONCURRENT_REQUESTS,
@@ -218,14 +218,6 @@ class LanguageModelConfig(BaseModel):
     responses: list[str | BaseModel] | None = Field(
         default=None, description="Static responses to use in mock mode."
     )
-    parallelization_stagger: float = Field(
-        description="The stagger to use for the LLM service.",
-        default=defs.PARALLELIZATION_STAGGER,
-    )
-    parallelization_num_threads: int = Field(
-        description="The number of threads to use for the LLM service.",
-        default=defs.PARALLELIZATION_NUM_THREADS,
-    )
     async_mode: AsyncType = Field(
         description="The async mode to use.", default=defs.ASYNC_MODE
     )
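After the LanguageModelConfig hunks above, the retry-related surface of a model config is `retry_strategy`, `max_retries`, `max_retry_wait`, and `concurrent_requests`; the sleep-on-rate-limit and parallelization fields are gone. The sketch below is a minimal pydantic model illustrating those surviving fields and their defaults; it is not the real LanguageModelConfig.

```python
# Minimal pydantic sketch of the retry-related fields after this commit.
# Defaults mirror the defs constants shown earlier; illustrative only.
from pydantic import BaseModel, Field


class RetrySettings(BaseModel):
    retry_strategy: str = Field(default="native")   # defs.RETRY_STRATEGY
    max_retries: int = Field(default=10)            # defs.LLM_MAX_RETRIES
    max_retry_wait: float = Field(default=10.0)     # defs.LLM_MAX_RETRY_WAIT
    concurrent_requests: int = Field(default=25)    # defs.LLM_CONCURRENT_REQUESTS


print(RetrySettings().model_dump())
# {'retry_strategy': 'native', 'max_retries': 10, 'max_retry_wait': 10.0, 'concurrent_requests': 25}
```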
@@ -40,8 +40,7 @@ class SummarizeDescriptionsConfig(BaseModel):
         return self.strategy or {
             "type": SummarizeStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "summarize_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
@@ -48,8 +48,7 @@ class TextEmbeddingConfig(BaseModel):
         return self.strategy or {
             "type": TextEmbedStrategyType.openai,
             "llm": model_config.model_dump(),
-            "stagger": model_config.parallelization_stagger,
-            "num_threads": model_config.parallelization_num_threads,
+            "num_threads": model_config.concurrent_requests,
             "batch_size": self.batch_size,
             "batch_max_tokens": self.batch_max_tokens,
         }
@@ -37,7 +37,7 @@ async def extract_graph_nlp(
         text_units,
         text_analyzer=text_analyzer,
         normalize_edge_weights=extraction_config.normalize_edge_weights,
-        num_threads=extraction_config.parallelization_num_threads,
+        num_threads=extraction_config.concurrent_requests,
         cache=cache,
     )
@@ -8,8 +8,9 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, Any

-from fnllm import ChatLLM, EmbeddingsLLM, JsonStrategy, LLMEvents
+from fnllm.base.config import JsonStrategy, RetryStrategy
 from fnllm.caching import Cache as LLMCache
+from fnllm.events import LLMEvents
 from fnllm.openai import (
     AzureOpenAIConfig,
     OpenAIConfig,
@@ -30,6 +31,8 @@ from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
 from .mock_llm import MockChatLLM

 if TYPE_CHECKING:
+    from fnllm.types import ChatLLM, EmbeddingsLLM
+
     from graphrag.cache.pipeline_cache import PipelineCache
     from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
     from graphrag.index.typing import ErrorHandlerFn
@@ -209,7 +212,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
         msg = "Azure OpenAI Chat LLM requires an API base"
         raise ValueError(msg)

-    audience = config.audience or defs.AZURE_AUDIENCE
+    audience = config.audience or defs.COGNITIVE_SERVICES_AUDIENCE
     return AzureOpenAIConfig(
         api_key=config.api_key,
         endpoint=config.api_base,
@@ -220,19 +223,22 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
             max_retry_wait=config.max_retry_wait,
             requests_per_minute=config.requests_per_minute,
             tokens_per_minute=config.tokens_per_minute,
-            cognitive_services_endpoint=audience,
+            audience=audience,
+            retry_strategy=RetryStrategy(config.retry_strategy),
+            timeout=config.request_timeout,
             max_concurrency=config.concurrent_requests,
             model=config.model,
             encoding=encoding_model,
             deployment=config.deployment_name,
             chat_parameters=chat_parameters,
             sleep_on_rate_limit_recommendation=True,
         )
     return PublicOpenAIConfig(
         api_key=config.api_key,
         base_url=config.api_base,
         json_strategy=json_strategy,
         organization=config.organization,
+        retry_strategy=RetryStrategy(config.retry_strategy),
         max_retries=config.max_retries,
         max_retry_wait=config.max_retry_wait,
         requests_per_minute=config.requests_per_minute,
@@ -242,6 +248,7 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
         model=config.model,
         encoding=encoding_model,
         chat_parameters=chat_parameters,
+        sleep_on_rate_limit_recommendation=True,
     )
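In the load_llm hunks above, the configured `retry_strategy` string is handed to fnllm as `RetryStrategy(config.retry_strategy)`. The sketch below shows the same string-to-enum conversion shape using a stdlib Enum so it runs without fnllm installed; the real member set lives in `fnllm.base.config.RetryStrategy` and, beyond `"native"`, is not shown in this diff, so the extra members here are assumptions.

```python
# Sketch of the string -> enum conversion pattern used above
# (retry_strategy=RetryStrategy(config.retry_strategy)). Stdlib Enum stand-in.
from enum import Enum


class RetryStrategy(str, Enum):
    NATIVE = "native"
    # ...other strategies omitted; assumption for illustration only


def to_retry_strategy(value: str) -> RetryStrategy:
    try:
        return RetryStrategy(value)  # Enum lookup by value, same call shape as above
    except ValueError as err:
        msg = f"Unknown retry strategy: {value!r}"
        raise ValueError(msg) from err


print(to_retry_strategy("native"))  # RetryStrategy.NATIVE
```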
@@ -5,7 +5,7 @@

 from functools import cache

-from fnllm import ChatLLM, EmbeddingsLLM
+from fnllm.types import ChatLLM, EmbeddingsLLM


 @cache
@@ -5,8 +5,9 @@
 from dataclasses import dataclass
 from typing import Any, cast

-from fnllm import ChatLLM, LLMInput, LLMOutput
+from fnllm.types import ChatLLM
 from fnllm.types.generics import THistoryEntry, TJsonModel, TModelParameters
+from fnllm.types.io import LLMInput, LLMOutput
 from pydantic import BaseModel
 from typing_extensions import Unpack
@@ -116,10 +116,10 @@ async def _text_embed_in_memory(
 ):
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
-    strategy_args = {**strategy}
+    strategy_config = {**strategy}

     texts: list[str] = input[embed_column].to_numpy().tolist()
-    result = await strategy_exec(texts, callbacks, cache, strategy_args)
+    result = await strategy_exec(texts, callbacks, cache, strategy_config)

     return result.embeddings

@@ -137,7 +137,11 @@ async def _text_embed_with_vector_store(
 ):
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
-    strategy_args = {**strategy}
+    strategy_config = {**strategy}
+
+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(input)

     # Get vector-storage configuration
     insert_batch_size: int = (
@@ -176,7 +180,7 @@ async def _text_embed_with_vector_store(
         texts: list[str] = batch[embed_column].to_numpy().tolist()
         titles: list[str] = batch[title].to_numpy().tolist()
         ids: list[str] = batch[id_column].to_numpy().tolist()
-        result = await strategy_exec(texts, callbacks, cache, strategy_args)
+        result = await strategy_exec(texts, callbacks, cache, strategy_config)
         if result.embeddings:
             embeddings = [
                 embedding for embedding in result.embeddings if embedding is not None
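The same "inject max_retries when the sentinel is still -1" guard recurs in several of the operations in this commit (text embedding here, and covariate extraction, graph extraction, community summarization, and description summarization below), each keyed to the number of rows that operation will send to the LLM. A generic sketch of that injection follows; `inject_dynamic_max_retries` is an illustrative name, not a graphrag function.

```python
# Generic sketch of the per-workflow injection used above: if the strategy's
# llm block still carries the -1 sentinel, bound retries by the number of
# expected LLM calls for this workflow (len(input), len(text_units), ...).
from typing import Any


def inject_dynamic_max_retries(strategy: dict[str, Any], expected_calls: int) -> dict[str, Any]:
    strategy_config = {**strategy}  # shallow copy, as in the hunks above
    llm = strategy_config.get("llm")
    if llm and llm.get("max_retries") == -1:
        llm["max_retries"] = expected_calls
    return strategy_config


cfg = inject_dynamic_max_retries({"llm": {"max_retries": -1}}, expected_calls=120)
print(cfg["llm"]["max_retries"])  # -> 120
```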
@@ -8,7 +8,7 @@ import logging
 from typing import Any

 import numpy as np
-from fnllm import EmbeddingsLLM
+from fnllm.types import EmbeddingsLLM

 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -9,7 +9,7 @@ from dataclasses import dataclass
 from typing import Any

 import tiktoken
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
@@ -50,6 +50,10 @@ async def extract_covariates(
     strategy = strategy or {}
     strategy_config = {**strategy}

+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(input)
+
     async def run_strategy(row):
         text = row[column]
         result = await run_extract_claims(
@@ -94,6 +94,10 @@ async def extract_graph(
     )
     strategy_config = {**strategy}

+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(text_units)
+
     num_started = 0

     async def run_strategy(row):
@@ -12,7 +12,7 @@ from typing import Any

 import networkx as nx
 import tiktoken
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
@@ -4,7 +4,7 @@
 """A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""

 import networkx as nx
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 import graphrag.config.defaults as defs
 from graphrag.cache.pipeline_cache import PipelineCache
@@ -8,7 +8,7 @@ import traceback
 from dataclasses import dataclass
 from typing import Any

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 from pydantic import BaseModel, Field

 from graphrag.index.typing import ErrorHandlerFn
@@ -6,7 +6,7 @@
 import logging
 import traceback

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -42,7 +42,12 @@ async def summarize_communities(
     """Generate community summaries."""
     reports: list[CommunityReport | None] = []
     tick = progress_ticker(callbacks.progress, len(local_contexts))
-    runner = load_strategy(strategy["type"])
+    strategy_exec = load_strategy(strategy["type"])
+    strategy_config = {**strategy}
+
+    # if max_retries is not set, inject a dynamically assigned value based on the total number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(nodes)

     community_hierarchy = restore_community_hierarchy(nodes)
     levels = get_levels(nodes)
@@ -62,13 +67,13 @@ async def summarize_communities(

     async def run_generate(record):
         result = await _generate_report(
-            runner,
+            strategy_exec,
             community_id=record[schemas.COMMUNITY_ID],
             community_level=record[schemas.COMMUNITY_LEVEL],
             community_context=record[schemas.CONTEXT_STRING],
             callbacks=callbacks,
             cache=cache,
-            strategy=strategy,
+            strategy=strategy_config,
         )
         tick()
         return result
@@ -6,7 +6,7 @@
 import json
 from dataclasses import dataclass

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.index.utils.tokens import num_tokens_from_string
@@ -3,7 +3,7 @@

 """A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -76,6 +76,10 @@ async def summarize_descriptions(
     )
     strategy_config = {**strategy}

+    # if max_retries is not set, inject a dynamically assigned value based on the maximum number of expected LLM calls to be made
+    if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
+        strategy_config["llm"]["max_retries"] = len(entities_df) + len(relationships_df)
+
     async def get_summarized(
         nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore
     ):
@@ -6,6 +6,7 @@
 import asyncio
 import sys

+import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
@@ -17,6 +18,9 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
     # Validate Chat LLM configs
     # TODO: Replace default_chat_model with a way to select the model
     default_llm_settings = parameters.get_language_model_config("default_chat_model")
+    # if max_retries is not set, set it to the default value
+    if default_llm_settings.max_retries == -1:
+        default_llm_settings.max_retries = defs.LLM_MAX_RETRIES
     llm = load_llm(
         name="test-llm",
         config=default_llm_settings,
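Outside of the indexing operations there is no call count to key the dynamic value off, so the -1 sentinel simply falls back to the default cap, as in the config validation above and in the query client factory further down this diff. A tiny sketch of that fallback, with an assumed constant name mirroring `defs.LLM_MAX_RETRIES`:

```python
# Sketch of the non-pipeline fallback: with no expected call count available,
# the -1 sentinel resolves to the default retry cap.
LLM_MAX_RETRIES = 10  # mirrors defs.LLM_MAX_RETRIES


def effective_max_retries(configured: int) -> int:
    return configured if configured != -1 else LLM_MAX_RETRIES


print(effective_max_retries(-1))  # -> 10
print(effective_max_retries(4))   # -> 4
```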
@@ -39,7 +39,7 @@ async def run_workflow(
         config.community_reports.model_id
     )
     async_mode = community_reports_llm_settings.async_mode
-    num_threads = community_reports_llm_settings.parallelization_num_threads
+    num_threads = community_reports_llm_settings.concurrent_requests
     summarization_strategy = config.community_reports.resolved_strategy(
         config.root_dir, community_reports_llm_settings
     )
@@ -31,7 +31,7 @@ async def run_workflow(
         config.community_reports.model_id
     )
     async_mode = community_reports_llm_settings.async_mode
-    num_threads = community_reports_llm_settings.parallelization_num_threads
+    num_threads = community_reports_llm_settings.concurrent_requests
     summarization_strategy = config.community_reports.resolved_strategy(
         config.root_dir, community_reports_llm_settings
     )
@@ -32,7 +32,7 @@ async def run_workflow(
     )

     async_mode = extract_claims_llm_settings.async_mode
-    num_threads = extract_claims_llm_settings.parallelization_num_threads
+    num_threads = extract_claims_llm_settings.concurrent_requests

     output = await extract_covariates(
         text_units,
@@ -32,9 +32,6 @@ async def run_workflow(
     extraction_strategy = config.extract_graph.resolved_strategy(
         config.root_dir, extract_graph_llm_settings
     )
-    extraction_num_threads = extract_graph_llm_settings.parallelization_num_threads
-    extraction_async_mode = extract_graph_llm_settings.async_mode
-    entity_types = config.extract_graph.entity_types

     summarization_llm_settings = config.get_language_model_config(
         config.summarize_descriptions.model_id
@@ -42,18 +39,17 @@ async def run_workflow(
     summarization_strategy = config.summarize_descriptions.resolved_strategy(
         config.root_dir, summarization_llm_settings
     )
-    summarization_num_threads = summarization_llm_settings.parallelization_num_threads

     entities, relationships = await extract_graph(
         text_units=text_units,
         callbacks=callbacks,
         cache=context.cache,
         extraction_strategy=extraction_strategy,
-        extraction_num_threads=extraction_num_threads,
-        extraction_async_mode=extraction_async_mode,
-        entity_types=entity_types,
+        extraction_num_threads=extract_graph_llm_settings.concurrent_requests,
+        extraction_async_mode=extract_graph_llm_settings.async_mode,
+        entity_types=config.extract_graph.entity_types,
         summarization_strategy=summarization_strategy,
-        summarization_num_threads=summarization_num_threads,
+        summarization_num_threads=summarization_llm_settings.concurrent_requests,
         embed_config=config.embed_graph,
         layout_enabled=config.umap.enabled,
     )
@@ -3,7 +3,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.prompt_tune.prompt.community_report_rating import (
     GENERATE_REPORT_RATING_PROMPT,
@@ -3,7 +3,7 @@

 """Generate a community reporter role for community summarization."""

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.prompt_tune.prompt.community_reporter_role import (
     GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT,
@@ -3,7 +3,7 @@

 """Domain generation for GraphRAG prompts."""

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.prompt_tune.prompt.domain import GENERATE_DOMAIN_PROMPT

@@ -6,7 +6,7 @@
 import asyncio
 import json

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.prompt_tune.prompt.entity_relationship import (
     ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT,
@@ -3,7 +3,7 @@

 """Entity type generation module for fine-tuning."""

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 from pydantic import BaseModel

 from graphrag.prompt_tune.defaults import DEFAULT_TASK
@@ -3,7 +3,7 @@

 """Language detection for GraphRAG prompts."""

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.prompt_tune.prompt.language import DETECT_LANGUAGE_PROMPT

@@ -3,7 +3,7 @@

 """Persona generating module for fine-tuning GraphRAG prompts."""

-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 from graphrag.prompt_tune.defaults import DEFAULT_TASK
 from graphrag.prompt_tune.prompt.persona import GENERATE_PERSONA_PROMPT
@@ -5,7 +5,7 @@

 import numpy as np
 import pandas as pd
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM

 import graphrag.config.defaults as defs
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
@@ -5,6 +5,7 @@

 from azure.identity import DefaultAzureCredential, get_bearer_token_provider

+import graphrag.config.defaults as defs
 from graphrag.config.enums import AuthType, LLMType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
@@ -14,67 +15,68 @@ from graphrag.query.llm.oai.typing import OpenaiApiType

 def get_llm(config: GraphRagConfig) -> ChatOpenAI:
     """Get the LLM client."""
-    default_llm_settings = config.get_language_model_config("default_chat_model")
-    is_azure_client = default_llm_settings.type == LLMType.AzureOpenAIChat
-    debug_llm_key = default_llm_settings.api_key or ""
+    llm_config = config.get_language_model_config("default_chat_model")
+    is_azure_client = llm_config.type == LLMType.AzureOpenAIChat
+    debug_llm_key = llm_config.api_key or ""
     llm_debug_info = {
-        **default_llm_settings.model_dump(),
+        **llm_config.model_dump(),
         "api_key": f"REDACTED,len={len(debug_llm_key)}",
     }
     audience = (
-        default_llm_settings.audience
-        if default_llm_settings.audience
+        llm_config.audience
+        if llm_config.audience
         else "https://cognitiveservices.azure.com/.default"
     )
     print(f"creating llm client with {llm_debug_info}") # noqa T201
     return ChatOpenAI(
-        api_key=default_llm_settings.api_key,
+        api_key=llm_config.api_key,
        azure_ad_token_provider=(
            get_bearer_token_provider(DefaultAzureCredential(), audience)
-            if is_azure_client
-            and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
+            if is_azure_client and llm_config.auth_type == AuthType.AzureManagedIdentity
            else None
        ),
-        api_base=default_llm_settings.api_base,
-        organization=default_llm_settings.organization,
-        model=default_llm_settings.model,
+        api_base=llm_config.api_base,
+        organization=llm_config.organization,
+        model=llm_config.model,
        api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
-        deployment_name=default_llm_settings.deployment_name,
-        api_version=default_llm_settings.api_version,
-        max_retries=default_llm_settings.max_retries,
-        request_timeout=default_llm_settings.request_timeout,
+        deployment_name=llm_config.deployment_name,
+        api_version=llm_config.api_version,
+        max_retries=llm_config.max_retries
+        if llm_config.max_retries != -1
+        else defs.LLM_MAX_RETRIES,
+        request_timeout=llm_config.request_timeout,
    )


 def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
     """Get the LLM client for embeddings."""
-    embeddings_llm_settings = config.get_language_model_config(
-        config.embed_text.model_id
-    )
-    is_azure_client = embeddings_llm_settings.type == LLMType.AzureOpenAIEmbedding
-    debug_embedding_api_key = embeddings_llm_settings.api_key or ""
+    embeddings_llm_config = config.get_language_model_config(config.embed_text.model_id)
+    is_azure_client = embeddings_llm_config.type == LLMType.AzureOpenAIEmbedding
+    debug_embedding_api_key = embeddings_llm_config.api_key or ""
     llm_debug_info = {
-        **embeddings_llm_settings.model_dump(),
+        **embeddings_llm_config.model_dump(),
         "api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
     }
-    if embeddings_llm_settings.audience is None:
+    if embeddings_llm_config.audience is None:
        audience = "https://cognitiveservices.azure.com/.default"
    else:
-        audience = embeddings_llm_settings.audience
+        audience = embeddings_llm_config.audience
     print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
     return OpenAIEmbedding(
-        api_key=embeddings_llm_settings.api_key,
+        api_key=embeddings_llm_config.api_key,
        azure_ad_token_provider=(
            get_bearer_token_provider(DefaultAzureCredential(), audience)
            if is_azure_client
-            and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
+            and embeddings_llm_config.auth_type == AuthType.AzureManagedIdentity
            else None
        ),
-        api_base=embeddings_llm_settings.api_base,
-        organization=embeddings_llm_settings.organization,
+        api_base=embeddings_llm_config.api_base,
+        organization=embeddings_llm_config.organization,
        api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
-        model=embeddings_llm_settings.model,
-        deployment_name=embeddings_llm_settings.deployment_name,
-        api_version=embeddings_llm_settings.api_version,
-        max_retries=embeddings_llm_settings.max_retries,
+        model=embeddings_llm_config.model,
+        deployment_name=embeddings_llm_config.deployment_name,
+        api_version=embeddings_llm_config.api_version,
+        max_retries=embeddings_llm_config.max_retries
+        if embeddings_llm_config.max_retries != -1
+        else defs.LLM_MAX_RETRIES,
    )
poetry.lock (generated, 864 changed lines)
File diff suppressed because it is too large
@@ -57,11 +57,15 @@ lancedb = "^0.17.0"
 aiofiles = "^24.1.0"

 # LLM
+fnllm = {extras = ["azure", "openai"], version = "^0.1.2"}
+httpx = "^0.28.1"
+json-repair = "^0.30.3"
 openai = "^1.57.0"
 nltk = "3.9.1"
+tenacity = "^9.0.0"
 tiktoken = "^0.8.0"

-# Data-Sci
+# Data-Science
 numpy = "^1.25.2"
 graspologic = "^3.4.1"
 networkx = "^3.4.2"
@@ -78,22 +82,18 @@ rich = "^13.9.4"
 devtools = "^0.12.2"
 typing-extensions = "^4.12.2"

-#Azure
+# Azure
 azure-cosmos = "^4.9.0"
 azure-identity = "^1.19.0"
 azure-storage-blob = "^12.24.0"

 future = "^1.0.0" # Needed until graspologic fixes their dependency
 typer = "^0.15.1"
-fnllm = "^0.0.10"
-
-tenacity = "^9.0.0"
-json-repair = "^0.30.3"
 tqdm = "^4.67.1"
-httpx = "^0.28.1"
+
 textblob = "^0.18.0.post0"
 spacy = "^3.8.4"

 [tool.poetry.group.dev.dependencies]
 coverage = "^7.6.9"
 ipykernel = "^6.29.5"
tests/fixtures/min-csv/config.json (vendored, 2 changed lines)
@@ -15,7 +15,7 @@
       1,
       2500
     ],
-    "max_runtime": 300,
+    "max_runtime": 500,
     "expected_artifacts": 2
   },
   "create_communities": {
tests/fixtures/min-csv/settings.yml (vendored, 6 changed lines)
@@ -10,8 +10,7 @@ models:
     tokens_per_minute: ${GRAPHRAG_LLM_TPM}
     requests_per_minute: ${GRAPHRAG_LLM_RPM}
     model_supports_json: true
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded
   default_embedding_model:
     azure_auth_type: api_key
@@ -23,8 +22,7 @@ models:
     model: ${GRAPHRAG_EMBEDDING_MODEL}
     tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
     requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded

 vector_store:
tests/fixtures/text/config.json (vendored, 2 changed lines)
@@ -15,7 +15,7 @@
       1,
       2500
     ],
-    "max_runtime": 300,
+    "max_runtime": 500,
     "expected_artifacts": 2
   },
   "extract_covariates": {
tests/fixtures/text/settings.yml (vendored, 6 changed lines)
@@ -10,8 +10,7 @@ models:
     tokens_per_minute: ${GRAPHRAG_LLM_TPM}
     requests_per_minute: ${GRAPHRAG_LLM_RPM}
     model_supports_json: true
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded
   default_embedding_model:
     azure_auth_type: api_key
@@ -23,8 +22,7 @@ models:
     model: ${GRAPHRAG_EMBEDDING_MODEL}
     tokens_per_minute: ${GRAPHRAG_EMBEDDING_TPM}
    requests_per_minute: ${GRAPHRAG_EMBEDDING_RPM}
-    parallelization_num_threads: 50
-    parallelization_stagger: 0.3
+    concurrent_requests: 50
     async_mode: threaded

 vector_store:
@@ -265,13 +265,7 @@ def assert_language_model_configs(
     assert actual.requests_per_minute == expected.requests_per_minute
     assert actual.max_retries == expected.max_retries
     assert actual.max_retry_wait == expected.max_retry_wait
-    assert (
-        actual.sleep_on_rate_limit_recommendation
-        == expected.sleep_on_rate_limit_recommendation
-    )
     assert actual.concurrent_requests == expected.concurrent_requests
-    assert actual.parallelization_stagger == expected.parallelization_stagger
-    assert actual.parallelization_num_threads == expected.parallelization_num_threads
     assert actual.async_mode == expected.async_mode
     if actual.responses is not None:
         assert expected.responses is not None
@@ -1,6 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
-from fnllm import ChatLLM
+from fnllm.types import ChatLLM
 from pydantic import BaseModel

 from graphrag.index.llm.mock_llm import MockChatLLM