mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Add group ID validation and error handling (#618)
- Introduced `GroupIdValidationError` to handle invalid group ID formats. - Added `validate_group_id` function to check that group IDs contain only alphanumeric characters, dashes, or underscores. - Integrated `validate_group_id` checks in the `Graphiti` class to ensure group IDs are validated during processing.
This commit is contained in:
parent
fe870b953f
commit
a6bb9b3eca
@ -73,3 +73,11 @@ class EntityTypeValidationError(GraphitiError):
|
|||||||
def __init__(self, entity_type: str, entity_type_attribute: str):
|
def __init__(self, entity_type: str, entity_type_attribute: str):
|
||||||
self.message = f'{entity_type_attribute} cannot be used as an attribute for {entity_type} as it is a protected attribute name.'
|
self.message = f'{entity_type_attribute} cannot be used as an attribute for {entity_type} as it is a protected attribute name.'
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupIdValidationError(GraphitiError):
|
||||||
|
"""Raised when a group_id contains invalid characters."""
|
||||||
|
|
||||||
|
def __init__(self, group_id: str):
|
||||||
|
self.message = f'group_id "{group_id}" must contain only alphanumeric characters, dashes, or underscores'
|
||||||
|
super().__init__(self.message)
|
||||||
|
@ -29,7 +29,7 @@ from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
|||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather, validate_group_id
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import SearchConfig, search
|
from graphiti_core.search.search import SearchConfig, search
|
||||||
@ -351,6 +351,7 @@ class Graphiti:
|
|||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
validate_entity_types(entity_types)
|
validate_entity_types(entity_types)
|
||||||
|
validate_group_id(group_id)
|
||||||
|
|
||||||
previous_episodes = (
|
previous_episodes = (
|
||||||
await self.retrieve_episodes(
|
await self.retrieve_episodes(
|
||||||
@ -503,6 +504,8 @@ class Graphiti:
|
|||||||
start = time()
|
start = time()
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
|
validate_group_id(group_id)
|
||||||
|
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
name=episode.name,
|
name=episode.name,
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@ -25,6 +26,8 @@ from neo4j import time as neo4j_time
|
|||||||
from numpy._typing import NDArray
|
from numpy._typing import NDArray
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.errors import GroupIdValidationError
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
||||||
@ -103,3 +106,29 @@ async def semaphore_gather(
|
|||||||
return await coroutine
|
return await coroutine
|
||||||
|
|
||||||
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
|
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_group_id(group_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: The group_id to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if valid, False otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GroupIdValidationError: If group_id contains invalid characters
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Allow empty string (default case)
|
||||||
|
if not group_id:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if string contains only ASCII alphanumeric characters, dashes, or underscores
|
||||||
|
# Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
|
||||||
|
raise GroupIdValidationError(group_id)
|
||||||
|
|
||||||
|
return True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user