Mirror of https://github.com/microsoft/graphrag.git, synced 2025-12-27 15:10:00 +00:00
New workflow to generate embeddings in a single workflow (#1296)
* New workflow to generate embeddings in a single workflow
* New workflow to generate embeddings in a single workflow
* version change
* clean tests without any embeddings references
* clean tests without any embeddings references
* remove code
* feedback implemented
* changes in logic
* feedback implemented
* store in table bug fixed
* smoke test for generate_text_embeddings workflow
* smoke test fix
* add generate_text_embeddings to the list of transient workflows
* smoke tests
* fix
* ruff formatting updates
* fix
* smoke test fixed
* smoke test fixed
* fix lancedb import
* smoke test fix
* ignore sorting
* smoke test fixed
* smoke test fixed
* check smoke test
* smoke test fixed
* change config for vector store
* format fix
* vector store changes
* revert debug profile back to empty filepath
* merge conflict solved
* merge conflict solved
* format fixed
* format fixed
* fix return dataframe
* snapshot fix
* format fix
* embeddings param implemented
* validation fixes
* fix map
* fix map
* fix properties
* config updates
* smoke test fixed
* settings change
* Update collection config and rework back-compat
* Replace . with - for embedding store

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
This commit is contained in: parent 8302920ac8, commit 17658c5df8
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "embeddings moved to a different workflow"
+}
@@ -85,8 +85,8 @@ This is the base LLM configuration section. Other steps may override this config
 - `async_mode` (see Async Mode top-level config)
 - `batch_size` **int** - The maximum batch size to use.
 - `batch_max_tokens` **int** - The maximum batch # of tokens.
-- `target` **required|all** - Determines which set of embeddings to emit.
-- `skip` **list[str]** - Which embeddings to skip.
+- `target` **required|all|none** - Determines which set of embeddings to emit.
+- `skip` **list[str]** - Which embeddings to skip. Only useful if target=all to customize the list.
 - `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
   - `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
   - `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
@@ -94,7 +94,7 @@ This is the base LLM configuration section. Other steps may override this config
   - `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
   - `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
   - `overwrite` **bool** (only used at index creation time) - Overwrite collection if it exists. Default=`True`
-  - `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
+  - `container_name` **str** - The name of a vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default`
   - `strategy` **dict** - Fully override the text-embedding strategy.

## chunks
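For orientation, a minimal sketch (assumed, mirroring the documented defaults above) of the same vector store settings as the Python dict the pipeline consumes:

# Sketch only: keys and defaults taken from the config docs above.
vector_store = {
    "type": "lancedb",            # or "azure_ai_search"
    "db_uri": "output/lancedb",   # lancedb only; relative to the project root
    "container_name": "default",  # groups all embedding tables for one dataset ingest
    "overwrite": True,            # only honored at index creation time
}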
@@ -108,7 +108,7 @@
     "# load description embeddings to an in-memory lancedb vectorstore\n",
     "# to connect to a remote db, specify url and port values.\n",
     "description_embedding_store = LanceDBVectorStore(\n",
-    "    collection_name=\"entity_description_embeddings\",\n",
+    "    collection_name=\"entity.description\",\n",
     ")\n",
     "description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
     "entity_description_embeddings = store_entity_semantic_embeddings(\n",
@@ -299,7 +299,7 @@
     "entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
     "\n",
     "description_embedding_store = LanceDBVectorStore(\n",
-    "    collection_name=\"entity_description_embeddings\",\n",
+    "    collection_name=\"entity.description\",\n",
     ")\n",
     "description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
     "entity_description_embeddings = store_entity_semantic_embeddings(\n",
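As a side note, a minimal sketch (assumed values, mirroring the notebook cells above) of pointing a query at the renamed collection:

# Sketch only: the db_uri is an assumed local path standing in for LANCEDB_URI.
from graphrag.vector_stores.lancedb import LanceDBVectorStore

LANCEDB_URI = "./output/lancedb"  # assumed; set to your index output location
description_embedding_store = LanceDBVectorStore(
    collection_name="entity.description",  # was "entity_description_embeddings"
)
description_embedding_store.connect(db_uri=LANCEDB_URI)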
@@ -59,13 +59,7 @@ async def build_index(
        msg = "Cannot resume and update a run at the same time."
        raise ValueError(msg)

-    # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
-    # TODO: remove the type ignore annotations below once the new config engine has been refactored
-    vector_store_type = config.embeddings.vector_store["type"]  # type: ignore
-    if vector_store_type == VectorStoreType.LanceDB:
-        db_uri = config.embeddings.vector_store["db_uri"]  # type: ignore
-        lancedb_dir = Path(config.root_dir).resolve() / db_uri
-        config.embeddings.vector_store["db_uri"] = str(lancedb_dir)  # type: ignore
+    config = _patch_vector_config(config)

    pipeline_config = create_pipeline_config(config)
    pipeline_cache = (
@@ -90,3 +84,22 @@ async def build_index(
            progress_reporter.success(output.workflow)
            progress_reporter.info(str(output.result))
    return outputs


+def _patch_vector_config(config: GraphRagConfig):
+    """Back-compat patch to ensure a default vector store configuration."""
+    if not config.embeddings.vector_store:
+        config.embeddings.vector_store = {
+            "type": "lancedb",
+            "db_uri": "output/lancedb",
+            "container_name": "default",
+            "overwrite": True,
+        }
+    # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
+    # TODO: remove the type ignore annotations below once the new config engine has been refactored
+    vector_store_type = config.embeddings.vector_store["type"]  # type: ignore
+    if vector_store_type == VectorStoreType.LanceDB:
+        db_uri = config.embeddings.vector_store["db_uri"]  # type: ignore
+        lancedb_dir = Path(config.root_dir).resolve() / db_uri
+        config.embeddings.vector_store["db_uri"] = str(lancedb_dir)  # type: ignore
+    return config
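To illustrate the back-compat behavior consolidated above, a hypothetical trace (paths assumed, not from the diff) of what `_patch_vector_config` yields when the config carries no vector store:

# Hypothetical: assumes config.root_dir == "/data/project" and an empty
# embeddings.vector_store before the call.
config = _patch_vector_config(config)
# config.embeddings.vector_store now holds:
# {
#     "type": "lancedb",
#     "db_uri": "/data/project/output/lancedb",  # resolved absolute lancedb dir
#     "container_name": "default",
#     "overwrite": True,
# }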
@@ -182,56 +182,22 @@ async def local_search(
     ------
     TODO: Document any exceptions to expect.
     """
-    #################################### BEGIN PATCH ####################################
-    # TODO: remove the following patch that checks for a vector_store prior to v1 release
-    # TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
-    # Only applicable in situations involving a local vector_store (lancedb). The general idea:
-    # if vector_store not in config:
-    #   1. assume user is running local if vector_store is not in config
-    #   2. insert default vector_store in config
-    #   3. create lancedb vector_store instance
-    #   4. upload vector embeddings from the input dataframes to the vector_store
-    backwards_compatible = False
-    if not config.embeddings.vector_store:
-        backwards_compatible = True
-        from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
-        from graphrag.vector_stores.lancedb import LanceDBVectorStore
-
-        config.embeddings.vector_store = {
-            "type": "lancedb",
-            "db_uri": f"{Path(config.storage.base_dir)}/lancedb",
-            "collection_name": "entity_description_embeddings",
-            "overwrite": True,
-        }
-        _entities = read_indexer_entities(nodes, entities, community_level)
-        description_embedding_store = LanceDBVectorStore(
-            db_uri=config.embeddings.vector_store["db_uri"],
-            collection_name=config.embeddings.vector_store["collection_name"],
-            overwrite=config.embeddings.vector_store["overwrite"],
-        )
-        description_embedding_store.connect(
-            db_uri=config.embeddings.vector_store["db_uri"]
-        )
-        # dump embeddings from the entities list to the description_embedding_store
-        store_entity_semantic_embeddings(
-            entities=_entities, vectorstore=description_embedding_store
-        )
-    #################################### END PATCH ####################################
+    config = _patch_vector_store(config, nodes, entities, community_level)

     # TODO: update filepath of lancedb (if used) until the new config engine has been implemented
     # TODO: remove the type ignore annotations below once the new config engine has been refactored
     vector_store_type = config.embeddings.vector_store.get("type")  # type: ignore
     vector_store_args = config.embeddings.vector_store
-    if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
+    if vector_store_type == VectorStoreType.LanceDB:
         db_uri = config.embeddings.vector_store["db_uri"]  # type: ignore
         lancedb_dir = Path(config.root_dir).resolve() / db_uri
         vector_store_args["db_uri"] = str(lancedb_dir)  # type: ignore

     reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
-    if not backwards_compatible:  # can remove this check and always set the description_embedding_store before v1 release
-        description_embedding_store = _get_embedding_description_store(
-            config_args=vector_store_args,  # type: ignore
-        )
+    description_embedding_store = _get_embedding_description_store(
+        config_args=vector_store_args,  # type: ignore
+    )

     _entities = read_indexer_entities(nodes, entities, community_level)
     _covariates = read_indexer_covariates(covariates) if covariates is not None else []
@@ -289,56 +255,22 @@ async def local_search_streaming(
     ------
     TODO: Document any exceptions to expect.
     """
-    #################################### BEGIN PATCH ####################################
-    # TODO: remove the following patch that checks for a vector_store prior to v1 release
-    # TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
-    # Only applicable in situations involving a local vector_store (lancedb). The general idea:
-    # if vector_store not in config:
-    #   1. assume user is running local if vector_store is not in config
-    #   2. insert default vector_store in config
-    #   3. create lancedb vector_store instance
-    #   4. upload vector embeddings from the input dataframes to the vector_store
-    backwards_compatible = False
-    if not config.embeddings.vector_store:
-        backwards_compatible = True
-        from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
-        from graphrag.vector_stores.lancedb import LanceDBVectorStore
-
-        config.embeddings.vector_store = {
-            "type": "lancedb",
-            "db_uri": f"{Path(config.storage.base_dir)}/lancedb",
-            "collection_name": "entity_description_embeddings",
-            "overwrite": True,
-        }
-        _entities = read_indexer_entities(nodes, entities, community_level)
-        description_embedding_store = LanceDBVectorStore(
-            db_uri=config.embeddings.vector_store["db_uri"],
-            collection_name=config.embeddings.vector_store["collection_name"],
-            overwrite=config.embeddings.vector_store["overwrite"],
-        )
-        description_embedding_store.connect(
-            db_uri=config.embeddings.vector_store["db_uri"]
-        )
-        # dump embeddings from the entities list to the description_embedding_store
-        store_entity_semantic_embeddings(
-            entities=_entities, vectorstore=description_embedding_store
-        )
-    #################################### END PATCH ####################################
+    config = _patch_vector_store(config, nodes, entities, community_level)

     # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
     # TODO: remove the type ignore annotations below once the new config engine has been refactored
     vector_store_type = config.embeddings.vector_store.get("type")  # type: ignore
     vector_store_args = config.embeddings.vector_store
-    if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
+    if vector_store_type == VectorStoreType.LanceDB:
         db_uri = config.embeddings.vector_store["db_uri"]  # type: ignore
         lancedb_dir = Path(config.root_dir).resolve() / db_uri
         vector_store_args["db_uri"] = str(lancedb_dir)  # type: ignore

     reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
-    if not backwards_compatible:  # can remove this check and always set the description_embedding_store before v1 release
-        description_embedding_store = _get_embedding_description_store(
-            config_args=vector_store_args,  # type: ignore
-        )
+    description_embedding_store = _get_embedding_description_store(
+        config_args=vector_store_args,  # type: ignore
+    )

     _entities = read_indexer_entities(nodes, entities, community_level)
     _covariates = read_indexer_covariates(covariates) if covariates is not None else []
@@ -368,13 +300,55 @@ async def local_search_streaming(
         yield stream_chunk


+def _patch_vector_store(
+    config: GraphRagConfig,
+    nodes: pd.DataFrame,
+    entities: pd.DataFrame,
+    community_level: int,
+) -> GraphRagConfig:
+    # TODO: remove the following patch that checks for a vector_store prior to v1 release
+    # TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
+    # Only applicable in situations involving a local vector_store (lancedb). The general idea:
+    # if vector_store not in config:
+    #   1. assume user is running local if vector_store is not in config
+    #   2. insert default vector_store in config
+    #   3. create lancedb vector_store instance
+    #   4. upload vector embeddings from the input dataframes to the vector_store
+    if not config.embeddings.vector_store:
+        from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
+        from graphrag.vector_stores.lancedb import LanceDBVectorStore
+
+        config.embeddings.vector_store = {
+            "type": "lancedb",
+            "db_uri": f"{Path(config.storage.base_dir)}/lancedb",
+            "container_name": "default",
+            "overwrite": True,
+        }
+        description_embedding_store = LanceDBVectorStore(
+            db_uri=config.embeddings.vector_store["db_uri"],
+            collection_name="default-entity-description",
+            overwrite=config.embeddings.vector_store["overwrite"],
+        )
+        description_embedding_store.connect(
+            db_uri=config.embeddings.vector_store["db_uri"]
+        )
+        # dump embeddings from the entities list to the description_embedding_store
+        _entities = read_indexer_entities(nodes, entities, community_level)
+        store_entity_semantic_embeddings(
+            entities=_entities, vectorstore=description_embedding_store
+        )
+    return config
+
+
 def _get_embedding_description_store(
     config_args: dict,
 ):
     """Get the embedding description store."""
     vector_store_type = config_args["type"]
+    collection_name = f"{config_args['container_name']}-entity-description"
     description_embedding_store = VectorStoreFactory.get_vector_store(
-        vector_store_type=vector_store_type, kwargs=config_args
+        vector_store_type=vector_store_type,
+        kwargs={**config_args, "collection_name": collection_name},
     )
     description_embedding_store.connect(**config_args)
     return description_embedding_store
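A hedged usage sketch of the helper above, with config values assumed from the lancedb defaults elsewhere in this commit:

# Sketch only: the kwargs mirror the default lancedb settings in this diff.
store = _get_embedding_description_store(
    config_args={
        "type": "lancedb",
        "db_uri": "./output/lancedb",
        "container_name": "default",
        "overwrite": True,
    },
)
# The factory is handed collection_name="default-entity-description",
# i.e. f"{container_name}-entity-description".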
@@ -115,6 +115,7 @@ def run_local_search(
     config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
     resolve_paths(config)

+    # TODO remove optional create_final_entities_description_embeddings.parquet to delete backwards compatibility
     dataframe_dict = _resolve_parquet_files(
         root_dir=root_dir,
         config=config,
@@ -125,7 +126,9 @@ def run_local_search(
             "create_final_relationships.parquet",
             "create_final_entities.parquet",
         ],
-        optional_list=["create_final_covariates.parquet"],
+        optional_list=[
+            "create_final_covariates.parquet",
+        ],
     )
     final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
     final_community_reports: pd.DataFrame = dataframe_dict[
@@ -414,6 +414,7 @@ def create_graphrag_config(
             raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
             top_level_nodes=reader.bool("top_level_nodes")
             or defs.SNAPSHOTS_TOP_LEVEL_NODES,
+            embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
         )
     with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
         umap_model = UmapConfig(
@@ -82,6 +82,7 @@ REPORTING_BASE_DIR = "logs"
 SNAPSHOTS_GRAPHML = False
 SNAPSHOTS_RAW_ENTITIES = False
 SNAPSHOTS_TOP_LEVEL_NODES = False
+SNAPSHOTS_EMBEDDINGS = False
 STORAGE_BASE_DIR = "output"
 STORAGE_TYPE = StorageType.file
 SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
@@ -91,7 +92,7 @@ UPDATE_STORAGE_BASE_DIR = "update_output"
 VECTOR_STORE = f"""
     type: {VectorStoreType.LanceDB.value}
    db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
-    collection_name: entity_description_embeddings
+    collection_name: default
    overwrite: true\
"""
@@ -86,6 +86,7 @@ class TextEmbeddingTarget(str, Enum):

     all = "all"
     required = "required"
+    none = "none"

     def __repr__(self):
         """Get a string representation."""
@@ -23,3 +23,7 @@ class SnapshotsConfig(BaseModel):
         description="A flag indicating whether to take snapshots of top-level nodes.",
         default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
     )
+    embeddings: bool = Field(
+        description="A flag indicating whether to take snapshots of embeddings.",
+        default=defs.SNAPSHOTS_EMBEDDINGS,
+    )
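For context, a small sketch (assumed import path, not from the diff) showing the new flag being set programmatically:

# Sketch only: the module path for SnapshotsConfig is an assumption.
from graphrag.config.models.snapshots_config import SnapshotsConfig

snapshots = SnapshotsConfig(embeddings=True)  # opt in to parquet embedding snapshots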
@@ -11,6 +11,18 @@ from .cache import (
     PipelineMemoryCacheConfig,
     PipelineNoneCacheConfig,
 )
+from .embeddings import (
+    all_embeddings,
+    community_full_content_embedding,
+    community_summary_embedding,
+    community_title_embedding,
+    document_raw_content_embedding,
+    entity_description_embedding,
+    entity_name_embedding,
+    relationship_description_embedding,
+    required_embeddings,
+    text_unit_text_embedding,
+)
 from .input import (
     PipelineCSVInputConfig,
     PipelineInputConfig,
@@ -66,4 +78,14 @@ __all__ = [
     "PipelineWorkflowConfig",
     "PipelineWorkflowReference",
     "PipelineWorkflowStep",
+    "all_embeddings",
+    "community_full_content_embedding",
+    "community_summary_embedding",
+    "community_title_embedding",
+    "document_raw_content_embedding",
+    "entity_description_embedding",
+    "entity_name_embedding",
+    "relationship_description_embedding",
+    "required_embeddings",
+    "text_unit_text_embedding",
 ]

graphrag/index/config/embeddings.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing embeddings values."""
+
+entity_name_embedding = "entity.name"
+entity_description_embedding = "entity.description"
+relationship_description_embedding = "relationship.description"
+document_raw_content_embedding = "document.raw_content"
+community_title_embedding = "community.title"
+community_summary_embedding = "community.summary"
+community_full_content_embedding = "community.full_content"
+text_unit_text_embedding = "text_unit.text"
+
+all_embeddings: set[str] = {
+    entity_name_embedding,
+    entity_description_embedding,
+    relationship_description_embedding,
+    document_raw_content_embedding,
+    community_title_embedding,
+    community_summary_embedding,
+    community_full_content_embedding,
+    text_unit_text_embedding,
+}
+required_embeddings: set[str] = {entity_description_embedding}
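A quick sketch of how these constants compose, derivable from the module above:

# Sketch only: plain set algebra over the exported constants.
from graphrag.index.config.embeddings import all_embeddings, required_embeddings

assert required_embeddings <= all_embeddings   # "required" is a subset of "all"
fields = all_embeddings - {"entity.name"}      # e.g. skip a single embedding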
@@ -22,6 +22,10 @@ from graphrag.index.config.cache import (
     PipelineMemoryCacheConfig,
     PipelineNoneCacheConfig,
 )
+from graphrag.index.config.embeddings import (
+    all_embeddings,
+    required_embeddings,
+)
 from graphrag.index.config.input import (
     PipelineCSVInputConfig,
     PipelineInputConfigTypes,
@@ -56,33 +60,11 @@ from graphrag.index.workflows.default_workflows import (
     create_final_nodes,
     create_final_relationships,
     create_final_text_units,
+    generate_text_embeddings,
 )

 log = logging.getLogger(__name__)

-
-entity_name_embedding = "entity.name"
-entity_description_embedding = "entity.description"
-relationship_description_embedding = "relationship.description"
-document_raw_content_embedding = "document.raw_content"
-community_title_embedding = "community.title"
-community_summary_embedding = "community.summary"
-community_full_content_embedding = "community.full_content"
-text_unit_text_embedding = "text_unit.text"
-
-all_embeddings: set[str] = {
-    entity_name_embedding,
-    entity_description_embedding,
-    relationship_description_embedding,
-    document_raw_content_embedding,
-    community_title_embedding,
-    community_summary_embedding,
-    community_full_content_embedding,
-    text_unit_text_embedding,
-}
-required_embeddings: set[str] = {entity_description_embedding}
-

 builtin_document_attributes: set[str] = {
     "id",
     "source",
@@ -121,11 +103,12 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
         ),
         cache=_get_cache_config(settings),
         workflows=[
-            *_document_workflows(settings, embedded_fields),
-            *_text_unit_workflows(settings, covariates_enabled, embedded_fields),
-            *_graph_workflows(settings, embedded_fields),
-            *_community_workflows(settings, covariates_enabled, embedded_fields),
+            *_document_workflows(settings),
+            *_text_unit_workflows(settings, covariates_enabled),
+            *_graph_workflows(settings),
+            *_community_workflows(settings, covariates_enabled),
             *(_covariate_workflows(settings) if covariates_enabled else []),
+            *(_embeddings_workflows(settings, embedded_fields)),
         ],
     )
@@ -138,9 +121,11 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
 def _get_embedded_fields(settings: GraphRagConfig) -> set[str]:
     match settings.embeddings.target:
         case TextEmbeddingTarget.all:
-            return all_embeddings - {*settings.embeddings.skip}
+            return all_embeddings.difference(settings.embeddings.skip)
         case TextEmbeddingTarget.required:
             return required_embeddings
+        case TextEmbeddingTarget.none:
+            return set()
         case _:
             msg = f"Unknown embeddings target: {settings.embeddings.target}"
             raise ValueError(msg)
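Concretely, the selection above resolves as follows (a sketch derived from the code, with an assumed skip list):

# target=all, skip=["entity.name"] -> every embedding constant except "entity.name"
# target=required                  -> {"entity.description"}
# target=none                      -> set()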
@@ -163,11 +148,8 @@ def _log_llm_settings(settings: GraphRagConfig) -> None:


 def _document_workflows(
-    settings: GraphRagConfig, embedded_fields: set[str]
+    settings: GraphRagConfig,
 ) -> list[PipelineWorkflowReference]:
-    skip_document_raw_content_embedding = (
-        document_raw_content_embedding not in embedded_fields
-    )
     return [
         PipelineWorkflowReference(
             name=create_final_documents,
@@ -176,15 +158,6 @@ def _document_workflows(
                     {*(settings.input.document_attribute_columns)}
                     - builtin_document_attributes
                 ),
-                "document_raw_content_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "document_raw_content",
-                    {
-                        "title_column": "raw_content",
-                        "collection_name": "final_documents_raw_content_embedding",
-                    },
-                ),
-                "skip_raw_content_embedding": skip_document_raw_content_embedding,
             },
         ),
     ]
@@ -193,9 +166,7 @@ def _document_workflows(
 def _text_unit_workflows(
     settings: GraphRagConfig,
     covariates_enabled: bool,
-    embedded_fields: set[str],
 ) -> list[PipelineWorkflowReference]:
-    skip_text_unit_embedding = text_unit_text_embedding not in embedded_fields
     return [
         PipelineWorkflowReference(
             name=create_base_text_units,
@@ -211,13 +182,7 @@ def _text_unit_workflows(
         PipelineWorkflowReference(
             name=create_final_text_units,
             config={
-                "text_unit_text_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "text_unit_text",
-                    {"title_column": "text", "collection_name": "text_units_embedding"},
-                ),
                 "covariates_enabled": covariates_enabled,
-                "skip_text_unit_embedding": skip_text_unit_embedding,
             },
         ),
     ]
@@ -225,7 +190,6 @@ def _text_unit_workflows(

 def _get_embedding_settings(
     settings: TextEmbeddingConfig,
-    embedding_name: str,
     vector_store_params: dict | None = None,
 ) -> dict:
     vector_store_settings = settings.vector_store
@@ -243,20 +207,10 @@ def _get_embedding_settings(
     # This ensures the vector store config is part of the strategy and not the global config
     return {
         "strategy": strategy,
-        "embedding_name": embedding_name,
     }


-def _graph_workflows(
-    settings: GraphRagConfig, embedded_fields: set[str]
-) -> list[PipelineWorkflowReference]:
-    skip_entity_name_embedding = entity_name_embedding not in embedded_fields
-    skip_entity_description_embedding = (
-        entity_description_embedding not in embedded_fields
-    )
-    skip_relationship_description_embedding = (
-        relationship_description_embedding not in embedded_fields
-    )
+def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
     return [
         PipelineWorkflowReference(
             name=create_base_entity_graph,
@@ -286,40 +240,11 @@ def _graph_workflows(
         ),
         PipelineWorkflowReference(
             name=create_final_entities,
-            config={
-                "entity_name_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "entity_name",
-                    {
-                        "title_column": "name",
-                        "collection_name": "entity_name_embeddings",
-                    },
-                ),
-                "entity_name_description_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "entity_name_description",
-                    {
-                        "title_column": "description",
-                        "collection_name": "entity_description_embeddings",
-                    },
-                ),
-                "skip_name_embedding": skip_entity_name_embedding,
-                "skip_description_embedding": skip_entity_description_embedding,
-            },
+            config={},
         ),
         PipelineWorkflowReference(
             name=create_final_relationships,
-            config={
-                "relationship_description_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "relationship_description",
-                    {
-                        "title_column": "description",
-                        "collection_name": "relationships_description_embeddings",
-                    },
-                ),
-                "skip_description_embedding": skip_relationship_description_embedding,
-            },
+            config={},
         ),
         PipelineWorkflowReference(
             name=create_final_nodes,
@@ -332,24 +257,14 @@ def _graph_workflows(


 def _community_workflows(
-    settings: GraphRagConfig, covariates_enabled: bool, embedded_fields: set[str]
+    settings: GraphRagConfig, covariates_enabled: bool
 ) -> list[PipelineWorkflowReference]:
-    skip_community_title_embedding = community_title_embedding not in embedded_fields
-    skip_community_summary_embedding = (
-        community_summary_embedding not in embedded_fields
-    )
-    skip_community_full_content_embedding = (
-        community_full_content_embedding not in embedded_fields
-    )
     return [
         PipelineWorkflowReference(name=create_final_communities),
         PipelineWorkflowReference(
             name=create_final_community_reports,
             config={
                 "covariates_enabled": covariates_enabled,
-                "skip_title_embedding": skip_community_title_embedding,
-                "skip_summary_embedding": skip_community_summary_embedding,
-                "skip_full_content_embedding": skip_community_full_content_embedding,
                 "create_community_reports": {
                     **settings.community_reports.parallelization.model_dump(),
                     "async_mode": settings.community_reports.async_mode,
@@ -357,27 +272,6 @@ def _community_workflows(
                         settings.root_dir
                     ),
                 },
-                "community_report_full_content_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "community_report_full_content",
-                    {
-                        "title_column": "full_content",
-                        "collection_name": "final_community_reports_full_content_embedding",
-                    },
-                ),
-                "community_report_summary_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "community_report_summary",
-                    {
-                        "title_column": "summary",
-                        "collection_name": "final_community_reports_summary_embedding",
-                    },
-                ),
-                "community_report_title_embed": _get_embedding_settings(
-                    settings.embeddings,
-                    "community_report_title",
-                    {"title_column": "title"},
-                ),
             },
         ),
     ]
@@ -401,6 +295,21 @@ def _covariate_workflows(
     ]


+def _embeddings_workflows(
+    settings: GraphRagConfig, embedded_fields: set[str]
+) -> list[PipelineWorkflowReference]:
+    return [
+        PipelineWorkflowReference(
+            name=generate_text_embeddings,
+            config={
+                "snapshot_embeddings": settings.snapshots.embeddings,
+                "text_embed": _get_embedding_settings(settings.embeddings),
+                "embedded_fields": embedded_fields,
+            },
+        ),
+    ]
+
+
 def _get_pipeline_input_config(
     settings: GraphRagConfig,
 ) -> PipelineInputConfigTypes:
@@ -31,7 +31,6 @@ from graphrag.index.graph.extractors.community_reports.schemas import (
     NODE_ID,
     NODE_NAME,
 )
-from graphrag.index.operations.embed_text import embed_text
 from graphrag.index.operations.summarize_communities import (
     prepare_community_reports,
     restore_community_hierarchy,
@@ -49,9 +48,6 @@ async def create_final_community_reports(
     summarization_strategy: dict,
     async_mode: AsyncType = AsyncType.AsyncIO,
     num_threads: int = 4,
-    full_content_text_embed: dict | None = None,
-    summary_text_embed: dict | None = None,
-    title_text_embed: dict | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform community reports."""
     nodes = _prep_nodes(nodes_input)
@@ -86,39 +82,6 @@ async def create_final_community_reports(
         lambda _x: str(uuid4())
     )

-    # Embed full content if not skipped
-    if full_content_text_embed:
-        community_reports["full_content_embedding"] = await embed_text(
-            community_reports,
-            callbacks,
-            cache,
-            column="full_content",
-            strategy=full_content_text_embed["strategy"],
-            embedding_name="community_report_full_content",
-        )
-
-    # Embed summary if not skipped
-    if summary_text_embed:
-        community_reports["summary_embedding"] = await embed_text(
-            community_reports,
-            callbacks,
-            cache,
-            column="summary",
-            strategy=summary_text_embed["strategy"],
-            embedding_name="community_report_summary",
-        )
-
-    # Embed title if not skipped
-    if title_text_embed:
-        community_reports["title_embedding"] = await embed_text(
-            community_reports,
-            callbacks,
-            cache,
-            column="title",
-            strategy=title_text_embed["strategy"],
-            embedding_name="community_report_title",
-        )
-
     # Merge by community id with communities to add size and period
     return community_reports.merge(
         communities_input.loc[:, ["id", "size", "period"]],
@@ -4,21 +4,12 @@

 """All the steps to transform final documents."""

 import pandas as pd
 from datashaper import (
     VerbCallbacks,
 )

 from graphrag.index.cache import PipelineCache
-from graphrag.index.operations.embed_text import embed_text


-async def create_final_documents(
+def create_final_documents(
     documents: pd.DataFrame,
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
     document_attribute_columns: list[str] | None = None,
-    raw_content_text_embed: dict | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform final documents."""
     exploded = (
@@ -72,13 +63,4 @@ async def create_final_documents(
         # Drop the original attribute columns after collapsing them
         rejoined.drop(columns=document_attribute_columns, inplace=True)

-    if raw_content_text_embed:
-        rejoined["raw_content_embedding"] = await embed_text(
-            rejoined,
-            callbacks,
-            cache,
-            column="raw_content",
-            strategy=raw_content_text_embed["strategy"],
-        )
-
     return rejoined
@@ -8,18 +8,13 @@ from datashaper import (
     VerbCallbacks,
 )

 from graphrag.index.cache import PipelineCache
-from graphrag.index.operations.embed_text import embed_text
 from graphrag.index.operations.split_text import split_text
 from graphrag.index.operations.unpack_graph import unpack_graph


-async def create_final_entities(
+def create_final_entities(
     entity_graph: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    name_text_embed: dict | None = None,
-    description_text_embed: dict | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform final entities."""
     # Process nodes
@@ -44,37 +39,6 @@ async def create_final_entities(
     nodes = nodes.loc[nodes["name"].notna()]

     # Split 'source_id' column into 'text_unit_ids'
-    nodes = split_text(
+    return split_text(
         nodes, column="source_id", separator=",", to="text_unit_ids"
     ).drop(columns=["source_id"])

-    # Embed name if not skipped
-    if name_text_embed:
-        nodes["name_embedding"] = await embed_text(
-            nodes,
-            callbacks,
-            cache,
-            column="name",
-            strategy=name_text_embed["strategy"],
-            embedding_name="entity_name",
-        )
-
-    # Embed description if not skipped
-    if description_text_embed:
-        # Concatenate 'name' and 'description' and embed
-        nodes["name_description"] = nodes["name"] + ":" + nodes["description"]
-        nodes["description_embedding"] = await embed_text(
-            nodes,
-            callbacks,
-            cache,
-            column="name_description",
-            strategy=description_text_embed["strategy"],
-            embedding_name="entity_name_description",
-        )
-
-        # Drop rows with NaN 'description_embedding' if not using vector store
-        if not description_text_embed.get("strategy", {}).get("vector_store"):
-            nodes = nodes.loc[nodes["description_embedding"].notna()]
-        nodes.drop(columns="name_description", inplace=True)
-
-    return nodes
@@ -10,20 +10,16 @@ from datashaper import (
     VerbCallbacks,
 )

 from graphrag.index.cache import PipelineCache
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
-from graphrag.index.operations.embed_text import embed_text
 from graphrag.index.operations.unpack_graph import unpack_graph


-async def create_final_relationships(
+def create_final_relationships(
     entity_graph: pd.DataFrame,
     nodes: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    description_text_embed: dict | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     graph_edges = unpack_graph(entity_graph, callbacks, "clustered_graph", "edges")
@@ -34,16 +30,6 @@ async def create_final_relationships(
         pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
     )

-    if description_text_embed:
-        filtered["description_embedding"] = await embed_text(
-            filtered,
-            callbacks,
-            cache,
-            column="description",
-            strategy=description_text_embed["strategy"],
-            embedding_name="relationship_description",
-        )
-
     pruned_edges = filtered.drop(columns=["level"])

     filtered_nodes = nodes[nodes["level"] == 0].reset_index(drop=True)
@@ -6,22 +6,13 @@

 from typing import cast

 import pandas as pd
 from datashaper import (
     VerbCallbacks,
 )

 from graphrag.index.cache import PipelineCache
-from graphrag.index.operations.embed_text import embed_text


-async def create_final_text_units(
+def create_final_text_units(
     text_units: pd.DataFrame,
     final_entities: pd.DataFrame,
     final_relationships: pd.DataFrame,
     final_covariates: pd.DataFrame | None,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    text_text_embed: dict | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform the text units."""
     selected = text_units.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
@@ -41,30 +32,12 @@ async def create_final_text_units(

     aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()

-    is_using_vector_store = False
-    if text_text_embed:
-        aggregated["text_embedding"] = await embed_text(
-            aggregated,
-            callbacks,
-            cache,
-            column="text",
-            strategy=text_text_embed["strategy"],
-        )
-        is_using_vector_store = (
-            text_text_embed.get("strategy", {}).get("vector_store", None) is not None
-        )
-
     return cast(
         pd.DataFrame,
         aggregated[
             [
                 "id",
                 "text",
-                *(
-                    []
-                    if (not text_text_embed or is_using_vector_store)
-                    else ["text_embedding"]
-                ),
                 "n_tokens",
                 "document_ids",
                 "entity_ids",
graphrag/index/flows/generate_text_embeddings.py (new file, 146 lines)

@@ -0,0 +1,146 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to generate text embeddings."""
+
+import logging
+
+import pandas as pd
+from datashaper import (
+    VerbCallbacks,
+)
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.config.embeddings import (
+    community_full_content_embedding,
+    community_summary_embedding,
+    community_title_embedding,
+    document_raw_content_embedding,
+    entity_description_embedding,
+    entity_name_embedding,
+    relationship_description_embedding,
+    text_unit_text_embedding,
+)
+from graphrag.index.operations.embed_text import embed_text
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.index.storage import PipelineStorage
+
+log = logging.getLogger(__name__)
+
+
+async def generate_text_embeddings(
+    final_documents: pd.DataFrame | None,
+    final_relationships: pd.DataFrame | None,
+    final_text_units: pd.DataFrame | None,
+    final_entities: pd.DataFrame | None,
+    final_community_reports: pd.DataFrame | None,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    storage: PipelineStorage,
+    text_embed_config: dict,
+    embedded_fields: set[str],
+    embeddings_snapshot_enabled: bool = False,
+) -> None:
+    """All the steps to generate all embeddings."""
+    embedding_param_map = {
+        document_raw_content_embedding: {
+            "data": final_documents.loc[:, ["id", "raw_content"]]
+            if final_documents is not None
+            else None,
+            "column_to_embed": "raw_content",
+        },
+        relationship_description_embedding: {
+            "data": final_relationships.loc[:, ["id", "description"]]
+            if final_relationships is not None
+            else None,
+            "column_to_embed": "description",
+        },
+        text_unit_text_embedding: {
+            "data": final_text_units.loc[:, ["id", "text"]]
+            if final_text_units is not None
+            else None,
+            "column_to_embed": "text",
+        },
+        entity_name_embedding: {
+            "data": final_entities.loc[:, ["id", "name", "description"]]
+            if final_entities is not None
+            else None,
+            "column_to_embed": "name",
+        },
+        entity_description_embedding: {
+            "data": final_entities.loc[:, ["id", "name", "description"]].assign(
+                name_description=lambda df: df["name"] + ":" + df["description"]
+            )
+            if final_entities is not None
+            else None,
+            "column_to_embed": "name_description",
+        },
+        community_title_embedding: {
+            "data": final_community_reports.loc[
+                :, ["id", "full_content", "summary", "title"]
+            ]
+            if final_community_reports is not None
+            else None,
+            "column_to_embed": "title",
+        },
+        community_summary_embedding: {
+            "data": final_community_reports.loc[
+                :, ["id", "full_content", "summary", "title"]
+            ]
+            if final_community_reports is not None
+            else None,
+            "column_to_embed": "summary",
+        },
+        community_full_content_embedding: {
+            "data": final_community_reports.loc[
+                :, ["id", "full_content", "summary", "title"]
+            ]
+            if final_community_reports is not None
+            else None,
+            "column_to_embed": "full_content",
+        },
+    }
+
+    log.info("Creating embeddings")
+    for field in embedded_fields:
+        await _run_and_snapshot_embeddings(
+            name=field,
+            callbacks=callbacks,
+            cache=cache,
+            storage=storage,
+            text_embed_config=text_embed_config,
+            embeddings_snapshot_enabled=embeddings_snapshot_enabled,
+            **embedding_param_map[field],
+        )
+
+
+async def _run_and_snapshot_embeddings(
+    name: str,
+    data: pd.DataFrame,
+    column_to_embed: str,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    storage: PipelineStorage,
+    text_embed_config: dict,
+    embeddings_snapshot_enabled: bool,
+) -> None:
+    """All the steps to generate a single embedding."""
+    if text_embed_config:
+        data["embedding"] = await embed_text(
+            data,
+            callbacks,
+            cache,
+            embed_column=column_to_embed,
+            embedding_name=name,
+            strategy=text_embed_config["strategy"],
+        )
+
+        data = data.loc[:, ["id", "embedding"]]
+
+        if embeddings_snapshot_enabled is True:
+            await snapshot(
+                data,
+                name=f"embeddings.{name}",
+                storage=storage,
+                formats=["parquet"],
+            )
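A hedged invocation sketch for the new flow; the callbacks, cache, storage, strategy, and dataframe objects are assumed to exist, as produced by the surrounding pipeline:

# Sketch only: embeds just the text-unit text and snapshots the result.
await generate_text_embeddings(
    final_documents=None,
    final_relationships=None,
    final_text_units=final_text_units,   # assumed dataframe with "id" and "text"
    final_entities=None,
    final_community_reports=None,
    callbacks=callbacks,
    cache=cache,
    storage=storage,
    text_embed_config={"strategy": strategy},  # shape mirrors _get_embedding_settings
    embedded_fields={"text_unit.text"},
    embeddings_snapshot_enabled=True,    # writes embeddings.text_unit.text.parquet
)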
@@ -49,7 +49,7 @@ embeddings:
     # api_key: <api_key> # if not set, will attempt to use managed identity. Expects the `Search Index Data Contributor` RBAC role in this case.
     # audience: <optional> # if using managed identity, the audience to use for the token
     # overwrite: true # or false. Only applicable at index creation time
-    # collection_name: <collection_name> # the name of the collection to use
+    # collection_name: <collection_name> # the name of the collection to use. Default: 'default'
 llm:
   api_key: ${{GRAPHRAG_API_KEY}}
   type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
@@ -42,9 +42,11 @@ async def embed_text(
     input: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    column: str,
+    embed_column: str,
     strategy: dict,
-    embedding_name: str = "default",
+    embedding_name: str,
+    id_column: str = "id",
+    title_column: str | None = None,
 ):
     """
     Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.
@@ -91,18 +93,19 @@ async def embed_text(
             input,
             callbacks,
             cache,
-            column,
+            embed_column,
             strategy,
             vector_store,
             vector_store_workflow_config,
-            vector_store_config.get("store_in_table", False),
+            id_column=id_column,
+            title_column=title_column,
         )

     return await _text_embed_in_memory(
         input,
         callbacks,
         cache,
-        column,
+        embed_column,
         strategy,
     )
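For reference, a call sketch against the new signature; the dataframe, callbacks, cache, and strategy dict are assumed:

# Sketch only: df must contain "id" and "description" columns.
result = await embed_text(
    df,
    callbacks,
    cache,
    embed_column="description",
    strategy=strategy,                          # the workflow's strategy dict
    embedding_name="relationship.description",  # now required, no "default"
)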
@@ -111,14 +114,14 @@ async def _text_embed_in_memory(
     input: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    column: str,
+    embed_column: str,
     strategy: dict,
 ):
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
     strategy_args = {**strategy}

-    texts: list[str] = input[column].to_numpy().tolist()
+    texts: list[str] = input[embed_column].to_numpy().tolist()
     result = await strategy_exec(texts, callbacks, cache, strategy_args)

     return result.embeddings
@@ -128,11 +131,12 @@ async def _text_embed_with_vector_store(
     input: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    column: str,
+    embed_column: str,
     strategy: dict[str, Any],
     vector_store: BaseVectorStore,
     vector_store_config: dict,
-    store_in_table: bool = False,
+    id_column: str = "id",
+    title_column: str | None = None,
 ):
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
@@ -142,24 +146,24 @@ async def _text_embed_with_vector_store(
     insert_batch_size: int = (
         vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE
     )
-    title_column: str = vector_store_config.get("title_column", "title")
-    id_column: str = vector_store_config.get("id_column", "id")
-
     overwrite: bool = vector_store_config.get("overwrite", True)

-    if column not in input.columns:
-        msg = (
-            f"Column {column} not found in input dataframe with columns {input.columns}"
-        )
+    if embed_column not in input.columns:
+        msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}"
         raise ValueError(msg)
-    if title_column not in input.columns:
-        msg = f"Column {title_column} not found in input dataframe with columns {input.columns}"
+    title = title_column or embed_column
+    if title not in input.columns:
+        msg = (
+            f"Column {title} not found in input dataframe with columns {input.columns}"
+        )
         raise ValueError(msg)
     if id_column not in input.columns:
         msg = f"Column {id_column} not found in input dataframe with columns {input.columns}"
         raise ValueError(msg)

     total_rows = 0
-    for row in input[column]:
+    for row in input[embed_column]:
         if isinstance(row, list):
             total_rows += len(row)
         else:
@@ -172,8 +176,8 @@ async def _text_embed_with_vector_store(

     while insert_batch_size * i < input.shape[0]:
         batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
-        texts: list[str] = batch[column].to_numpy().tolist()
-        titles: list[str] = batch[title_column].to_numpy().tolist()
+        texts: list[str] = batch[embed_column].to_numpy().tolist()
+        titles: list[str] = batch[title].to_numpy().tolist()
         ids: list[str] = batch[id_column].to_numpy().tolist()
         result = await strategy_exec(
             texts,
@@ -181,7 +185,7 @@ async def _text_embed_with_vector_store(
             cache,
             strategy_args,
         )
-        if store_in_table and result.embeddings:
+        if result.embeddings:
             embeddings = [
                 embedding for embedding in result.embeddings if embedding is not None
             ]
@@ -204,10 +208,7 @@ async def _text_embed_with_vector_store(
         starting_index += len(documents)
         i += 1

-    if store_in_table:
-        return all_results
-
-    return None
+    return all_results


 def _create_vector_store(
@@ -226,12 +227,10 @@ def _create_vector_store(


 def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
-    collection_name = vector_store_config.get("collection_name")
-    if not collection_name:
-        collection_names = vector_store_config.get("collection_names", {})
-        collection_name = collection_names.get(embedding_name, embedding_name)
+    container_name = vector_store_config.get("container_name")
+    collection_name = f"{container_name}.{embedding_name}".replace(".", "-")

-    msg = f"using vector store {vector_store_config.get('type')} with collection_name {collection_name} for embedding {embedding_name}"
+    msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
     log.info(msg)
     return collection_name
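The naming rule above in one line, derivable directly from the code:

# container_name="default", embedding_name="entity.description"
# -> "default.entity.description".replace(".", "-") == "default-entity-description"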
@@ -17,8 +17,8 @@ async def snapshot(
     """Take an entire snapshot of the tabular data."""
     for fmt in formats:
         if fmt == "parquet":
-            await storage.set(name + ".parquet", input.to_parquet())
+            await storage.set(f"{name}.parquet", input.to_parquet())
         elif fmt == "json":
             await storage.set(
-                name + ".json", input.to_json(orient="records", lines=True)
+                f"{name}.json", input.to_json(orient="records", lines=True)
             )
@@ -181,7 +181,8 @@ async def _run_entity_description_embedding(
         entities_df,
         callbacks,
         cache,
-        column="name_description",
+        embed_column="name_description",
+        embedding_name="entity.description",
         strategy=embed_config.get("strategy", {}),
     )
     return entities_df.drop(columns=["name_description"])
@@ -67,7 +67,13 @@ from .v1.create_final_text_units import (
 from .v1.create_final_text_units import (
     workflow_name as create_final_text_units,
 )
+from .v1.generate_text_embeddings import (
+    build_steps as build_generate_text_embeddings_steps,
+)
+from .v1.generate_text_embeddings import (
+    workflow_name as generate_text_embeddings,
+)

 default_workflows: WorkflowDefinitions = {
     create_base_entity_graph: build_create_base_entity_graph_steps,
@@ -80,4 +86,5 @@ default_workflows: WorkflowDefinitions = {
     create_final_covariates: build_create_final_covariates_steps,
     create_final_entities: build_create_final_entities_steps,
     create_final_communities: build_create_final_communities_steps,
+    generate_text_embeddings: build_generate_text_embeddings_steps,
 }
@@ -23,19 +23,6 @@ def build_steps(
     async_mode = create_community_reports_config.get("async_mode")
     num_threads = create_community_reports_config.get("num_threads")

-    base_text_embed = config.get("text_embed", {})
-    community_report_full_content_embed_config = config.get(
-        "community_report_full_content_embed", base_text_embed
-    )
-    community_report_summary_embed_config = config.get(
-        "community_report_summary_embed", base_text_embed
-    )
-    community_report_title_embed_config = config.get(
-        "community_report_title_embed", base_text_embed
-    )
-    skip_title_embedding = config.get("skip_title_embedding", False)
-    skip_summary_embedding = config.get("skip_summary_embedding", False)
-    skip_full_content_embedding = config.get("skip_full_content_embedding", False)
     input = {
         "source": "workflow:create_final_nodes",
         "relationships": "workflow:create_final_relationships",
@@ -48,21 +35,6 @@ def build_steps(
         {
             "verb": "create_final_community_reports",
             "args": {
-                "full_content_text_embed": (
-                    community_report_full_content_embed_config
-                    if not skip_full_content_embedding
-                    else None
-                ),
-                "summary_text_embed": (
-                    community_report_summary_embed_config
-                    if not skip_summary_embedding
-                    else None
-                ),
-                "title_text_embed": (
-                    community_report_title_embed_config
-                    if not skip_title_embedding
-                    else None
-                ),
                 "summarization_strategy": summarization_strategy,
                 "async_mode": async_mode,
                 "num_threads": num_threads,
@@ -19,23 +19,11 @@ def build_steps(
     ## Dependencies
     * `workflow:create_final_text_units`
     """
-    base_text_embed = config.get("text_embed", {})
-    document_raw_content_embed_config = config.get(
-        "document_raw_content_embed", base_text_embed
-    )
-    skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
     document_attribute_columns = config.get("document_attribute_columns", [])
     return [
         {
             "verb": "create_final_documents",
-            "args": {
-                "document_attribute_columns": document_attribute_columns,
-                "raw_content_text_embed": (
-                    document_raw_content_embed_config
-                    if not skip_raw_content_embedding
-                    else None
-                ),
-            },
+            "args": {"document_attribute_columns": document_attribute_columns},
             "input": {
                 "source": DEFAULT_INPUT_NAME,
                 "text_units": "workflow:create_final_text_units",
@@ -3,13 +3,16 @@

 """A module containing build_steps method definition."""

+import logging
+
 from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

 workflow_name = "create_final_entities"
+log = logging.getLogger(__name__)


 def build_steps(
-    config: PipelineWorkflowConfig,
+    config: PipelineWorkflowConfig,  # noqa: ARG001
 ) -> list[PipelineWorkflowStep]:
     """
     Create the final entities table.
@@ -17,26 +20,10 @@ def build_steps(
     ## Dependencies
     * `workflow:create_base_entity_graph`
     """
-    base_text_embed = config.get("text_embed", {})
-    entity_name_embed_config = config.get("entity_name_embed", base_text_embed)
-    entity_name_description_embed_config = config.get(
-        "entity_name_description_embed", base_text_embed
-    )
-
-    skip_name_embedding = config.get("skip_name_embedding", False)
-    skip_description_embedding = config.get("skip_description_embedding", False)
-
     return [
         {
             "verb": "create_final_entities",
-            "args": {
-                "name_text_embed": entity_name_embed_config
-                if not skip_name_embedding
-                else None,
-                "description_text_embed": entity_name_description_embed_config
-                if not skip_description_embedding
-                else None,
-            },
+            "args": {},
             "input": {"source": "workflow:create_base_entity_graph"},
         },
     ]
@@ -3,13 +3,17 @@

 """A module containing build_steps method definition."""

+import logging
+
 from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

 workflow_name = "create_final_relationships"

+log = logging.getLogger(__name__)
+

 def build_steps(
-    config: PipelineWorkflowConfig,
+    config: PipelineWorkflowConfig,  # noqa: ARG001
 ) -> list[PipelineWorkflowStep]:
     """
     Create the final relationships table.
@@ -18,19 +22,10 @@ def build_steps(
     * `workflow:create_base_entity_graph`
     * `workflow:create_final_nodes`
     """
-    base_text_embed = config.get("text_embed", {})
-    relationship_description_embed_config = config.get(
-        "relationship_description_embed", base_text_embed
-    )
-    skip_description_embedding = config.get("skip_description_embedding", False)
     return [
         {
             "verb": "create_final_relationships",
-            "args": {
-                "description_text_embed": relationship_description_embed_config
-                if not skip_description_embedding
-                else None,
-            },
+            "args": {},
             "input": {
                 "source": "workflow:create_base_entity_graph",
                 "nodes": "workflow:create_final_nodes",
@ -19,10 +19,6 @@ def build_steps(
    * `workflow:create_final_entities`
    * `workflow:create_final_communities`
    """
    base_text_embed = config.get("text_embed", {})
    text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed)

    skip_text_unit_embedding = config.get("skip_text_unit_embedding", False)
    covariates_enabled = config.get("covariates_enabled", False)

    input = {
@ -37,11 +33,7 @@ def build_steps(
    return [
        {
            "verb": "create_final_text_units",
            "args": {
                "text_text_embed": text_unit_text_embed_config
                if not skip_text_unit_embedding
                else None,
            },
            "args": {},
            "input": input,
        },
    ]
49 graphrag/index/workflows/v1/generate_text_embeddings.py Normal file
@ -0,0 +1,49 @@

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

import logging

from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep

log = logging.getLogger(__name__)

workflow_name = "generate_text_embeddings"

input = {
    "source": "workflow:create_final_documents",
    "relationships": "workflow:create_final_relationships",
    "text_units": "workflow:create_final_text_units",
    "entities": "workflow:create_final_entities",
    "community_reports": "workflow:create_final_community_reports",
}


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the final embeddings files.

    ## Dependencies
    * `workflow:create_final_documents`
    * `workflow:create_final_relationships`
    * `workflow:create_final_text_units`
    * `workflow:create_final_entities`
    * `workflow:create_final_community_reports`
    """
    text_embed = config.get("text_embed", {})
    embedded_fields = config.get("embedded_fields", {})
    embeddings_snapshot_enabled = config.get("snapshot_embeddings", False)
    return [
        {
            "verb": "generate_text_embeddings",
            "args": {
                "text_embed": text_embed,
                "embedded_fields": embedded_fields,
                "embeddings_snapshot_enabled": embeddings_snapshot_enabled,
            },
            "input": input,
        },
    ]
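For orientation, a minimal sketch of the config dict this build_steps reads. The three keys come straight from the hunk above; the strategy type reuses the "mock" value exercised by the tests later in this diff, and the two field names are taken from the snapshot files those tests assert on:

config = {
    "text_embed": {"strategy": {"type": "mock"}},
    "embedded_fields": {"entity.description", "document.raw_content"},
    "snapshot_embeddings": True,
}

steps = build_steps(config)
# -> a single "generate_text_embeddings" step whose args carry the three values above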
@ -15,6 +15,7 @@ from .create_final_relationships import (
    create_final_relationships,
)
from .create_final_text_units import create_final_text_units
from .generate_text_embeddings import generate_text_embeddings

__all__ = [
    "create_base_entity_graph",
@ -27,4 +28,5 @@ __all__ = [
    "create_final_nodes",
    "create_final_relationships",
    "create_final_text_units",
    "generate_text_embeddings",
]
@ -30,9 +30,6 @@ async def create_final_community_reports(
    summarization_strategy: dict,
    async_mode: AsyncType = AsyncType.AsyncIO,
    num_threads: int = 4,
    full_content_text_embed: dict | None = None,
    summary_text_embed: dict | None = None,
    title_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform community reports."""
@ -57,9 +54,6 @@ async def create_final_community_reports(
        summarization_strategy,
        async_mode=async_mode,
        num_threads=num_threads,
        full_content_text_embed=full_content_text_embed,
        summary_text_embed=summary_text_embed,
        title_text_embed=title_text_embed,
    )

    return create_verb_result(
@ -8,13 +8,11 @@ from typing import cast
import pandas as pd
from datashaper import (
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_documents import (
    create_final_documents as create_final_documents_flow,
)
@ -25,25 +23,15 @@ from graphrag.index.utils.ds_util import get_required_input_table
    name="create_final_documents",
    treats_input_tables_as_immutable=True,
)
async def create_final_documents(
def create_final_documents(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    document_attribute_columns: list[str] | None = None,
    raw_content_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final documents."""
    source = cast(pd.DataFrame, input.get_input())
    text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)

    output = await create_final_documents_flow(
        source,
        text_units,
        callbacks,
        cache,
        document_attribute_columns=document_attribute_columns,
        raw_content_text_embed=raw_content_text_embed,
    )
    output = create_final_documents_flow(source, text_units, document_attribute_columns)

    return create_verb_result(cast(Table, output))
@ -14,7 +14,6 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_entities import (
    create_final_entities as create_final_entities_flow,
)
@ -24,23 +23,17 @@ from graphrag.index.flows.create_final_entities import (
    name="create_final_entities",
    treats_input_tables_as_immutable=True,
)
async def create_final_entities(
def create_final_entities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    name_text_embed: dict | None = None,
    description_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final entities."""
    source = cast(pd.DataFrame, input.get_input())

    output = await create_final_entities_flow(
    output = create_final_entities_flow(
        source,
        callbacks,
        cache,
        name_text_embed=name_text_embed,
        description_text_embed=description_text_embed,
    )

    return create_verb_result(cast(Table, output))
@ -14,7 +14,6 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_relationships import (
    create_final_relationships as create_final_relationships_flow,
)
@ -25,23 +24,19 @@ from graphrag.index.utils.ds_util import get_required_input_table
    name="create_final_relationships",
    treats_input_tables_as_immutable=True,
)
async def create_final_relationships(
def create_final_relationships(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    description_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final relationships."""
    source = cast(pd.DataFrame, input.get_input())
    nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)

    output = await create_final_relationships_flow(
        source,
        nodes,
        callbacks,
        cache,
        description_text_embed=description_text_embed,
    output = create_final_relationships_flow(
        entity_graph=source,
        nodes=nodes,
        callbacks=callbacks,
    )

    return create_verb_result(cast(Table, output))
@ -8,14 +8,12 @@ from typing import cast
import pandas as pd
from datashaper import (
    Table,
    VerbCallbacks,
    VerbInput,
    VerbResult,
    create_verb_result,
    verb,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_text_units import (
    create_final_text_units as create_final_text_units_flow,
)
@ -26,10 +24,7 @@ from graphrag.index.utils.ds_util import get_named_input_table, get_required_inp
@verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
async def create_final_text_units(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    runtime_storage: PipelineStorage,
    text_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform the text units."""
@ -45,14 +40,11 @@ async def create_final_text_units(
    if final_covariates:
        final_covariates = cast(pd.DataFrame, final_covariates.table)

    output = await create_final_text_units_flow(
    output = create_final_text_units_flow(
        text_units,
        final_entities,
        final_relationships,
        final_covariates,
        callbacks,
        cache,
        text_text_embed=text_text_embed,
    )

    return create_verb_result(cast(Table, output))
@ -0,0 +1,70 @@

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform the text units."""

import logging
from typing import cast

import pandas as pd
from datashaper import (
    Table,
    VerbCallbacks,
    VerbInput,
    VerbResult,
    create_verb_result,
    verb,
)

from graphrag.index.cache import PipelineCache
from graphrag.index.flows.generate_text_embeddings import (
    generate_text_embeddings as generate_text_embeddings_flow,
)
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils.ds_util import get_required_input_table

log = logging.getLogger(__name__)


@verb(name="generate_text_embeddings", treats_input_tables_as_immutable=True)
async def generate_text_embeddings(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    text_embed: dict,
    embedded_fields: set[str],
    embeddings_snapshot_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to generate embeddings."""
    source = cast(pd.DataFrame, input.get_input())
    final_relationships = cast(
        pd.DataFrame, get_required_input_table(input, "relationships").table
    )
    final_text_units = cast(
        pd.DataFrame, get_required_input_table(input, "text_units").table
    )
    final_entities = cast(
        pd.DataFrame, get_required_input_table(input, "entities").table
    )

    final_community_reports = cast(
        pd.DataFrame, get_required_input_table(input, "community_reports").table
    )

    await generate_text_embeddings_flow(
        final_documents=source,
        final_relationships=final_relationships,
        final_text_units=final_text_units,
        final_entities=final_entities,
        final_community_reports=final_community_reports,
        callbacks=callbacks,
        cache=cache,
        storage=storage,
        text_embed_config=text_embed,
        embedded_fields=embedded_fields,
        embeddings_snapshot_enabled=embeddings_snapshot_enabled,
    )

    return create_verb_result(cast(Table, pd.DataFrame()))
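The flow this verb delegates to (graphrag.index.flows.generate_text_embeddings) is not shown in this diff. As a rough sketch of the contract the tests below rely on, each embedded field yields a two-column id/embedding table that, when snapshots are enabled, is persisted as embeddings.<field>.parquet; the helper below is hypothetical, not code from this PR:

import pandas as pd

async def snapshot_embedding_table(storage, name: str, table: pd.DataFrame) -> None:
    # Hypothetical helper. Assumes `table` holds exactly the "id" and
    # "embedding" columns the tests assert on, and that storage is a
    # PipelineStorage accepting parquet bytes.
    await storage.set(f"embeddings.{name}.parquet", table.to_parquet())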
@ -3,10 +3,9 @@

"""The LanceDB vector storage implementation package."""

import json
import json  # noqa: I001
from typing import Any

import lancedb as lancedb
import pyarrow as pa

from graphrag.model.types import TextEmbedder
@ -16,6 +15,7 @@ from .base import (
    VectorStoreDocument,
    VectorStoreSearchResult,
)
import lancedb


class LanceDBVectorStore(BaseVectorStore):
4 tests/fixtures/azure/settings.yml vendored
@ -6,9 +6,7 @@ embeddings:
    type: "azure_ai_search"
    url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
    api_key: ${AZURE_AI_SEARCH_API_KEY}
    collection_name: "azure_ci"
    entity_name_description:
      title_column: "name"
    container_name: "azure_ci"

input:
  type: blob
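Note the rename in the fixtures: container_name replaces collection_name, matching the reworked vector-store config. A hedged sketch of the back-compat resolution that rework implies; the helper name and fallback order here are assumptions, not code from this diff:

def resolve_container_name(vector_store_config: dict) -> str:
    # Assumed fallback order: prefer the new key, tolerate the legacy one,
    # then use the documented default.
    return (
        vector_store_config.get("container_name")
        or vector_store_config.get("collection_name")  # legacy key
        or "default"
    )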
36 tests/fixtures/min-csv/config.json vendored
@ -8,7 +8,8 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 0
  },
  "create_base_entity_graph": {
    "row_range": [
@ -16,7 +17,8 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_final_entities": {
    "row_range": [
@ -29,7 +31,8 @@
      "graph_embedding"
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_final_relationships": {
    "row_range": [
@ -37,7 +40,8 @@
      6000
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_nodes": {
    "row_range": [
@ -52,7 +56,8 @@
      "level"
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_communities": {
    "row_range": [
@ -60,7 +65,8 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_community_reports": {
    "row_range": [
@ -78,7 +84,8 @@
      "findings"
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_final_text_units": {
    "row_range": [
@ -90,7 +97,8 @@
      "entity_ids"
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_documents": {
    "row_range": [
@ -98,7 +106,17 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "generate_text_embeddings": {
    "row_range": [
      1,
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150,
    "expected_artifacts": 1
  }
},
"query_config": [
18 tests/fixtures/min-csv/settings.yml vendored
@ -5,19 +5,8 @@ embeddings:
  vector_store:
    type: "lancedb"
    db_uri: "./tests/fixtures/min-csv/lancedb"
    collection_name: "lancedb_ci"
    container_name: "lancedb_ci"
    overwrite: True
    store_in_table: True
    entity_name_description:
      title_column: "name"
      # id_column: "id"
    # entity_name: ...
    # relationship_description: ...
    # community_report_full_content: ...
    # community_report_summary: ...
    # community_report_title: ...
    # document_raw_content: ...
    # text_unit_text: ...

storage:
  type: file # or blob
@ -29,4 +18,7 @@ reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>
  # container_name: <azure_blob_storage_container_name>

snapshots:
  embeddings: True
39 tests/fixtures/text/config.json vendored
@ -8,7 +8,8 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 0
  },
  "create_final_covariates": {
    "row_range": [
@ -25,7 +26,8 @@
      "source_text"
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_base_entity_graph": {
    "row_range": [
@ -33,7 +35,8 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_final_entities": {
    "row_range": [
@ -46,7 +49,8 @@
      "graph_embedding"
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_final_relationships": {
    "row_range": [
@ -54,7 +58,8 @@
      6000
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_nodes": {
    "row_range": [
@ -69,7 +74,8 @@
      "level"
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_communities": {
    "row_range": [
@ -77,7 +83,8 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_community_reports": {
    "row_range": [
@ -95,7 +102,8 @@
      "findings"
    ],
    "subworkflows": 1,
    "max_runtime": 300
    "max_runtime": 300,
    "expected_artifacts": 1
  },
  "create_final_text_units": {
    "row_range": [
@ -107,7 +115,8 @@
      "entity_ids"
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "create_final_documents": {
    "row_range": [
@ -115,7 +124,17 @@
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150
    "max_runtime": 150,
    "expected_artifacts": 1
  },
  "generate_text_embeddings": {
    "row_range": [
      1,
      2500
    ],
    "subworkflows": 1,
    "max_runtime": 150,
    "expected_artifacts": 1
  }
},
"query_config": [
10 tests/fixtures/text/settings.yml vendored
@ -6,10 +6,7 @@ embeddings:
    type: "azure_ai_search"
    url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
    api_key: ${AZURE_AI_SEARCH_API_KEY}
    collection_name: "simple_text_ci"
    store_in_table: True
    entity_name_description:
      title_column: "name"
    container_name: "simple_text_ci"

community_reports:
  prompt: "prompts/community_report.txt"
@ -27,4 +24,7 @@ reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>
  # container_name: <azure_blob_storage_container_name>

snapshots:
  embeddings: True
@ -172,6 +172,7 @@ class TestIndexer:
        stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))

        # Check all workflows run
        expected_artifacts = 0
        expected_workflows = set(workflow_config.keys())
        workflows = set(stats["workflows"].keys())
        assert (
@ -180,6 +181,10 @@ class TestIndexer:

        # [OPTIONAL] Check runtime
        for workflow in expected_workflows:
            # Check expected artifacts
            expected_artifacts = expected_artifacts + workflow_config[workflow].get(
                "expected_artifacts", 1
            )
            # Check max runtime
            max_runtime = workflow_config[workflow].get("max_runtime", None)
            if max_runtime:
@ -189,38 +194,40 @@ class TestIndexer:

        # Check artifacts
        artifact_files = os.listdir(artifacts)
        # check that the number of workflows matches the number of artifacts, but:
        # (1) do not count workflows with only transient output
        # (2) account for the stats.json file
        transient_workflows = [
            "workflow:create_base_text_units",
        ]

        # check that the number of workflows matches the number of artifacts
        assert (
            len(artifact_files)
            == (len(expected_workflows) - len(transient_workflows) + 1)
            len(artifact_files) == (expected_artifacts + 1)
        ), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"

        for artifact in artifact_files:
            if artifact.endswith(".parquet"):
                output_df = pd.read_parquet(artifacts / artifact)
                artifact_name = artifact.split(".")[0]
                workflow = workflow_config[artifact_name]

                # Check number of rows between range
                assert (
                    workflow["row_range"][0]
                    <= len(output_df)
                    <= workflow["row_range"][1]
                ), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
                try:
                    workflow = workflow_config[artifact_name]

                    # Get non-nan rows
                    nan_df = output_df.loc[
                        :, ~output_df.columns.isin(workflow.get("nan_allowed_columns", []))
                    ]
                    nan_df = nan_df[nan_df.isna().any(axis=1)]
                    assert (
                        len(nan_df) == 0
                    ), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
                    # Check number of rows between range
                    assert (
                        workflow["row_range"][0]
                        <= len(output_df)
                        <= workflow["row_range"][1]
                    ), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"

                    # Get non-nan rows
                    nan_df = output_df.loc[
                        :,
                        ~output_df.columns.isin(
                            workflow.get("nan_allowed_columns", [])
                        ),
                    ]
                    nan_df = nan_df[nan_df.isna().any(axis=1)]
                    assert (
                        len(nan_df) == 0
                    ), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
                except KeyError:
                    log.warning("No workflow config found %s", artifact_name)

    def __run_query(self, root: Path, query_config: dict[str, str]):
        command = [
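Worked example of the new assertion against the min-csv fixture above: nine workflow entries declare "expected_artifacts": 1 and one declares 0, so expected_artifacts sums to 9 and the test requires exactly 9 + 1 = 10 files, the extra one being stats.json.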
@ -74,44 +74,6 @@ async def test_create_final_community_reports():
    assert actual["rank_explanation"][:1][0] == "<rating_explanation>"


async def test_create_final_community_reports_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
        "workflow:create_final_communities",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["create_community_reports"]["strategy"]["llm"] = MOCK_LLM_CONFIG

    config["skip_full_content_embedding"] = False
    config["community_report_full_content_embed"]["strategy"]["type"] = "mock"
    config["skip_summary_embedding"] = False
    config["community_report_summary_embed"]["strategy"]["type"] = "mock"
    config["skip_title_embedding"] = False
    config["community_report_title_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert len(actual.columns) == len(expected.columns) + 3
    assert "full_content_embedding" in actual.columns
    assert len(actual["full_content_embedding"][:1][0]) == 3
    assert "summary_embedding" in actual.columns
    assert len(actual["summary_embedding"][:1][0]) == 3
    assert "title_embedding" in actual.columns
    assert len(actual["title_embedding"][:1][0]) == 3


async def test_create_final_community_reports_missing_llm_throws():
    input_tables = load_input_tables([
        "workflow:create_final_nodes",
@ -23,8 +23,6 @@ async def test_create_final_documents():

    config = get_config_for_workflow(workflow_name)

    config["skip_raw_content_embedding"] = True

    steps = build_steps(config)

    actual = await get_workflow_output(
@ -37,34 +35,6 @@ async def test_create_final_documents():
    compare_outputs(actual, expected)


async def test_create_final_documents_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_final_text_units",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["skip_raw_content_embedding"] = False
    # default config has a detailed standard embed config
    # just override the strategy to mock so the rest of the required parameters are in place
    config["document_raw_content_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "raw_content_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["raw_content_embedding"][:1][0]) == 3


async def test_create_final_documents_with_attribute_columns():
    input_tables = load_input_tables(["workflow:create_final_text_units"])
    expected = load_expected(workflow_name)
@ -23,9 +23,6 @@ async def test_create_final_entities():

    config = get_config_for_workflow(workflow_name)

    config["skip_name_embedding"] = True
    config["skip_description_embedding"] = True

    steps = build_steps(config)

    actual = await get_workflow_output(
@ -50,83 +47,3 @@ async def test_create_final_entities():
        ],
    )
    assert len(actual.columns) == len(expected.columns) - 1


async def test_create_final_entities_with_name_embeddings():
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["skip_name_embedding"] = False
    config["skip_description_embedding"] = True
    config["entity_name_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "name_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns)
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["name_embedding"][:1][0]) == 3


async def test_create_final_entities_with_description_embeddings():
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["skip_name_embedding"] = True
    config["skip_description_embedding"] = False
    config["entity_name_description_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns)
    assert len(actual["description_embedding"][:1][0]) == 3


async def test_create_final_entities_with_name_and_description_embeddings():
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["skip_name_embedding"] = False
    config["skip_description_embedding"] = False
    config["entity_name_description_embed"]["strategy"]["type"] = "mock"
    config["entity_name_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    assert len(actual["description_embedding"][:1][0]) == 3
@ -24,8 +24,6 @@ async def test_create_final_relationships():

    config = get_config_for_workflow(workflow_name)

    config["skip_description_embedding"] = True

    steps = build_steps(config)

    actual = await get_workflow_output(
@ -36,32 +34,3 @@ async def test_create_final_relationships():
    )

    compare_outputs(actual, expected)


async def test_create_final_relationships_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
        "workflow:create_final_nodes",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["skip_description_embedding"] = False
    # default config has a detailed standard embed config
    # just override the strategy to mock so the rest of the required parameters are in place
    config["relationship_description_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["description_embedding"][:1][0]) == 3
@ -33,7 +33,6 @@ async def test_create_final_text_units():
    config = get_config_for_workflow(workflow_name)

    config["covariates_enabled"] = True
    config["skip_text_unit_embedding"] = True

    steps = build_steps(config)

@ -65,7 +64,6 @@ async def test_create_final_text_units_no_covariates():
    config = get_config_for_workflow(workflow_name)

    config["covariates_enabled"] = False
    config["skip_text_unit_embedding"] = True

    steps = build_steps(config)

@ -83,41 +81,3 @@ async def test_create_final_text_units_no_covariates():
        expected,
        ["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"],
    )


async def test_create_final_text_units_with_embeddings():
    input_tables = load_input_tables([
        "workflow:create_base_text_units",
        "workflow:create_final_entities",
        "workflow:create_final_relationships",
        "workflow:create_final_covariates",
    ])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["covariates_enabled"] = True
    config["skip_text_unit_embedding"] = False
    # default config has a detailed standard embed config
    # just override the strategy to mock so the rest of the required parameters are in place
    config["text_unit_text_embed"]["strategy"]["type"] = "mock"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        context=context,
    )

    assert "text_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["text_embedding"][:1][0]) == 3
82 tests/verbs/test_generate_text_embeddings.py Normal file
@ -0,0 +1,82 @@

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from io import BytesIO

import pandas as pd

from graphrag.index.config.embeddings import (
    all_embeddings,
)
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.generate_text_embeddings import (
    build_steps,
    workflow_name,
)

from .util import (
    get_config_for_workflow,
    get_workflow_output,
    load_input_tables,
)


async def test_generate_text_embeddings():
    input_tables = load_input_tables(
        inputs=[
            "workflow:create_final_documents",
            "workflow:create_final_relationships",
            "workflow:create_final_text_units",
            "workflow:create_final_entities",
            "workflow:create_final_community_reports",
        ]
    )
    context = create_run_context(None, None, None)

    config = get_config_for_workflow(workflow_name)

    config["text_embed"]["strategy"]["type"] = "mock"
    config["snapshot_embeddings"] = True

    config["embedded_fields"] = all_embeddings

    steps = build_steps(config)

    await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        context,
    )

    parquet_files = context.storage.keys()

    for field in all_embeddings:
        assert f"embeddings.{field}.parquet" in parquet_files

    # entity description should always be here, let's assert its format
    entity_description_embeddings_buffer = BytesIO(
        await context.storage.get(
            "embeddings.entity.description.parquet", as_bytes=True
        )
    )
    entity_description_embeddings = pd.read_parquet(
        entity_description_embeddings_buffer
    )
    assert len(entity_description_embeddings.columns) == 2
    assert "id" in entity_description_embeddings.columns
    assert "embedding" in entity_description_embeddings.columns

    # every other embedding is optional but we've turned them all on, so check a random one
    document_raw_content_embeddings_buffer = BytesIO(
        await context.storage.get(
            "embeddings.document.raw_content.parquet", as_bytes=True
        )
    )
    document_raw_content_embeddings = pd.read_parquet(
        document_raw_content_embeddings_buffer
    )
    assert len(document_raw_content_embeddings.columns) == 2
    assert "id" in document_raw_content_embeddings.columns
    assert "embedding" in document_raw_content_embeddings.columns
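One naming detail implied by the assertions above: embedding names contain dots (entity.description, document.raw_content), which some vector stores disallow in index names, hence the "Replace . with - for embedding store" change noted in this PR. A hedged sketch of that derivation; the function name is an assumption:

def derive_index_name(container_name: str, embedding_name: str) -> str:
    # e.g. ("default", "entity.description") -> "default-entity-description"
    return f"{container_name}-{embedding_name}".replace(".", "-")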