From d26491a622f7b819e837165bd480cf5d406705d7 Mon Sep 17 00:00:00 2001 From: Gabriel Nieves-Ponce <39567323+nievespg1@users.noreply.github.com> Date: Tue, 30 Jul 2024 19:59:04 -0400 Subject: [PATCH] Gnievesponce/query client vectore store (#771) * added default title_column and collection_name values for workflows using the vector store option * incorporated vector database support to the query client * Updated documentation to reflect the new query client param. * Fixed ruff formatting * added new poetry lock file --------- Co-authored-by: Gabriel Nieves-Ponce Co-authored-by: Alonso Guevara --- .../minor-20240729193644620548.json | 4 + docsite/posts/query/3-cli.md | 3 +- graphrag/query/__main__.py | 9 ++ graphrag/query/cli.py | 88 ++++++++++++++----- 4 files changed, 82 insertions(+), 22 deletions(-) create mode 100644 .semversioner/next-release/minor-20240729193644620548.json diff --git a/.semversioner/next-release/minor-20240729193644620548.json b/.semversioner/next-release/minor-20240729193644620548.json new file mode 100644 index 00000000..df862dab --- /dev/null +++ b/.semversioner/next-release/minor-20240729193644620548.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Fixed a bug that erased the vector database, added a new parameter to specify the config file path, and updated the documentation accordingly." +} diff --git a/docsite/posts/query/3-cli.md b/docsite/posts/query/3-cli.md index c1a88f9e..518d9cd8 100644 --- a/docsite/posts/query/3-cli.md +++ b/docsite/posts/query/3-cli.md @@ -9,11 +9,12 @@ date: 2024-27-03 The GraphRAG query CLI allows for no-code usage of the GraphRAG Query engine. ```bash -python -m graphrag.query --data --community_level --response_type --method <"local"|"global"> +python -m graphrag.query --config --data --community_level --response_type --method <"local"|"global"> ``` ## CLI Arguments +- `--config ` - The configuration yaml file to use when running the query. 
If this is used, then none of the environment-variables below will apply. - `--data ` - Folder containing the `.parquet` output files from running the Indexer. - `--community_level ` - Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2 - `--response_type ` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`. diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index 9367f9a9..edf678fa 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -25,6 +25,13 @@ class SearchType(Enum): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + help="The configuration yaml file to use when running the query", + required=False, + type=str, + ) + parser.add_argument( "--data", help="The path with the output data from the pipeline", @@ -74,6 +81,7 @@ if __name__ == "__main__": match args.method: case SearchType.LOCAL: run_local_search( + args.config, args.data, args.root, args.community_level, @@ -82,6 +90,7 @@ if __name__ == "__main__": ) case SearchType.GLOBAL: run_global_search( + args.config, args.data, args.root, args.community_level, diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 10ad95d5..59430948 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -14,10 +14,12 @@ from graphrag.config import ( create_graphrag_config, ) from graphrag.index.progress import PrintProgressReporter +from graphrag.model.entity import Entity from graphrag.query.input.loaders.dfs import ( store_entity_semantic_embeddings, ) from graphrag.vector_stores import VectorStoreFactory, VectorStoreType +from graphrag.vector_stores.lancedb import LanceDBVectorStore from .factories import 
get_global_search_engine, get_local_search_engine from .indexer_adapters import ( @@ -32,28 +34,54 @@ reporter = PrintProgressReporter("") def __get_embedding_description_store( - vector_store_type: str = VectorStoreType.LanceDB, config_args: dict | None = None + entities: list[Entity], + vector_store_type: str = VectorStoreType.LanceDB, + config_args: dict | None = None, ): """Get the embedding description store.""" if not config_args: config_args = {} - config_args.update({ - "collection_name": config_args.get( - "query_collection_name", - config_args.get("collection_name", "description_embedding"), - ), - }) - + collection_name = config_args.get( + "query_collection_name", "entity_description_embeddings" + ) + config_args.update({"collection_name": collection_name}) description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, kwargs=config_args ) description_embedding_store.connect(**config_args) + + if config_args.get("overwrite", False): + # this step assumes the embeddings were originally stored in a file rather + # than a vector database + + # dump embeddings from the entities list to the description_embedding_store + store_entity_semantic_embeddings( + entities=entities, vectorstore=description_embedding_store + ) + else: + # load description embeddings to an in-memory lancedb vectorstore + # to connect to a remote db, specify url and port values. 
+ description_embedding_store = LanceDBVectorStore( + collection_name=collection_name + ) + description_embedding_store.connect( + db_uri=config_args.get("db_uri", "./lancedb") + ) + + # load data from an existing table + description_embedding_store.document_collection = ( + description_embedding_store.db_connection.open_table( + description_embedding_store.collection_name + ) + ) + return description_embedding_store def run_global_search( + config_dir: str | None, data_dir: str | None, root_dir: str | None, community_level: int, @@ -61,7 +89,9 @@ def run_global_search( query: str, ): """Run a global search with the given query.""" - data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir) + data_dir, root_dir, config = _configure_paths_and_settings( + data_dir, root_dir, config_dir + ) data_path = Path(data_dir) final_nodes: pd.DataFrame = pd.read_parquet( @@ -92,6 +122,7 @@ def run_global_search( def run_local_search( + config_dir: str | None, data_dir: str | None, root_dir: str | None, community_level: int, @@ -99,7 +130,9 @@ def run_local_search( query: str, ): """Run a local search with the given query.""" - data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir) + data_dir, root_dir, config = _configure_paths_and_settings( + data_dir, root_dir, config_dir + ) data_path = Path(data_dir) final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet") @@ -121,16 +154,16 @@ def run_local_search( vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) + + reporter.info(f"Vector Store Args: {vector_store_args}") vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) + entities = read_indexer_entities(final_nodes, final_entities, community_level) description_embedding_store = __get_embedding_description_store( + entities=entities, vector_store_type=vector_store_type, config_args=vector_store_args, ) - entities = read_indexer_entities(final_nodes, 
final_entities, community_level) - store_entity_semantic_embeddings( - entities=entities, vectorstore=description_embedding_store - ) covariates = ( read_indexer_covariates(final_covariates) if final_covariates is not None @@ -156,14 +189,16 @@ def run_local_search( def _configure_paths_and_settings( - data_dir: str | None, root_dir: str | None + data_dir: str | None, + root_dir: str | None, + config_dir: str | None, ) -> tuple[str, str | None, GraphRagConfig]: if data_dir is None and root_dir is None: msg = "Either data_dir or root_dir must be provided." raise ValueError(msg) if data_dir is None: data_dir = _infer_data_dir(cast(str, root_dir)) - config = _create_graphrag_config(root_dir, data_dir) + config = _create_graphrag_config(root_dir, config_dir) return data_dir, root_dir, config @@ -179,17 +214,23 @@ def _infer_data_dir(root: str) -> str: raise ValueError(msg) -def _create_graphrag_config(root: str | None, data_dir: str | None) -> GraphRagConfig: +def _create_graphrag_config( + root: str | None, + config_dir: str | None, +) -> GraphRagConfig: """Create a GraphRag configuration.""" - return _read_config_parameters(cast(str, root or data_dir)) + return _read_config_parameters(root or "./", config_dir) -def _read_config_parameters(root: str): +def _read_config_parameters(root: str, config: str | None): _root = Path(root) - settings_yaml = _root / "settings.yaml" + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) if not settings_yaml.exists(): settings_yaml = _root / "settings.yml" - settings_json = _root / "settings.json" if settings_yaml.exists(): reporter.info(f"Reading settings from {settings_yaml}") @@ -201,6 +242,11 @@ def _read_config_parameters(root: str): data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) return create_graphrag_config(data, root) + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / 
"settings.json" + ) if settings_json.exists(): reporter.info(f"Reading settings from {settings_json}") with settings_json.open("rb") as file: