Entity classification updates (#281)

* node classification updates

* update

* remove unused code

* update
This commit is contained in:
Preston Rasmussen 2025-02-27 15:12:50 -05:00 committed by GitHub
parent 1d2417ec26
commit 6f874730f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 8 deletions

View File

@ -31,9 +31,13 @@ class MissedEntities(BaseModel):
class EntityClassification(BaseModel):
entity_classification: str = Field(
entities: list[str] = Field(
...,
description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
description='List of entities',
)
entity_classifications: list[str | None] = Field(
...,
description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
)
@ -180,7 +184,8 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
Guidelines:
1. Each entity must have exactly one type
2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
"""
return [
Message(role='system', content=sys_prompt),

View File

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import ast
import logging
from time import time
@ -163,8 +162,9 @@ async def extract_nodes(
prompt_library.extract_nodes.classify_nodes(node_classification_context),
response_model=EntityClassification,
)
response_string = llm_response.get('entity_classification', '{}')
node_classifications.update(ast.literal_eval(response_string))
entities = llm_response.get('entities', [])
entity_classifications = llm_response.get('entity_classifications', [])
node_classifications.update(dict(zip(entities, entity_classifications)))
end = time()
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
@ -173,7 +173,9 @@ async def extract_nodes(
for name in extracted_node_names:
entity_type = node_classifications.get(name)
labels = (
['Entity'] if entity_type is None or entity_type == 'None' else ['Entity', entity_type]
['Entity']
if entity_type is None or entity_type == 'None' or entity_type == 'null'
else ['Entity', entity_type]
)
new_node = EntityNode(

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.7.4"
version = "0.7.5"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",