Mirror of https://github.com/microsoft/graphrag.git (synced 2025-08-18 05:31:30 +00:00)

Commit bf45f42969: Merge branch 'main' into incremental_indexing/main
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse create_base_text_units."
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Merge text_embed into create-final-relationships subflow."
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Optimize Create Base Documents subflow"
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse covariates flow."
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse create-final-documents."
+}
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix nested json parsing"
+}
@@ -100,6 +100,7 @@ aembed
 dedupe
 dropna
 dtypes
+notna

 # LLM Terms
 AOAI
@@ -152,7 +152,6 @@ class ClaimExtractor:
         subject = resolved_entities.get(subject, subject)
         claim["object_id"] = obj
         claim["subject_id"] = subject
-        claim["doc_id"] = document_id
         return claim

     async def _process_document(
@@ -200,10 +199,7 @@ class ClaimExtractor:
             if response.output != "YES":
                 break

-        result = self._parse_claim_tuples(results, prompt_args)
-        for r in result:
-            r["doc_id"] = f"{doc_index}"
-        return result
+        return self._parse_claim_tuples(results, prompt_args)

     def _parse_claim_tuples(
         self, claims: str, prompt_variables: dict
@@ -243,6 +239,5 @@ class ClaimExtractor:
                 "end_date": pull_field(5, claim_fields),
                 "description": pull_field(6, claim_fields),
                 "source_text": pull_field(7, claim_fields),
-                "doc_id": pull_field(8, claim_fields),
             })
         return result
@@ -49,16 +49,37 @@ async def extract_covariates(
     entity_types: list[str] | None = None,
     **kwargs,
 ) -> TableContainer:
-    """
-    Extract claims from a piece of text.
-
-    ## Usage
-    TODO
-    """
+    """Extract claims from a piece of text."""
+    source = cast(pd.DataFrame, input.get_input())
+    output = await extract_covariates_df(
+        source,
+        cache,
+        callbacks,
+        column,
+        covariate_type,
+        strategy,
+        async_mode,
+        entity_types,
+        **kwargs,
+    )
+    return TableContainer(table=output)
+
+
+async def extract_covariates_df(
+    input: pd.DataFrame,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    column: str,
+    covariate_type: str,
+    strategy: dict[str, Any] | None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    **kwargs,
+):
+    """Extract claims from a piece of text."""
     log.debug("extract_covariates strategy=%s", strategy)
     if entity_types is None:
         entity_types = DEFAULT_ENTITY_TYPES
-    output = cast(pd.DataFrame, input.get_input())

     resolved_entities_map = {}

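Note: this hunk is the recurring refactor pattern in the PR: each datashaper verb keeps its TableContainer-based signature but delegates to a DataFrame-in/DataFrame-out helper (here `extract_covariates_df`) that collapsed subflows and unit tests can call directly. A minimal, self-contained sketch of that shape follows; `my_verb_df` is a hypothetical stand-in, not code from the repo.

```python
# Illustrative sketch only (not from the diff): "thin verb wrapper + DataFrame-first helper".
import pandas as pd


def my_verb_df(frame: pd.DataFrame, column: str) -> pd.DataFrame:
    """Pure-pandas core, analogous to extract_covariates_df / chunk_df / genid_df."""
    out = frame.copy()
    out[f"{column}_upper"] = out[column].str.upper()  # stand-in transformation
    return out


if __name__ == "__main__":
    df = pd.DataFrame({"text": ["alpha", "beta"]})
    print(my_verb_df(df, "text"))
```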
@@ -79,14 +100,13 @@ async def extract_covariates(
     ]

     results = await derive_from_rows(
-        output,
+        input,
         run_strategy,
         callbacks,
         scheduling_type=async_mode,
         num_threads=kwargs.get("num_threads", 4),
     )
-    output = pd.DataFrame([item for row in results for item in row or []])
-    return TableContainer(table=output)
+    return pd.DataFrame([item for row in results for item in row or []])


 def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractStrategy:
@@ -103,8 +123,4 @@ def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractStrategy:

 def create_row_from_claim_data(row, covariate_data: Covariate, covariate_type: str):
     """Create a row from the claim data and the input row."""
-    item = {**row, **asdict(covariate_data), "covariate_type": covariate_type}
-    # TODO: doc_id from extraction isn't necessary
-    # since chunking happens before this
-    del item["doc_id"]
-    return item
+    return {**row, **asdict(covariate_data), "covariate_type": covariate_type}
@@ -5,17 +5,6 @@

 MOCK_LLM_RESPONSES = [
     """
-[
-    {
-        "subject": "COMPANY A",
-        "object": "GOVERNMENT AGENCY B",
-        "type": "ANTI-COMPETITIVE PRACTICES",
-        "status": "TRUE",
-        "start_date": "2022-01-10T00:00:00",
-        "end_date": "2022-01-10T00:00:00",
-        "description": "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10",
-        "source_text": ["According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."]
-    }
-]
+(COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10<|>According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
 """.strip()
 ]
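Note: the mock now mirrors the delimited tuple format that `ClaimExtractor._parse_claim_tuples` splits apart, with fields pulled by position (the later hunk shows indices 5-7 mapping to end_date, description and source_text, and index 8, doc_id, being dropped). The sketch below is illustrative only; the exact index-to-field mapping for positions 0-4 is my inference from the mock content, and the real parsing lives in `_parse_claim_tuples`.

```python
# Illustrative positional parse of one mock claim record (assumed field order).
record = (
    "COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE"
    "<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>description text<|>source text"
)
fields = record.split("<|>")
claim = {
    "subject_id": fields[0],
    "object_id": fields[1],
    "type": fields[2],
    "status": fields[3],
    "start_date": fields[4],
    "end_date": fields[5],
    "description": fields[6],
    "source_text": fields[7],
}
print(claim["subject_id"])  # COMPANY A
```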
@@ -91,16 +91,13 @@ def create_covariate(item: dict[str, Any]) -> Covariate:
     """Create a covariate from the item."""
     return Covariate(
         subject_id=item.get("subject_id"),
-        subject_type=item.get("subject_type"),
         object_id=item.get("object_id"),
-        object_type=item.get("object_type"),
         type=item.get("type"),
         status=item.get("status"),
         start_date=item.get("start_date"),
         end_date=item.get("end_date"),
         description=item.get("description"),
         source_text=item.get("source_text"),
-        doc_id=item.get("doc_id"),
         record_id=item.get("record_id"),
         id=item.get("id"),
     )
@@ -18,9 +18,7 @@ class Covariate:

     covariate_type: str | None = None
     subject_id: str | None = None
-    subject_type: str | None = None
     object_id: str | None = None
-    object_type: str | None = None
     type: str | None = None
     status: str | None = None
     start_date: str | None = None
@@ -16,7 +16,7 @@ def genid(
     input: VerbInput,
     to: str,
     method: str = "md5_hash",
-    hash: list[str] = [],  # noqa A002
+    hash: list[str] | None = None,  # noqa A002
     **_kwargs: dict,
 ) -> TableContainer:
     """
@@ -52,15 +52,29 @@ def genid(
     """
     data = cast(pd.DataFrame, input.source.table)

-    if method == "md5_hash":
-        if len(hash) == 0:
-            msg = 'Must specify the "hash" columns to use md5_hash method'
-            raise ValueError(msg)
-
-        data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
-    elif method == "increment":
-        data[to] = data.index + 1
-    else:
-        msg = f"Unknown method {method}"
-        raise ValueError(msg)
-    return TableContainer(table=data)
+    output = genid_df(data, to, method, hash)
+
+    return TableContainer(table=output)
+
+
+def genid_df(
+    input: pd.DataFrame,
+    to: str,
+    method: str = "md5_hash",
+    hash: list[str] | None = None,  # noqa A002
+):
+    """Generate a unique id for each row in the tabular data."""
+    data = input
+    match method:
+        case "md5_hash":
+            if not hash:
+                msg = 'Must specify the "hash" columns to use md5_hash method'
+                raise ValueError(msg)
+            data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
+        case "increment":
+            data[to] = data.index + 1
+        case _:
+            msg = f"Unknown method {method}"
+            raise ValueError(msg)
+
+    return data
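Note: `genid_df` keeps the two id strategies the old verb supported: a deterministic md5 hash of selected columns, or a positional increment. A standalone illustration of both modes is below; it uses `hashlib` directly rather than the repo's `gen_md5_hash` helper, whose digest formatting may differ.

```python
# Standalone sketch of the two genid_df modes (assumption: the repo's
# gen_md5_hash hashes the chosen columns; its exact output format may differ).
from hashlib import md5

import pandas as pd

df = pd.DataFrame({"chunk": ["first chunk", "second chunk"]})

# md5_hash mode: deterministic id derived from the chosen "hash" columns
df["chunk_id"] = df.apply(
    lambda row: md5("".join(str(row[c]) for c in ["chunk"]).encode()).hexdigest(),
    axis=1,
)

# increment mode: positional id
df["human_readable_id"] = df.index + 1
print(df)
```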
@@ -85,9 +85,24 @@ def chunk(
         type: sentence
     ```
     """
+    input_table = cast(pd.DataFrame, input.get_input())
+
+    output = chunk_df(input_table, column, to, callbacks, strategy)
+
+    return TableContainer(table=output)
+
+
+def chunk_df(
+    input: pd.DataFrame,
+    column: str,
+    to: str,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any] | None = None,
+) -> pd.DataFrame:
+    """Chunk a piece of text into smaller pieces."""
+    output = input
     if strategy is None:
         strategy = {}
-    output = cast(pd.DataFrame, input.get_input())
     strategy_name = strategy.get("type", ChunkStrategyType.tokens)
     strategy_config = {**strategy}
     strategy_exec = load_strategy(strategy_name)
@@ -102,7 +117,7 @@ def chunk(
         ),
         axis=1,
     )
-    return TableContainer(table=output)
+    return output


 def run_strategy(
@@ -79,6 +79,23 @@ async def text_embed(
         <...>
     ```
     """
+    input_df = cast(pd.DataFrame, input.get_input())
+    result_df = await text_embed_df(
+        input_df, callbacks, cache, column, strategy, **kwargs
+    )
+    return TableContainer(table=result_df)
+
+
+# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe
+async def text_embed_df(
+    input: pd.DataFrame,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    column: str,
+    strategy: dict,
+    **kwargs,
+):
+    """Embed a piece of text into a vector space."""
     vector_store_config = strategy.get("vector_store")

     if vector_store_config:
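Note: the new `text_embed_df` is an async, DataFrame-first helper that the collapsed subflows later in this commit (`create_final_documents`, `create_final_relationships`) can await directly with their own column/strategy arguments. A self-contained sketch of that calling shape is below; the fake 3-dimensional embedding stands in for the real strategy loading and vector-store handling, which this sketch omits.

```python
# Self-contained sketch of an async DataFrame-first embedding helper (mock vectors).
import asyncio

import pandas as pd


async def embed_df(frame: pd.DataFrame, column: str, to: str) -> pd.DataFrame:
    out = frame.copy()
    out[to] = [[float(len(text)), 0.0, 0.0] for text in out[column]]  # fake 3-dim vectors
    return out


async def main() -> None:
    df = pd.DataFrame({"raw_content": ["doc one", "doc two"]})
    print(await embed_df(df, "raw_content", "raw_content_embedding"))


asyncio.run(main())
```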
@@ -113,28 +130,28 @@ async def text_embed(


 async def _text_embed_in_memory(
-    input: VerbInput,
+    input: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
     column: str,
     strategy: dict,
     to: str,
 ):
-    output_df = cast(pd.DataFrame, input.get_input())
+    output_df = input
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
     strategy_args = {**strategy}
-    input_table = input.get_input()
+    input_table = input

     texts: list[str] = input_table[column].to_numpy().tolist()
     result = await strategy_exec(texts, callbacks, cache, strategy_args)

     output_df[to] = result.embeddings
-    return TableContainer(table=output_df)
+    return output_df


 async def _text_embed_with_vector_store(
-    input: VerbInput,
+    input: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
     column: str,
@@ -144,7 +161,7 @@ async def _text_embed_with_vector_store(
     store_in_table: bool = False,
     to: str = "",
 ):
-    output_df = cast(pd.DataFrame, input.get_input())
+    output_df = input
     strategy_type = strategy["type"]
     strategy_exec = load_strategy(strategy_type)
     strategy_args = {**strategy}
@@ -179,10 +196,8 @@ async def _text_embed_with_vector_store(

     all_results = []

-    while insert_batch_size * i < input.get_input().shape[0]:
-        batch = input.get_input().iloc[
-            insert_batch_size * i : insert_batch_size * (i + 1)
-        ]
+    while insert_batch_size * i < input.shape[0]:
+        batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
         texts: list[str] = batch[column].to_numpy().tolist()
         titles: list[str] = batch[title_column].to_numpy().tolist()
         ids: list[str] = batch[id_column].to_numpy().tolist()
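Note: with the verb now receiving a plain DataFrame, the vector-store insert loop can slice the frame directly with `iloc` instead of calling `input.get_input()` on every iteration. A self-contained illustration of the same windowing logic:

```python
# Batching sketch mirroring the loop above: slice the DataFrame itself with iloc.
import pandas as pd

frame = pd.DataFrame({"text": [f"row {n}" for n in range(10)]})
insert_batch_size = 4

i = 0
while insert_batch_size * i < frame.shape[0]:
    batch = frame.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
    print(f"batch {i}: {len(batch)} rows")
    i += 1
```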
@@ -218,7 +233,7 @@ async def _text_embed_with_vector_store(
     if store_in_table:
         output_df[to] = all_results

-    return TableContainer(table=output_df)
+    return output_df


 def _create_vector_store(
@@ -22,91 +22,16 @@ def build_steps(
     chunk_column_name = config.get("chunk_column", "chunk")
     chunk_by_columns = config.get("chunk_by", []) or []
     n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
+    text_chunk = config.get("text_chunk", {})
     return [
         {
-            "verb": "orderby",
+            "verb": "create_base_text_units",
             "args": {
-                "orders": [
-                    # sort for reproducibility
-                    {"column": "id", "direction": "asc"},
-                ]
+                "chunk_column_name": chunk_column_name,
+                "n_tokens_column_name": n_tokens_column_name,
+                "chunk_by_columns": chunk_by_columns,
+                **text_chunk,
             },
             "input": {"source": DEFAULT_INPUT_NAME},
         },
-        {
-            "verb": "zip",
-            "args": {
-                # Pack the document ids with the text
-                # So when we unpack the chunks, we can restore the document id
-                "columns": ["id", "text"],
-                "to": "text_with_ids",
-            },
-        },
-        {
-            "verb": "aggregate_override",
-            "args": {
-                "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
-                "aggregations": [
-                    {
-                        "column": "text_with_ids",
-                        "operation": "array_agg",
-                        "to": "texts",
-                    }
-                ],
-            },
-        },
-        {
-            "verb": "chunk",
-            "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})},
-        },
-        {
-            "verb": "select",
-            "args": {
-                "columns": [*chunk_by_columns, "chunks"],
-            },
-        },
-        {
-            "verb": "unroll",
-            "args": {
-                "column": "chunks",
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "chunks": chunk_column_name,
-                }
-            },
-        },
-        {
-            "verb": "genid",
-            "args": {
-                # Generate a unique id for each chunk
-                "to": "chunk_id",
-                "method": "md5_hash",
-                "hash": [chunk_column_name],
-            },
-        },
-        {
-            "verb": "unzip",
-            "args": {
-                "column": chunk_column_name,
-                "to": ["document_ids", chunk_column_name, n_tokens_column_name],
-            },
-        },
-        {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}},
-        {
-            # ELIMINATE EMPTY CHUNKS
-            "verb": "filter",
-            "args": {
-                "column": chunk_column_name,
-                "criteria": [
-                    {
-                        "type": "value",
-                        "operator": "is not empty",
-                    }
-                ],
-            },
-        },
     ]
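Note: the whole orderby / zip / aggregate / chunk / select / unroll / rename / genid / unzip / copy / filter pipeline collapses into a single `create_base_text_units` step. A rough, runnable sketch of what `build_steps` now emits for a default config is below; the real values come from the workflow's PipelineWorkflowConfig, and `"source"` stands in for the `DEFAULT_INPUT_NAME` constant.

```python
# Illustrative sketch of the collapsed workflow output for a default config.
config: dict = {}  # defaults only
chunk_column_name = config.get("chunk_column", "chunk")
chunk_by_columns = config.get("chunk_by", []) or []
n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
text_chunk = config.get("text_chunk", {})

steps = [
    {
        "verb": "create_base_text_units",
        "args": {
            "chunk_column_name": chunk_column_name,
            "n_tokens_column_name": n_tokens_column_name,
            "chunk_by_columns": chunk_by_columns,
            **text_chunk,
        },
        "input": {"source": "source"},  # DEFAULT_INPUT_NAME in the real code
    },
]
print(steps)
```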
@@ -21,70 +21,19 @@ def build_steps(
     * `workflow:create_base_extracted_entities`
     """
     claim_extract_config = config.get("claim_extract", {})
-
-    input = {"source": "workflow:create_base_text_units"}
+    chunk_column = config.get("chunk_column", "chunk")
+    chunk_id_column = config.get("chunk_id_column", "chunk_id")
+    async_mode = config.get("async_mode", AsyncType.AsyncIO)
     return [
         {
-            "verb": "extract_covariates",
+            "verb": "create_final_covariates",
             "args": {
-                "column": config.get("chunk_column", "chunk"),
-                "id_column": config.get("chunk_id_column", "chunk_id"),
-                "resolved_entities_column": "resolved_entities",
+                "column": chunk_column,
+                "id_column": chunk_id_column,
                 "covariate_type": "claim",
-                "async_mode": config.get("async_mode", AsyncType.AsyncIO),
+                "async_mode": async_mode,
                 **claim_extract_config,
             },
-            "input": input,
-        },
-        {
-            "verb": "window",
-            "args": {"to": "id", "operation": "uuid", "column": "covariate_type"},
-        },
-        {
-            "verb": "genid",
-            "args": {
-                "to": "human_readable_id",
-                "method": "increment",
-            },
-        },
-        {
-            "verb": "convert",
-            "args": {
-                "column": "human_readable_id",
-                "type": "string",
-                "to": "human_readable_id",
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "chunk_id": "text_unit_id",
-                }
-            },
-        },
-        {
-            "verb": "select",
-            "args": {
-                "columns": [
-                    "id",
-                    "human_readable_id",
-                    "covariate_type",
-                    "type",
-                    "description",
-                    "subject_id",
-                    "subject_type",
-                    "object_id",
-                    "object_type",
-                    "status",
-                    "start_date",
-                    "end_date",
-                    "source_text",
-                    "text_unit_id",
-                    "document_ids",
-                    "n_tokens",
-                ]
-            },
+            "input": {"source": "workflow:create_base_text_units"},
         },
     ]
@@ -16,7 +16,6 @@ def build_steps(

     ## Dependencies
     * `workflow:create_base_documents`
-    * `workflow:create_base_document_nodes`
     """
     base_text_embed = config.get("text_embed", {})
     document_raw_content_embed_config = config.get(
|
|||||||
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
|
skip_raw_content_embedding = config.get("skip_raw_content_embedding", False)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"verb": "rename",
|
"verb": "create_final_documents",
|
||||||
"args": {"columns": {"text_units": "text_unit_ids"}},
|
"args": {
|
||||||
|
"columns": {"text_units": "text_unit_ids"},
|
||||||
|
"skip_embedding": skip_raw_content_embedding,
|
||||||
|
"text_embed": document_raw_content_embed_config,
|
||||||
|
},
|
||||||
"input": {"source": "workflow:create_base_documents"},
|
"input": {"source": "workflow:create_base_documents"},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"verb": "text_embed",
|
|
||||||
"enabled": not skip_raw_content_embedding,
|
|
||||||
"args": {
|
|
||||||
"column": "raw_content",
|
|
||||||
"to": "raw_content_embedding",
|
|
||||||
**document_raw_content_embed_config,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
@@ -23,30 +23,15 @@ def build_steps(
         "relationship_description_embed", base_text_embed
     )
     skip_description_embedding = config.get("skip_description_embedding", False)

     return [
         {
-            "id": "pre_embedding",
-            "verb": "create_final_relationships_pre_embedding",
-            "input": {"source": "workflow:create_base_entity_graph"},
-        },
-        {
-            "id": "description_embedding",
-            "verb": "text_embed",
-            "enabled": not skip_description_embedding,
+            "verb": "create_final_relationships",
             "args": {
-                "embedding_name": "relationship_description",
-                "column": "description",
-                "to": "description_embedding",
-                **relationship_description_embed_config,
+                "skip_embedding": skip_description_embedding,
+                "text_embed": relationship_description_embed_config,
             },
-        },
-        {
-            "verb": "create_final_relationships_post_embedding",
             "input": {
-                "source": "pre_embedding"
-                if skip_description_embedding
-                else "description_embedding",
+                "source": "workflow:create_base_entity_graph",
                 "nodes": "workflow:create_final_nodes",
             },
         },
@@ -4,21 +4,23 @@
 """The Indexing Engine workflows -> subflows package root."""

 from .create_base_documents import create_base_documents
+from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
+from .create_final_covariates import create_final_covariates
+from .create_final_documents import create_final_documents
 from .create_final_nodes import create_final_nodes
-from .create_final_relationships_post_embedding import (
-    create_final_relationships_post_embedding,
-)
-from .create_final_relationships_pre_embedding import (
-    create_final_relationships_pre_embedding,
+from .create_final_relationships import (
+    create_final_relationships,
 )
 from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding

 __all__ = [
     "create_base_documents",
+    "create_base_text_units",
     "create_final_communities",
+    "create_final_covariates",
+    "create_final_documents",
     "create_final_nodes",
-    "create_final_relationships_post_embedding",
-    "create_final_relationships_pre_embedding",
+    "create_final_relationships",
     "create_final_text_units_pre_embedding",
 ]
@@ -13,8 +13,6 @@ from datashaper import (
 )
 from datashaper.table_store.types import VerbResult, create_verb_result

-from graphrag.index.verbs.overrides.aggregate import aggregate_df
-

 @verb(name="create_base_documents", treats_input_tables_as_immutable=True)
 def create_base_documents(
@@ -26,16 +24,16 @@ def create_base_documents(
     source = cast(pd.DataFrame, input.get_input())
     text_units = cast(pd.DataFrame, input.get_others()[0])

-    text_units = cast(
-        pd.DataFrame, text_units.explode("document_ids")[["id", "document_ids", "text"]]
-    )
-    text_units.rename(
-        columns={
-            "document_ids": "chunk_doc_id",
-            "id": "chunk_id",
-            "text": "chunk_text",
-        },
-        inplace=True,
+    text_units = (
+        text_units.explode("document_ids")
+        .loc[:, ["id", "document_ids", "text"]]
+        .rename(
+            columns={
+                "document_ids": "chunk_doc_id",
+                "id": "chunk_id",
+                "text": "chunk_text",
+            }
+        )
     )

     joined = text_units.merge(
|
|||||||
left_on="chunk_doc_id",
|
left_on="chunk_doc_id",
|
||||||
right_on="id",
|
right_on="id",
|
||||||
how="inner",
|
how="inner",
|
||||||
|
copy=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
docs_with_text_units = aggregate_df(
|
docs_with_text_units = joined.groupby("id", sort=False).agg(
|
||||||
joined,
|
text_units=("chunk_id", list)
|
||||||
groupby=["id"],
|
|
||||||
aggregations=[
|
|
||||||
{
|
|
||||||
"column": "chunk_id",
|
|
||||||
"operation": "array_agg",
|
|
||||||
"to": "text_units",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rejoined = docs_with_text_units.merge(
|
rejoined = docs_with_text_units.merge(
|
||||||
source,
|
source,
|
||||||
on="id",
|
on="id",
|
||||||
how="right",
|
how="right",
|
||||||
)
|
copy=False,
|
||||||
|
).reset_index(drop=True)
|
||||||
|
|
||||||
rejoined.rename(columns={"text": "raw_content"}, inplace=True)
|
rejoined.rename(columns={"text": "raw_content"}, inplace=True)
|
||||||
rejoined["id"] = rejoined["id"].astype(str)
|
rejoined["id"] = rejoined["id"].astype(str)
|
||||||
|
|
||||||
# attribute columns are converted to strings and then collapsed into a single json object
|
# Convert attribute columns to strings and collapse them into a JSON object
|
||||||
if document_attribute_columns:
|
if document_attribute_columns:
|
||||||
for column in document_attribute_columns:
|
# Convert all specified columns to string at once
|
||||||
rejoined[column] = rejoined[column].astype(str)
|
rejoined[document_attribute_columns] = rejoined[
|
||||||
rejoined["attributes"] = rejoined[document_attribute_columns].apply(
|
document_attribute_columns
|
||||||
lambda row: {**row},
|
].astype(str)
|
||||||
axis=1,
|
|
||||||
|
# Collapse the document_attribute_columns into a single JSON object column
|
||||||
|
rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
|
||||||
|
orient="records"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Drop the original attribute columns after collapsing them
|
||||||
rejoined.drop(columns=document_attribute_columns, inplace=True)
|
rejoined.drop(columns=document_attribute_columns, inplace=True)
|
||||||
rejoined.reset_index()
|
|
||||||
|
|
||||||
return create_verb_result(
|
return create_verb_result(
|
||||||
cast(
|
cast(
|
||||||
|
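Note: the `aggregate_df` call is replaced by a plain pandas groupby with a named aggregation that collects chunk ids into a list per document. A self-contained illustration of that swap:

```python
# Sketch of the aggregate_df -> groupby().agg() replacement above.
import pandas as pd

joined = pd.DataFrame({
    "id": ["doc1", "doc1", "doc2"],
    "chunk_id": ["c1", "c2", "c3"],
})

docs_with_text_units = joined.groupby("id", sort=False).agg(
    text_units=("chunk_id", list)
)
print(docs_with_text_units)
# index "id": doc1 -> [c1, c2], doc2 -> [c3]
```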
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to transform base text_units."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.verbs.genid import genid_df
+from graphrag.index.verbs.overrides.aggregate import aggregate_df
+from graphrag.index.verbs.text.chunk.text_chunk import chunk_df
+
+
+@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
+def create_base_text_units(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    chunk_column_name: str,
+    n_tokens_column_name: str,
+    chunk_by_columns: list[str],
+    strategy: dict[str, Any] | None = None,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform base text_units."""
+    table = cast(pd.DataFrame, input.get_input())
+
+    sort = table.sort_values(by=["id"], ascending=[True])
+
+    sort["text_with_ids"] = list(
+        zip(*[sort[col] for col in ["id", "text"]], strict=True)
+    )
+
+    aggregated = aggregate_df(
+        sort,
+        groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
+        aggregations=[
+            {
+                "column": "text_with_ids",
+                "operation": "array_agg",
+                "to": "texts",
+            }
+        ],
+    )
+
+    chunked = chunk_df(
+        aggregated,
+        column="texts",
+        to="chunks",
+        callbacks=callbacks,
+        strategy=strategy,
+    )
+
+    chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
+    chunked = chunked.explode("chunks")
+    chunked.rename(
+        columns={
+            "chunks": chunk_column_name,
+        },
+        inplace=True,
+    )
+
+    chunked = genid_df(
+        chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
+    )
+
+    chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
+        chunked[chunk_column_name].tolist(), index=chunked.index
+    )
+    chunked["id"] = chunked["chunk_id"]
+
+    filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
+
+    return create_verb_result(
+        cast(
+            Table,
+            filtered,
+        )
+    )
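Note: the new verb replays the old zip / unroll / unzip verbs in pandas: pack (id, text) tuples, chunk the aggregated texts, explode the chunk list, then split each chunk tuple back into document_ids, text, and token-count columns. The self-contained walkthrough below shows just that packing/unpacking shape; the word-splitting lambda is only a stand-in for `chunk_df` and its token counting.

```python
# Pack -> chunk -> unpack sketch (fake splitter stands in for chunk_df).
import pandas as pd

docs = pd.DataFrame({"id": ["d1", "d2"], "text": ["alpha beta", "gamma"]})
docs["text_with_ids"] = list(zip(docs["id"], docs["text"], strict=True))

# one aggregated row of (id, text) tuples, "chunked" into (doc_ids, chunk, n_tokens)
texts = docs["text_with_ids"].tolist()
chunks = [([doc_id], word, 1) for doc_id, text in texts for word in text.split()]

chunked = pd.DataFrame({"chunks": [chunks]}).explode("chunks").reset_index(drop=True)
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
    chunked["chunks"].tolist(), index=chunked.index
)
print(chunked[["document_ids", "chunk", "n_tokens"]])
```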
@@ -0,0 +1,78 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to extract and format covariates."""
+
+from typing import Any, cast
+from uuid import uuid4
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.verbs.covariates.extract_covariates.extract_covariates import (
+    extract_covariates_df,
+)
+
+
+@verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
+async def create_final_covariates(
+    input: VerbInput,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    column: str,
+    covariate_type: str,
+    strategy: dict[str, Any] | None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    **kwargs: dict,
+) -> VerbResult:
+    """All the steps to extract and format covariates."""
+    source = cast(pd.DataFrame, input.get_input())
+
+    covariates = await extract_covariates_df(
+        source,
+        cache,
+        callbacks,
+        column,
+        covariate_type,
+        strategy,
+        async_mode,
+        entity_types,
+        **kwargs,
+    )
+
+    covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
+    covariates["human_readable_id"] = (covariates.index + 1).astype(str)
+    covariates.rename(columns={"chunk_id": "text_unit_id"}, inplace=True)
+
+    return create_verb_result(
+        cast(
+            Table,
+            covariates[
+                [
+                    "id",
+                    "human_readable_id",
+                    "covariate_type",
+                    "type",
+                    "description",
+                    "subject_id",
+                    "object_id",
+                    "status",
+                    "start_date",
+                    "end_date",
+                    "source_text",
+                    "text_unit_id",
+                    "document_ids",
+                    "n_tokens",
+                ]
+            ],
+        )
+    )
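Note: the id handling that previously took four workflow verbs (window/uuid, genid/increment, convert, rename) is now three pandas statements inside the verb. A self-contained illustration of exactly those statements:

```python
# Sketch of the id columns added by create_final_covariates.
from uuid import uuid4

import pandas as pd

covariates = pd.DataFrame({"covariate_type": ["claim", "claim"], "chunk_id": ["c1", "c2"]})

covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
covariates["human_readable_id"] = (covariates.index + 1).astype(str)  # "1", "2", ...
covariates.rename(columns={"chunk_id": "text_unit_id"}, inplace=True)
print(covariates)
```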
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to transform final documents."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.verbs.text.embed.text_embed import text_embed_df
+
+
+@verb(
+    name="create_final_documents",
+    treats_input_tables_as_immutable=True,
+)
+async def create_final_documents(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    text_embed: dict,
+    skip_embedding: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform final documents."""
+    source = cast(pd.DataFrame, input.get_input())
+
+    source.rename(columns={"text_units": "text_unit_ids"}, inplace=True)
+
+    if not skip_embedding:
+        source = await text_embed_df(
+            source,
+            callbacks,
+            cache,
+            column="raw_content",
+            strategy=text_embed["strategy"],
+            to="raw_content_embedding",
+        )
+
+    return create_verb_result(cast(Table, source))
@@ -1,37 +1,64 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""All the steps to transform final relationships after they are embedded."""
+"""All the steps to transform final relationships."""

 from typing import Any, cast

 import pandas as pd
 from datashaper import (
     Table,
+    VerbCallbacks,
     VerbInput,
     verb,
 )
 from datashaper.table_store.types import VerbResult, create_verb_result

+from graphrag.index.cache import PipelineCache
 from graphrag.index.utils.ds_util import get_required_input_table
 from graphrag.index.verbs.graph.compute_edge_combined_degree import (
     compute_edge_combined_degree_df,
 )
+from graphrag.index.verbs.graph.unpack import unpack_graph_df
+from graphrag.index.verbs.text.embed.text_embed import text_embed_df


 @verb(
-    name="create_final_relationships_post_embedding",
+    name="create_final_relationships",
     treats_input_tables_as_immutable=True,
 )
-def create_final_relationships_post_embedding(
+async def create_final_relationships(
     input: VerbInput,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    text_embed: dict,
+    skip_embedding: bool = False,
     **_kwargs: dict,
 ) -> VerbResult:
-    """All the steps to transform final relationships after they are embedded."""
+    """All the steps to transform final relationships."""
     table = cast(pd.DataFrame, input.get_input())
     nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)

-    pruned_edges = table.drop(columns=["level"])
+    graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
+
+    graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
+
+    filtered = cast(
+        pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
+    )
+
+    if not skip_embedding:
+        filtered = await text_embed_df(
+            filtered,
+            callbacks,
+            cache,
+            column="description",
+            strategy=text_embed["strategy"],
+            to="description_embedding",
+            embedding_name="relationship_description",
+        )
+
+    pruned_edges = filtered.drop(columns=["level"])
+
     filtered_nodes = cast(
         pd.DataFrame,
@@ -1,38 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform final relationships before they are embedded."""
-
-from typing import cast
-
-import pandas as pd
-from datashaper import (
-    Table,
-    VerbCallbacks,
-    VerbInput,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.verbs.graph.unpack import unpack_graph_df
-
-
-@verb(
-    name="create_final_relationships_pre_embedding",
-    treats_input_tables_as_immutable=True,
-)
-def create_final_relationships_pre_embedding(
-    input: VerbInput,
-    callbacks: VerbCallbacks,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform final relationships before they are embedded."""
-    table = cast(pd.DataFrame, input.get_input())
-
-    graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
-
-    graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
-
-    filtered = graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
-
-    return create_verb_result(cast(Table, filtered))
@@ -104,7 +104,7 @@ def try_parse_json_object(input: str) -> tuple[str, dict]:
         return input, result

     _pattern = r"\{(.*)\}"
-    _match = re.search(_pattern, input)
+    _match = re.search(_pattern, input, re.DOTALL)
     input = "{" + _match.group(1) + "}" if _match else input

     # Clean up json string.
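Note: this is the "Fix nested json parsing" change from the changelog entry above. Without `re.DOTALL`, `.` does not match newlines, so a pretty-printed (multi-line, nested) JSON object embedded in an LLM response is never captured by the `\{(.*)\}` pattern. A self-contained check:

```python
# Why re.DOTALL matters for the nested-JSON fix above.
import re

text = 'prefix {\n  "a": {\n    "b": 1\n  }\n} suffix'
pattern = r"\{(.*)\}"

print(re.search(pattern, text))              # None: "." stops at newlines
print(re.search(pattern, text, re.DOTALL))   # matches the full nested object
```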
@@ -41,7 +41,6 @@ class Covariate(Identified):
         d: dict[str, Any],
         id_key: str = "id",
         subject_id_key: str = "subject_id",
-        subject_type_key: str = "subject_type",
         covariate_type_key: str = "covariate_type",
         short_id_key: str = "short_id",
         text_unit_ids_key: str = "text_unit_ids",
@@ -53,7 +52,6 @@ class Covariate(Identified):
             id=d[id_key],
             short_id=d.get(short_id_key),
             subject_id=d[subject_id_key],
-            subject_type=d.get(subject_type_key, "entity"),
             covariate_type=d.get(covariate_type_key, "claim"),
             text_unit_ids=d.get(text_unit_ids_key),
             document_ids=d.get(document_ids_key),
@@ -157,8 +157,7 @@ def read_covariates(
     id_col: str = "id",
     short_id_col: str | None = "short_id",
    subject_col: str = "subject_id",
-    subject_type_col: str | None = "subject_type",
-    covariate_type_col: str | None = "covariate_type",
+    covariate_type_col: str | None = "type",
     text_unit_ids_col: str | None = "text_unit_ids",
     document_ids_col: str | None = "document_ids",
     attributes_cols: list[str] | None = None,
@@ -170,9 +169,6 @@ def read_covariates(
             id=to_str(row, id_col),
             short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
             subject_id=to_str(row, subject_col),
-            subject_type=(
-                to_str(row, subject_type_col) if subject_type_col else "entity"
-            ),
             covariate_type=(
                 to_str(row, covariate_type_col) if covariate_type_col else "claim"
             ),
tests/fixtures/min-csv/config.json (vendored)
@@ -7,7 +7,7 @@
             1,
             2000
         ],
-        "subworkflows": 11,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_base_extracted_entities": {
@@ -52,7 +52,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 100
     },
     "create_final_nodes": {
tests/fixtures/text/config.json (vendored)
@@ -7,7 +7,7 @@
             1,
             2000
         ],
-        "subworkflows": 11,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_base_extracted_entities": {
@@ -26,15 +26,13 @@
         "nan_allowed_columns": [
             "type",
             "description",
-            "subject_type",
             "object_id",
-            "object_type",
             "status",
             "start_date",
             "end_date",
             "source_text"
         ],
-        "subworkflows": 6,
+        "subworkflows": 1,
         "max_runtime": 300
     },
     "create_summarized_entities": {
@@ -71,7 +69,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 100
     },
     "create_final_nodes": {
tests/verbs/test_create_base_text_units.py (new file)
@@ -0,0 +1,35 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.index.workflows.v1.create_base_text_units import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    compare_outputs,
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_base_text_units():
+    input_tables = load_input_tables(inputs=[])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+    # test data was created with 4o, so we need to match the encoding for chunks to be identical
+    config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    compare_outputs(actual, expected)
tests/verbs/test_create_final_covariates.py (new file)
@@ -0,0 +1,68 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from pandas.testing import assert_series_equal
+
+from graphrag.index.workflows.v1.create_final_covariates import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_final_covariates():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    # deleting the llm config results in a default mock injection in run_gi_extract_claims
+    del config["claim_extract"]["strategy"]["llm"]
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    input = input_tables["workflow:create_base_text_units"]
+    # we removed the subject_type and object_type columns so expect two less columns than the pre-refactor outputs
+    assert len(actual.columns) == (len(expected.columns) - 2)
+    # our mock only returns one covariate per text unit, so that's a 1:1 mapping versus the LLM-extracted content in the test data
+    assert len(actual) == len(input)
+
+    # assert all of the columns that covariates copied from the input
+    assert_series_equal(actual["text_unit_id"], input["id"], check_names=False)
+    assert_series_equal(actual["text_unit_id"], input["chunk_id"], check_names=False)
+    assert_series_equal(actual["document_ids"], input["document_ids"])
+    assert_series_equal(actual["n_tokens"], input["n_tokens"])
+
+    # make sure the human ids are incrementing and cast to strings
+    assert actual["human_readable_id"][0] == "1"
+    assert actual["human_readable_id"][1] == "2"
+
+    # check that the mock data is parsed and inserted into the correct columns
+    assert actual["covariate_type"][0] == "claim"
+    assert actual["subject_id"][0] == "COMPANY A"
+    assert actual["object_id"][0] == "GOVERNMENT AGENCY B"
+    assert actual["type"][0] == "ANTI-COMPETITIVE PRACTICES"
+    assert actual["status"][0] == "TRUE"
+    assert actual["start_date"][0] == "2022-01-10T00:00:00"
+    assert actual["end_date"][0] == "2022-01-10T00:00:00"
+    assert (
+        actual["description"][0]
+        == "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10"
+    )
+    assert (
+        actual["source_text"][0]
+        == "According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."
+    )
tests/verbs/test_create_final_documents.py (new file)
@@ -0,0 +1,66 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.index.workflows.v1.create_final_documents import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    compare_outputs,
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+    remove_disabled_steps,
+)
+
+
+async def test_create_final_documents():
+    input_tables = load_input_tables([
+        "workflow:create_base_documents",
+    ])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["skip_raw_content_embedding"] = True
+
+    steps = remove_disabled_steps(build_steps(config))
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    compare_outputs(actual, expected)
+
+
+async def test_create_final_documents_with_embeddings():
+    input_tables = load_input_tables([
+        "workflow:create_base_documents",
+    ])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["skip_raw_content_embedding"] = False
+    # default config has a detailed standard embed config
+    # just override the strategy to mock so the rest of the required parameters are in place
+    config["document_raw_content_embed"]["strategy"]["type"] = "mock"
+
+    steps = remove_disabled_steps(build_steps(config))
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    assert "raw_content_embedding" in actual.columns
+    assert len(actual.columns) == len(expected.columns) + 1
+    # the mock impl returns an array of 3 floats for each embedding
+    assert len(actual["raw_content_embedding"][0]) == 3
@@ -37,3 +37,32 @@ async def test_create_final_relationships():
     )

     compare_outputs(actual, expected)
+
+
+async def test_create_final_relationships_with_embeddings():
+    input_tables = load_input_tables([
+        "workflow:create_base_entity_graph",
+        "workflow:create_final_nodes",
+    ])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["skip_description_embedding"] = False
+    # default config has a detailed standard embed config
+    # just override the strategy to mock so the rest of the required parameters are in place
+    config["relationship_description_embed"]["strategy"]["type"] = "mock"
+
+    steps = remove_disabled_steps(build_steps(config))
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    assert "description_embedding" in actual.columns
+    assert len(actual.columns) == len(expected.columns) + 1
+    # the mock impl returns an array of 3 floats for each embedding
+    assert len(actual["description_embedding"][0]) == 3
@@ -13,6 +13,7 @@ from graphrag.index import (
     PipelineWorkflowStep,
     create_pipeline_config,
 )
+from graphrag.index.run.utils import _create_run_context


 def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
@@ -31,6 +32,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
         # remove the workflow: prefix if it exists, because that is not part of the actual table filename
         name = input.replace("workflow:", "")
         input_tables[input] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")
+
     return input_tables


@@ -42,8 +44,12 @@ def load_expected(output: str) -> pd.DataFrame:
 def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
     """Instantiates the bare minimum config to get a default workflow config for testing."""
     config = create_graphrag_config()
+
+    # this flag needs to be set before creating the pipeline config, or the entire covariate workflow will be excluded
+    config.claim_extraction.enabled = True
+
     pipeline_config = create_pipeline_config(config)
-    print(pipeline_config.workflows)
     result = next(conf for conf in pipeline_config.workflows if conf.name == name)
     return cast(PipelineWorkflowConfig, result.config)

@@ -59,7 +65,9 @@ async def get_workflow_output(
         input_tables=input_tables,
     )

-    await workflow.run()
+    context = _create_run_context(None, None, None)
+
+    await workflow.run(context=context)

     # if there's only one output, it is the default here, no name required
     return cast(pd.DataFrame, workflow.output())