mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-26 14:38:52 +00:00
Perf optimizations in map_query_to_entities() (#1276)
* Address perf issue in map_query_to_entities()
* Add semver

---------

Co-authored-by: Matthieu Maitre <mmaitre@microsoft.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
1f70d42572
commit
6aae386b30
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Perf optimizations in map_query_to_entities()"
|
||||
}
|
||||
@ -7,6 +7,7 @@ from enum import Enum
|
||||
|
||||
from graphrag.model import Entity, Relationship
|
||||
from graphrag.query.input.retrieval.entities import (
|
||||
get_entity_by_id,
|
||||
get_entity_by_key,
|
||||
get_entity_by_name,
|
||||
)
|
||||
@ -36,7 +37,7 @@ def map_query_to_entities(
|
||||
query: str,
|
||||
text_embedding_vectorstore: BaseVectorStore,
|
||||
text_embedder: BaseTextEmbedding,
|
||||
all_entities: list[Entity],
|
||||
all_entities_dict: dict[str, Entity],
|
||||
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
|
||||
include_entity_names: list[str] | None = None,
|
||||
exclude_entity_names: list[str] | None = None,
|
||||
@ -48,6 +49,7 @@ def map_query_to_entities(
|
||||
include_entity_names = []
|
||||
if exclude_entity_names is None:
|
||||
exclude_entity_names = []
|
||||
all_entities = list(all_entities_dict.values())
|
||||
matched_entities = []
|
||||
if query != "":
|
||||
# get entities with highest semantic similarity to query
|
||||
@ -58,11 +60,16 @@ def map_query_to_entities(
|
||||
k=k * oversample_scaler,
|
||||
)
|
||||
for result in search_results:
|
||||
matched = get_entity_by_key(
|
||||
entities=all_entities,
|
||||
key=embedding_vectorstore_key,
|
||||
value=result.document.id,
|
||||
)
|
||||
if embedding_vectorstore_key == EntityVectorStoreKey.ID and isinstance(
|
||||
result.document.id, str
|
||||
):
|
||||
matched = get_entity_by_id(all_entities_dict, result.document.id)
|
||||
else:
|
||||
matched = get_entity_by_key(
|
||||
entities=all_entities,
|
||||
key=embedding_vectorstore_key,
|
||||
value=result.document.id,
|
||||
)
|
||||
if matched:
|
||||
matched_entities.append(matched)
|
||||
else:
|
||||
|
||||
@ -12,17 +12,26 @@ import pandas as pd
|
||||
from graphrag.model import Entity
|
||||
|
||||
|
||||
def get_entity_by_id(entities: dict[str, Entity], value: str) -> Entity | None:
    """Look up an entity in an id-keyed mapping.

    The id is tried verbatim first. Only when that lookup misses and ``value``
    passes ``is_valid_uuid`` is the lookup retried with the dashes removed.
    Returns ``None`` when neither form is a key of ``entities``.
    """
    found = entities.get(value)
    if found is not None:
        return found
    if not is_valid_uuid(value):
        return None
    # Second chance: the mapping may be keyed by the dash-less form of the id.
    return entities.get(value.replace("-", ""))
|
||||
|
||||
|
||||
def get_entity_by_key(
    entities: Iterable[Entity], key: str, value: str | int
) -> Entity | None:
    """Return the first entity whose ``key`` attribute matches ``value``.

    Args:
        entities: Entities to scan, in order.
        key: Name of the attribute read from each entity via ``getattr``.
        value: Value to look for. When it is a string accepted by
            ``is_valid_uuid``, an entity also matches if its attribute equals
            ``value`` with the dashes removed.

    Returns:
        The first matching entity, or ``None`` when nothing matches.
    """
    if isinstance(value, str) and is_valid_uuid(value):
        # The dash-less form does not depend on the entity, so build it once
        # instead of once per iteration.
        value_no_dashes = value.replace("-", "")
        for entity in entities:
            entity_value = getattr(entity, key)
            if entity_value in (value, value_no_dashes):
                return entity
    else:
        for entity in entities:
            if getattr(entity, key) == value:
                return entity
    return None
|
||||
|
||||
@ -141,7 +141,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
query=query,
|
||||
text_embedding_vectorstore=self.entity_text_embeddings,
|
||||
text_embedder=self.text_embedder,
|
||||
all_entities=list(self.entities.values()),
|
||||
all_entities_dict=self.entities,
|
||||
embedding_vectorstore_key=self.embedding_vectorstore_key,
|
||||
include_entity_names=include_entity_names,
|
||||
exclude_entity_names=exclude_entity_names,
|
||||
|
||||
2
tests/unit/query/context_builder/__init__.py
Normal file
2
tests/unit/query/context_builder/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
182
tests/unit/query/context_builder/test_entity_extraction.py
Normal file
182
tests/unit/query/context_builder/test_entity_extraction.py
Normal file
@ -0,0 +1,182 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.model import Entity
|
||||
from graphrag.model.types import TextEmbedder
|
||||
from graphrag.query.context_builder.entity_extraction import (
|
||||
EntityVectorStoreKey,
|
||||
map_query_to_entities,
|
||||
)
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.vector_stores import (
|
||||
BaseVectorStore,
|
||||
VectorStoreDocument,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
|
||||
class MockBaseVectorStore(BaseVectorStore):
    """Vector store stub for tests, backed by a fixed in-memory document list."""

    def __init__(self, documents: list[VectorStoreDocument]) -> None:
        super().__init__("mock")
        self.documents = documents

    def connect(self, **kwargs: Any) -> None:
        """Not supported by this stub."""
        raise NotImplementedError

    def load_documents(
        self, documents: list[VectorStoreDocument], overwrite: bool = True
    ) -> None:
        """Not supported by this stub."""
        raise NotImplementedError

    def similarity_search_by_vector(
        self, query_embedding: list[float], k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Return the first ``k`` stored documents, each scored 1.

        The query vector itself is ignored.
        """
        hits = []
        for doc in self.documents[:k]:
            hits.append(VectorStoreSearchResult(document=doc, score=1))
        return hits

    def similarity_search_by_text(
        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Return the ``k`` documents whose text length is closest to ``len(text)``.

        The score is the absolute length difference (missing text counts as
        empty), and results are ordered by ascending score.
        """
        query_length = len(text)
        hits = []
        for doc in self.documents:
            gap = abs(query_length - len(doc.text or ""))
            hits.append(VectorStoreSearchResult(document=doc, score=gap))
        # list.sort is stable, so ties keep their original document order.
        hits.sort(key=lambda hit: hit.score)
        return hits[:k]

    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
        """Return the stored documents whose id appears in ``include_ids``."""
        kept = []
        for doc in self.documents:
            if doc.id in include_ids:
                kept.append(doc)
        return kept
|
||||
|
||||
|
||||
class MockBaseTextEmbedding(BaseTextEmbedding):
    """Embedder stub: the "embedding" is a one-element list holding the text length."""

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Return ``[len(text)]``."""
        size = len(text)
        return [size]

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """Async counterpart of ``embed``; returns ``[len(text)]``."""
        size = len(text)
        return [size]
|
||||
|
||||
|
||||
def test_map_query_to_entities():
    """Exercise map_query_to_entities with ID- and title-keyed lookups, with and without a query."""
    # One row per fixture entity: (id, short_id, title, rank).
    rows = [
        ("2da37c7a-50a8-44d4-aa2c-fd401e19976c", "sid1", "t1", 2),
        ("c4f93564-4507-4ee4-b102-98add401a965", "sid2", "t22", 4),
        ("7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", "sid3", "t333", 1),
        ("8fd6d72a-8e9d-4183-8a97-c38bcc971c83", "sid4", "t4444", 3),
    ]

    def build(row):
        entity_id, short_id, title, rank = row
        return Entity(id=entity_id, short_id=short_id, title=title, rank=rank)

    entities = [build(row) for row in rows]
    ids = [entity.id for entity in entities]
    titles = [entity.title for entity in entities]

    def run(query, doc_ids, store_key, **limits):
        # Fresh store, embedder and id->entity dict for every call.
        documents = [
            VectorStoreDocument(id=doc_id, text=entity.title, vector=None)
            for doc_id, entity in zip(doc_ids, entities)
        ]
        return map_query_to_entities(
            query=query,
            text_embedding_vectorstore=MockBaseVectorStore(documents),
            text_embedder=MockBaseTextEmbedding(),
            all_entities_dict={entity.id: entity for entity in entities},
            embedding_vectorstore_key=store_key,
            **limits,
        )

    # Non-empty query, documents keyed by entity id.
    by_id = run("t22", ids, EntityVectorStoreKey.ID, k=1, oversample_scaler=1)
    assert by_id == [build(rows[1])]

    # Non-empty query, documents keyed by entity title.
    by_title = run("t22", titles, EntityVectorStoreKey.TITLE, k=1, oversample_scaler=1)
    assert by_title == [build(rows[1])]

    # Empty query: the expected result is the two highest-rank entities,
    # whichever store key is configured.
    top_two = [build(rows[1]), build(rows[3])]
    assert run("", ids, EntityVectorStoreKey.ID, k=2) == top_two
    assert run("", ids, EntityVectorStoreKey.TITLE, k=2) == top_two
|
||||
2
tests/unit/query/input/__init__.py
Normal file
2
tests/unit/query/input/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
2
tests/unit/query/input/retrieval/__init__.py
Normal file
2
tests/unit/query/input/retrieval/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
167
tests/unit/query/input/retrieval/test_entities.py
Normal file
167
tests/unit/query/input/retrieval/test_entities.py
Normal file
@ -0,0 +1,167 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.model import Entity
|
||||
from graphrag.query.input.retrieval.entities import (
|
||||
get_entity_by_id,
|
||||
get_entity_by_key,
|
||||
)
|
||||
|
||||
|
||||
def test_get_entity_by_id():
    """Cover miss, exact UUID hit, dash-less UUID fallback and plain-string ids."""
    # Fixture rows are (id, short_id, title).
    dashed = [
        ("2da37c7a-50a8-44d4-aa2c-fd401e19976c", "sid1", "title1"),
        ("c4f93564-4507-4ee4-b102-98add401a965", "sid2", "title2"),
        ("7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", "sid3", "title3"),
    ]
    undashed = [
        ("2da37c7a50a844d4aa2cfd401e19976c", "sid1", "title1"),
        ("c4f9356445074ee4b10298add401a965", "sid2", "title2"),
        ("7c6f2bc947c9445393a3d2e174a02cd9", "sid3", "title3"),
    ]
    plain = [
        ("id1", "sid1", "title1"),
        ("id2", "sid2", "title2"),
        ("id3", "sid3", "title3"),
    ]

    def index(rows):
        built = [
            Entity(id=entity_id, short_id=short_id, title=title)
            for entity_id, short_id, title in rows
        ]
        return {entity.id: entity for entity in built}

    # Unknown UUID -> None.
    missing = get_entity_by_id(
        index(dashed[:1]), "00000000-0000-0000-0000-000000000000"
    )
    assert missing is None

    # Dashed UUID stored and queried as-is.
    exact = get_entity_by_id(index(dashed), "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9")
    assert exact == Entity(
        id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", short_id="sid3", title="title3"
    )

    # Stored without dashes, queried with dashes.
    stripped = get_entity_by_id(
        index(undashed), "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9"
    )
    assert stripped == Entity(
        id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3"
    )

    # Ids that are not UUIDs at all.
    simple = get_entity_by_id(index(plain), "id3")
    assert simple == Entity(id="id3", short_id="sid3", title="title3")
|
||||
|
||||
|
||||
def test_get_entity_by_key():
    """Cover miss, exact UUID hit, dash-less UUID match, plain ids and a non-string key."""
    # Fixture rows are (id, short_id, title).
    dashed = [
        ("2da37c7a-50a8-44d4-aa2c-fd401e19976c", "sid1", "title1"),
        ("c4f93564-4507-4ee4-b102-98add401a965", "sid2", "title2"),
        ("7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", "sid3", "title3"),
    ]
    undashed = [
        ("2da37c7a50a844d4aa2cfd401e19976c", "sid1", "title1"),
        ("c4f9356445074ee4b10298add401a965", "sid2", "title2"),
        ("7c6f2bc947c9445393a3d2e174a02cd9", "sid3", "title3"),
    ]
    plain = [
        ("id1", "sid1", "title1"),
        ("id2", "sid2", "title2"),
        ("id3", "sid3", "title3"),
    ]

    def build(rows):
        return [
            Entity(id=entity_id, short_id=short_id, title=title)
            for entity_id, short_id, title in rows
        ]

    # Unknown UUID -> None.
    missing = get_entity_by_key(
        build(dashed[:1]), "id", "00000000-0000-0000-0000-000000000000"
    )
    assert missing is None

    # Dashed UUID stored and queried as-is.
    exact = get_entity_by_key(
        build(dashed), "id", "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9"
    )
    assert exact == Entity(
        id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", short_id="sid3", title="title3"
    )

    # Stored without dashes, queried with dashes.
    stripped = get_entity_by_key(
        build(undashed), "id", "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9"
    )
    assert stripped == Entity(
        id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3"
    )

    # Ids that are not UUIDs at all.
    simple = get_entity_by_key(build(plain), "id", "id3")
    assert simple == Entity(id="id3", short_id="sid3", title="title3")

    # Integer-valued key with two candidates: the first one in order wins.
    ranked = [
        Entity(id="id1", short_id="sid1", title="title1", rank=1),
        Entity(id="id2", short_id="sid2", title="title2a", rank=2),
        Entity(id="id3", short_id="sid3", title="title3", rank=3),
        Entity(id="id2", short_id="sid2", title="title2b", rank=2),
    ]
    first_match = get_entity_by_key(ranked, "rank", 2)
    assert first_match == Entity(id="id2", short_id="sid2", title="title2a", rank=2)
|
||||
Loading…
x
Reference in New Issue
Block a user