Merge branch 'main' into incremental_indexing/main

This commit is contained in:
Alonso Guevara 2024-09-25 17:33:33 -06:00
commit bf45f42969
36 changed files with 644 additions and 329 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create_base_text_units."
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Merge text_embed into create-final-relationships subflow."
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Optimize Create Base Documents subflow"
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse covariates flow."
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create-final-documents."
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix nested json parsing"
}

View File

@ -100,6 +100,7 @@ aembed
dedupe
dropna
dtypes
notna
# LLM Terms
AOAI

View File

@ -152,7 +152,6 @@ class ClaimExtractor:
subject = resolved_entities.get(subject, subject)
claim["object_id"] = obj
claim["subject_id"] = subject
claim["doc_id"] = document_id
return claim
async def _process_document(
@ -200,10 +199,7 @@ class ClaimExtractor:
if response.output != "YES":
break
result = self._parse_claim_tuples(results, prompt_args)
for r in result:
r["doc_id"] = f"{doc_index}"
return result
return self._parse_claim_tuples(results, prompt_args)
def _parse_claim_tuples(
self, claims: str, prompt_variables: dict
@ -243,6 +239,5 @@ class ClaimExtractor:
"end_date": pull_field(5, claim_fields),
"description": pull_field(6, claim_fields),
"source_text": pull_field(7, claim_fields),
"doc_id": pull_field(8, claim_fields),
})
return result

View File

@ -49,16 +49,37 @@ async def extract_covariates(
entity_types: list[str] | None = None,
**kwargs,
) -> TableContainer:
"""
Extract claims from a piece of text.
"""Extract claims from a piece of text."""
source = cast(pd.DataFrame, input.get_input())
output = await extract_covariates_df(
source,
cache,
callbacks,
column,
covariate_type,
strategy,
async_mode,
entity_types,
**kwargs,
)
return TableContainer(table=output)
## Usage
TODO
"""
async def extract_covariates_df(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
covariate_type: str,
strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
**kwargs,
):
"""Extract claims from a piece of text."""
log.debug("extract_covariates strategy=%s", strategy)
if entity_types is None:
entity_types = DEFAULT_ENTITY_TYPES
output = cast(pd.DataFrame, input.get_input())
resolved_entities_map = {}
@ -79,14 +100,13 @@ async def extract_covariates(
]
results = await derive_from_rows(
output,
input,
run_strategy,
callbacks,
scheduling_type=async_mode,
num_threads=kwargs.get("num_threads", 4),
)
output = pd.DataFrame([item for row in results for item in row or []])
return TableContainer(table=output)
return pd.DataFrame([item for row in results for item in row or []])
def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractStrategy:
@ -103,8 +123,4 @@ def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractS
def create_row_from_claim_data(row, covariate_data: Covariate, covariate_type: str):
"""Create a row from the claim data and the input row."""
item = {**row, **asdict(covariate_data), "covariate_type": covariate_type}
# TODO: doc_id from extraction isn't necessary
# since chunking happens before this
del item["doc_id"]
return item
return {**row, **asdict(covariate_data), "covariate_type": covariate_type}
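For reference, the merge-and-flatten that the new `extract_covariates_df` helper performs is plain pandas. A minimal sketch of the idea, using an illustrative stand-in dataclass and fake field values rather than graphrag's real `Covariate` type and strategy results:

```
from dataclasses import asdict, dataclass

import pandas as pd


@dataclass
class FakeCovariate:
    # Illustrative stand-in for the real Covariate result type.
    subject_id: str | None = None
    object_id: str | None = None
    type: str | None = None


def make_row(row: dict, covariate: FakeCovariate, covariate_type: str) -> dict:
    # Same merge create_row_from_claim_data performs: input row + extracted fields + covariate_type.
    return {**row, **asdict(covariate), "covariate_type": covariate_type}


rows = [{"chunk_id": "c1"}, {"chunk_id": "c2"}]
results = [
    [make_row(rows[0], FakeCovariate("COMPANY A", "AGENCY B", "FINE"), "claim")],
    None,  # a chunk where the strategy extracted nothing
]

# Flatten the per-row result lists into a single frame, as extract_covariates_df does.
output = pd.DataFrame([item for row in results for item in row or []])
print(output)
```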

View File

@ -5,17 +5,6 @@
MOCK_LLM_RESPONSES = [
"""
[
{
"subject": "COMPANY A",
"object": "GOVERNMENT AGENCY B",
"type": "ANTI-COMPETITIVE PRACTICES",
"status": "TRUE",
"start_date": "2022-01-10T00:00:00",
"end_date": "2022-01-10T00:00:00",
"description": "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10",
"source_text": ["According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."]
}
]
(COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10<|>According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
""".strip()
]

View File

@ -91,16 +91,13 @@ def create_covariate(item: dict[str, Any]) -> Covariate:
"""Create a covariate from the item."""
return Covariate(
subject_id=item.get("subject_id"),
subject_type=item.get("subject_type"),
object_id=item.get("object_id"),
object_type=item.get("object_type"),
type=item.get("type"),
status=item.get("status"),
start_date=item.get("start_date"),
end_date=item.get("end_date"),
description=item.get("description"),
source_text=item.get("source_text"),
doc_id=item.get("doc_id"),
record_id=item.get("record_id"),
id=item.get("id"),
)

View File

@ -18,9 +18,7 @@ class Covariate:
covariate_type: str | None = None
subject_id: str | None = None
subject_type: str | None = None
object_id: str | None = None
object_type: str | None = None
type: str | None = None
status: str | None = None
start_date: str | None = None

View File

@ -16,7 +16,7 @@ def genid(
input: VerbInput,
to: str,
method: str = "md5_hash",
hash: list[str] = [], # noqa A002
hash: list[str] | None = None, # noqa A002
**_kwargs: dict,
) -> TableContainer:
"""
@ -52,15 +52,29 @@ def genid(
"""
data = cast(pd.DataFrame, input.source.table)
if method == "md5_hash":
if len(hash) == 0:
msg = 'Must specify the "hash" columns to use md5_hash method'
output = genid_df(data, to, method, hash)
return TableContainer(table=output)
def genid_df(
input: pd.DataFrame,
to: str,
method: str = "md5_hash",
hash: list[str] | None = None, # noqa A002
):
"""Generate a unique id for each row in the tabular data."""
data = input
match method:
case "md5_hash":
if not hash:
msg = 'Must specify the "hash" columns to use md5_hash method'
raise ValueError(msg)
data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
case "increment":
data[to] = data.index + 1
case _:
msg = f"Unknown method {method}"
raise ValueError(msg)
data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
elif method == "increment":
data[to] = data.index + 1
else:
msg = f"Unknown method {method}"
raise ValueError(msg)
return TableContainer(table=data)
return data
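The md5_hash branch above is a row-wise hash over the selected columns. A self-contained sketch of the idea; `row_md5` below is an illustrative stand-in, not graphrag's actual `gen_md5_hash` helper:

```
import hashlib

import pandas as pd


def row_md5(row: pd.Series, columns: list[str]) -> str:
    # Hash the selected column values into a stable, deterministic id.
    joined = "".join(str(row[col]) for col in columns)
    return hashlib.md5(joined.encode("utf-8")).hexdigest()


data = pd.DataFrame({"chunk": ["alpha", "beta"]})
data["chunk_id"] = data.apply(lambda row: row_md5(row, ["chunk"]), axis=1)
print(data)
```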

View File

@ -85,9 +85,24 @@ def chunk(
type: sentence
```
"""
input_table = cast(pd.DataFrame, input.get_input())
output = chunk_df(input_table, column, to, callbacks, strategy)
return TableContainer(table=output)
def chunk_df(
input: pd.DataFrame,
column: str,
to: str,
callbacks: VerbCallbacks,
strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
"""Chunk a piece of text into smaller pieces."""
output = input
if strategy is None:
strategy = {}
output = cast(pd.DataFrame, input.get_input())
strategy_name = strategy.get("type", ChunkStrategyType.tokens)
strategy_config = {**strategy}
strategy_exec = load_strategy(strategy_name)
@ -102,7 +117,7 @@ def chunk(
),
axis=1,
)
return TableContainer(table=output)
return output
def run_strategy(
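For intuition only: the default tokens strategy that `chunk_df` loads boils down to sliding a token window over the encoded text. A rough sketch of that idea; the real strategy also carries document ids and per-chunk token counts, and its parameter names differ:

```
import tiktoken


def chunk_by_tokens(text: str, max_tokens: int = 1200, overlap: int = 100) -> list[str]:
    # Encode once, then take overlapping windows of max_tokens tokens.
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)
    step = max_tokens - overlap
    return [
        enc.decode(tokens[start : start + max_tokens])
        for start in range(0, len(tokens), step)
    ]


print(chunk_by_tokens("some long document text to split into chunks", max_tokens=8, overlap=2))
```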

View File

@ -79,6 +79,23 @@ async def text_embed(
<...>
```
"""
input_df = cast(pd.DataFrame, input.get_input())
result_df = await text_embed_df(
input_df, callbacks, cache, column, strategy, **kwargs
)
return TableContainer(table=result_df)
# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe
async def text_embed_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
strategy: dict,
**kwargs,
):
"""Embed a piece of text into a vector space."""
vector_store_config = strategy.get("vector_store")
if vector_store_config:
@ -113,28 +130,28 @@ async def text_embed(
async def _text_embed_in_memory(
input: VerbInput,
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
strategy: dict,
to: str,
):
output_df = cast(pd.DataFrame, input.get_input())
output_df = input
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
input_table = input.get_input()
input_table = input
texts: list[str] = input_table[column].to_numpy().tolist()
result = await strategy_exec(texts, callbacks, cache, strategy_args)
output_df[to] = result.embeddings
return TableContainer(table=output_df)
return output_df
async def _text_embed_with_vector_store(
input: VerbInput,
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
@ -144,7 +161,7 @@ async def _text_embed_with_vector_store(
store_in_table: bool = False,
to: str = "",
):
output_df = cast(pd.DataFrame, input.get_input())
output_df = input
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
@ -179,10 +196,8 @@ async def _text_embed_with_vector_store(
all_results = []
while insert_batch_size * i < input.get_input().shape[0]:
batch = input.get_input().iloc[
insert_batch_size * i : insert_batch_size * (i + 1)
]
while insert_batch_size * i < input.shape[0]:
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
texts: list[str] = batch[column].to_numpy().tolist()
titles: list[str] = batch[title_column].to_numpy().tolist()
ids: list[str] = batch[id_column].to_numpy().tolist()
@ -218,7 +233,7 @@ async def _text_embed_with_vector_store(
if store_in_table:
output_df[to] = all_results
return TableContainer(table=output_df)
return output_df
def _create_vector_store(
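The vector-store path above embeds in batches before upserting. The batching itself is plain iloc slicing; a sketch with a fake embedder standing in for the loaded strategy (the real one calls an embedding model and also writes the batch to the vector store):

```
import pandas as pd


def fake_embed(texts: list[str]) -> list[list[float]]:
    # Stand-in for the real embedding strategy.
    return [[0.0, 0.0, 0.0] for _ in texts]


frame = pd.DataFrame({"id": [f"r{i}" for i in range(10)], "description": ["text"] * 10})
insert_batch_size = 4
i = 0
all_results: list[list[float]] = []

while insert_batch_size * i < frame.shape[0]:
    batch = frame.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
    texts: list[str] = batch["description"].to_numpy().tolist()
    all_results.extend(fake_embed(texts))
    i += 1

frame["description_embedding"] = all_results
print(frame.head())
```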

View File

@ -22,91 +22,16 @@ def build_steps(
chunk_column_name = config.get("chunk_column", "chunk")
chunk_by_columns = config.get("chunk_by", []) or []
n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
text_chunk = config.get("text_chunk", {})
return [
{
"verb": "orderby",
"verb": "create_base_text_units",
"args": {
"orders": [
# sort for reproducibility
{"column": "id", "direction": "asc"},
]
"chunk_column_name": chunk_column_name,
"n_tokens_column_name": n_tokens_column_name,
"chunk_by_columns": chunk_by_columns,
**text_chunk,
},
"input": {"source": DEFAULT_INPUT_NAME},
},
{
"verb": "zip",
"args": {
# Pack the document ids with the text
# So when we unpack the chunks, we can restore the document id
"columns": ["id", "text"],
"to": "text_with_ids",
},
},
{
"verb": "aggregate_override",
"args": {
"groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
"aggregations": [
{
"column": "text_with_ids",
"operation": "array_agg",
"to": "texts",
}
],
},
},
{
"verb": "chunk",
"args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})},
},
{
"verb": "select",
"args": {
"columns": [*chunk_by_columns, "chunks"],
},
},
{
"verb": "unroll",
"args": {
"column": "chunks",
},
},
{
"verb": "rename",
"args": {
"columns": {
"chunks": chunk_column_name,
}
},
},
{
"verb": "genid",
"args": {
# Generate a unique id for each chunk
"to": "chunk_id",
"method": "md5_hash",
"hash": [chunk_column_name],
},
},
{
"verb": "unzip",
"args": {
"column": chunk_column_name,
"to": ["document_ids", chunk_column_name, n_tokens_column_name],
},
},
{"verb": "copy", "args": {"column": "chunk_id", "to": "id"}},
{
# ELIMINATE EMPTY CHUNKS
"verb": "filter",
"args": {
"column": chunk_column_name,
"criteria": [
{
"type": "value",
"operator": "is not empty",
}
],
},
},
]

View File

@ -21,70 +21,19 @@ def build_steps(
* `workflow:create_base_extracted_entities`
"""
claim_extract_config = config.get("claim_extract", {})
input = {"source": "workflow:create_base_text_units"}
chunk_column = config.get("chunk_column", "chunk")
chunk_id_column = config.get("chunk_id_column", "chunk_id")
async_mode = config.get("async_mode", AsyncType.AsyncIO)
return [
{
"verb": "extract_covariates",
"verb": "create_final_covariates",
"args": {
"column": config.get("chunk_column", "chunk"),
"id_column": config.get("chunk_id_column", "chunk_id"),
"resolved_entities_column": "resolved_entities",
"column": chunk_column,
"id_column": chunk_id_column,
"covariate_type": "claim",
"async_mode": config.get("async_mode", AsyncType.AsyncIO),
"async_mode": async_mode,
**claim_extract_config,
},
"input": input,
},
{
"verb": "window",
"args": {"to": "id", "operation": "uuid", "column": "covariate_type"},
},
{
"verb": "genid",
"args": {
"to": "human_readable_id",
"method": "increment",
},
},
{
"verb": "convert",
"args": {
"column": "human_readable_id",
"type": "string",
"to": "human_readable_id",
},
},
{
"verb": "rename",
"args": {
"columns": {
"chunk_id": "text_unit_id",
}
},
},
{
"verb": "select",
"args": {
"columns": [
"id",
"human_readable_id",
"covariate_type",
"type",
"description",
"subject_id",
"subject_type",
"object_id",
"object_type",
"status",
"start_date",
"end_date",
"source_text",
"text_unit_id",
"document_ids",
"n_tokens",
]
},
"input": {"source": "workflow:create_base_text_units"},
},
]

View File

@ -16,7 +16,6 @@ def build_steps(
## Dependencies
* `workflow:create_base_documents`
* `workflow:create_base_document_nodes`
"""
base_text_embed = config.get("text_embed", {})
document_raw_content_embed_config = config.get(
@ -25,17 +24,12 @@ def build_steps(
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
return [
{
"verb": "rename",
"args": {"columns": {"text_units": "text_unit_ids"}},
"verb": "create_final_documents",
"args": {
"columns": {"text_units": "text_unit_ids"},
"skip_embedding": skip_raw_content_embedding,
"text_embed": document_raw_content_embed_config,
},
"input": {"source": "workflow:create_base_documents"},
},
{
"verb": "text_embed",
"enabled": not skip_raw_content_embedding,
"args": {
"column": "raw_content",
"to": "raw_content_embedding",
**document_raw_content_embed_config,
},
},
]

View File

@ -23,30 +23,15 @@ def build_steps(
"relationship_description_embed", base_text_embed
)
skip_description_embedding = config.get("skip_description_embedding", False)
return [
{
"id": "pre_embedding",
"verb": "create_final_relationships_pre_embedding",
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "description_embedding",
"verb": "text_embed",
"enabled": not skip_description_embedding,
"verb": "create_final_relationships",
"args": {
"embedding_name": "relationship_description",
"column": "description",
"to": "description_embedding",
**relationship_description_embed_config,
"skip_embedding": skip_description_embedding,
"text_embed": relationship_description_embed_config,
},
},
{
"verb": "create_final_relationships_post_embedding",
"input": {
"source": "pre_embedding"
if skip_description_embedding
else "description_embedding",
"source": "workflow:create_base_entity_graph",
"nodes": "workflow:create_final_nodes",
},
},

View File

@ -4,21 +4,23 @@
"""The Indexing Engine workflows -> subflows package root."""
from .create_base_documents import create_base_documents
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_covariates import create_final_covariates
from .create_final_documents import create_final_documents
from .create_final_nodes import create_final_nodes
from .create_final_relationships_post_embedding import (
create_final_relationships_post_embedding,
)
from .create_final_relationships_pre_embedding import (
create_final_relationships_pre_embedding,
from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
__all__ = [
"create_base_documents",
"create_base_text_units",
"create_final_communities",
"create_final_covariates",
"create_final_documents",
"create_final_nodes",
"create_final_relationships_post_embedding",
"create_final_relationships_pre_embedding",
"create_final_relationships",
"create_final_text_units_pre_embedding",
]

View File

@ -13,8 +13,6 @@ from datashaper import (
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.verbs.overrides.aggregate import aggregate_df
@verb(name="create_base_documents", treats_input_tables_as_immutable=True)
def create_base_documents(
@ -26,16 +24,16 @@ def create_base_documents(
source = cast(pd.DataFrame, input.get_input())
text_units = cast(pd.DataFrame, input.get_others()[0])
text_units = cast(
pd.DataFrame, text_units.explode("document_ids")[["id", "document_ids", "text"]]
)
text_units.rename(
columns={
"document_ids": "chunk_doc_id",
"id": "chunk_id",
"text": "chunk_text",
},
inplace=True,
text_units = (
text_units.explode("document_ids")
.loc[:, ["id", "document_ids", "text"]]
.rename(
columns={
"document_ids": "chunk_doc_id",
"id": "chunk_id",
"text": "chunk_text",
}
)
)
joined = text_units.merge(
@ -43,38 +41,37 @@ def create_base_documents(
left_on="chunk_doc_id",
right_on="id",
how="inner",
copy=False,
)
docs_with_text_units = aggregate_df(
joined,
groupby=["id"],
aggregations=[
{
"column": "chunk_id",
"operation": "array_agg",
"to": "text_units",
}
],
docs_with_text_units = joined.groupby("id", sort=False).agg(
text_units=("chunk_id", list)
)
rejoined = docs_with_text_units.merge(
source,
on="id",
how="right",
)
copy=False,
).reset_index(drop=True)
rejoined.rename(columns={"text": "raw_content"}, inplace=True)
rejoined["id"] = rejoined["id"].astype(str)
# attribute columns are converted to strings and then collapsed into a single json object
# Convert attribute columns to strings and collapse them into a JSON object
if document_attribute_columns:
for column in document_attribute_columns:
rejoined[column] = rejoined[column].astype(str)
rejoined["attributes"] = rejoined[document_attribute_columns].apply(
lambda row: {**row},
axis=1,
# Convert all specified columns to string at once
rejoined[document_attribute_columns] = rejoined[
document_attribute_columns
].astype(str)
# Collapse the document_attribute_columns into a single JSON object column
rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
orient="records"
)
# Drop the original attribute columns after collapsing them
rejoined.drop(columns=document_attribute_columns, inplace=True)
rejoined.reset_index()
return create_verb_result(
cast(
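The two pandas changes here, replacing aggregate_df with a native groupby and collapsing attribute columns via to_dict, can be exercised in isolation. A toy sketch with made-up frames and column values:

```
import pandas as pd

text_units = pd.DataFrame({
    "chunk_id": ["c1", "c2", "c3"],
    "chunk_doc_id": ["d1", "d1", "d2"],
})
docs = pd.DataFrame({"id": ["d1", "d2"], "text": ["doc one", "doc two"], "title": ["A", "B"]})

# Native groupby replacement for the aggregate_df call: collect chunk ids per document.
docs_with_text_units = (
    text_units.groupby("chunk_doc_id", sort=False)
    .agg(text_units=("chunk_id", list))
    .reset_index()
    .rename(columns={"chunk_doc_id": "id"})
)

rejoined = docs_with_text_units.merge(docs, on="id", how="right")

# Collapse the attribute columns into a single dict-per-row column, as the subflow does.
attribute_cols = ["title"]
rejoined[attribute_cols] = rejoined[attribute_cols].astype(str)
rejoined["attributes"] = rejoined[attribute_cols].to_dict(orient="records")
rejoined = rejoined.drop(columns=attribute_cols)
print(rejoined)
```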

View File

@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform base text_units."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.verbs.genid import genid_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df
@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
def create_base_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
chunk_column_name: str,
n_tokens_column_name: str,
chunk_by_columns: list[str],
strategy: dict[str, Any] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform base text_units."""
table = cast(pd.DataFrame, input.get_input())
sort = table.sort_values(by=["id"], ascending=[True])
sort["text_with_ids"] = list(
zip(*[sort[col] for col in ["id", "text"]], strict=True)
)
aggregated = aggregate_df(
sort,
groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
aggregations=[
{
"column": "text_with_ids",
"operation": "array_agg",
"to": "texts",
}
],
)
chunked = chunk_df(
aggregated,
column="texts",
to="chunks",
callbacks=callbacks,
strategy=strategy,
)
chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
chunked = chunked.explode("chunks")
chunked.rename(
columns={
"chunks": chunk_column_name,
},
inplace=True,
)
chunked = genid_df(
chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
)
chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
chunked[chunk_column_name].tolist(), index=chunked.index
)
chunked["id"] = chunked["chunk_id"]
filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
return create_verb_result(
cast(
Table,
filtered,
)
)
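The subflow keeps document ids attached to their text by zipping them into tuples before chunking, then splits the chunk tuples back into columns afterwards. A toy version of that round trip; the fake chunker below just wraps each text in a single (document_ids, chunk_text, n_tokens) tuple:

```
import pandas as pd

docs = pd.DataFrame({"id": ["a", "b"], "text": ["first doc", "second doc"]})

# Pack id + text so the document id survives chunking.
docs["text_with_ids"] = list(zip(docs["id"], docs["text"], strict=True))

# Fake chunker: one chunk per document, shaped like (document_ids, chunk_text, n_tokens).
docs["chunks"] = docs["text_with_ids"].apply(lambda t: [([t[0]], t[1], len(t[1].split()))])

chunked = docs[["chunks"]].explode("chunks")
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
    chunked["chunks"].tolist(), index=chunked.index
)
print(chunked[["document_ids", "chunk", "n_tokens"]])
```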

View File

@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to extract and format covariates."""
from typing import Any, cast
from uuid import uuid4
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.covariates.extract_covariates.extract_covariates import (
extract_covariates_df,
)
@verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
async def create_final_covariates(
input: VerbInput,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
covariate_type: str,
strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
**kwargs: dict,
) -> VerbResult:
"""All the steps to extract and format covariates."""
source = cast(pd.DataFrame, input.get_input())
covariates = await extract_covariates_df(
source,
cache,
callbacks,
column,
covariate_type,
strategy,
async_mode,
entity_types,
**kwargs,
)
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
covariates["human_readable_id"] = (covariates.index + 1).astype(str)
covariates.rename(columns={"chunk_id": "text_unit_id"}, inplace=True)
return create_verb_result(
cast(
Table,
covariates[
[
"id",
"human_readable_id",
"covariate_type",
"type",
"description",
"subject_id",
"object_id",
"status",
"start_date",
"end_date",
"source_text",
"text_unit_id",
"document_ids",
"n_tokens",
]
],
)
)

View File

@ -0,0 +1,48 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final documents."""
from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
@verb(
name="create_final_documents",
treats_input_tables_as_immutable=True,
)
async def create_final_documents(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_embed: dict,
skip_embedding: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final documents."""
source = cast(pd.DataFrame, input.get_input())
source.rename(columns={"text_units": "text_unit_ids"}, inplace=True)
if not skip_embedding:
source = await text_embed_df(
source,
callbacks,
cache,
column="raw_content",
strategy=text_embed["strategy"],
to="raw_content_embedding",
)
return create_verb_result(cast(Table, source))

View File

@ -1,37 +1,64 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final relationships after they are embedded."""
"""All the steps to transform final relationships."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.utils.ds_util import get_required_input_table
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
compute_edge_combined_degree_df,
)
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
@verb(
name="create_final_relationships_post_embedding",
name="create_final_relationships",
treats_input_tables_as_immutable=True,
)
def create_final_relationships_post_embedding(
async def create_final_relationships(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_embed: dict,
skip_embedding: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final relationships after they are embedded."""
"""All the steps to transform final relationships."""
table = cast(pd.DataFrame, input.get_input())
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
pruned_edges = table.drop(columns=["level"])
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
filtered = cast(
pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
)
if not skip_embedding:
filtered = await text_embed_df(
filtered,
callbacks,
cache,
column="description",
strategy=text_embed["strategy"],
to="description_embedding",
embedding_name="relationship_description",
)
pruned_edges = filtered.drop(columns=["level"])
filtered_nodes = cast(
pd.DataFrame,

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final relationships before they are embedded."""
from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.verbs.graph.unpack import unpack_graph_df
@verb(
name="create_final_relationships_pre_embedding",
treats_input_tables_as_immutable=True,
)
def create_final_relationships_pre_embedding(
input: VerbInput,
callbacks: VerbCallbacks,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final relationships before they are embedded."""
table = cast(pd.DataFrame, input.get_input())
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
filtered = graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
return create_verb_result(cast(Table, filtered))

View File

@ -104,7 +104,7 @@ def try_parse_json_object(input: str) -> tuple[str, dict]:
return input, result
_pattern = r"\{(.*)\}"
_match = re.search(_pattern, input)
_match = re.search(_pattern, input, re.DOTALL)
input = "{" + _match.group(1) + "}" if _match else input
# Clean up json string.
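The re.DOTALL flag is what lets the brace pattern span a multi-line, nested payload; without it the dot stops at newlines, so only a single-line inner object can match. A quick illustration:

```
import re

text = 'Model said:\n{\n  "status": "TRUE",\n  "nested": {"subject": "COMPANY A"}\n}\nDone.'

# Without re.DOTALL only the inner, single-line object is captured.
partial = re.search(r"\{(.*)\}", text)
print("{" + partial.group(1) + "}")  # {"subject": "COMPANY A"}

# With re.DOTALL the outer object spanning multiple lines is captured.
full = re.search(r"\{(.*)\}", text, re.DOTALL)
print("{" + full.group(1) + "}")  # the whole nested payload
```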

View File

@ -41,7 +41,6 @@ class Covariate(Identified):
d: dict[str, Any],
id_key: str = "id",
subject_id_key: str = "subject_id",
subject_type_key: str = "subject_type",
covariate_type_key: str = "covariate_type",
short_id_key: str = "short_id",
text_unit_ids_key: str = "text_unit_ids",
@ -53,7 +52,6 @@ class Covariate(Identified):
id=d[id_key],
short_id=d.get(short_id_key),
subject_id=d[subject_id_key],
subject_type=d.get(subject_type_key, "entity"),
covariate_type=d.get(covariate_type_key, "claim"),
text_unit_ids=d.get(text_unit_ids_key),
document_ids=d.get(document_ids_key),

View File

@ -157,8 +157,7 @@ def read_covariates(
id_col: str = "id",
short_id_col: str | None = "short_id",
subject_col: str = "subject_id",
subject_type_col: str | None = "subject_type",
covariate_type_col: str | None = "covariate_type",
covariate_type_col: str | None = "type",
text_unit_ids_col: str | None = "text_unit_ids",
document_ids_col: str | None = "document_ids",
attributes_cols: list[str] | None = None,
@ -170,9 +169,6 @@ def read_covariates(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
subject_id=to_str(row, subject_col),
subject_type=(
to_str(row, subject_type_col) if subject_type_col else "entity"
),
covariate_type=(
to_str(row, covariate_type_col) if covariate_type_col else "claim"
),

View File

@ -7,7 +7,7 @@
1,
2000
],
"subworkflows": 11,
"subworkflows": 1,
"max_runtime": 10
},
"create_base_extracted_entities": {
@ -52,7 +52,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_final_nodes": {

View File

@ -7,7 +7,7 @@
1,
2000
],
"subworkflows": 11,
"subworkflows": 1,
"max_runtime": 10
},
"create_base_extracted_entities": {
@ -26,15 +26,13 @@
"nan_allowed_columns": [
"type",
"description",
"subject_type",
"object_id",
"object_type",
"status",
"start_date",
"end_date",
"source_text"
],
"subworkflows": 6,
"subworkflows": 1,
"max_runtime": 300
},
"create_summarized_entities": {
@ -71,7 +69,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_final_nodes": {

View File

@ -0,0 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.workflows.v1.create_base_text_units import (
build_steps,
workflow_name,
)
from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
)
async def test_create_base_text_units():
input_tables = load_input_tables(inputs=[])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
# test data was created with GPT-4o, so we need to match its o200k_base encoding for the chunks to be identical
config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
compare_outputs(actual, expected)

View File

@ -0,0 +1,68 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from pandas.testing import assert_series_equal
from graphrag.index.workflows.v1.create_final_covariates import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
)
async def test_create_final_covariates():
input_tables = load_input_tables(["workflow:create_base_text_units"])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
# deleting the llm config results in a default mock injection in run_gi_extract_claims
del config["claim_extract"]["strategy"]["llm"]
steps = build_steps(config)
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
input = input_tables["workflow:create_base_text_units"]
# we removed the subject_type and object_type columns, so expect two fewer columns than the pre-refactor outputs
assert len(actual.columns) == (len(expected.columns) - 2)
# our mock only returns one covariate per text unit, so that's a 1:1 mapping versus the LLM-extracted content in the test data
assert len(actual) == len(input)
# assert all of the columns that covariates copied from the input
assert_series_equal(actual["text_unit_id"], input["id"], check_names=False)
assert_series_equal(actual["text_unit_id"], input["chunk_id"], check_names=False)
assert_series_equal(actual["document_ids"], input["document_ids"])
assert_series_equal(actual["n_tokens"], input["n_tokens"])
# make sure the human ids are incrementing and cast to strings
assert actual["human_readable_id"][0] == "1"
assert actual["human_readable_id"][1] == "2"
# check that the mock data is parsed and inserted into the correct columns
assert actual["covariate_type"][0] == "claim"
assert actual["subject_id"][0] == "COMPANY A"
assert actual["object_id"][0] == "GOVERNMENT AGENCY B"
assert actual["type"][0] == "ANTI-COMPETITIVE PRACTICES"
assert actual["status"][0] == "TRUE"
assert actual["start_date"][0] == "2022-01-10T00:00:00"
assert actual["end_date"][0] == "2022-01-10T00:00:00"
assert (
actual["description"][0]
== "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10"
)
assert (
actual["source_text"][0]
== "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."
)

View File

@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.workflows.v1.create_final_documents import (
build_steps,
workflow_name,
)
from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)
async def test_create_final_documents():
input_tables = load_input_tables([
"workflow:create_base_documents",
])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["skip_raw_content_embedding"] = True
steps = remove_disabled_steps(build_steps(config))
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
compare_outputs(actual, expected)
async def test_create_final_documents_with_embeddings():
input_tables = load_input_tables([
"workflow:create_base_documents",
])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["skip_raw_content_embedding"] = False
# default config has a detailed standard embed config
# just override the strategy to mock so the rest of the required parameters are in place
config["document_raw_content_embed"]["strategy"]["type"] = "mock"
steps = remove_disabled_steps(build_steps(config))
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
assert "raw_content_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["raw_content_embedding"][0]) == 3

View File

@ -37,3 +37,32 @@ async def test_create_final_relationships():
)
compare_outputs(actual, expected)
async def test_create_final_relationships_with_embeddings():
input_tables = load_input_tables([
"workflow:create_base_entity_graph",
"workflow:create_final_nodes",
])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["skip_description_embedding"] = False
# default config has a detailed standard embed config
# just override the strategy to mock so the rest of the required parameters are in place
config["relationship_description_embed"]["strategy"]["type"] = "mock"
steps = remove_disabled_steps(build_steps(config))
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
assert "description_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["description_embedding"][0]) == 3

View File

@ -13,6 +13,7 @@ from graphrag.index import (
PipelineWorkflowStep,
create_pipeline_config,
)
from graphrag.index.run.utils import _create_run_context
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
@ -31,6 +32,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
# remove the workflow: prefix if it exists, because that is not part of the actual table filename
name = input.replace("workflow:", "")
input_tables[input] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")
return input_tables
@ -42,8 +44,12 @@ def load_expected(output: str) -> pd.DataFrame:
def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
"""Instantiates the bare minimum config to get a default workflow config for testing."""
config = create_graphrag_config()
# this flag needs to be set before creating the pipeline config, or the entire covariate workflow will be excluded
config.claim_extraction.enabled = True
pipeline_config = create_pipeline_config(config)
print(pipeline_config.workflows)
result = next(conf for conf in pipeline_config.workflows if conf.name == name)
return cast(PipelineWorkflowConfig, result.config)
@ -59,7 +65,9 @@ async def get_workflow_output(
input_tables=input_tables,
)
await workflow.run()
context = _create_run_context(None, None, None)
await workflow.run(context=context)
# if there's only one output, it is the default here, no name required
return cast(pd.DataFrame, workflow.output())