Gnievesponce/query client vectore store (#771)

* Added default title_column and collection_name values for workflows using the vector store option

* Incorporated vector database support into the query client

* Updated documentation to reflect the new query client parameter

* Fixed ruff formatting

* Added new poetry lock file

---------

Co-authored-by: Gabriel Nieves-Ponce <gnievesponce@microsoft.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Gabriel Nieves-Ponce 2024-07-30 19:59:04 -04:00 committed by GitHub
parent fc9f29dccd
commit d26491a622
4 changed files with 82 additions and 22 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Fixed a bug that erased the vector database, added a new parameter to specify the config file path, and updated the documentation accordingly."
+}

View File

@@ -9,11 +9,12 @@ date: 2024-27-03

 The GraphRAG query CLI allows for no-code usage of the GraphRAG Query engine.

 ```bash
-python -m graphrag.query --data <path-to-data> --community_level <community-level> --response_type <response-type> --method <"local"|"global"> <query>
+python -m graphrag.query --config <config_file.yml> --data <path-to-data> --community_level <community-level> --response_type <response-type> --method <"local"|"global"> <query>
 ```

 ## CLI Arguments
+- `--config <config_file.yml>` - The configuration yaml file to use when running the query. If this is used, then none of the environment variables below will apply.
 - `--data <path-to-data>` - Folder containing the `.parquet` output files from running the Indexer.
 - `--community_level <community-level>` - Community level in the Leiden community hierarchy from which we will load the community reports. Higher value means we use reports on smaller communities. Default: 2
 - `--response_type <response-type>` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`.
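
For illustration, here is a minimal sketch of invoking the updated CLI from Python. The flags are the ones documented above; the config path, data folder, and query string are hypothetical.

```python
# Sketch only: runs the query CLI with the new --config flag.
# "settings.yml" and "./output/artifacts" are hypothetical paths.
import subprocess

subprocess.run(
    [
        "python", "-m", "graphrag.query",
        "--config", "settings.yml",      # bypasses the environment variables
        "--data", "./output/artifacts",  # folder with the indexer .parquet files
        "--method", "local",
        "Who are the main characters?",  # hypothetical query
    ],
    check=True,
)
```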

View File

@@ -25,6 +25,13 @@ class SearchType(Enum):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        help="The configuration yaml file to use when running the query",
+        required=False,
+        type=str,
+    )
     parser.add_argument(
         "--data",
         help="The path with the output data from the pipeline",
@@ -74,6 +81,7 @@ if __name__ == "__main__":
     match args.method:
         case SearchType.LOCAL:
             run_local_search(
+                args.config,
                 args.data,
                 args.root,
                 args.community_level,
@@ -82,6 +90,7 @@ if __name__ == "__main__":
             )
         case SearchType.GLOBAL:
             run_global_search(
+                args.config,
                 args.data,
                 args.root,
                 args.community_level,
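
The dispatch above follows a standard argparse pattern. Below is a self-contained sketch of the same shape, simplified, with print stubs standing in for the real search functions.

```python
import argparse
from enum import Enum


class SearchType(Enum):
    """Mirrors the two query modes wired up above."""

    LOCAL = "local"
    GLOBAL = "global"


parser = argparse.ArgumentParser()
parser.add_argument("--config", required=False, type=str)  # new optional flag
parser.add_argument("--method", required=True, type=SearchType)
parser.add_argument("query", type=str)

args = parser.parse_args(["--config", "settings.yml", "--method", "local", "demo"])

# args.config is None when --config is omitted, so downstream code can fall
# back to settings discovered under the root directory.
match args.method:
    case SearchType.LOCAL:
        print(f"run_local_search(config={args.config!r}, query={args.query!r})")
    case SearchType.GLOBAL:
        print(f"run_global_search(config={args.config!r}, query={args.query!r})")
```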

View File

@@ -14,10 +14,12 @@ from graphrag.config import (
     create_graphrag_config,
 )
 from graphrag.index.progress import PrintProgressReporter
+from graphrag.model.entity import Entity
 from graphrag.query.input.loaders.dfs import (
     store_entity_semantic_embeddings,
 )
 from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
+from graphrag.vector_stores.lancedb import LanceDBVectorStore

 from .factories import get_global_search_engine, get_local_search_engine
 from .indexer_adapters import (
@@ -32,28 +34,54 @@ reporter = PrintProgressReporter("")


 def __get_embedding_description_store(
-    vector_store_type: str = VectorStoreType.LanceDB, config_args: dict | None = None
+    entities: list[Entity],
+    vector_store_type: str = VectorStoreType.LanceDB,
+    config_args: dict | None = None,
 ):
     """Get the embedding description store."""
     if not config_args:
         config_args = {}

-    config_args.update({
-        "collection_name": config_args.get(
-            "query_collection_name",
-            config_args.get("collection_name", "description_embedding"),
-        ),
-    })
+    collection_name = config_args.get(
+        "query_collection_name", "entity_description_embeddings"
+    )
+    config_args.update({"collection_name": collection_name})

     description_embedding_store = VectorStoreFactory.get_vector_store(
         vector_store_type=vector_store_type, kwargs=config_args
     )
     description_embedding_store.connect(**config_args)

+    if config_args.get("overwrite", False):
+        # this step assumes the embeddings were originally stored in a file
+        # rather than a vector database
+        # dump embeddings from the entities list to the description_embedding_store
+        store_entity_semantic_embeddings(
+            entities=entities, vectorstore=description_embedding_store
+        )
+    else:
+        # load description embeddings to an in-memory lancedb vectorstore
+        # to connect to a remote db, specify url and port values.
+        description_embedding_store = LanceDBVectorStore(
+            collection_name=collection_name
+        )
+        description_embedding_store.connect(
+            db_uri=config_args.get("db_uri", "./lancedb")
+        )
+        # load data from an existing table
+        description_embedding_store.document_collection = (
+            description_embedding_store.db_connection.open_table(
+                description_embedding_store.collection_name
+            )
+        )
+
     return description_embedding_store


 def run_global_search(
+    config_dir: str | None,
     data_dir: str | None,
     root_dir: str | None,
     community_level: int,
@@ -61,7 +89,9 @@ def run_global_search(
     query: str,
 ):
     """Run a global search with the given query."""
-    data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir)
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
     data_path = Path(data_dir)

     final_nodes: pd.DataFrame = pd.read_parquet(
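
The branch on `overwrite` is what fixes the erased-database bug noted in the change file: embeddings are only re-dumped when explicitly requested, and otherwise an existing LanceDB table is reopened. Below is a hypothetical `config_args` dict that would drive the read-only branch; the key names are taken from the code above, the values are assumptions.

```python
# Hypothetical config_args for __get_embedding_description_store; every key
# here is read by the function above. With overwrite False, the existing
# LanceDB table is opened instead of being rewritten.
vector_store_args = {
    "type": "lancedb",                      # assumed value of VectorStoreType.LanceDB
    "db_uri": "./lancedb",                  # where the indexer wrote the table
    "overwrite": False,                     # keep the stored embeddings
    "query_collection_name": "entity_description_embeddings",  # becomes collection_name
}
```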
@@ -92,6 +122,7 @@ def run_global_search(


 def run_local_search(
+    config_dir: str | None,
     data_dir: str | None,
     root_dir: str | None,
     community_level: int,
@@ -99,7 +130,9 @@ def run_local_search(
     query: str,
 ):
     """Run a local search with the given query."""
-    data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir)
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
     data_path = Path(data_dir)

     final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
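
`run_local_search` loads the indexer's parquet artifacts before building the search engine. A small standalone sketch of that loading step follows; the directory is hypothetical, the file name comes from the diff above.

```python
from pathlib import Path

import pandas as pd

# Hypothetical output folder from an indexing run.
data_path = Path("./output/artifacts")

# File name as read in run_local_search above.
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
print(f"loaded {len(final_nodes)} node rows")
```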
@@ -121,16 +154,16 @@ def run_local_search(
     vector_store_args = (
         config.embeddings.vector_store if config.embeddings.vector_store else {}
     )
+    reporter.info(f"Vector Store Args: {vector_store_args}")
     vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

+    entities = read_indexer_entities(final_nodes, final_entities, community_level)
     description_embedding_store = __get_embedding_description_store(
+        entities=entities,
         vector_store_type=vector_store_type,
         config_args=vector_store_args,
     )
-    entities = read_indexer_entities(final_nodes, final_entities, community_level)
-    store_entity_semantic_embeddings(
-        entities=entities, vectorstore=description_embedding_store
-    )
     covariates = (
         read_indexer_covariates(final_covariates)
         if final_covariates is not None
@@ -156,14 +189,16 @@ def run_local_search(


 def _configure_paths_and_settings(
-    data_dir: str | None, root_dir: str | None
+    data_dir: str | None,
+    root_dir: str | None,
+    config_dir: str | None,
 ) -> tuple[str, str | None, GraphRagConfig]:
     if data_dir is None and root_dir is None:
         msg = "Either data_dir or root_dir must be provided."
         raise ValueError(msg)
     if data_dir is None:
         data_dir = _infer_data_dir(cast(str, root_dir))
-    config = _create_graphrag_config(root_dir, data_dir)
+    config = _create_graphrag_config(root_dir, config_dir)
     return data_dir, root_dir, config
@@ -179,17 +214,23 @@ def _infer_data_dir(root: str) -> str:
     raise ValueError(msg)


-def _create_graphrag_config(root: str | None, data_dir: str | None) -> GraphRagConfig:
+def _create_graphrag_config(
+    root: str | None,
+    config_dir: str | None,
+) -> GraphRagConfig:
     """Create a GraphRag configuration."""
-    return _read_config_parameters(cast(str, root or data_dir))
+    return _read_config_parameters(root or "./", config_dir)


-def _read_config_parameters(root: str):
+def _read_config_parameters(root: str, config: str | None):
     _root = Path(root)
-    settings_yaml = _root / "settings.yaml"
+    settings_yaml = (
+        Path(config)
+        if config and Path(config).suffix in [".yaml", ".yml"]
+        else _root / "settings.yaml"
+    )
     if not settings_yaml.exists():
         settings_yaml = _root / "settings.yml"
-    settings_json = _root / "settings.json"

     if settings_yaml.exists():
         reporter.info(f"Reading settings from {settings_yaml}")
@@ -201,6 +242,11 @@ def _read_config_parameters(root: str):
         data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
         return create_graphrag_config(data, root)

+    settings_json = (
+        Path(config)
+        if config and Path(config).suffix == ".json"
+        else _root / "settings.json"
+    )
     if settings_json.exists():
         reporter.info(f"Reading settings from {settings_json}")
         with settings_json.open("rb") as file:
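
Taken together, the two changed lookups give an explicit `--config` path precedence when its suffix matches, with the root-relative defaults as fallback. Below is a condensed sketch of that resolution order; it is not the library function itself and flattens some edge cases (the real code, for instance, still falls back when an explicit yaml path does not exist).

```python
from pathlib import Path


def resolve_settings(root: str, config: str | None) -> Path:
    """Sketch of the lookup order implemented above: an explicit config path
    wins when its suffix matches, otherwise fall back to settings.yaml,
    settings.yml, then settings.json under the root directory.
    """
    _root = Path(root)
    if config and Path(config).suffix in (".yaml", ".yml", ".json"):
        return Path(config)
    for name in ("settings.yaml", "settings.yml", "settings.json"):
        candidate = _root / name
        if candidate.exists():
            return candidate
    msg = f"No settings file found under {root}"
    raise FileNotFoundError(msg)
```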