mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-01 01:20:22 +00:00
Implement query api (#839)
* initial API redesign * typo fix * update docstring * update docstring * remove artifacts caused by the merge from main * minor typo updates * add semversioner check * switch API to async function calls --------- Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
7fd23fa79c
commit
4bcbfd10eb
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "minor",
|
||||
"description": "Implement query engine API."
|
||||
}
|
||||
@ -23,7 +23,10 @@ class SearchType(Enum):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="python -m graphrag.query",
|
||||
description="The graphrag query engine",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
@ -49,7 +52,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
help="The method to run, one of: local or global",
|
||||
help="The method to run",
|
||||
required=True,
|
||||
type=SearchType,
|
||||
choices=list(SearchType),
|
||||
@ -57,14 +60,14 @@ if __name__ == "__main__":
|
||||
|
||||
parser.add_argument(
|
||||
"--community_level",
|
||||
help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities",
|
||||
help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2",
|
||||
type=int,
|
||||
default=2,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--response_type",
|
||||
help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report",
|
||||
help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs",
|
||||
type=str,
|
||||
default="Multiple Paragraphs",
|
||||
)
|
||||
|
||||
192
graphrag/query/api.py
Normal file
192
graphrag/query/api.py
Normal file
@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""
|
||||
Query Engine API.
|
||||
|
||||
This API provides access to the query engine of graphrag, allowing external applications
|
||||
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
|
||||
|
||||
WARNING: This API is under development and may undergo changes in future releases.
|
||||
Backwards compatibility is not guaranteed at this time.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from pydantic import validate_call
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.progress.types import PrintProgressReporter
|
||||
from graphrag.model.entity import Entity
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
|
||||
|
||||
from .factories import get_global_search_engine, get_local_search_engine
|
||||
from .indexer_adapters import (
|
||||
read_indexer_covariates,
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
read_indexer_reports,
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from .input.loaders.dfs import store_entity_semantic_embeddings
|
||||
|
||||
reporter = PrintProgressReporter("")
|
||||
|
||||
|
||||
def __get_embedding_description_store(
    entities: list[Entity],
    vector_store_type: str = VectorStoreType.LanceDB,
    config_args: dict | None = None,
):
    """Build and connect the vector store that serves entity description embeddings.

    Parameters
    ----------
    - entities: entity objects whose description embeddings may be (re)stored.
    - vector_store_type: which vector store backend to instantiate via the factory.
    - config_args: backend connection/configuration options; NOTE this dict is
      mutated in place (a "collection_name" key is added) so the caller sees it.

    Returns the connected vector store instance.
    """
    if not config_args:
        config_args = {}

    collection_name = config_args.get(
        "query_collection_name", "entity_description_embeddings"
    )
    # Intentionally mutates the caller-supplied dict, mirroring the CLI behavior.
    config_args["collection_name"] = collection_name

    store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type, kwargs=config_args
    )
    store.connect(**config_args)

    if config_args.get("overwrite", True):
        # The embeddings were originally written to a file during indexing rather
        # than to a vector database, so dump them from the entities list into the
        # freshly created store.
        store_entity_semantic_embeddings(entities=entities, vectorstore=store)
        return store

    # overwrite == False: attach to an already-populated collection. A LanceDB
    # store is used here regardless of vector_store_type — presumably a remote
    # db is addressed via url/port in config_args; TODO confirm this is intended.
    store = LanceDBVectorStore(collection_name=collection_name)
    store.connect(db_uri=config_args.get("db_uri", "./lancedb"))

    # Open the pre-existing table so reads hit the stored embeddings.
    store.document_collection = store.db_connection.open_table(
        store.collection_name
    )
    return store
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Run a global (community-report driven) search for a single query.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): Final nodes output (create_final_nodes.parquet)
    - entities (pd.DataFrame): Final entities output (create_final_entities.parquet)
    - community_reports (pd.DataFrame): Final community reports output
      (create_final_community_reports.parquet)
    - community_level (int): Leiden hierarchy level to load reports from.
    - response_type (str): Free-form description of the desired answer format.
    - query (str): The user query to search for.

    Returns
    -------
    The search engine's response payload (TODO: document exact type/format).

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Adapt the raw parquet DataFrames into the query engine's model objects.
    report_objs = read_indexer_reports(community_reports, nodes, community_level)
    entity_objs = read_indexer_entities(nodes, entities, community_level)

    engine = get_global_search_engine(
        config,
        reports=report_objs,
        entities=entity_objs,
        response_type=response_type,
    )

    search_result = await engine.asearch(query=query)
    reporter.success(f"Global Search Response: {search_result.response}")
    return search_result.response
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Run a local (entity-neighborhood driven) search for a single query.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): Final nodes output (create_final_nodes.parquet)
    - entities (pd.DataFrame): Final entities output (create_final_entities.parquet)
    - community_reports (pd.DataFrame): Final community reports output
      (create_final_community_reports.parquet)
    - text_units (pd.DataFrame): Final text units output (create_final_text_units.parquet)
    - relationships (pd.DataFrame): Final relationships output
      (create_final_relationships.parquet)
    - covariates (pd.DataFrame | None): Optional final covariates output
      (create_final_covariates.parquet)
    - community_level (int): Leiden hierarchy level to load reports from.
    - response_type (str): Free-form description of the desired answer format.
    - query (str): The user query to search for.

    Returns
    -------
    The search engine's response payload (TODO: document exact type/format).

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Pull vector-store settings out of the embeddings config, if present.
    vector_store_args = (
        config.embeddings.vector_store if config.embeddings.vector_store else {}
    )
    reporter.info(f"Vector Store Args: {vector_store_args}")
    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

    # Adapt DataFrames to model objects, then make the description-embedding
    # store available for the local search context builder.
    entity_objs = read_indexer_entities(nodes, entities, community_level)
    description_embedding_store = __get_embedding_description_store(
        entities=entity_objs,
        vector_store_type=vector_store_type,
        config_args=vector_store_args,
    )

    # Covariates ("claims") are optional; fall back to an empty list.
    claims = read_indexer_covariates(covariates) if covariates is not None else []

    engine = get_local_search_engine(
        config=config,
        reports=read_indexer_reports(community_reports, nodes, community_level),
        text_units=read_indexer_text_units(text_units),
        entities=entity_objs,
        relationships=read_indexer_relationships(relationships),
        covariates={"claims": claims},
        description_embedding_store=description_embedding_store,
        response_type=response_type,
    )

    search_result = await engine.asearch(query=query)
    reporter.success(f"Local Search Response: {search_result.response}")
    return search_result.response
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
"""Command line interface for the query module."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
@ -14,72 +15,12 @@ from graphrag.config import (
|
||||
create_graphrag_config,
|
||||
)
|
||||
from graphrag.index.progress import PrintProgressReporter
|
||||
from graphrag.model.entity import Entity
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
store_entity_semantic_embeddings,
|
||||
)
|
||||
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
from .factories import get_global_search_engine, get_local_search_engine
|
||||
from .indexer_adapters import (
|
||||
read_indexer_covariates,
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
read_indexer_reports,
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from . import api
|
||||
|
||||
reporter = PrintProgressReporter("")
|
||||
|
||||
|
||||
def __get_embedding_description_store(
|
||||
entities: list[Entity],
|
||||
vector_store_type: str = VectorStoreType.LanceDB,
|
||||
config_args: dict | None = None,
|
||||
):
|
||||
"""Get the embedding description store."""
|
||||
if not config_args:
|
||||
config_args = {}
|
||||
|
||||
collection_name = config_args.get(
|
||||
"query_collection_name", "entity_description_embeddings"
|
||||
)
|
||||
config_args.update({"collection_name": collection_name})
|
||||
description_embedding_store = VectorStoreFactory.get_vector_store(
|
||||
vector_store_type=vector_store_type, kwargs=config_args
|
||||
)
|
||||
|
||||
description_embedding_store.connect(**config_args)
|
||||
|
||||
if config_args.get("overwrite", True):
|
||||
# this step assumps the embeddings where originally stored in a file rather
|
||||
# than a vector database
|
||||
|
||||
# dump embeddings from the entities list to the description_embedding_store
|
||||
store_entity_semantic_embeddings(
|
||||
entities=entities, vectorstore=description_embedding_store
|
||||
)
|
||||
else:
|
||||
# load description embeddings to an in-memory lancedb vectorstore
|
||||
# to connect to a remote db, specify url and port values.
|
||||
description_embedding_store = LanceDBVectorStore(
|
||||
collection_name=collection_name
|
||||
)
|
||||
description_embedding_store.connect(
|
||||
db_uri=config_args.get("db_uri", "./lancedb")
|
||||
)
|
||||
|
||||
# load data from an existing table
|
||||
description_embedding_store.document_collection = (
|
||||
description_embedding_store.db_connection.open_table(
|
||||
description_embedding_store.collection_name
|
||||
)
|
||||
)
|
||||
|
||||
return description_embedding_store
|
||||
|
||||
|
||||
def run_global_search(
|
||||
config_dir: str | None,
|
||||
data_dir: str | None,
|
||||
@ -88,7 +29,10 @@ def run_global_search(
|
||||
response_type: str,
|
||||
query: str,
|
||||
):
|
||||
"""Run a global search with the given query."""
|
||||
"""Perform a global search with a given query.
|
||||
|
||||
Loads index files required for global search and calls the Query API.
|
||||
"""
|
||||
data_dir, root_dir, config = _configure_paths_and_settings(
|
||||
data_dir, root_dir, config_dir
|
||||
)
|
||||
@ -104,21 +48,17 @@ def run_global_search(
|
||||
data_path / "create_final_community_reports.parquet"
|
||||
)
|
||||
|
||||
reports = read_indexer_reports(
|
||||
final_community_reports, final_nodes, community_level
|
||||
)
|
||||
entities = read_indexer_entities(final_nodes, final_entities, community_level)
|
||||
search_engine = get_global_search_engine(
|
||||
config,
|
||||
reports=reports,
|
||||
entities=entities,
|
||||
return asyncio.run(
|
||||
api.global_search(
|
||||
config=config,
|
||||
nodes=final_nodes,
|
||||
entities=final_entities,
|
||||
community_reports=final_community_reports,
|
||||
community_level=community_level,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
result = search_engine.search(query=query)
|
||||
|
||||
reporter.success(f"Global Search Response: {result.response}")
|
||||
return result.response
|
||||
|
||||
|
||||
def run_local_search(
|
||||
@ -129,7 +69,10 @@ def run_local_search(
|
||||
response_type: str,
|
||||
query: str,
|
||||
):
|
||||
"""Run a local search with the given query."""
|
||||
"""Perform a local search with a given query.
|
||||
|
||||
Loads index files required for local search and calls the Query API.
|
||||
"""
|
||||
data_dir, root_dir, config = _configure_paths_and_settings(
|
||||
data_dir, root_dir, config_dir
|
||||
)
|
||||
@ -151,41 +94,21 @@ def run_local_search(
|
||||
else None
|
||||
)
|
||||
|
||||
vector_store_args = (
|
||||
config.embeddings.vector_store if config.embeddings.vector_store else {}
|
||||
)
|
||||
|
||||
reporter.info(f"Vector Store Args: {vector_store_args}")
|
||||
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
|
||||
|
||||
entities = read_indexer_entities(final_nodes, final_entities, community_level)
|
||||
description_embedding_store = __get_embedding_description_store(
|
||||
entities=entities,
|
||||
vector_store_type=vector_store_type,
|
||||
config_args=vector_store_args,
|
||||
)
|
||||
covariates = (
|
||||
read_indexer_covariates(final_covariates)
|
||||
if final_covariates is not None
|
||||
else []
|
||||
)
|
||||
|
||||
search_engine = get_local_search_engine(
|
||||
config,
|
||||
reports=read_indexer_reports(
|
||||
final_community_reports, final_nodes, community_level
|
||||
),
|
||||
text_units=read_indexer_text_units(final_text_units),
|
||||
entities=entities,
|
||||
relationships=read_indexer_relationships(final_relationships),
|
||||
covariates={"claims": covariates},
|
||||
description_embedding_store=description_embedding_store,
|
||||
# call the Query API
|
||||
return asyncio.run(
|
||||
api.local_search(
|
||||
config=config,
|
||||
nodes=final_nodes,
|
||||
entities=final_entities,
|
||||
community_reports=final_community_reports,
|
||||
text_units=final_text_units,
|
||||
relationships=final_relationships,
|
||||
covariates=final_covariates,
|
||||
community_level=community_level,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
result = search_engine.search(query=query)
|
||||
reporter.success(f"Local Search Response: {result.response}")
|
||||
return result.response
|
||||
|
||||
|
||||
def _configure_paths_and_settings(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user