mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-13 07:51:34 +00:00
Gnievesponce/query client vectore store (#771)
* added default title_column and collection_name values for workflows using the vector store option * incorporated vector database support to the query client * Updated documentation to reflect the new query client param. * Fixed ruff formatting * added new poetry lock file --------- Co-authored-by: Gabriel Nieves-Ponce <gnievesponce@microsoft.com> Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
fc9f29dccd
commit
d26491a622
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "minor",
|
||||
"description": "Fixed a bug that erased the vector database, added a new parameter to specify the config file path, and updated the documentation accordingly."
|
||||
}
|
||||
@ -9,11 +9,12 @@ date: 2024-27-03
|
||||
The GraphRAG query CLI allows for no-code usage of the GraphRAG Query engine.
|
||||
|
||||
```bash
|
||||
python -m graphrag.query --data <path-to-data> --community_level <community-level> --response_type <response-type> --method <"local"|"global"> <query>
|
||||
python -m graphrag.query --config <config_file.yml> --data <path-to-data> --community_level <community-level> --response_type <response-type> --method <"local"|"global"> <query>
|
||||
```
|
||||
|
||||
## CLI Arguments
|
||||
|
||||
- `--config <config_file.yml>` - The configuration yaml file to use when running the query. If this is used, then none of the environment-variables below will apply.
|
||||
- `--data <path-to-data>` - Folder containing the `.parquet` output files from running the Indexer.
|
||||
- `--community_level <community-level>` - Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2
|
||||
- `--response_type <response-type>` - Free form text describing the response type and format, can be anything, e.g. `Multiple Paragraphs`, `Single Paragraph`, `Single Sentence`, `List of 3-7 Points`, `Single Page`, `Multi-Page Report`. Default: `Multiple Paragraphs`.
|
||||
|
||||
@ -25,6 +25,13 @@ class SearchType(Enum):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help="The configuration yaml file to use when running the query",
|
||||
required=False,
|
||||
type=str,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
help="The path with the output data from the pipeline",
|
||||
@ -74,6 +81,7 @@ if __name__ == "__main__":
|
||||
match args.method:
|
||||
case SearchType.LOCAL:
|
||||
run_local_search(
|
||||
args.config,
|
||||
args.data,
|
||||
args.root,
|
||||
args.community_level,
|
||||
@ -82,6 +90,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
case SearchType.GLOBAL:
|
||||
run_global_search(
|
||||
args.config,
|
||||
args.data,
|
||||
args.root,
|
||||
args.community_level,
|
||||
|
||||
@ -14,10 +14,12 @@ from graphrag.config import (
|
||||
create_graphrag_config,
|
||||
)
|
||||
from graphrag.index.progress import PrintProgressReporter
|
||||
from graphrag.model.entity import Entity
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
store_entity_semantic_embeddings,
|
||||
)
|
||||
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
from .factories import get_global_search_engine, get_local_search_engine
|
||||
from .indexer_adapters import (
|
||||
@ -32,28 +34,54 @@ reporter = PrintProgressReporter("")
|
||||
|
||||
|
||||
def __get_embedding_description_store(
|
||||
vector_store_type: str = VectorStoreType.LanceDB, config_args: dict | None = None
|
||||
entities: list[Entity],
|
||||
vector_store_type: str = VectorStoreType.LanceDB,
|
||||
config_args: dict | None = None,
|
||||
):
|
||||
"""Get the embedding description store."""
|
||||
if not config_args:
|
||||
config_args = {}
|
||||
|
||||
config_args.update({
|
||||
"collection_name": config_args.get(
|
||||
"query_collection_name",
|
||||
config_args.get("collection_name", "description_embedding"),
|
||||
),
|
||||
})
|
||||
|
||||
collection_name = config_args.get(
|
||||
"query_collection_name", "entity_description_embeddings"
|
||||
)
|
||||
config_args.update({"collection_name": collection_name})
|
||||
description_embedding_store = VectorStoreFactory.get_vector_store(
|
||||
vector_store_type=vector_store_type, kwargs=config_args
|
||||
)
|
||||
|
||||
description_embedding_store.connect(**config_args)
|
||||
|
||||
if config_args.get("overwrite", False):
|
||||
# this step assumes the embeddings were originally stored in a file rather
|
||||
# than a vector database
|
||||
|
||||
# dump embeddings from the entities list to the description_embedding_store
|
||||
store_entity_semantic_embeddings(
|
||||
entities=entities, vectorstore=description_embedding_store
|
||||
)
|
||||
else:
|
||||
# load description embeddings to an in-memory lancedb vectorstore
|
||||
# to connect to a remote db, specify url and port values.
|
||||
description_embedding_store = LanceDBVectorStore(
|
||||
collection_name=collection_name
|
||||
)
|
||||
description_embedding_store.connect(
|
||||
db_uri=config_args.get("db_uri", "./lancedb")
|
||||
)
|
||||
|
||||
# load data from an existing table
|
||||
description_embedding_store.document_collection = (
|
||||
description_embedding_store.db_connection.open_table(
|
||||
description_embedding_store.collection_name
|
||||
)
|
||||
)
|
||||
|
||||
return description_embedding_store
|
||||
|
||||
|
||||
def run_global_search(
|
||||
config_dir: str | None,
|
||||
data_dir: str | None,
|
||||
root_dir: str | None,
|
||||
community_level: int,
|
||||
@ -61,7 +89,9 @@ def run_global_search(
|
||||
query: str,
|
||||
):
|
||||
"""Run a global search with the given query."""
|
||||
data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir)
|
||||
data_dir, root_dir, config = _configure_paths_and_settings(
|
||||
data_dir, root_dir, config_dir
|
||||
)
|
||||
data_path = Path(data_dir)
|
||||
|
||||
final_nodes: pd.DataFrame = pd.read_parquet(
|
||||
@ -92,6 +122,7 @@ def run_global_search(
|
||||
|
||||
|
||||
def run_local_search(
|
||||
config_dir: str | None,
|
||||
data_dir: str | None,
|
||||
root_dir: str | None,
|
||||
community_level: int,
|
||||
@ -99,7 +130,9 @@ def run_local_search(
|
||||
query: str,
|
||||
):
|
||||
"""Run a local search with the given query."""
|
||||
data_dir, root_dir, config = _configure_paths_and_settings(data_dir, root_dir)
|
||||
data_dir, root_dir, config = _configure_paths_and_settings(
|
||||
data_dir, root_dir, config_dir
|
||||
)
|
||||
data_path = Path(data_dir)
|
||||
|
||||
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
|
||||
@ -121,16 +154,16 @@ def run_local_search(
|
||||
vector_store_args = (
|
||||
config.embeddings.vector_store if config.embeddings.vector_store else {}
|
||||
)
|
||||
|
||||
reporter.info(f"Vector Store Args: {vector_store_args}")
|
||||
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
|
||||
|
||||
entities = read_indexer_entities(final_nodes, final_entities, community_level)
|
||||
description_embedding_store = __get_embedding_description_store(
|
||||
entities=entities,
|
||||
vector_store_type=vector_store_type,
|
||||
config_args=vector_store_args,
|
||||
)
|
||||
entities = read_indexer_entities(final_nodes, final_entities, community_level)
|
||||
store_entity_semantic_embeddings(
|
||||
entities=entities, vectorstore=description_embedding_store
|
||||
)
|
||||
covariates = (
|
||||
read_indexer_covariates(final_covariates)
|
||||
if final_covariates is not None
|
||||
@ -156,14 +189,16 @@ def run_local_search(
|
||||
|
||||
|
||||
def _configure_paths_and_settings(
|
||||
data_dir: str | None, root_dir: str | None
|
||||
data_dir: str | None,
|
||||
root_dir: str | None,
|
||||
config_dir: str | None,
|
||||
) -> tuple[str, str | None, GraphRagConfig]:
|
||||
if data_dir is None and root_dir is None:
|
||||
msg = "Either data_dir or root_dir must be provided."
|
||||
raise ValueError(msg)
|
||||
if data_dir is None:
|
||||
data_dir = _infer_data_dir(cast(str, root_dir))
|
||||
config = _create_graphrag_config(root_dir, data_dir)
|
||||
config = _create_graphrag_config(root_dir, config_dir)
|
||||
return data_dir, root_dir, config
|
||||
|
||||
|
||||
@ -179,17 +214,23 @@ def _infer_data_dir(root: str) -> str:
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _create_graphrag_config(root: str | None, data_dir: str | None) -> GraphRagConfig:
|
||||
def _create_graphrag_config(
|
||||
root: str | None,
|
||||
config_dir: str | None,
|
||||
) -> GraphRagConfig:
|
||||
"""Create a GraphRag configuration."""
|
||||
return _read_config_parameters(cast(str, root or data_dir))
|
||||
return _read_config_parameters(root or "./", config_dir)
|
||||
|
||||
|
||||
def _read_config_parameters(root: str):
|
||||
def _read_config_parameters(root: str, config: str | None):
|
||||
_root = Path(root)
|
||||
settings_yaml = _root / "settings.yaml"
|
||||
settings_yaml = (
|
||||
Path(config)
|
||||
if config and Path(config).suffix in [".yaml", ".yml"]
|
||||
else _root / "settings.yaml"
|
||||
)
|
||||
if not settings_yaml.exists():
|
||||
settings_yaml = _root / "settings.yml"
|
||||
settings_json = _root / "settings.json"
|
||||
|
||||
if settings_yaml.exists():
|
||||
reporter.info(f"Reading settings from {settings_yaml}")
|
||||
@ -201,6 +242,11 @@ def _read_config_parameters(root: str):
|
||||
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
|
||||
return create_graphrag_config(data, root)
|
||||
|
||||
settings_json = (
|
||||
Path(config)
|
||||
if config and Path(config).suffix == ".json"
|
||||
else _root / "settings.json"
|
||||
)
|
||||
if settings_json.exists():
|
||||
reporter.info(f"Reading settings from {settings_json}")
|
||||
with settings_json.open("rb") as file:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user