Add group ID validation and error handling (#618)

- Introduced `GroupIdValidationError` to handle invalid group ID formats.
- Added `validate_group_id` function to check that group IDs contain only alphanumeric characters, dashes, or underscores.
- Integrated `validate_group_id` checks in the `Graphiti` class to ensure group IDs are validated during processing.
This commit is contained in:
Daniel Chalef 2025-06-24 09:33:54 -07:00 committed by GitHub
parent fe870b953f
commit a6bb9b3eca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 1 deletions

View File

@ -73,3 +73,11 @@ class EntityTypeValidationError(GraphitiError):
def __init__(self, entity_type: str, entity_type_attribute: str):
self.message = f'{entity_type_attribute} cannot be used as an attribute for {entity_type} as it is a protected attribute name.'
super().__init__(self.message)
class GroupIdValidationError(GraphitiError):
"""Raised when a group_id contains invalid characters."""
def __init__(self, group_id: str):
self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores'
super().__init__(self.message)

View File

@ -29,7 +29,7 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search
@ -351,6 +351,7 @@ class Graphiti:
now = utc_now()
validate_entity_types(entity_types)
validate_group_id(group_id)
previous_episodes = (
await self.retrieve_episodes(
@ -503,6 +504,8 @@ class Graphiti:
start = time()
now = utc_now()
validate_group_id(group_id)
episodes = [
EpisodicNode(
name=episode.name,

View File

@ -16,6 +16,7 @@ limitations under the License.
import asyncio
import os
import re
from collections.abc import Coroutine
from datetime import datetime
@ -25,6 +26,8 @@ from neo4j import time as neo4j_time
from numpy._typing import NDArray
from typing_extensions import LiteralString
from graphiti_core.errors import GroupIdValidationError
load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
@ -103,3 +106,29 @@ async def semaphore_gather(
return await coroutine
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
def validate_group_id(group_id: str) -> bool:
"""
Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
Args:
group_id: The group_id to validate
Returns:
True if valid, False otherwise
Raises:
GroupIdValidationError: If group_id contains invalid characters
"""
# Allow empty string (default case)
if not group_id:
return True
# Check if string contains only ASCII alphanumeric characters, dashes, or underscores
# Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
raise GroupIdValidationError(group_id)
return True