From a6bb9b3eca5560d12abc973b00a0afbdcb117b2e Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Tue, 24 Jun 2025 09:33:54 -0700
Subject: [PATCH] Add group ID validation and error handling (#618)

- Introduced `GroupIdValidationError` to handle invalid group ID formats.
- Added `validate_group_id` function to check that group IDs contain only alphanumeric characters, dashes, or underscores.
- Integrated `validate_group_id` checks in the `Graphiti` class to ensure group IDs are validated during processing.
---
 graphiti_core/errors.py   |  8 ++++++++
 graphiti_core/graphiti.py |  5 ++++-
 graphiti_core/helpers.py  | 29 +++++++++++++++++++++++++++++
 3 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/graphiti_core/errors.py b/graphiti_core/errors.py
index de333010..3bbb9a94 100644
--- a/graphiti_core/errors.py
+++ b/graphiti_core/errors.py
@@ -73,3 +73,11 @@ class EntityTypeValidationError(GraphitiError):
     def __init__(self, entity_type: str, entity_type_attribute: str):
         self.message = f'{entity_type_attribute} cannot be used as an attribute for {entity_type} as it is a protected attribute name.'
         super().__init__(self.message)
+
+
+class GroupIdValidationError(GraphitiError):
+    """Raised when a group_id contains invalid characters."""
+
+    def __init__(self, group_id: str):
+        self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores'
+        super().__init__(self.message)
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 50eda87a..18a39657 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -29,7 +29,7 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -351,6 +351,7 @@ class Graphiti:
             now = utc_now()
 
             validate_entity_types(entity_types)
+            validate_group_id(group_id)
 
             previous_episodes = (
                 await self.retrieve_episodes(
@@ -503,6 +504,8 @@ class Graphiti:
             start = time()
             now = utc_now()
 
+            validate_group_id(group_id)
+
             episodes = [
                 EpisodicNode(
                     name=episode.name,
diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index 47b3a6da..a115f075 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -16,6 +16,7 @@ limitations under the License.
 
 import asyncio
 import os
+import re
 from collections.abc import Coroutine
 from datetime import datetime
 
@@ -25,6 +26,8 @@ from neo4j import time as neo4j_time
 from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
+from graphiti_core.errors import GroupIdValidationError
+
 load_dotenv()
 
 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
@@ -103,3 +106,29 @@ async def semaphore_gather(
             return await coroutine
 
     return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
+
+
+def validate_group_id(group_id: str) -> bool:
+    """
+    Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
+
+    Args:
+        group_id: The group_id to validate
+
+    Returns:
+        True if group_id is empty or contains only allowed characters
+
+    Raises:
+        GroupIdValidationError: If group_id contains invalid characters
+    """
+
+    # Allow empty string (default case)
+    if not group_id:
+        return True
+
+    # Use fullmatch so the entire string must consist of allowed characters; with
+    # re.match and a '$' anchor, a trailing newline (e.g. 'abc\n') would be accepted
+    if not re.fullmatch(r'[a-zA-Z0-9_-]+', group_id):
+        raise GroupIdValidationError(group_id)
+
+    return True