Gnievesponce/query client vectore store (#771)

* added default title_column and collection_name values for workflows using the vector store option

* incorporated vector database support to the query client

* Updated docuemnatation to reflect the new query client param.

* Fixed ruff formatting

* added new poetry lock file

---------

Co-authored-by: Gabriel Nieves-Ponce <gnievesponce@microsoft.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
Gabriel Nieves-Ponce 2024-07-30 19:59:04 -04:00 committed by GitHub
parent fc9f29dccd
commit d26491a622
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 82 additions and 22 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Fixed a bug that erased the vector database, added a new parameter to specify the config file path, and updated the documentation accordingly."
}

View File

@ -9,11 +9,12 @@ date: 2024-27-03
The GraphRAG query CLI allows for no-code usage of the GraphRAG Query engine.
```bash
python -m graphrag.query --data <path-to-data> --community_level <comunit-level> --response_type <response-type> --method <"local"|"global"> <query>
python -m graphrag.query --config <config_file.yml> --data <path-to-data> --community_level <comunit-level> --response_type <response-type> --method <"local"|"global"> <query>
```
## CLI Arguments
- `--config <config_file.yml>` - The configuration yaml file to use when running the query. If this is used, then none of the environment-variables below will apply.
- `--data <path-to-data>` - Folder containing the `.parquet` output files from running the Indexer.
- `--community_level <community-level>` - Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2
- `--response_type <response-type>` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`.

View File

@ -25,6 +25,13 @@ class SearchType(Enum):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
help="The configuration yaml file to use when running the query",
required=False,
type=str,
)
parser.add_argument(
"--data",
help="The path with the output data from the pipeline",
@ -74,6 +81,7 @@ if __name__ == "__main__":
match args.method:
case SearchType.LOCAL:
run_local_search(
args.config,
args.data,
args.root,
args.community_level,
@ -82,6 +90,7 @@ if __name__ == "__main__":
)
case SearchType.GLOBAL:
run_global_search(
args.config,
args.data,
args.root,
args.community_level,

View File

@ -14,10 +14,12 @@ from graphrag.config import (
create_graphrag_config,
)
from graphrag.index.progress import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.input.loaders.dfs import (
store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from .factories import get_global_search_engine, get_local_search_engine
from .indexer_adapters import (
@ -32,28 +34,54 @@ reporter = PrintProgressReporter("")
def __get_embedding_description_store(
vector_store_type: str = VectorStoreType.LanceDB, config_args: dict | None = None
entities: list[Entity],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
):
"""Get the embedding description store."""
if not config_args:
config_args = {}
config_args.update({
"collection_name": config_args.get(
"query_collection_name",
config_args.get("collection_name", "description_embedding"),
),
})
collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
)
description_embedding_store.connect(**config_args)
if config_args.get("overwrite", False):
# this step assumps the embeddings where originally stored in a file rather
# than a vector database
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
else:
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name=collection_name
)
description_embedding_store.connect(
db_uri=config_args.get("db_uri", "./lancedb")
)
# load data from an existing table
description_embedding_store.document_collection = (
description_embedding_store.db_connection.open_table(
description_embedding_store.collection_name
)
)
return description_embedding_store
def run_global_search(
config_dir: str | None,
data_dir: str | None,
root_dir: str | None,
community_level: int,
@ -61,7 +89,9 @@ def run_global_search(
query: str,
):
"""Run a global search with the given query."""
data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir)
data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_dir
)
data_path = Path(data_dir)
final_nodes: pd.DataFrame = pd.read_parquet(
@ -92,6 +122,7 @@ def run_global_search(
def run_local_search(
config_dir: str | None,
data_dir: str | None,
root_dir: str | None,
community_level: int,
@ -99,7 +130,9 @@ def run_local_search(
query: str,
):
"""Run a local search with the given query."""
data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir)
data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_dir
)
data_path = Path(data_dir)
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
@ -121,16 +154,16 @@ def run_local_search(
vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
)
reporter.info(f"Vector Store Args: {vector_store_args}")
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
entities = read_indexer_entities(final_nodes, final_entities, community_level)
description_embedding_store = __get_embedding_description_store(
entities=entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
)
entities = read_indexer_entities(final_nodes, final_entities, community_level)
store_entity_semantic_embeddings(
entities=entities, vectorstore=description_embedding_store
)
covariates = (
read_indexer_covariates(final_covariates)
if final_covariates is not None
@ -156,14 +189,16 @@ def run_local_search(
def _configure_paths_and_settings(
data_dir: str | None, root_dir: str | None
data_dir: str | None,
root_dir: str | None,
config_dir: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
if data_dir is None and root_dir is None:
msg = "Either data_dir or root_dir must be provided."
raise ValueError(msg)
if data_dir is None:
data_dir = _infer_data_dir(cast(str, root_dir))
config = _create_graphrag_config(root_dir, data_dir)
config = _create_graphrag_config(root_dir, config_dir)
return data_dir, root_dir, config
@ -179,17 +214,23 @@ def _infer_data_dir(root: str) -> str:
raise ValueError(msg)
def _create_graphrag_config(root: str | None, data_dir: str | None) -> GraphRagConfig:
def _create_graphrag_config(
root: str | None,
config_dir: str | None,
) -> GraphRagConfig:
"""Create a GraphRag configuration."""
return _read_config_parameters(cast(str, root or data_dir))
return _read_config_parameters(root or "./", config_dir)
def _read_config_parameters(root: str):
def _read_config_parameters(root: str, config: str | None):
_root = Path(root)
settings_yaml = _root / "settings.yaml"
settings_yaml = (
Path(config)
if config and Path(config).suffix in [".yaml", ".yml"]
else _root / "settings.yaml"
)
if not settings_yaml.exists():
settings_yaml = _root / "settings.yml"
settings_json = _root / "settings.json"
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
@ -201,6 +242,11 @@ def _read_config_parameters(root: str):
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
settings_json = (
Path(config)
if config and Path(config).suffix == ".json"
else _root / "settings.json"
)
if settings_json.exists():
reporter.info(f"Reading settings from {settings_json}")
with settings_json.open("rb") as file: