Mirror of https://github.com/microsoft/graphrag.git, synced 2025-11-06 21:05:56 +00:00
Parent: c02ab0984a
Commit: a6a78d5897
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add caching to NLP extractor."
+}
@@ -64,3 +64,7 @@ class ExtractGraphNLPConfig(BaseModel):
     text_analyzer: TextAnalyzerConfig = Field(
         description="The text analyzer configuration.", default=TextAnalyzerConfig()
     )
+    parallelization_num_threads: int = Field(
+        description="The number of threads to use for the extraction process.",
+        default=defs.PARALLELIZATION_NUM_THREADS,
+    )
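The new `parallelization_num_threads` field controls how many threads the noun-phrase extraction step may use. A minimal usage sketch, assuming only the model and field name shown in this hunk (the override value, and the assumption that the model's other fields all have defaults, are illustrative):

```python
# Illustrative only: override the new thread-count field when building the
# config in code. Field name comes from this diff; the value is arbitrary.
from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig

config = ExtractGraphNLPConfig(parallelization_num_threads=8)
assert config.parallelization_num_threads == 8
```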
@@ -5,6 +5,7 @@
 
 import pandas as pd
 
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.embed_graph_config import EmbedGraphConfig
 from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
@@ -20,9 +21,10 @@ from graphrag.index.operations.graph_to_dataframes import graph_to_dataframes
 from graphrag.index.operations.prune_graph import prune_graph
 
 
-def extract_graph_nlp(
+async def extract_graph_nlp(
     text_units: pd.DataFrame,
     callbacks: WorkflowCallbacks,
+    cache: PipelineCache,
     extraction_config: ExtractGraphNLPConfig,
     pruning_config: PruneGraphConfig,
     embed_config: EmbedGraphConfig | None = None,
@@ -31,10 +33,12 @@ def extract_graph_nlp(
     """All the steps to create the base entity graph."""
     text_analyzer_config = extraction_config.text_analyzer
     text_analyzer = create_noun_phrase_extractor(text_analyzer_config)
-    extracted_nodes, extracted_edges = build_noun_graph(
+    extracted_nodes, extracted_edges = await build_noun_graph(
         text_units,
         text_analyzer=text_analyzer,
         normalize_edge_weights=extraction_config.normalize_edge_weights,
+        num_threads=extraction_config.parallelization_num_threads,
+        cache=cache,
     )
 
     # create a temporary graph to prune, then turn it back into dataframes
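The flow now threads a `PipelineCache` from the caller down into `build_noun_graph`. This change only exercises three methods of that cache, `child`, `get`, and `set`; a minimal in-memory stand-in sketch of that contract, where anything beyond those three calls is an assumption rather than the real ABC:

```python
# Minimal in-memory sketch of the PipelineCache contract this change relies
# on. Only child/get/set appear in the diff; the real ABC may differ.
class DictPipelineCache:
    def __init__(self, prefix: str = "") -> None:
        self._prefix = prefix
        self._store: dict[str, object] = {}

    def child(self, name: str) -> "DictPipelineCache":
        # Namespace keys per pipeline step, sharing the parent's backing store.
        child = DictPipelineCache(f"{self._prefix}{name}/")
        child._store = self._store
        return child

    async def get(self, key: str) -> object | None:
        return self._store.get(self._prefix + key)

    async def set(self, key: str, value: object) -> None:
        self._store[self._prefix + key] = value
```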
@@ -7,27 +7,38 @@ import math
 
 import pandas as pd
 
+from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
+from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.config.enums import AsyncType
 from graphrag.index.operations.build_noun_graph.np_extractors.base import (
     BaseNounPhraseExtractor,
 )
+from graphrag.index.run.derive_from_rows import derive_from_rows
+from graphrag.index.utils.hashing import gen_sha512_hash
 
 
-def build_noun_graph(
+async def build_noun_graph(
     text_unit_df: pd.DataFrame,
     text_analyzer: BaseNounPhraseExtractor,
     normalize_edge_weights: bool,
+    num_threads: int = 4,
+    cache: PipelineCache | None = None,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Build a noun graph from text units."""
     text_units = text_unit_df.loc[:, ["id", "text"]]
-    nodes_df = _extract_nodes(text_units, text_analyzer)
+    nodes_df = await _extract_nodes(
+        text_units, text_analyzer, num_threads=num_threads, cache=cache
+    )
     edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
 
     return (nodes_df, edges_df)
 
 
-def _extract_nodes(
+async def _extract_nodes(
     text_unit_df: pd.DataFrame,
     text_analyzer: BaseNounPhraseExtractor,
+    num_threads: int = 4,
+    cache: PipelineCache | None = None,
 ) -> pd.DataFrame:
     """
     Extract initial nodes and edges from text units.
@@ -35,9 +46,26 @@ def _extract_nodes(
     Input: text unit df with schema [id, text, document_id]
     Returns a dataframe with schema [id, title, freq, text_unit_ids].
     """
-    text_unit_df["noun_phrases"] = text_unit_df["text"].apply(
-        lambda text: text_analyzer.extract(text)
+    cache = cache or NoopPipelineCache()
+    cache = cache.child("extract_noun_phrases")
+
+    async def extract(row):
+        text = row["text"]
+        attrs = {"text": text, "analyzer": str(text_analyzer)}
+        key = gen_sha512_hash(attrs, attrs.keys())
+        result = await cache.get(key)
+        if not result:
+            result = text_analyzer.extract(text)
+            await cache.set(key, result)
+        return result
+
+    text_unit_df["noun_phrases"] = await derive_from_rows(
+        text_unit_df,
+        extract,
+        num_threads=num_threads,
+        async_type=AsyncType.Threaded,
     )
 
     noun_node_df = text_unit_df.explode("noun_phrases")
     noun_node_df = noun_node_df.rename(
         columns={"noun_phrases": "title", "id": "text_unit_id"}
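The cache key couples the input text with `str(text_analyzer)`, so changing either the text or the extractor configuration yields a fresh key rather than a stale hit. `gen_sha512_hash` itself is not shown in this diff; a hedged sketch of the idea, not the library's implementation:

```python
# Illustrative equivalent of the key derivation above, NOT the real
# gen_sha512_hash (whose implementation this diff does not show).
import hashlib

def sketch_cache_key(text: str, analyzer_repr: str) -> str:
    attrs = {"text": text, "analyzer": analyzer_repr}
    joined = "|".join(f"{k}={attrs[k]}" for k in sorted(attrs))
    return hashlib.sha512(joined.encode("utf-8")).hexdigest()

# Same text + same analyzer string -> same key across runs.
assert sketch_cache_key("a", "x") == sketch_cache_key("a", "x")
```

One design note visible in the hunk itself: `if not result` treats an empty cached value as a miss, so text units that genuinely yield no noun phrases are re-extracted on each run.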
@@ -33,3 +33,7 @@ class BaseNounPhraseExtractor(metaclass=ABCMeta):
 
         Returns: List of noun phrases.
         """
+
+    @abstractmethod
+    def __str__(self) -> str:
+        """Return string representation of the extractor, used for cache key generation."""
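Making `__str__` abstract forces every extractor to provide a deterministic, parameter-complete string, since that string feeds the cache key above. A toy sketch of the property this buys (the class is illustrative, not part of graphrag):

```python
# Toy illustration: identically configured extractors must stringify
# identically, or cache hits would be lost between runs.
class ToyExtractor:
    def __init__(self, exclude_nouns: list[str], max_word_length: int) -> None:
        self.exclude_nouns = exclude_nouns
        self.max_word_length = max_word_length

    def __str__(self) -> str:
        # Encode every parameter that affects extraction output.
        return f"toy_{self.exclude_nouns}_{self.max_word_length}"

assert str(ToyExtractor(["the"], 15)) == str(ToyExtractor(["the"], 15))
```

The three concrete extractors below implement exactly this pattern, each folding its full configuration into the string.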
@@ -172,3 +172,7 @@ class CFGNounPhraseExtractor(BaseNounPhraseExtractor):
                 cleaned_tokens, self.max_word_length
             ),
         }
+
+    def __str__(self) -> str:
+        """Return string representation of the extractor, used for cache key generation."""
+        return f"cfg_{self.model_name}_{self.max_word_length}_{self.include_named_entities}_{self.exclude_entity_tags}_{self.exclude_pos_tags}_{self.exclude_nouns}_{self.word_delimiter}_{self.noun_phrase_grammars}_{self.noun_phrase_tags}"
@@ -117,3 +117,7 @@ class RegexENNounPhraseExtractor(BaseNounPhraseExtractor):
             "has_compound_words": has_compound_words,
             "has_valid_tokens": has_valid_tokens,
         }
+
+    def __str__(self) -> str:
+        """Return string representation of the extractor, used for cache key generation."""
+        return f"regex_en_{self.exclude_nouns}_{self.max_word_length}_{self.word_delimiter}"
@@ -157,3 +157,7 @@ class SyntacticNounPhraseExtractor(BaseNounPhraseExtractor):
                 cleaned_token_texts, self.max_word_length
             ),
         }
+
+    def __str__(self) -> str:
+        """Return string representation of the extractor, used for cache key generation."""
+        return f"syntactic_{self.model_name}_{self.max_word_length}_{self.include_named_entities}_{self.exclude_entity_tags}_{self.exclude_pos_tags}_{self.exclude_nouns}_{self.word_delimiter}"
@@ -12,6 +12,7 @@ from typing import Any, TypeVar, cast
 
 import pandas as pd
 
+from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import AsyncType
 from graphrag.logger.progress import progress_ticker
@@ -33,11 +34,12 @@ class ParallelizationError(ValueError):
 async def derive_from_rows(
     input: pd.DataFrame,
     transform: Callable[[pd.Series], Awaitable[ItemType]],
-    callbacks: WorkflowCallbacks,
+    callbacks: WorkflowCallbacks | None = None,
     num_threads: int = 4,
     async_type: AsyncType = AsyncType.AsyncIO,
 ) -> list[ItemType | None]:
     """Apply a generic transform function to each row. Any errors will be reported and thrown."""
+    callbacks = callbacks or NoopWorkflowCallbacks()
     match async_type:
         case AsyncType.AsyncIO:
             return await derive_from_rows_asyncio(
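With `callbacks` now optional and defaulting to `NoopWorkflowCallbacks`, internal callers such as `_extract_nodes` no longer have to plumb a callbacks object through. A small calling sketch, assuming only the signature shown in this hunk (the dataframe and transform are illustrative):

```python
# Illustrative call using the new default: no callbacks argument needed.
import asyncio

import pandas as pd

from graphrag.config.enums import AsyncType
from graphrag.index.run.derive_from_rows import derive_from_rows

async def upper_text(row: pd.Series) -> str:
    return row["text"].upper()

async def demo() -> None:
    df = pd.DataFrame({"text": ["alpha", "beta"]})
    # AsyncType.Threaded suits CPU-bound transforms like noun-phrase extraction.
    results = await derive_from_rows(
        df, upper_text, num_threads=2, async_type=AsyncType.Threaded
    )
    print(results)  # ['ALPHA', 'BETA']

asyncio.run(demo())
```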
@@ -26,9 +26,10 @@ async def run_workflow(
     """All the steps to create the base entity graph."""
     text_units = await load_table_from_storage("text_units", context.storage)
 
-    entities, relationships = extract_graph_nlp(
+    entities, relationships = await extract_graph_nlp(
         text_units,
         callbacks,
+        context.cache,
         extraction_config=config.extract_graph_nlp,
         pruning_config=config.prune_graph,
         embed_config=config.embed_graph,