New workflow to generate embeddings in a single workflow (#1296)

* New workflow to generate embeddings in a single workflow

* New workflow to generate embeddings in a single workflow

* version change

* clean tests without any embeddings references

* clean tests without any embeddings references

* remove code

* feedback implemented

* changes in logic

* feedback implemented

* store in table bug fixed

* smoke test for generate_text_embeddings workflow

* smoke test fix

* add generate_text_embeddings to the list of transient workflows

* smoke tests

* fix

* ruff formatting updates

* fix

* smoke test fixed

* smoke test fixed

* fix lancedb import

* smoke test fix

* ignore sorting

* smoke test fixed

* smoke test fixed

* check smoke test

* smoke test fixed

* change config for vector store

* format fix

* vector store changes

* revert debug profile back to empty filepath

* merge conflict solved

* merge conflict solved

* format fixed

* format fixed

* fix return dataframe

* snapshot fix

* format fix

* embeddings param implemented

* validation fixes

* fix map

* fix map

* fix properties

* config updates

* smoke test fixed

* settings change

* Update collection config and rework back-compat

* Replace . with - for embedding store

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
This commit is contained in:
gaudyb 2024-11-01 16:01:35 -06:00 committed by GitHub
parent 8302920ac8
commit 17658c5df8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 693 additions and 804 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "embeddings moved to a different workflow"
}

View File

@ -85,8 +85,8 @@ This is the base LLM configuration section. Other steps may override this config
- `async_mode` (see Async Mode top-level config)
- `batch_size` **int** - The maximum batch size to use.
- `batch_max_tokens` **int** - The maximum batch # of tokens.
- `target` **required|all** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip.
- `target` **required|all|none** - Determines which set of embeddings to emit.
- `skip` **list[str]** - Which embeddings to skip. Only useful if target=all to customize the list.
- `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
- `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
- `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
@ -94,7 +94,7 @@ This is the base LLM configuration section. Other steps may override this config
- `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
- `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
- `overwrite` **bool** (only used at index creation time) - Overwrite collection if it exists. Default=`True`
- `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
- `container_name` **str** - The name of a vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default`
- `strategy` **dict** - Fully override the text-embedding strategy.
## chunks

View File

@ -108,7 +108,7 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"entity.description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",

View File

@ -299,7 +299,7 @@
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"entity_description_embeddings\",\n",
" collection_name=\"entity.description\",\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"entity_description_embeddings = store_entity_semantic_embeddings(\n",

View File

@ -59,13 +59,7 @@ async def build_index(
msg = "Cannot resume and update a run at the same time."
raise ValueError(msg)
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store["type"] # type: ignore
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore
config = _patch_vector_config(config)
pipeline_config = create_pipeline_config(config)
pipeline_cache = (
@ -90,3 +84,22 @@ async def build_index(
progress_reporter.success(output.workflow)
progress_reporter.info(str(output.result))
return outputs
def _patch_vector_config(config: GraphRagConfig):
    """Make sure the config carries a usable vector store section (back-compat).

    If no vector store is configured, a default local LanceDB store is
    injected. For a LanceDB store, ``db_uri`` is then rewritten so that it is
    joined onto the resolved ``config.root_dir``.

    Args:
        config: The configuration to patch. It is modified in place.

    Returns:
        The same ``config`` object that was passed in.
    """
    embeddings_section = config.embeddings
    if not embeddings_section.vector_store:
        embeddings_section.vector_store = {
            "type": "lancedb",
            "db_uri": "output/lancedb",
            "container_name": "default",
            "overwrite": True,
        }

    # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
    # TODO: remove the type ignore annotations below once the new config engine has been refactored
    store_settings = embeddings_section.vector_store
    if store_settings["type"] == VectorStoreType.LanceDB:  # type: ignore
        relative_uri = store_settings["db_uri"]  # type: ignore
        # Anchor the database location at the project root.
        store_settings["db_uri"] = str(Path(config.root_dir).resolve() / relative_uri)  # type: ignore
    return config

View File

@ -182,56 +182,22 @@ async def local_search(
------
TODO: Document any exceptions to expect.
"""
#################################### BEGIN PATCH ####################################
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
backwards_compatible = False
if not config.embeddings.vector_store:
backwards_compatible = True
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore
config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"collection_name": "entity_description_embeddings",
"overwrite": True,
}
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=config.embeddings.vector_store["collection_name"],
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
#################################### END PATCH ####################################
config = _patch_vector_store(config, nodes, entities, community_level)
# TODO: update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
@ -289,56 +255,22 @@ async def local_search_streaming(
------
TODO: Document any exceptions to expect.
"""
#################################### BEGIN PATCH ####################################
# TODO: remove the following patch that checks for a vector_store prior to v1 release
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
# Only applicable in situations involving a local vector_store (lancedb). The general idea:
# if vector_store not in config:
# 1. assume user is running local if vector_store is not in config
# 2. insert default vector_store in config
# 3 .create lancedb vector_store instance
# 4. upload vector embeddings from the input dataframes to the vector_store
backwards_compatible = False
if not config.embeddings.vector_store:
backwards_compatible = True
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.vector_stores.lancedb import LanceDBVectorStore
config.embeddings.vector_store = {
"type": "lancedb",
"db_uri": f"{Path(config.storage.base_dir)}/lancedb",
"collection_name": "entity_description_embeddings",
"overwrite": True,
}
_entities = read_indexer_entities(nodes, entities, community_level)
description_embedding_store = LanceDBVectorStore(
db_uri=config.embeddings.vector_store["db_uri"],
collection_name=config.embeddings.vector_store["collection_name"],
overwrite=config.embeddings.vector_store["overwrite"],
)
description_embedding_store.connect(
db_uri=config.embeddings.vector_store["db_uri"]
)
# dump embeddings from the entities list to the description_embedding_store
store_entity_semantic_embeddings(
entities=_entities, vectorstore=description_embedding_store
)
#################################### END PATCH ####################################
config = _patch_vector_store(config, nodes, entities, community_level)
# TODO: must update filepath of lancedb (if used) until the new config engine has been implemented
# TODO: remove the type ignore annotations below once the new config engine has been refactored
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
vector_store_args = config.embeddings.vector_store
if vector_store_type == VectorStoreType.LanceDB and not backwards_compatible:
if vector_store_type == VectorStoreType.LanceDB:
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
lancedb_dir = Path(config.root_dir).resolve() / db_uri
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
if not backwards_compatible: # can remove this check and always set the description_embedding_store before v1 release
description_embedding_store = _get_embedding_description_store(
config_args=vector_store_args, # type: ignore
)
# The keyword must match the `config_args` parameter of
# _get_embedding_description_store; any other name raises TypeError.
description_embedding_store = _get_embedding_description_store(
    config_args=vector_store_args,  # type: ignore
)
_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
@ -368,13 +300,55 @@ async def local_search_streaming(
yield stream_chunk
def _patch_vector_store(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_level: int,
) -> GraphRagConfig:
    """Inject a default LanceDB vector store when the config has none (back-compat).

    When ``config.embeddings.vector_store`` is already set, the config is
    returned untouched. Otherwise:

    1. a default lancedb vector_store section is written into the config,
    2. a LanceDB vector store instance is created and connected,
    3. the entities read from the input dataframes are uploaded to it.

    Args:
        config: The configuration to patch. It is modified in place.
        nodes: Passed to ``read_indexer_entities``.
        entities: Passed to ``read_indexer_entities``.
        community_level: Passed to ``read_indexer_entities``.

    Returns:
        The same ``config`` object that was passed in.
    """
    # TODO: remove this backwards compatibility patch prior to v1 release
    if config.embeddings.vector_store:
        return config

    from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
    from graphrag.vector_stores.lancedb import LanceDBVectorStore

    config.embeddings.vector_store = {
        "type": "lancedb",
        "db_uri": f"{Path(config.storage.base_dir)}/lancedb",
        "container_name": "default",
        "overwrite": True,
    }
    store_settings = config.embeddings.vector_store
    lancedb_uri = store_settings["db_uri"]

    entity_store = LanceDBVectorStore(
        db_uri=lancedb_uri,
        collection_name="default-entity-description",
        overwrite=store_settings["overwrite"],
    )
    entity_store.connect(db_uri=lancedb_uri)

    # dump embeddings from the entities list to the vector store
    loaded_entities = read_indexer_entities(nodes, entities, community_level)
    store_entity_semantic_embeddings(entities=loaded_entities, vectorstore=entity_store)
    return config
def _get_embedding_description_store(
    config_args: dict,
):
    """Get the embedding description store.

    Args:
        config_args: Vector store settings. ``type`` is required.
            ``container_name`` is optional and falls back to ``"default"``,
            the documented default, so settings written before the
            container-based naming do not fail with ``KeyError``.

    Returns:
        The vector store produced by ``VectorStoreFactory`` for the
        ``<container_name>-entity-description`` collection, after
        ``connect`` has been called on it with ``config_args``.
    """
    vector_store_type = config_args["type"]
    container_name = config_args.get("container_name", "default")
    collection_name = f"{container_name}-entity-description"
    description_embedding_store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type,
        kwargs={**config_args, "collection_name": collection_name},
    )
    description_embedding_store.connect(**config_args)
    return description_embedding_store

View File

@ -115,6 +115,7 @@ def run_local_search(
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
# TODO remove optional create_final_entities_description_embeddings.parquet to delete backwards compatibility
dataframe_dict = _resolve_parquet_files(
root_dir=root_dir,
config=config,
@ -125,7 +126,9 @@ def run_local_search(
"create_final_relationships.parquet",
"create_final_entities.parquet",
],
optional_list=["create_final_covariates.parquet"],
optional_list=[
"create_final_covariates.parquet",
],
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_community_reports: pd.DataFrame = dataframe_dict[

View File

@ -414,6 +414,7 @@ def create_graphrag_config(
raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
top_level_nodes=reader.bool("top_level_nodes")
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
)
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
umap_model = UmapConfig(

View File

@ -82,6 +82,7 @@ REPORTING_BASE_DIR = "logs"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
SNAPSHOTS_EMBEDDINGS = False
STORAGE_BASE_DIR = "output"
STORAGE_TYPE = StorageType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
@ -91,7 +92,7 @@ UPDATE_STORAGE_BASE_DIR = "update_output"
# Default vector store settings as a YAML fragment. The key is `container_name`
# (not `collection_name`): that is the key the query code reads to build
# collection names, and the key the back-compat defaults use.
VECTOR_STORE = f"""
type: {VectorStoreType.LanceDB.value}
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
container_name: default
overwrite: true\
"""

View File

@ -86,6 +86,7 @@ class TextEmbeddingTarget(str, Enum):
all = "all"
required = "required"
none = "none"
def __repr__(self):
"""Get a string representation."""

View File

@ -23,3 +23,7 @@ class SnapshotsConfig(BaseModel):
description="A flag indicating whether to take snapshots of top-level nodes.",
default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
)
embeddings: bool = Field(
description="A flag indicating whether to take snapshots of embeddings.",
default=defs.SNAPSHOTS_EMBEDDINGS,
)

View File

@ -11,6 +11,18 @@ from .cache import (
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from .embeddings import (
all_embeddings,
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
document_raw_content_embedding,
entity_description_embedding,
entity_name_embedding,
relationship_description_embedding,
required_embeddings,
text_unit_text_embedding,
)
from .input import (
PipelineCSVInputConfig,
PipelineInputConfig,
@ -66,4 +78,14 @@ __all__ = [
"PipelineWorkflowConfig",
"PipelineWorkflowReference",
"PipelineWorkflowStep",
"all_embeddings",
"community_full_content_embedding",
"community_summary_embedding",
"community_title_embedding",
"document_raw_content_embedding",
"entity_description_embedding",
"entity_name_embedding",
"relationship_description_embedding",
"required_embeddings",
"text_unit_text_embedding",
]

View File

@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing embeddings values."""
# Names of the individual text embeddings, written as dotted strings.
entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_raw_content_embedding = "document.raw_content"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"
# The complete set: every embedding name defined above.
all_embeddings: set[str] = {
entity_name_embedding,
entity_description_embedding,
relationship_description_embedding,
document_raw_content_embedding,
community_title_embedding,
community_summary_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}
# The required subset: only the entity description embedding.
required_embeddings: set[str] = {entity_description_embedding}

View File

@ -22,6 +22,10 @@ from graphrag.index.config.cache import (
PipelineMemoryCacheConfig,
PipelineNoneCacheConfig,
)
from graphrag.index.config.embeddings import (
all_embeddings,
required_embeddings,
)
from graphrag.index.config.input import (
PipelineCSVInputConfig,
PipelineInputConfigTypes,
@ -56,33 +60,11 @@ from graphrag.index.workflows.default_workflows import (
create_final_nodes,
create_final_relationships,
create_final_text_units,
generate_text_embeddings,
)
log = logging.getLogger(__name__)
entity_name_embedding = "entity.name"
entity_description_embedding = "entity.description"
relationship_description_embedding = "relationship.description"
document_raw_content_embedding = "document.raw_content"
community_title_embedding = "community.title"
community_summary_embedding = "community.summary"
community_full_content_embedding = "community.full_content"
text_unit_text_embedding = "text_unit.text"
all_embeddings: set[str] = {
entity_name_embedding,
entity_description_embedding,
relationship_description_embedding,
document_raw_content_embedding,
community_title_embedding,
community_summary_embedding,
community_full_content_embedding,
text_unit_text_embedding,
}
required_embeddings: set[str] = {entity_description_embedding}
builtin_document_attributes: set[str] = {
"id",
"source",
@ -121,11 +103,12 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
),
cache=_get_cache_config(settings),
workflows=[
*_document_workflows(settings, embedded_fields),
*_text_unit_workflows(settings, covariates_enabled, embedded_fields),
*_graph_workflows(settings, embedded_fields),
*_community_workflows(settings, covariates_enabled, embedded_fields),
*_document_workflows(settings),
*_text_unit_workflows(settings, covariates_enabled),
*_graph_workflows(settings),
*_community_workflows(settings, covariates_enabled),
*(_covariate_workflows(settings) if covariates_enabled else []),
*(_embeddings_workflows(settings, embedded_fields)),
],
)
@ -138,9 +121,11 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC
def _get_embedded_fields(settings: GraphRagConfig) -> set[str]:
    """Resolve the set of embedding names selected by ``settings.embeddings``.

    Returns:
        For target ``all``: every known embedding minus those in ``skip``.
        For target ``required``: the required embeddings.
        For target ``none``: an empty set.

    Raises:
        ValueError: If the target is not one of the values above.
    """
    target = settings.embeddings.target
    if target == TextEmbeddingTarget.all:
        return all_embeddings.difference(settings.embeddings.skip)
    if target == TextEmbeddingTarget.required:
        return required_embeddings
    if target == TextEmbeddingTarget.none:
        return set()
    msg = f"Unknown embeddings target: {settings.embeddings.target}"
    raise ValueError(msg)
@ -163,11 +148,8 @@ def _log_llm_settings(settings: GraphRagConfig) -> None:
def _document_workflows(
settings: GraphRagConfig, embedded_fields: set[str]
settings: GraphRagConfig,
) -> list[PipelineWorkflowReference]:
skip_document_raw_content_embedding = (
document_raw_content_embedding not in embedded_fields
)
return [
PipelineWorkflowReference(
name=create_final_documents,
@ -176,15 +158,6 @@ def _document_workflows(
{*(settings.input.document_attribute_columns)}
- builtin_document_attributes
),
"document_raw_content_embed": _get_embedding_settings(
settings.embeddings,
"document_raw_content",
{
"title_column": "raw_content",
"collection_name": "final_documents_raw_content_embedding",
},
),
"skip_raw_content_embedding": skip_document_raw_content_embedding,
},
),
]
@ -193,9 +166,7 @@ def _document_workflows(
def _text_unit_workflows(
settings: GraphRagConfig,
covariates_enabled: bool,
embedded_fields: set[str],
) -> list[PipelineWorkflowReference]:
skip_text_unit_embedding = text_unit_text_embedding not in embedded_fields
return [
PipelineWorkflowReference(
name=create_base_text_units,
@ -211,13 +182,7 @@ def _text_unit_workflows(
PipelineWorkflowReference(
name=create_final_text_units,
config={
"text_unit_text_embed": _get_embedding_settings(
settings.embeddings,
"text_unit_text",
{"title_column": "text", "collection_name": "text_units_embedding"},
),
"covariates_enabled": covariates_enabled,
"skip_text_unit_embedding": skip_text_unit_embedding,
},
),
]
@ -225,7 +190,6 @@ def _text_unit_workflows(
def _get_embedding_settings(
settings: TextEmbeddingConfig,
embedding_name: str,
vector_store_params: dict | None = None,
) -> dict:
vector_store_settings = settings.vector_store
@ -243,20 +207,10 @@ def _get_embedding_settings(
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
"embedding_name": embedding_name,
}
def _graph_workflows(
settings: GraphRagConfig, embedded_fields: set[str]
) -> list[PipelineWorkflowReference]:
skip_entity_name_embedding = entity_name_embedding not in embedded_fields
skip_entity_description_embedding = (
entity_description_embedding not in embedded_fields
)
skip_relationship_description_embedding = (
relationship_description_embedding not in embedded_fields
)
def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=create_base_entity_graph,
@ -286,40 +240,11 @@ def _graph_workflows(
),
PipelineWorkflowReference(
name=create_final_entities,
config={
"entity_name_embed": _get_embedding_settings(
settings.embeddings,
"entity_name",
{
"title_column": "name",
"collection_name": "entity_name_embeddings",
},
),
"entity_name_description_embed": _get_embedding_settings(
settings.embeddings,
"entity_name_description",
{
"title_column": "description",
"collection_name": "entity_description_embeddings",
},
),
"skip_name_embedding": skip_entity_name_embedding,
"skip_description_embedding": skip_entity_description_embedding,
},
config={},
),
PipelineWorkflowReference(
name=create_final_relationships,
config={
"relationship_description_embed": _get_embedding_settings(
settings.embeddings,
"relationship_description",
{
"title_column": "description",
"collection_name": "relationships_description_embeddings",
},
),
"skip_description_embedding": skip_relationship_description_embedding,
},
config={},
),
PipelineWorkflowReference(
name=create_final_nodes,
@ -332,24 +257,14 @@ def _graph_workflows(
def _community_workflows(
settings: GraphRagConfig, covariates_enabled: bool, embedded_fields: set[str]
settings: GraphRagConfig, covariates_enabled: bool
) -> list[PipelineWorkflowReference]:
skip_community_title_embedding = community_title_embedding not in embedded_fields
skip_community_summary_embedding = (
community_summary_embedding not in embedded_fields
)
skip_community_full_content_embedding = (
community_full_content_embedding not in embedded_fields
)
return [
PipelineWorkflowReference(name=create_final_communities),
PipelineWorkflowReference(
name=create_final_community_reports,
config={
"covariates_enabled": covariates_enabled,
"skip_title_embedding": skip_community_title_embedding,
"skip_summary_embedding": skip_community_summary_embedding,
"skip_full_content_embedding": skip_community_full_content_embedding,
"create_community_reports": {
**settings.community_reports.parallelization.model_dump(),
"async_mode": settings.community_reports.async_mode,
@ -357,27 +272,6 @@ def _community_workflows(
settings.root_dir
),
},
"community_report_full_content_embed": _get_embedding_settings(
settings.embeddings,
"community_report_full_content",
{
"title_column": "full_content",
"collection_name": "final_community_reports_full_content_embedding",
},
),
"community_report_summary_embed": _get_embedding_settings(
settings.embeddings,
"community_report_summary",
{
"title_column": "summary",
"collection_name": "final_community_reports_summary_embedding",
},
),
"community_report_title_embed": _get_embedding_settings(
settings.embeddings,
"community_report_title",
{"title_column": "title"},
),
},
),
]
@ -401,6 +295,21 @@ def _covariate_workflows(
]
def _embeddings_workflows(
    settings: GraphRagConfig, embedded_fields: set[str]
) -> list[PipelineWorkflowReference]:
    """Build the reference to the single text-embedding workflow.

    Args:
        settings: Supplies the snapshot flag and the embedding settings.
        embedded_fields: Names of the embeddings to generate; passed through
            unchanged in the workflow config.

    Returns:
        A one-element list holding the ``generate_text_embeddings`` reference.
    """
    workflow_config = {
        "snapshot_embeddings": settings.snapshots.embeddings,
        "text_embed": _get_embedding_settings(settings.embeddings),
        "embedded_fields": embedded_fields,
    }
    embeddings_step = PipelineWorkflowReference(
        name=generate_text_embeddings,
        config=workflow_config,
    )
    return [embeddings_step]
def _get_pipeline_input_config(
settings: GraphRagConfig,
) -> PipelineInputConfigTypes:

View File

@ -31,7 +31,6 @@ from graphrag.index.graph.extractors.community_reports.schemas import (
NODE_ID,
NODE_NAME,
)
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.summarize_communities import (
prepare_community_reports,
restore_community_hierarchy,
@ -49,9 +48,6 @@ async def create_final_community_reports(
summarization_strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
full_content_text_embed: dict | None = None,
summary_text_embed: dict | None = None,
title_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform community reports."""
nodes = _prep_nodes(nodes_input)
@ -86,39 +82,6 @@ async def create_final_community_reports(
lambda _x: str(uuid4())
)
# Embed full content if not skipped
if full_content_text_embed:
community_reports["full_content_embedding"] = await embed_text(
community_reports,
callbacks,
cache,
column="full_content",
strategy=full_content_text_embed["strategy"],
embedding_name="community_report_full_content",
)
# Embed summary if not skipped
if summary_text_embed:
community_reports["summary_embedding"] = await embed_text(
community_reports,
callbacks,
cache,
column="summary",
strategy=summary_text_embed["strategy"],
embedding_name="community_report_summary",
)
# Embed title if not skipped
if title_text_embed:
community_reports["title_embedding"] = await embed_text(
community_reports,
callbacks,
cache,
column="title",
strategy=title_text_embed["strategy"],
embedding_name="community_report_title",
)
# Merge by community and it with communities to add size and period
return community_reports.merge(
communities_input.loc[:, ["id", "size", "period"]],

View File

@ -4,21 +4,12 @@
"""All the steps to transform final documents."""
import pandas as pd
from datashaper import (
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text import embed_text
async def create_final_documents(
def create_final_documents(
documents: pd.DataFrame,
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
document_attribute_columns: list[str] | None = None,
raw_content_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform final documents."""
exploded = (
@ -72,13 +63,4 @@ async def create_final_documents(
# Drop the original attribute columns after collapsing them
rejoined.drop(columns=document_attribute_columns, inplace=True)
if raw_content_text_embed:
rejoined["raw_content_embedding"] = await embed_text(
rejoined,
callbacks,
cache,
column="raw_content",
strategy=raw_content_text_embed["strategy"],
)
return rejoined

View File

@ -8,18 +8,13 @@ from datashaper import (
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.split_text import split_text
from graphrag.index.operations.unpack_graph import unpack_graph
async def create_final_entities(
def create_final_entities(
entity_graph: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
name_text_embed: dict | None = None,
description_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform final entities."""
# Process nodes
@ -44,37 +39,6 @@ async def create_final_entities(
nodes = nodes.loc[nodes["name"].notna()]
# Split 'source_id' column into 'text_unit_ids'
nodes = split_text(
return split_text(
nodes, column="source_id", separator=",", to="text_unit_ids"
).drop(columns=["source_id"])
# Embed name if not skipped
if name_text_embed:
nodes["name_embedding"] = await embed_text(
nodes,
callbacks,
cache,
column="name",
strategy=name_text_embed["strategy"],
embedding_name="entity_name",
)
# Embed description if not skipped
if description_text_embed:
# Concatenate 'name' and 'description' and embed
nodes["name_description"] = nodes["name"] + ":" + nodes["description"]
nodes["description_embedding"] = await embed_text(
nodes,
callbacks,
cache,
column="name_description",
strategy=description_text_embed["strategy"],
embedding_name="entity_name_description",
)
# Drop rows with NaN 'description_embedding' if not using vector store
if not description_text_embed.get("strategy", {}).get("vector_store"):
nodes = nodes.loc[nodes["description_embedding"].notna()]
nodes.drop(columns="name_description", inplace=True)
return nodes

View File

@ -10,20 +10,16 @@ from datashaper import (
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.compute_edge_combined_degree import (
compute_edge_combined_degree,
)
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.unpack_graph import unpack_graph
async def create_final_relationships(
def create_final_relationships(
entity_graph: pd.DataFrame,
nodes: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
description_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform final relationships."""
graph_edges = unpack_graph(entity_graph, callbacks, "clustered_graph", "edges")
@ -34,16 +30,6 @@ async def create_final_relationships(
pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
)
if description_text_embed:
filtered["description_embedding"] = await embed_text(
filtered,
callbacks,
cache,
column="description",
strategy=description_text_embed["strategy"],
embedding_name="relationship_description",
)
pruned_edges = filtered.drop(columns=["level"])
filtered_nodes = nodes[nodes["level"] == 0].reset_index(drop=True)

View File

@ -6,22 +6,13 @@
from typing import cast
import pandas as pd
from datashaper import (
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text import embed_text
async def create_final_text_units(
def create_final_text_units(
text_units: pd.DataFrame,
final_entities: pd.DataFrame,
final_relationships: pd.DataFrame,
final_covariates: pd.DataFrame | None,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform the text units."""
selected = text_units.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
@ -41,30 +32,12 @@ async def create_final_text_units(
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
is_using_vector_store = False
if text_text_embed:
aggregated["text_embedding"] = await embed_text(
aggregated,
callbacks,
cache,
column="text",
strategy=text_text_embed["strategy"],
)
is_using_vector_store = (
text_text_embed.get("strategy", {}).get("vector_store", None) is not None
)
return cast(
pd.DataFrame,
aggregated[
[
"id",
"text",
*(
[]
if (not text_text_embed or is_using_vector_store)
else ["text_embedding"]
),
"n_tokens",
"document_ids",
"entity_ids",

View File

@ -0,0 +1,146 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform the text units."""
import logging
import pandas as pd
from datashaper import (
VerbCallbacks,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.config.embeddings import (
community_full_content_embedding,
community_summary_embedding,
community_title_embedding,
document_raw_content_embedding,
entity_description_embedding,
entity_name_embedding,
relationship_description_embedding,
text_unit_text_embedding,
)
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.storage import PipelineStorage
log = logging.getLogger(__name__)
async def generate_text_embeddings(
    final_documents: pd.DataFrame | None,
    final_relationships: pd.DataFrame | None,
    final_text_units: pd.DataFrame | None,
    final_entities: pd.DataFrame | None,
    final_community_reports: pd.DataFrame | None,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    text_embed_config: dict,
    embedded_fields: set[str],
    embeddings_snapshot_enabled: bool = False,
) -> None:
    """Generate every requested embedding, one field at a time.

    Args:
        final_documents: Source of the `id` / `raw_content` columns, or None.
        final_relationships: Source of the `id` / `description` columns, or None.
        final_text_units: Source of the `id` / `text` columns, or None.
        final_entities: Source of the `id` / `name` / `description` columns,
            or None.
        final_community_reports: Source of the `id` / `full_content` /
            `summary` / `title` columns, or None.
        callbacks: Forwarded unchanged to `_run_and_snapshot_embeddings`.
        cache: Forwarded unchanged to `_run_and_snapshot_embeddings`.
        storage: Forwarded unchanged to `_run_and_snapshot_embeddings`.
        text_embed_config: Forwarded unchanged to
            `_run_and_snapshot_embeddings`.
        embedded_fields: Names of the embeddings to generate; each must be one
            of the embedding-name constants listed in the table below.
        embeddings_snapshot_enabled: Forwarded unchanged to
            `_run_and_snapshot_embeddings`.

    Raises:
        KeyError: If `embedded_fields` holds a name that is not handled here,
            or a requested field's source table lacks a needed column.
    """
    entity_columns = ["id", "name", "description"]
    report_columns = ["id", "full_content", "summary", "title"]

    # embedding name -> (source table, columns to carry over, column to embed)
    embedding_sources = {
        document_raw_content_embedding: (
            final_documents,
            ["id", "raw_content"],
            "raw_content",
        ),
        relationship_description_embedding: (
            final_relationships,
            ["id", "description"],
            "description",
        ),
        text_unit_text_embedding: (
            final_text_units,
            ["id", "text"],
            "text",
        ),
        entity_name_embedding: (
            final_entities,
            entity_columns,
            "name",
        ),
        entity_description_embedding: (
            final_entities,
            entity_columns,
            "name_description",
        ),
        community_title_embedding: (
            final_community_reports,
            report_columns,
            "title",
        ),
        community_summary_embedding: (
            final_community_reports,
            report_columns,
            "summary",
        ),
        community_full_content_embedding: (
            final_community_reports,
            report_columns,
            "full_content",
        ),
    }

    log.info("Creating embeddings")
    for field in embedded_fields:
        table, columns, column_to_embed = embedding_sources[field]

        if table is None:
            # The signature allows absent tables, so a requested field with no
            # source is skipped with a warning instead of crashing downstream.
            log.warning(
                "Skipping embedding %s because its source table was not provided",
                field,
            )
            continue

        # Selected per requested field, so tables backing fields that were not
        # requested are never sliced (or validated) at all.
        data = table.loc[:, columns]
        if field == entity_description_embedding:
            # The entity description embedding is computed over "name:description".
            data = data.assign(
                name_description=lambda df: df["name"] + ":" + df["description"]
            )

        await _run_and_snapshot_embeddings(
            name=field,
            data=data,
            column_to_embed=column_to_embed,
            callbacks=callbacks,
            cache=cache,
            storage=storage,
            text_embed_config=text_embed_config,
            embeddings_snapshot_enabled=embeddings_snapshot_enabled,
        )
async def _run_and_snapshot_embeddings(
    name: str,
    data: pd.DataFrame,
    column_to_embed: str,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    text_embed_config: dict,
    embeddings_snapshot_enabled: bool,
) -> None:
    """Embed one column of `data` and optionally snapshot the result.

    Args:
        name: Embedding name; passed to `embed_text` as `embedding_name` and
            used to build the snapshot name `embeddings.<name>`.
        data: Table holding at least `id` and `column_to_embed`. An
            `embedding` column is written onto it in place.
        column_to_embed: Name of the column handed to `embed_text`.
        callbacks: Forwarded to `embed_text`.
        cache: Forwarded to `embed_text`.
        storage: Destination handed to `snapshot`.
        text_embed_config: Embedding configuration; its `strategy` entry is
            passed to `embed_text`. When falsy, nothing is done.
        embeddings_snapshot_enabled: A parquet snapshot is written only when
            this is exactly `True`.
    """
    # No embedding configuration means there is nothing to run or persist.
    if not text_embed_config:
        return

    vectors = await embed_text(
        data,
        callbacks,
        cache,
        embed_column=column_to_embed,
        embedding_name=name,
        strategy=text_embed_config["strategy"],
    )
    data["embedding"] = vectors

    # Only the identifier and its vector are kept for the snapshot.
    id_and_embedding = data.loc[:, ["id", "embedding"]]

    if embeddings_snapshot_enabled is not True:
        return

    await snapshot(
        id_and_embedding,
        name=f"embeddings.{name}",
        storage=storage,
        formats=["parquet"],
    )

View File

@ -49,7 +49,7 @@ embeddings:
# api_key: <api_key> # if not set, will attempt to use managed identity. Expects the `Search Index Data Contributor` RBAC role in this case.
# audience: <optional> # if using managed identity, the audience to use for the token
# overwrite: true # or false. Only applicable at index creation time
# collection_name: <collection_name> # the name of the collection to use
# collection_name: <collection_name> # the name of the collection to use. Default: 'default'
llm:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding

View File

@ -42,9 +42,11 @@ async def embed_text(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
embed_column: str,
strategy: dict,
embedding_name: str = "default",
embedding_name: str,
id_column: str = "id",
title_column: str | None = None,
):
"""
Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.
@ -91,18 +93,19 @@ async def embed_text(
input,
callbacks,
cache,
column,
embed_column,
strategy,
vector_store,
vector_store_workflow_config,
vector_store_config.get("store_in_table", False),
id_column=id_column,
title_column=title_column,
)
return await _text_embed_in_memory(
input,
callbacks,
cache,
column,
embed_column,
strategy,
)
@ -111,14 +114,14 @@ async def _text_embed_in_memory(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
embed_column: str,
strategy: dict,
):
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
texts: list[str] = input[column].to_numpy().tolist()
texts: list[str] = input[embed_column].to_numpy().tolist()
result = await strategy_exec(texts, callbacks, cache, strategy_args)
return result.embeddings
@ -128,11 +131,12 @@ async def _text_embed_with_vector_store(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
embed_column: str,
strategy: dict[str, Any],
vector_store: BaseVectorStore,
vector_store_config: dict,
store_in_table: bool = False,
id_column: str = "id",
title_column: str | None = None,
):
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
@ -142,24 +146,24 @@ async def _text_embed_with_vector_store(
insert_batch_size: int = (
vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE
)
title_column: str = vector_store_config.get("title_column", "title")
id_column: str = vector_store_config.get("id_column", "id")
overwrite: bool = vector_store_config.get("overwrite", True)
if column not in input.columns:
msg = (
f"Column {column} not found in input dataframe with columns {input.columns}"
)
if embed_column not in input.columns:
msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}"
raise ValueError(msg)
if title_column not in input.columns:
msg = f"Column {title_column} not found in input dataframe with columns {input.columns}"
title = title_column or embed_column
if title not in input.columns:
msg = (
f"Column {title} not found in input dataframe with columns {input.columns}"
)
raise ValueError(msg)
if id_column not in input.columns:
msg = f"Column {id_column} not found in input dataframe with columns {input.columns}"
raise ValueError(msg)
total_rows = 0
for row in input[column]:
for row in input[embed_column]:
if isinstance(row, list):
total_rows += len(row)
else:
@ -172,8 +176,8 @@ async def _text_embed_with_vector_store(
while insert_batch_size * i < input.shape[0]:
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
texts: list[str] = batch[column].to_numpy().tolist()
titles: list[str] = batch[title_column].to_numpy().tolist()
texts: list[str] = batch[embed_column].to_numpy().tolist()
titles: list[str] = batch[title].to_numpy().tolist()
ids: list[str] = batch[id_column].to_numpy().tolist()
result = await strategy_exec(
texts,
@ -181,7 +185,7 @@ async def _text_embed_with_vector_store(
cache,
strategy_args,
)
if store_in_table and result.embeddings:
if result.embeddings:
embeddings = [
embedding for embedding in result.embeddings if embedding is not None
]
@ -204,10 +208,7 @@ async def _text_embed_with_vector_store(
starting_index += len(documents)
i += 1
if store_in_table:
return all_results
return None
return all_results
def _create_vector_store(
@ -226,12 +227,10 @@ def _create_vector_store(
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
collection_name = vector_store_config.get("collection_name")
if not collection_name:
collection_names = vector_store_config.get("collection_names", {})
collection_name = collection_names.get(embedding_name, embedding_name)
container_name = vector_store_config.get("container_name")
collection_name = f"{container_name}.{embedding_name}".replace(".", "-")
msg = f"using vector store {vector_store_config.get('type')} with collection_name {collection_name} for embedding {embedding_name}"
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
log.info(msg)
return collection_name

View File

@ -17,8 +17,8 @@ async def snapshot(
"""Take a entire snapshot of the tabular data."""
for fmt in formats:
if fmt == "parquet":
await storage.set(name + ".parquet", input.to_parquet())
await storage.set(f"{name}.parquet", input.to_parquet())
elif fmt == "json":
await storage.set(
name + ".json", input.to_json(orient="records", lines=True)
f"{name}.json", input.to_json(orient="records", lines=True)
)

View File

@ -181,7 +181,8 @@ async def _run_entity_description_embedding(
entities_df,
callbacks,
cache,
column="name_description",
embed_column="name_description",
embedding_name="entity.description",
strategy=embed_config.get("strategy", {}),
)
return entities_df.drop(columns=["name_description"])

View File

@ -67,7 +67,13 @@ from .v1.create_final_text_units import (
from .v1.create_final_text_units import (
workflow_name as create_final_text_units,
)
from .v1.generate_text_embeddings import (
build_steps as build_generate_text_embeddings_steps,
)
from .v1.generate_text_embeddings import (
workflow_name as generate_text_embeddings,
)
default_workflows: WorkflowDefinitions = {
create_base_entity_graph: build_create_base_entity_graph_steps,
@ -80,4 +86,5 @@ default_workflows: WorkflowDefinitions = {
create_final_covariates: build_create_final_covariates_steps,
create_final_entities: build_create_final_entities_steps,
create_final_communities: build_create_final_communities_steps,
generate_text_embeddings: build_generate_text_embeddings_steps,
}

View File

@ -23,19 +23,6 @@ def build_steps(
async_mode = create_community_reports_config.get("async_mode")
num_threads = create_community_reports_config.get("num_threads")
base_text_embed = config.get("text_embed", {})
community_report_full_content_embed_config = config.get(
"community_report_full_content_embed", base_text_embed
)
community_report_summary_embed_config = config.get(
"community_report_summary_embed", base_text_embed
)
community_report_title_embed_config = config.get(
"community_report_title_embed", base_text_embed
)
skip_title_embedding = config.get("skip_title_embedding", False)
skip_summary_embedding = config.get("skip_summary_embedding", False)
skip_full_content_embedding = config.get("skip_full_content_embedding", False)
input = {
"source": "workflow:create_final_nodes",
"relationships": "workflow:create_final_relationships",
@ -48,21 +35,6 @@ def build_steps(
{
"verb": "create_final_community_reports",
"args": {
"full_content_text_embed": (
community_report_full_content_embed_config
if not skip_full_content_embedding
else None
),
"summary_text_embed": (
community_report_summary_embed_config
if not skip_summary_embedding
else None
),
"title_text_embed": (
community_report_title_embed_config
if not skip_title_embedding
else None
),
"summarization_strategy": summarization_strategy,
"async_mode": async_mode,
"num_threads": num_threads,

View File

@ -19,23 +19,11 @@ def build_steps(
## Dependencies
* `workflow:create_final_text_units`
"""
base_text_embed = config.get("text_embed", {})
document_raw_content_embed_config = config.get(
"document_raw_content_embed", base_text_embed
)
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
document_attribute_columns = config.get("document_attribute_columns", [])
return [
{
"verb": "create_final_documents",
"args": {
"document_attribute_columns": document_attribute_columns,
"raw_content_text_embed": (
document_raw_content_embed_config
if not skip_raw_content_embedding
else None
),
},
"args": {"document_attribute_columns": document_attribute_columns},
"input": {
"source": DEFAULT_INPUT_NAME,
"text_units": "workflow:create_final_text_units",

View File

@ -3,13 +3,16 @@
"""A module containing build_steps method definition."""
import logging
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_final_entities"
log = logging.getLogger(__name__)
def build_steps(
config: PipelineWorkflowConfig,
config: PipelineWorkflowConfig, # noqa: ARG001
) -> list[PipelineWorkflowStep]:
"""
Create the final entities table.
@ -17,26 +20,10 @@ def build_steps(
## Dependencies
* `workflow:create_base_entity_graph`
"""
base_text_embed = config.get("text_embed", {})
entity_name_embed_config = config.get("entity_name_embed", base_text_embed)
entity_name_description_embed_config = config.get(
"entity_name_description_embed", base_text_embed
)
skip_name_embedding = config.get("skip_name_embedding", False)
skip_description_embedding = config.get("skip_description_embedding", False)
return [
{
"verb": "create_final_entities",
"args": {
"name_text_embed": entity_name_embed_config
if not skip_name_embedding
else None,
"description_text_embed": entity_name_description_embed_config
if not skip_description_embedding
else None,
},
"args": {},
"input": {"source": "workflow:create_base_entity_graph"},
},
]

View File

@ -3,13 +3,17 @@
"""A module containing build_steps method definition."""
import logging
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
workflow_name = "create_final_relationships"
log = logging.getLogger(__name__)
def build_steps(
config: PipelineWorkflowConfig,
config: PipelineWorkflowConfig, # noqa: ARG001
) -> list[PipelineWorkflowStep]:
"""
Create the final relationships table.
@ -18,19 +22,10 @@ def build_steps(
* `workflow:create_base_entity_graph`
* `workflow:create_final_nodes`
"""
base_text_embed = config.get("text_embed", {})
relationship_description_embed_config = config.get(
"relationship_description_embed", base_text_embed
)
skip_description_embedding = config.get("skip_description_embedding", False)
return [
{
"verb": "create_final_relationships",
"args": {
"description_text_embed": relationship_description_embed_config
if not skip_description_embedding
else None,
},
"args": {},
"input": {
"source": "workflow:create_base_entity_graph",
"nodes": "workflow:create_final_nodes",

View File

@ -19,10 +19,6 @@ def build_steps(
* `workflow:create_final_entities`
* `workflow:create_final_communities`
"""
base_text_embed = config.get("text_embed", {})
text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed)
skip_text_unit_embedding = config.get("skip_text_unit_embedding", False)
covariates_enabled = config.get("covariates_enabled", False)
input = {
@ -37,11 +33,7 @@ def build_steps(
return [
{
"verb": "create_final_text_units",
"args": {
"text_text_embed": text_unit_text_embed_config
if not skip_text_unit_embedding
else None,
},
"args": {},
"input": input,
},
]

View File

@ -0,0 +1,49 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing build_steps method definition."""
import logging
from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
log = logging.getLogger(__name__)

workflow_name = "generate_text_embeddings"

# Upstream workflow outputs consumed by the `generate_text_embeddings` verb:
# "source" is the primary input, the remaining keys are named inputs.
input = {
    "source": "workflow:create_final_documents",
    "relationships": "workflow:create_final_relationships",
    "text_units": "workflow:create_final_text_units",
    "entities": "workflow:create_final_entities",
    "community_reports": "workflow:create_final_community_reports",
}


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the final embeddings files.

    Reads `text_embed`, `embedded_fields` and `snapshot_embeddings` from
    `config` and returns a single `generate_text_embeddings` step.

    ## Dependencies
    * `workflow:create_final_documents`
    * `workflow:create_final_relationships`
    * `workflow:create_final_text_units`
    * `workflow:create_final_entities`
    * `workflow:create_final_community_reports`
    """
    text_embed = config.get("text_embed", {})
    # Default to an empty set (not `{}`, which is a dict) so the value is the
    # same kind of collection whether or not the key is configured.
    embedded_fields = config.get("embedded_fields", set())
    embeddings_snapshot_enabled = config.get("snapshot_embeddings", False)

    return [
        {
            "verb": "generate_text_embeddings",
            "args": {
                "text_embed": text_embed,
                "embedded_fields": embedded_fields,
                "embeddings_snapshot_enabled": embeddings_snapshot_enabled,
            },
            # Copy so that mutating one step's input cannot leak into the
            # module-level mapping shared by every call.
            "input": dict(input),
        },
    ]

View File

@ -15,6 +15,7 @@ from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units import create_final_text_units
from .generate_text_embeddings import generate_text_embeddings
__all__ = [
"create_base_entity_graph",
@ -27,4 +28,5 @@ __all__ = [
"create_final_nodes",
"create_final_relationships",
"create_final_text_units",
"generate_text_embeddings",
]

View File

@ -30,9 +30,6 @@ async def create_final_community_reports(
summarization_strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
full_content_text_embed: dict | None = None,
summary_text_embed: dict | None = None,
title_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform community reports."""
@ -57,9 +54,6 @@ async def create_final_community_reports(
summarization_strategy,
async_mode=async_mode,
num_threads=num_threads,
full_content_text_embed=full_content_text_embed,
summary_text_embed=summary_text_embed,
title_text_embed=title_text_embed,
)
return create_verb_result(

View File

@ -8,13 +8,11 @@ from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_documents import (
create_final_documents as create_final_documents_flow,
)
@ -25,25 +23,15 @@ from graphrag.index.utils.ds_util import get_required_input_table
name="create_final_documents",
treats_input_tables_as_immutable=True,
)
async def create_final_documents(
def create_final_documents(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
document_attribute_columns: list[str] | None = None,
raw_content_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final documents."""
source = cast(pd.DataFrame, input.get_input())
text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table)
output = await create_final_documents_flow(
source,
text_units,
callbacks,
cache,
document_attribute_columns=document_attribute_columns,
raw_content_text_embed=raw_content_text_embed,
)
output = create_final_documents_flow(source, text_units, document_attribute_columns)
return create_verb_result(cast(Table, output))

View File

@ -14,7 +14,6 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_entities import (
create_final_entities as create_final_entities_flow,
)
@ -24,23 +23,17 @@ from graphrag.index.flows.create_final_entities import (
name="create_final_entities",
treats_input_tables_as_immutable=True,
)
async def create_final_entities(
def create_final_entities(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
name_text_embed: dict | None = None,
description_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final entities."""
source = cast(pd.DataFrame, input.get_input())
output = await create_final_entities_flow(
output = create_final_entities_flow(
source,
callbacks,
cache,
name_text_embed=name_text_embed,
description_text_embed=description_text_embed,
)
return create_verb_result(cast(Table, output))

View File

@ -14,7 +14,6 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_relationships import (
create_final_relationships as create_final_relationships_flow,
)
@ -25,23 +24,19 @@ from graphrag.index.utils.ds_util import get_required_input_table
name="create_final_relationships",
treats_input_tables_as_immutable=True,
)
async def create_final_relationships(
def create_final_relationships(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
description_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final relationships."""
source = cast(pd.DataFrame, input.get_input())
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
output = await create_final_relationships_flow(
source,
nodes,
callbacks,
cache,
description_text_embed=description_text_embed,
output = create_final_relationships_flow(
entity_graph=source,
nodes=nodes,
callbacks=callbacks,
)
return create_verb_result(cast(Table, output))

View File

@ -8,14 +8,12 @@ from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
VerbResult,
create_verb_result,
verb,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_text_units import (
create_final_text_units as create_final_text_units_flow,
)
@ -26,10 +24,7 @@ from graphrag.index.utils.ds_util import get_named_input_table, get_required_inp
@verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
async def create_final_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
runtime_storage: PipelineStorage,
text_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform the text units."""
@ -45,14 +40,11 @@ async def create_final_text_units(
if final_covariates:
final_covariates = cast(pd.DataFrame, final_covariates.table)
output = await create_final_text_units_flow(
output = create_final_text_units_flow(
text_units,
final_entities,
final_relationships,
final_covariates,
callbacks,
cache,
text_text_embed=text_text_embed,
)
return create_verb_result(cast(Table, output))

View File

@ -0,0 +1,70 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform the text units."""
import logging
from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
VerbResult,
create_verb_result,
verb,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.flows.generate_text_embeddings import (
generate_text_embeddings as generate_text_embeddings_flow,
)
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils.ds_util import get_required_input_table
log = logging.getLogger(__name__)
@verb(name="generate_text_embeddings", treats_input_tables_as_immutable=True)
async def generate_text_embeddings(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    text_embed: dict,
    embedded_fields: set[str],
    embeddings_snapshot_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Collect the input tables and run the text-embedding flow over them.

    The primary input is passed to the flow as the documents table; the
    `relationships`, `text_units`, `entities` and `community_reports` named
    inputs supply the remaining tables. The flow's work happens through its
    side effects, so the verb result wraps an empty DataFrame.
    """
    final_documents = cast(pd.DataFrame, input.get_input())

    # Fetched in a fixed order; each is looked up via get_required_input_table
    # (presumably failing when the named input is absent — verify in ds_util).
    named_tables = {
        table_name: cast(
            pd.DataFrame, get_required_input_table(input, table_name).table
        )
        for table_name in (
            "relationships",
            "text_units",
            "entities",
            "community_reports",
        )
    }

    await generate_text_embeddings_flow(
        final_documents=final_documents,
        final_relationships=named_tables["relationships"],
        final_text_units=named_tables["text_units"],
        final_entities=named_tables["entities"],
        final_community_reports=named_tables["community_reports"],
        callbacks=callbacks,
        cache=cache,
        storage=storage,
        text_embed_config=text_embed,
        embedded_fields=embedded_fields,
        embeddings_snapshot_enabled=embeddings_snapshot_enabled,
    )

    empty_table = pd.DataFrame()
    return create_verb_result(cast(Table, empty_table))

View File

@ -3,10 +3,9 @@
"""The LanceDB vector storage implementation package."""
import json
import json # noqa: I001
from typing import Any
import lancedb as lancedb
import pyarrow as pa
from graphrag.model.types import TextEmbedder
@ -16,6 +15,7 @@ from .base import (
VectorStoreDocument,
VectorStoreSearchResult,
)
import lancedb
class LanceDBVectorStore(BaseVectorStore):

View File

@ -6,9 +6,7 @@ embeddings:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
collection_name: "azure_ci"
entity_name_description:
title_column: "name"
container_name: "azure_ci"
input:
type: blob

View File

@ -8,7 +8,8 @@
2500
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 0
},
"create_base_entity_graph": {
"row_range": [
@ -16,7 +17,8 @@
2500
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_final_entities": {
"row_range": [
@ -29,7 +31,8 @@
"graph_embedding"
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_final_relationships": {
"row_range": [
@ -37,7 +40,8 @@
6000
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_nodes": {
"row_range": [
@ -52,7 +56,8 @@
"level"
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_communities": {
"row_range": [
@ -60,7 +65,8 @@
2500
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_community_reports": {
"row_range": [
@ -78,7 +84,8 @@
"findings"
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_final_text_units": {
"row_range": [
@ -90,7 +97,8 @@
"entity_ids"
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_documents": {
"row_range": [
@ -98,7 +106,17 @@
2500
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"generate_text_embeddings": {
"row_range": [
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
}
},
"query_config": [

View File

@ -5,19 +5,8 @@ embeddings:
vector_store:
type: "lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
collection_name: "lancedb_ci"
container_name: "lancedb_ci"
overwrite: True
store_in_table: True
entity_name_description:
title_column: "name"
# id_column: "id"
# entity_name: ...
# relationship_description: ...
# community_report_full_content: ...
# community_report_summary: ...
# community_report_title: ...
# document_raw_content: ...
# text_unit_text: ...
storage:
type: file # or blob
@ -29,4 +18,7 @@ reporting:
type: file # or console, blob
base_dir: "output/${timestamp}/reports"
# connection_string: <azure_blob_storage_connection_string>
# container_name: <azure_blob_storage_container_name>
# container_name: <azure_blob_storage_container_name>
snapshots:
embeddings: True

View File

@ -8,7 +8,8 @@
2500
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 0
},
"create_final_covariates": {
"row_range": [
@ -25,7 +26,8 @@
"source_text"
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_base_entity_graph": {
"row_range": [
@ -33,7 +35,8 @@
2500
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_final_entities": {
"row_range": [
@ -46,7 +49,8 @@
"graph_embedding"
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_final_relationships": {
"row_range": [
@ -54,7 +58,8 @@
6000
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_nodes": {
"row_range": [
@ -69,7 +74,8 @@
"level"
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_communities": {
"row_range": [
@ -77,7 +83,8 @@
2500
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_community_reports": {
"row_range": [
@ -95,7 +102,8 @@
"findings"
],
"subworkflows": 1,
"max_runtime": 300
"max_runtime": 300,
"expected_artifacts": 1
},
"create_final_text_units": {
"row_range": [
@ -107,7 +115,8 @@
"entity_ids"
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"create_final_documents": {
"row_range": [
@ -115,7 +124,17 @@
2500
],
"subworkflows": 1,
"max_runtime": 150
"max_runtime": 150,
"expected_artifacts": 1
},
"generate_text_embeddings": {
"row_range": [
1,
2500
],
"subworkflows": 1,
"max_runtime": 150,
"expected_artifacts": 1
}
},
"query_config": [

View File

@ -6,10 +6,7 @@ embeddings:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
collection_name: "simple_text_ci"
store_in_table: True
entity_name_description:
title_column: "name"
container_name: "simple_text_ci"
community_reports:
prompt: "prompts/community_report.txt"
@ -27,4 +24,7 @@ reporting:
type: file # or console, blob
base_dir: "output/${timestamp}/reports"
# connection_string: <azure_blob_storage_connection_string>
# container_name: <azure_blob_storage_container_name>
# container_name: <azure_blob_storage_container_name>
snapshots:
embeddings: True

View File

@ -172,6 +172,7 @@ class TestIndexer:
stats = json.loads((artifacts / "stats.json").read_bytes().decode("utf-8"))
# Check all workflows run
expected_artifacts = 0
expected_workflows = set(workflow_config.keys())
workflows = set(stats["workflows"].keys())
assert (
@ -180,6 +181,10 @@ class TestIndexer:
# [OPTIONAL] Check runtime
for workflow in expected_workflows:
# Check expected artifacts
expected_artifacts = expected_artifacts + workflow_config[workflow].get(
"expected_artifacts", 1
)
# Check max runtime
max_runtime = workflow_config[workflow].get("max_runtime", None)
if max_runtime:
@ -189,38 +194,40 @@ class TestIndexer:
# Check artifacts
artifact_files = os.listdir(artifacts)
# check that the number of workflows matches the number of artifacts, but:
# (1) do not count workflows with only transient output
# (2) account for the stats.json file
transient_workflows = [
"workflow:create_base_text_units",
]
# check that the number of workflows matches the number of artifacts
assert (
len(artifact_files)
== (len(expected_workflows) - len(transient_workflows) + 1)
len(artifact_files) == (expected_artifacts + 1)
), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"
for artifact in artifact_files:
if artifact.endswith(".parquet"):
output_df = pd.read_parquet(artifacts / artifact)
artifact_name = artifact.split(".")[0]
workflow = workflow_config[artifact_name]
# Check number of rows between range
assert (
workflow["row_range"][0]
<= len(output_df)
<= workflow["row_range"][1]
), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
try:
workflow = workflow_config[artifact_name]
# Get non-nan rows
nan_df = output_df.loc[
:, ~output_df.columns.isin(workflow.get("nan_allowed_columns", []))
]
nan_df = nan_df[nan_df.isna().any(axis=1)]
assert (
len(nan_df) == 0
), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
# Check number of rows between range
assert (
workflow["row_range"][0]
<= len(output_df)
<= workflow["row_range"][1]
), f"Expected between {workflow['row_range'][0]} and {workflow['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
# Get non-nan rows
nan_df = output_df.loc[
:,
~output_df.columns.isin(
workflow.get("nan_allowed_columns", [])
),
]
nan_df = nan_df[nan_df.isna().any(axis=1)]
assert (
len(nan_df) == 0
), f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
except KeyError:
log.warning("No workflow config found %s", artifact_name)
def __run_query(self, root: Path, query_config: dict[str, str]):
command = [

View File

@ -74,44 +74,6 @@ async def test_create_final_community_reports():
assert actual["rank_explanation"][:1][0] == "<rating_explanation>"
async def test_create_final_community_reports_with_embeddings():
    """Run the community reports workflow with all three report embeddings enabled.

    The LLM strategy is replaced with ``MOCK_LLM_CONFIG`` and every embedding
    step has its strategy type set to "mock". The output must then carry three
    columns more than the expected table, and the first value of each
    embedding column must have length 3.
    """
    upstream_workflows = [
        "workflow:create_final_nodes",
        "workflow:create_final_covariates",
        "workflow:create_final_relationships",
        "workflow:create_final_communities",
    ]
    input_tables = load_input_tables(upstream_workflows)
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["create_community_reports"]["strategy"]["llm"] = MOCK_LLM_CONFIG

    # (skip flag, embedding step) pairs: un-skip each embedding and mock its strategy
    embedding_settings = (
        ("skip_full_content_embedding", "community_report_full_content_embed"),
        ("skip_summary_embedding", "community_report_summary_embed"),
        ("skip_title_embedding", "community_report_title_embed"),
    )
    for skip_flag, embed_step in embedding_settings:
        config[skip_flag] = False
        config[embed_step]["strategy"]["type"] = "mock"

    actual = await get_workflow_output(
        input_tables,
        {"steps": build_steps(config)},
    )

    assert len(actual.columns) == len(expected.columns) + 3
    for column in (
        "full_content_embedding",
        "summary_embedding",
        "title_embedding",
    ):
        assert column in actual.columns
        assert len(actual[column][:1][0]) == 3
async def test_create_final_community_reports_missing_llm_throws():
input_tables = load_input_tables([
"workflow:create_final_nodes",

View File

@ -23,8 +23,6 @@ async def test_create_final_documents():
config = get_config_for_workflow(workflow_name)
config["skip_raw_content_embedding"] = True
steps = build_steps(config)
actual = await get_workflow_output(
@ -37,34 +35,6 @@ async def test_create_final_documents():
compare_outputs(actual, expected)
async def test_create_final_documents_with_embeddings():
    """Run the final documents workflow with raw-content embedding enabled.

    Checks that a ``raw_content_embedding`` column is added on top of the
    expected columns and that its first value has length 3.
    """
    input_tables = load_input_tables(["workflow:create_final_text_units"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["skip_raw_content_embedding"] = False
    # The default config already holds a full standard embed config, so only
    # the strategy type needs switching to mock; the other required
    # parameters stay in place.
    config["document_raw_content_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": build_steps(config)}
    actual = await get_workflow_output(input_tables, workflow)

    assert "raw_content_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # each embedding from the mock implementation is an array of 3 floats
    first_embedding = actual["raw_content_embedding"][:1][0]
    assert len(first_embedding) == 3
async def test_create_final_documents_with_attribute_columns():
input_tables = load_input_tables(["workflow:create_final_text_units"])
expected = load_expected(workflow_name)

View File

@ -23,9 +23,6 @@ async def test_create_final_entities():
config = get_config_for_workflow(workflow_name)
config["skip_name_embedding"] = True
config["skip_description_embedding"] = True
steps = build_steps(config)
actual = await get_workflow_output(
@ -50,83 +47,3 @@ async def test_create_final_entities():
],
)
assert len(actual.columns) == len(expected.columns) - 1
async def test_create_final_entities_with_name_embeddings():
    """Run the final entities workflow with only the name embedding enabled.

    The description embedding stays skipped. The output must contain a
    ``name_embedding`` column, have the same number of columns as the expected
    table, and its first name embedding must have length 3.
    """
    input_tables = load_input_tables(["workflow:create_base_entity_graph"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["skip_name_embedding"] = False
    config["skip_description_embedding"] = True
    config["entity_name_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": build_steps(config)}
    actual = await get_workflow_output(input_tables, workflow)

    assert "name_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns)
    # each embedding from the mock implementation is an array of 3 floats
    first_embedding = actual["name_embedding"][:1][0]
    assert len(first_embedding) == 3
async def test_create_final_entities_with_description_embeddings():
    """Run the final entities workflow with only the description embedding enabled.

    The name embedding stays skipped. The output must contain a
    ``description_embedding`` column, have the same number of columns as the
    expected table, and its first description embedding must have length 3.
    """
    input_tables = load_input_tables(["workflow:create_base_entity_graph"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["skip_name_embedding"] = True
    config["skip_description_embedding"] = False
    config["entity_name_description_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": build_steps(config)}
    actual = await get_workflow_output(input_tables, workflow)

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns)
    first_embedding = actual["description_embedding"][:1][0]
    assert len(first_embedding) == 3
async def test_create_final_entities_with_name_and_description_embeddings():
    """Run the final entities workflow with both name and description embeddings enabled.

    Both embedding steps use the "mock" strategy type. The assertions cover the
    ``description_embedding`` column only: it must be present, the output must
    have one column more than the expected table, and the first description
    embedding must have length 3.
    """
    input_tables = load_input_tables(["workflow:create_base_entity_graph"])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["skip_name_embedding"] = False
    config["skip_description_embedding"] = False
    for embed_step in ("entity_name_description_embed", "entity_name_embed"):
        config[embed_step]["strategy"]["type"] = "mock"

    workflow = {"steps": build_steps(config)}
    actual = await get_workflow_output(input_tables, workflow)

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    first_embedding = actual["description_embedding"][:1][0]
    assert len(first_embedding) == 3

View File

@ -24,8 +24,6 @@ async def test_create_final_relationships():
config = get_config_for_workflow(workflow_name)
config["skip_description_embedding"] = True
steps = build_steps(config)
actual = await get_workflow_output(
@ -36,32 +34,3 @@ async def test_create_final_relationships():
)
compare_outputs(actual, expected)
async def test_create_final_relationships_with_embeddings():
    """Run the final relationships workflow with description embedding enabled.

    Checks that a ``description_embedding`` column is added on top of the
    expected columns and that its first value has length 3.
    """
    upstream_workflows = [
        "workflow:create_base_entity_graph",
        "workflow:create_final_nodes",
    ]
    input_tables = load_input_tables(upstream_workflows)
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["skip_description_embedding"] = False
    # The default config already holds a full standard embed config, so only
    # the strategy type needs switching to mock; the other required
    # parameters stay in place.
    config["relationship_description_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": build_steps(config)}
    actual = await get_workflow_output(input_tables, workflow)

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # each embedding from the mock implementation is an array of 3 floats
    first_embedding = actual["description_embedding"][:1][0]
    assert len(first_embedding) == 3

View File

@ -33,7 +33,6 @@ async def test_create_final_text_units():
config = get_config_for_workflow(workflow_name)
config["covariates_enabled"] = True
config["skip_text_unit_embedding"] = True
steps = build_steps(config)
@ -65,7 +64,6 @@ async def test_create_final_text_units_no_covariates():
config = get_config_for_workflow(workflow_name)
config["covariates_enabled"] = False
config["skip_text_unit_embedding"] = True
steps = build_steps(config)
@ -83,41 +81,3 @@ async def test_create_final_text_units_no_covariates():
expected,
["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"],
)
async def test_create_final_text_units_with_embeddings():
    """Run the final text units workflow with covariates and text embedding enabled.

    The base text units table is placed in the run context's runtime storage
    under "base_text_units" before the workflow runs. The output must gain a
    ``text_embedding`` column over the expected table, and its first value
    must have length 3.
    """
    upstream_workflows = [
        "workflow:create_base_text_units",
        "workflow:create_final_entities",
        "workflow:create_final_relationships",
        "workflow:create_final_covariates",
    ]
    input_tables = load_input_tables(upstream_workflows)
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    base_text_units = input_tables["workflow:create_base_text_units"]
    await context.runtime_storage.set("base_text_units", base_text_units)

    config = get_config_for_workflow(workflow_name)
    config["covariates_enabled"] = True
    config["skip_text_unit_embedding"] = False
    # The default config already holds a full standard embed config, so only
    # the strategy type needs switching to mock; the other required
    # parameters stay in place.
    config["text_unit_text_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": build_steps(config)}
    actual = await get_workflow_output(input_tables, workflow, context=context)

    assert "text_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # each embedding from the mock implementation is an array of 3 floats
    first_embedding = actual["text_embedding"][:1][0]
    assert len(first_embedding) == 3

View File

@ -0,0 +1,82 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from io import BytesIO
import pandas as pd
from graphrag.index.config.embeddings import (
all_embeddings,
)
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.generate_text_embeddings import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_input_tables,
)
async def test_generate_text_embeddings():
    """Run generate_text_embeddings with every embedding enabled and check its snapshots.

    The workflow runs with the "mock" text-embed strategy type, embedding
    snapshots turned on, and ``all_embeddings`` as the embedded fields. The
    test then checks that the context storage holds one
    ``embeddings.<field>.parquet`` entry per field, and that two of those
    snapshots contain exactly the columns ``id`` and ``embedding``.
    """
    source_workflows = [
        "workflow:create_final_documents",
        "workflow:create_final_relationships",
        "workflow:create_final_text_units",
        "workflow:create_final_entities",
        "workflow:create_final_community_reports",
    ]
    input_tables = load_input_tables(inputs=source_workflows)
    context = create_run_context(None, None, None)

    config = get_config_for_workflow(workflow_name)
    config["text_embed"]["strategy"]["type"] = "mock"
    config["snapshot_embeddings"] = True
    config["embedded_fields"] = all_embeddings

    workflow = {"steps": build_steps(config)}
    await get_workflow_output(input_tables, workflow, context)

    stored_keys = context.storage.keys()
    for field in all_embeddings:
        assert f"embeddings.{field}.parquet" in stored_keys

    async def read_snapshot(key):
        # Fetch the stored parquet bytes and parse them into a DataFrame.
        raw = await context.storage.get(key, as_bytes=True)
        return pd.read_parquet(BytesIO(raw))

    # entity description should always be here, so assert its format
    entity_description_embeddings = await read_snapshot(
        "embeddings.entity.description.parquet"
    )
    assert len(entity_description_embeddings.columns) == 2
    assert "id" in entity_description_embeddings.columns
    assert "embedding" in entity_description_embeddings.columns

    # all other embeddings are optional, but every one is enabled here, so spot-check one
    document_raw_content_embeddings = await read_snapshot(
        "embeddings.document.raw_content.parquet"
    )
    assert len(document_raw_content_embeddings.columns) == 2
    assert "id" in document_raw_content_embeddings.columns
    assert "embedding" in document_raw_content_embeddings.columns