mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Add group ID validation and error handling (#618)
- Introduced `GroupIdValidationError` to handle invalid group ID formats. - Added `validate_group_id` function to check that group IDs contain only alphanumeric characters, dashes, or underscores. - Integrated `validate_group_id` checks in the `Graphiti` class to ensure group IDs are validated during processing.
This commit is contained in:
parent
fe870b953f
commit
a6bb9b3eca
@ -73,3 +73,11 @@ class EntityTypeValidationError(GraphitiError):
|
||||
def __init__(self, entity_type: str, entity_type_attribute: str):
|
||||
self.message = f'{entity_type_attribute} cannot be used as an attribute for {entity_type} as it is a protected attribute name.'
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class GroupIdValidationError(GraphitiError):
|
||||
"""Raised when a group_id contains invalid characters."""
|
||||
|
||||
def __init__(self, group_id: str):
|
||||
self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores'
|
||||
super().__init__(self.message)
|
||||
|
@ -29,7 +29,7 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
|
||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import SearchConfig, search
|
||||
@ -351,6 +351,7 @@ class Graphiti:
|
||||
now = utc_now()
|
||||
|
||||
validate_entity_types(entity_types)
|
||||
validate_group_id(group_id)
|
||||
|
||||
previous_episodes = (
|
||||
await self.retrieve_episodes(
|
||||
@ -503,6 +504,8 @@ class Graphiti:
|
||||
start = time()
|
||||
now = utc_now()
|
||||
|
||||
validate_group_id(group_id)
|
||||
|
||||
episodes = [
|
||||
EpisodicNode(
|
||||
name=episode.name,
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime
|
||||
|
||||
@ -25,6 +26,8 @@ from neo4j import time as neo4j_time
|
||||
from numpy._typing import NDArray
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.errors import GroupIdValidationError
|
||||
|
||||
load_dotenv()
|
||||
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
||||
@ -103,3 +106,29 @@ async def semaphore_gather(
|
||||
return await coroutine
|
||||
|
||||
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
|
||||
|
||||
|
||||
def validate_group_id(group_id: str) -> bool:
|
||||
"""
|
||||
Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
|
||||
|
||||
Args:
|
||||
group_id: The group_id to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
|
||||
Raises:
|
||||
GroupIdValidationError: If group_id contains invalid characters
|
||||
"""
|
||||
|
||||
# Allow empty string (default case)
|
||||
if not group_id:
|
||||
return True
|
||||
|
||||
# Check if string contains only ASCII alphanumeric characters, dashes, or underscores
|
||||
# Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
|
||||
raise GroupIdValidationError(group_id)
|
||||
|
||||
return True
|
||||
|
Loading…
x
Reference in New Issue
Block a user