mirror of
https://github.com/getzep/graphiti.git
synced 2025-12-29 08:05:02 +00:00
Entity classification updates (#281)
* node classification updates * update * remove unused code * update
This commit is contained in:
parent
1d2417ec26
commit
6f874730f3
@ -31,9 +31,13 @@ class MissedEntities(BaseModel):
|
||||
|
||||
|
||||
class EntityClassification(BaseModel):
|
||||
entity_classification: str = Field(
|
||||
entities: list[str] = Field(
|
||||
...,
|
||||
description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
|
||||
description='List of entities',
|
||||
)
|
||||
entity_classifications: list[str | None] = Field(
|
||||
...,
|
||||
description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
|
||||
)
|
||||
|
||||
|
||||
@ -180,7 +184,8 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
|
||||
|
||||
Guidelines:
|
||||
1. Each entity must have exactly one type
|
||||
2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
|
||||
2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
|
||||
3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
|
||||
"""
|
||||
return [
|
||||
Message(role='system', content=sys_prompt),
|
||||
|
||||
@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import logging
|
||||
from time import time
|
||||
|
||||
@ -163,8 +162,9 @@ async def extract_nodes(
|
||||
prompt_library.extract_nodes.classify_nodes(node_classification_context),
|
||||
response_model=EntityClassification,
|
||||
)
|
||||
response_string = llm_response.get('entity_classification', '{}')
|
||||
node_classifications.update(ast.literal_eval(response_string))
|
||||
entities = llm_response.get('entities', [])
|
||||
entity_classifications = llm_response.get('entity_classifications', [])
|
||||
node_classifications.update(dict(zip(entities, entity_classifications)))
|
||||
|
||||
end = time()
|
||||
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
|
||||
@ -173,7 +173,9 @@ async def extract_nodes(
|
||||
for name in extracted_node_names:
|
||||
entity_type = node_classifications.get(name)
|
||||
labels = (
|
||||
['Entity'] if entity_type is None or entity_type == 'None' else ['Entity', entity_type]
|
||||
['Entity']
|
||||
if entity_type is None or entity_type == 'None' or entity_type == 'null'
|
||||
else ['Entity', entity_type]
|
||||
)
|
||||
|
||||
new_node = EntityNode(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.7.4"
|
||||
version = "0.7.5"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user