Implement query api (#839)

* initial API redesign

* typo fix

* update docstring

* update docstring

* remove artifacts caused by the merge from main

* minor typo updates

* add semversioner check

* switch API to async function calls

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
Josh Bradley 2024-08-12 15:40:10 -04:00 committed by GitHub
parent 7fd23fa79c
commit 4bcbfd10eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 237 additions and 115 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Implement query engine API."
}

View File

@ -23,7 +23,10 @@ class SearchType(Enum):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(
prog="python -m graphrag.query",
description="The graphrag query engine",
)
parser.add_argument( parser.add_argument(
"--config", "--config",
@ -49,7 +52,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--method", "--method",
help="The method to run, one of: local or global", help="The method to run",
required=True, required=True,
type=SearchType, type=SearchType,
choices=list(SearchType), choices=list(SearchType),
@ -57,14 +60,14 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--community_level", "--community_level",
help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities", help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2",
type=int, type=int,
default=2, default=2,
) )
parser.add_argument( parser.add_argument(
"--response_type", "--response_type",
help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report", help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs",
type=str, type=str,
default="Multiple Paragraphs", default="Multiple Paragraphs",
) )

192
graphrag/query/api.py Normal file
View File

@ -0,0 +1,192 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Query Engine API.
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
from typing import Any
import pandas as pd
from pydantic import validate_call
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
from .factories import get_global_search_engine, get_local_search_engine
from .indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from .input.loaders.dfs import store_entity_semantic_embeddings
# Module-level progress reporter shared by the search entry points below;
# the empty prefix means messages print without a label.
reporter = PrintProgressReporter("")
def __get_embedding_description_store(
    entities: list[Entity],
    vector_store_type: str = VectorStoreType.LanceDB,
    config_args: dict | None = None,
):
    """Build and connect the vector store that holds entity description embeddings.

    Mutates ``config_args`` in place to record the resolved collection name,
    then either populates a fresh store from ``entities`` (overwrite mode) or
    opens an existing LanceDB table (read-only mode).
    """
    if not config_args:
        config_args = {}

    # Resolve the collection name and record it in the connection args so the
    # factory and connect() see the same value.
    config_args["collection_name"] = config_args.get(
        "query_collection_name", "entity_description_embeddings"
    )

    store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type, kwargs=config_args
    )
    store.connect(**config_args)

    if config_args.get("overwrite", True):
        # this step assumes the embeddings were originally stored in a file rather
        # than a vector database; dump the entity embeddings into the store we
        # just connected
        store_entity_semantic_embeddings(entities=entities, vectorstore=store)
        return store

    # The embeddings already live in a database: open the existing LanceDB
    # table instead of rewriting it. To connect to a remote db, supply url and
    # port values via config_args.
    store = LanceDBVectorStore(collection_name=config_args["collection_name"])
    store.connect(db_uri=config_args.get("db_uri", "./lancedb"))
    store.document_collection = store.db_connection.open_table(
        store.collection_name
    )
    return store
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Run a global search over the indexed knowledge graph and return the response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The type of response to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Adapt the raw parquet frames to the model objects the engine expects,
    # then run the (async) search and surface the response via the reporter.
    search_engine = get_global_search_engine(
        config,
        reports=read_indexer_reports(community_reports, nodes, community_level),
        entities=read_indexer_entities(nodes, entities, community_level),
        response_type=response_type,
    )
    result = await search_engine.asearch(query=query)
    reporter.success(f"Global Search Response: {result.response}")
    return result.response
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Run a local search over the indexed knowledge graph and return the response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
    - covariates (pd.DataFrame | None): A DataFrame containing the final covariates (from create_final_covariates.parquet), or None if none were produced
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Vector-store settings come from the embeddings section of the config;
    # fall back to an empty dict (defaults) when none are configured.
    vector_store_args = (
        config.embeddings.vector_store if config.embeddings.vector_store else {}
    )
    reporter.info(f"Vector Store Args: {vector_store_args}")
    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

    indexed_entities = read_indexer_entities(nodes, entities, community_level)
    description_embedding_store = __get_embedding_description_store(
        entities=indexed_entities,
        vector_store_type=vector_store_type,
        config_args=vector_store_args,
    )

    # Covariates (claims) are optional; an absent frame means no claims.
    claims = [] if covariates is None else read_indexer_covariates(covariates)

    search_engine = get_local_search_engine(
        config=config,
        reports=read_indexer_reports(community_reports, nodes, community_level),
        text_units=read_indexer_text_units(text_units),
        entities=indexed_entities,
        relationships=read_indexer_relationships(relationships),
        covariates={"claims": claims},
        description_embedding_store=description_embedding_store,
        response_type=response_type,
    )
    result = await search_engine.asearch(query=query)
    reporter.success(f"Local Search Response: {result.response}")
    return result.response

View File

@ -3,6 +3,7 @@
"""Command line interface for the query module.""" """Command line interface for the query module."""
import asyncio
import os import os
from pathlib import Path from pathlib import Path
from typing import cast from typing import cast
@ -14,72 +15,12 @@ from graphrag.config import (
create_graphrag_config, create_graphrag_config,
) )
from graphrag.index.progress import PrintProgressReporter from graphrag.index.progress import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.input.loaders.dfs import (
store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from .factories import get_global_search_engine, get_local_search_engine from . import api
from .indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
reporter = PrintProgressReporter("") reporter = PrintProgressReporter("")
def __get_embedding_description_store(
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}
collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)
description_embedding_store.connect(**config_args)
if config_args.get("overwrite", True):
# this step assumps the embeddings where originally stored in a file rather
# than a vector database
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)
# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)
return description_embedding_store
def run_global_search( def run_global_search(
config_dir: str | None, config_dir: str | None,
data_dir: str | None, data_dir: str | None,
@ -88,7 +29,10 @@ def run_global_search(
response_type: str, response_type: str,
query: str, query: str,
): ):
"""Run a global search with the given query.""" """Perform a global search with a given query.
Loads index files required for global search and calls the Query API.
"""
data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_dir data_dir, root_dir, config_dir
) )
@ -104,21 +48,17 @@ def run_global_search(
data_path / "create_final_community_reports.parquet" data_path / "create_final_community_reports.parquet"
) )
reports = read_indexer_reports( return asyncio.run(
final_community_reports, final_nodes, community_level api.global_search(
) config=config,
entities = read_indexer_entities(final_nodes, final_entities, community_level) nodes=final_nodes,
search_engine = get_global_search_engine( entities=final_entities,
config, community_reports=final_community_reports,
reports=reports, community_level=community_level,
entities=entities,
response_type=response_type, response_type=response_type,
query=query,
)
) )
result = search_engine.search(query=query)
reporter.success(f"Global Search Response: {result.response}")
return result.response
def run_local_search( def run_local_search(
@ -129,7 +69,10 @@ def run_local_search(
response_type: str, response_type: str,
query: str, query: str,
): ):
"""Run a local search with the given query.""" """Perform a local search with a given query.
Loads index files required for local search and calls the Query API.
"""
data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_dir data_dir, root_dir, config_dir
) )
@ -151,41 +94,21 @@ def run_local_search(
else None else None
) )
vector_store_args = ( # call the Query API
config.embeddings.vector_store if config.embeddings.vector_store else {} return asyncio.run(
) api.local_search(
config=config,
reporter.info(f"Vector Store Args: {vector_store_args}") nodes=final_nodes,
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) entities=final_entities,
community_reports=final_community_reports,
entities = read_indexer_entities(final_nodes, final_entities, community_level) text_units=final_text_units,
description_embedding_store = __get_embedding_description_store( relationships=final_relationships,
entities=entities, covariates=final_covariates,
vector_store_type=vector_store_type, community_level=community_level,
config_args=vector_store_args,
)
covariates = (
read_indexer_covariates(final_covariates)
if final_covariates is not None
else []
)
search_engine = get_local_search_engine(
config,
reports=read_indexer_reports(
final_community_reports, final_nodes, community_level
),
text_units=read_indexer_text_units(final_text_units),
entities=entities,
relationships=read_indexer_relationships(final_relationships),
covariates={"claims": covariates},
description_embedding_store=description_embedding_store,
response_type=response_type, response_type=response_type,
query=query,
)
) )
result = search_engine.search(query=query)
reporter.success(f"Local Search Response: {result.response}")
return result.response
def _configure_paths_and_settings( def _configure_paths_and_settings(