graphrag/graphrag/api/index.py

100 lines
3.4 KiB
Python
Raw Normal View History

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Indexing API for GraphRAG.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
import logging
Refactor config (#1593) * Refactor config - Add new ModelConfig to represent LLM settings - Combines LLMParameters, ParallelizationParameters, encoding_model, and async_mode - Add top level models config that is a list of available LLM ModelConfigs - Remove LLMConfig inheritance and delete LLMConfig - Replace the inheritance with a model_id reference to the ModelConfig listed in the top level models config - Remove all fallbacks and hydration logic from create_graphrag_config - This removes the automatic env variable overrides - Support env variables within config files using Templating - This requires "$" to be escaped with extra "$" so ".*\\.txt$" becomes ".*\\.txt$$" - Update init content to initialize new config file with the ModelConfig structure * Use dict of ModelConfig instead of list * Add model validations and unit tests * Fix ruff checks * Add semversioner change * Fix unit tests * validate root_dir in pydantic model * Rename ModelConfig to LanguageModelConfig * Rename ModelConfigMissingError to LanguageModelConfigMissingError * Add validationg for unexpected API keys * Allow skipping pydantic validation for testing/mocking purposes. * Add default lm configs to verb tests * smoke test * remove config from flows to fix llm arg mapping * Fix embedding llm arg mapping * Remove timestamp from smoke test outputs * Remove unused "subworkflows" smoke test properties * Add models to smoke test configs * Update smoke test output path * Send logs to logs folder * Fix output path * Fix csv test file pattern * Update placeholder * Format * Instantiate default model configs * Fix unit tests for config defaults * Fix migration notebook * Remove create_pipeline_config * Remove several unused config models * Remove indexing embedding and input configs * Move embeddings function to config * Remove skip_workflows * Remove skip embeddings in favor of explicit naming * fix unit test spelling mistake * self.models[model_id] is already a language model. 
Remove redundant casting. * update validation errors to instruct users to rerun graphrag init * instantiate LanguageModelConfigs with validation * skip validation in unit tests * update verb tests to use default model settings instead of skipping validation * test using llm settings * cleanup verb tests * remove unsafe default model config * remove the ability to skip pydantic validation * remove None union types when default values are set * move vector_store from embeddings to top level of config and delete resolve_paths * update vector store settings * fix vector store and smoke tests * fix serializing vector_store settings * fix vector_store usage * fix vector_store type * support cli overrides for loading graphrag config * rename storage to output * Add --force flag to init * Remove run_id and resume, fix Drift config assignment * Ruff --------- Co-authored-by: Nathan Evans <github@talkswithnumbers.com> Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
2025-01-21 15:52:06 -08:00
from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import IndexingMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.run_pipeline import run_pipeline
from graphrag.index.run.utils import create_callback_chain
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.workflow import WorkflowFunction
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
log = logging.getLogger(__name__)
async def build_index(
    config: GraphRagConfig,
    method: IndexingMethod | str = IndexingMethod.Standard,
    is_update_run: bool = False,
    memory_profile: bool = False,
    callbacks: list[WorkflowCallbacks] | None = None,
    progress_logger: ProgressLogger | None = None,
) -> list[PipelineRunResult]:
    """Run the pipeline with the given configuration.

    Parameters
    ----------
    config : GraphRagConfig
        The configuration.
    method : IndexingMethod default=IndexingMethod.Standard
        Styling of indexing to perform (full LLM, NLP + LLM, etc.).
    is_update_run : bool default=False
        Whether to run an incremental update against an existing index.
    memory_profile : bool
        Whether to enable memory profiling.
    callbacks : list[WorkflowCallbacks] | None default=None
        A list of callbacks to register. The caller's list is never mutated.
    progress_logger : ProgressLogger | None default=None
        The progress logger.

    Returns
    -------
    list[PipelineRunResult]
        The list of pipeline run results
    """
    logger = progress_logger or NullProgressLogger()
    # Copy the incoming list so appending the pipeline reporter does not
    # mutate the caller's list (the original appended in place).
    callbacks = list(callbacks) if callbacks else []
    callbacks.append(create_pipeline_reporter(config.reporting, None))
    workflow_callbacks = create_callback_chain(callbacks, logger)
    outputs: list[PipelineRunResult] = []

    if memory_profile:
        log.warning("New pipeline does not yet support memory profiling.")

    # todo: this could propagate out to the cli for better clarity, but will be a breaking api change
    method = _get_method(method, is_update_run)
    pipeline = PipelineFactory.create_pipeline(config, method)

    workflow_callbacks.pipeline_start(pipeline.names())

    async for output in run_pipeline(
        pipeline,
        config,
        callbacks=workflow_callbacks,
        logger=logger,
        is_update_run=is_update_run,
    ):
        outputs.append(output)
        # A non-empty errors list marks the workflow as failed; empty/None is success.
        if output.errors:
            logger.error(output.workflow)
        else:
            logger.success(output.workflow)
        logger.info(str(output.result))

    workflow_callbacks.pipeline_end(outputs)
    return outputs
def register_workflow_function(name: str, workflow: WorkflowFunction) -> None:
    """Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
    # Delegates to the pipeline factory's global registry; registration is a
    # module-level side effect visible to all subsequent create_pipeline calls.
    PipelineFactory.register(name, workflow)
def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
    """Resolve the pipeline method name, appending '-update' for update runs."""
    # Unwrap the enum to its string value; plain strings pass through as-is.
    if isinstance(method, IndexingMethod):
        base = method.value
    else:
        base = method
    suffix = "-update" if is_update_run else ""
    return f"{base}{suffix}"