#!/usr/bin/env python3
"""
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
"""

import argparse
import asyncio
import logging
import os
import sys
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any, TypedDict, cast

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from openai import AsyncAzureOpenAI
from pydantic import BaseModel, Field

from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import (
    NODE_HYBRID_SEARCH_NODE_DISTANCE,
    NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

load_dotenv()


DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'

# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase it if you have high rate limits.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))
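
# For example (hypothetical value), a deployment that keeps hitting 429s could
# start the server with SEMAPHORE_LIMIT=5 in its environment to halve the
# default concurrency.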


class Requirement(BaseModel):
    """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.

    Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
    edge that the requirement is a requirement.

    Instructions for identifying and extracting requirements:
    1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
    2. Identify functional specifications that describe what the system should do
    3. Pay attention to non-functional requirements like performance, security, or usability criteria
    4. Extract constraints or limitations that must be adhered to
    5. Focus on clear, specific, and measurable requirements rather than vague wishes
    6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
    7. Include any dependencies between requirements when explicitly stated
    8. Preserve the original intent and scope of the requirement
    9. Categorize requirements appropriately based on their domain or function
    """

    project_name: str = Field(
        ...,
        description='The name of the project to which the requirement belongs.',
    )
    description: str = Field(
        ...,
        description='Description of the requirement. Only use information mentioned in the context to write this description.',
    )


class Preference(BaseModel):
    """A Preference represents a user's expressed like, dislike, or preference for something.

    Instructions for identifying and extracting preferences:
    1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
    2. Pay attention to comparative statements ("I prefer X over Y")
    3. Consider the emotional tone when users mention certain topics
    4. Extract only preferences that are clearly expressed, not assumptions
    5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
    6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
    7. Only extract preferences directly stated by the user, not preferences of others they mention
    8. Provide a concise but specific description that captures the nature of the preference
    """

    category: str = Field(
        ...,
        description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
    )
    description: str = Field(
        ...,
        description='Brief description of the preference. Only use information mentioned in the context to write this description.',
    )


class Procedure(BaseModel):
    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.

    Instructions for identifying and extracting procedures:
    1. Look for sequential instructions or steps ("First do X, then do Y")
    2. Identify explicit directives or commands ("Always do X when Y happens")
    3. Pay attention to conditional statements ("If X occurs, then do Y")
    4. Extract procedures that have clear beginning and end points
    5. Focus on actionable instructions rather than general information
    6. Preserve the original sequence and dependencies between steps
    7. Include any specified conditions or triggers for the procedure
    8. Capture any stated purpose or goal of the procedure
    9. Summarize complex procedures while maintaining critical details
    """

    description: str = Field(
        ...,
        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
    )


ENTITY_TYPES: dict[str, BaseModel] = {
    'Requirement': Requirement,  # type: ignore
    'Preference': Preference,  # type: ignore
    'Procedure': Procedure,  # type: ignore
}
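
# Illustrative sketch (not executed here): these Pydantic models are passed to
# Graphiti as custom entity types during episode ingestion, roughly:
#
#   await client.add_episode(
#       name='example',
#       episode_body='We need OAuth login for Project Phoenix.',
#       source=EpisodeType.text,
#       source_description='example',
#       group_id='default',
#       reference_time=datetime.now(timezone.utc),
#       entity_types=ENTITY_TYPES,
#   )
#
# The actual wiring happens in add_memory() below.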


# Type definitions for API responses
class ErrorResponse(TypedDict):
    error: str


class SuccessResponse(TypedDict):
    message: str


class NodeResult(TypedDict):
    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str
    created_at: str
    attributes: dict[str, Any]


class NodeSearchResponse(TypedDict):
    message: str
    nodes: list[NodeResult]


class FactSearchResponse(TypedDict):
    message: str
    facts: list[dict[str, Any]]


class EpisodeSearchResponse(TypedDict):
    message: str
    episodes: list[dict[str, Any]]


class StatusResponse(TypedDict):
    status: str
    message: str


def create_azure_credential_token_provider() -> Callable[[], str]:
    credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(
        credential, 'https://cognitiveservices.azure.com/.default'
    )
    return token_provider
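
# Hypothetical usage sketch: the provider returned above plugs into the Azure
# OpenAI SDK in place of an API key (endpoint, deployment, and api_version are
# placeholders):
#
#   token_provider = create_azure_credential_token_provider()
#   azure_client = AsyncAzureOpenAI(
#       azure_endpoint='https://example.openai.azure.com',
#       azure_deployment='my-deployment',
#       api_version='2024-10-21',
#       azure_ad_token_provider=token_provider,
#   )
#
# The create_client() methods below perform this wiring with configured values.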


# Server configuration classes
# The configuration system has a hierarchy:
# - GraphitiConfig is the top-level configuration
# - LLMConfig handles all OpenAI/LLM related settings
# - EmbedderConfig manages embedding settings
# - Neo4jConfig manages database connection details
# - Various other settings like group_id and feature flags
# Configuration values are loaded from:
# 1. Default values in the class definitions
# 2. Environment variables (loaded via load_dotenv())
# 3. Command line arguments (which override environment variables)
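# For example (hypothetical values): with MODEL_NAME=gpt-4.1 in the environment
# and --model gpt-4o on the command line, the CLI value gpt-4o wins.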
class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys and model selection.
    """

    api_key: str | None = None
    model: str = DEFAULT_LLM_MODEL
    small_model: str = SMALL_LLM_MODEL
    temperature: float = 0.0
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables."""
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL

        # Get small_model from environment, or use default if not set or empty
        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is None:
            # Setup for OpenAI API
            # Log if no model or an empty model was provided
            if model_env == '':
                logger.debug(
                    f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
                )
            elif not model_env.strip():
                logger.warning(
                    f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
                )

            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )
        else:
            # Setup for Azure OpenAI API
            # Fail fast if no deployment name was provided
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
                raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')

            if not azure_openai_use_managed_identity:
                # API key authentication
                api_key = os.environ.get('OPENAI_API_KEY', None)
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables."""
        # Start with environment-based config
        config = cls.from_env()

        # CLI arguments override environment variables when provided
        if hasattr(args, 'model') and args.model:
            # Only use CLI model if it's not empty
            if args.model.strip():
                config.model = args.model
            else:
                # Log that an empty model was provided and the default is used
                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature

        return config

    def create_client(self) -> LLMClient:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance
        """

        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            else:
                raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')

        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')

        llm_client_config = LLMConfig(
            api_key=self.api_key, model=self.model, small_model=self.small_model
        )

        # Set temperature
        llm_client_config.temperature = self.temperature

        return OpenAIClient(config=llm_client_config)
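
# Minimal sketch (assumes OPENAI_API_KEY is set in the environment): building a
# config from env vars and turning it into a client.
#
#   llm_config = GraphitiLLMConfig.from_env()
#   llm_client = llm_config.create_client()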


class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters.
    """

    model: str = DEFAULT_EMBEDDER_MODEL
    api_key: str | None = None
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables."""

        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )
        if azure_openai_endpoint is not None:
            # Setup for Azure OpenAI API
            # Fail fast if no deployment name was provided
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
                )

            if not azure_openai_use_managed_identity:
                # API key authentication
                api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
                    'OPENAI_API_KEY', None
                )
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
            )
        else:
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

    def create_client(self) -> EmbedderClient | None:
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    model=self.model,
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    model=self.model,
                )
            else:
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
        else:
            # OpenAI API setup
            if not self.api_key:
                return None

            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)

            return OpenAIEmbedder(config=embedder_config)


class Neo4jConfig(BaseModel):
    """Configuration for Neo4j database connection."""

    uri: str = 'bolt://localhost:7687'
    user: str = 'neo4j'
    password: str = 'password'

    @classmethod
    def from_env(cls) -> 'Neo4jConfig':
        """Create Neo4j configuration from environment variables."""
        return cls(
            uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            user=os.environ.get('NEO4J_USER', 'neo4j'),
            password=os.environ.get('NEO4J_PASSWORD', 'password'),
        )


class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

    Centralizes all configuration parameters for the Graphiti client.
    """

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
    destroy_graph: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Create a configuration instance from environment variables."""
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
        """Create configuration from CLI arguments, falling back to environment variables."""
        # Start with environment configuration
        config = cls.from_env()

        # Apply CLI overrides
        if args.group_id:
            config.group_id = args.group_id
        else:
            config.group_id = 'default'

        config.use_custom_entities = args.use_custom_entities
        config.destroy_graph = args.destroy_graph

        # Update LLM config using CLI args
        config.llm = GraphitiLLMConfig.from_cli_and_env(args)

        return config


class MCPConfig(BaseModel):
    """Configuration for MCP server."""

    transport: str = 'sse'  # Default to SSE transport

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
        """Create MCP configuration from CLI arguments."""
        return cls(transport=args.transport)


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stderr,
)
logger = logging.getLogger(__name__)

# Create global config instance - will be properly initialized later
config = GraphitiConfig()

# MCP server instructions
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""

# MCP server instance
mcp = FastMCP(
    'Graphiti Agent Memory',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Initialize Graphiti client
graphiti_client: Graphiti | None = None


async def initialize_graphiti():
    """Initialize the Graphiti client with the configured settings."""
    global graphiti_client, config

    try:
        # Create LLM client if possible
        llm_client = config.llm.create_client()
        if not llm_client and config.use_custom_entities:
            # If custom entities are enabled, we must have an LLM client
            raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')

        # Validate Neo4j configuration
        if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
            raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

        embedder_client = config.embedder.create_client()

        # Initialize Graphiti client
        graphiti_client = Graphiti(
            uri=config.neo4j.uri,
            user=config.neo4j.user,
            password=config.neo4j.password,
            llm_client=llm_client,
            embedder=embedder_client,
            max_coroutines=SEMAPHORE_LIMIT,
        )

        # Destroy graph if requested
        if config.destroy_graph:
            logger.info('Destroying graph...')
            await clear_data(graphiti_client.driver)

        # Initialize the graph database with Graphiti's indices
        await graphiti_client.build_indices_and_constraints()
        logger.info('Graphiti client initialized successfully')

        # Log configuration details for transparency
        if llm_client:
            logger.info(f'Using OpenAI model: {config.llm.model}')
            logger.info(f'Using temperature: {config.llm.temperature}')
        else:
            logger.info('No LLM client configured - entity extraction will be limited')

        logger.info(f'Using group_id: {config.group_id}')
        logger.info(
            f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
        )
        logger.info(f'Using concurrency limit: {SEMAPHORE_LIMIT}')

    except Exception as e:
        logger.error(f'Failed to initialize Graphiti: {str(e)}')
        raise


def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Format an entity edge into a readable result.

    Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and excluded embeddings
    """
    result = edge.model_dump(
        mode='json',
        exclude={
            'fact_embedding',
        },
    )
    result.get('attributes', {}).pop('fact_embedding', None)
    return result
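
# Illustrative output shape (field values are hypothetical; the exact keys come
# from EntityEdge.model_dump()):
#
#   {
#       'uuid': '...',
#       'name': 'WORKS_AT',
#       'fact': 'Alice works at Acme Corp',
#       'group_id': 'default',
#       'created_at': '2025-01-01T00:00:00Z',
#       ...
#   }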


# Dictionary to store queues for each group_id
# Each queue holds episode-processing tasks to be run sequentially
episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id
queue_workers: dict[str, bool] = {}


async def process_episode_queue(group_id: str):
    """Process episodes for a specific group_id sequentially.

    This function runs as a long-lived task that processes episodes
    from the queue one at a time.
    """
    global queue_workers

    logger.info(f'Starting episode queue worker for group_id: {group_id}')
    queue_workers[group_id] = True

    try:
        while True:
            # Get the next episode processing function from the queue
            # This will wait if the queue is empty
            process_func = await episode_queues[group_id].get()

            try:
                # Process the episode
                await process_func()
            except Exception as e:
                logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
            finally:
                # Mark the task as done regardless of success/failure
                episode_queues[group_id].task_done()
    except asyncio.CancelledError:
        logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
    except Exception as e:
        logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
    finally:
        queue_workers[group_id] = False
        logger.info(f'Stopped episode queue worker for group_id: {group_id}')


@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                            properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                            automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id
                                  configured at server start (the CLI value, or 'default').
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body must be a properly escaped JSON string. Note the triple backslashes
        add_memory(
            name="Customer Profile",
            episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}",
            source="json",
            source_description="CRM data"
        )

        # Adding message-style content
        add_memory(
            name="Customer Conversation",
            episode_body="user: What's your return policy?\nassistant: You can return items within 30 days.",
            source="message",
            source_description="chat transcript",
            group_id="some_arbitrary_string"
        )

    Notes:
        When using source='json':
        - The JSON must be a properly escaped string, not a raw Python dictionary
        - The JSON will be automatically processed to extract entities and relationships
        - Complex nested structures are supported (arrays, nested objects, mixed data types), but keep nesting to a minimum
        - Entities will be created from appropriate JSON properties
        - Relationships between entities will be established based on the JSON structure
    """
    global graphiti_client, episode_queues, queue_workers

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Map string source to EpisodeType enum
        source_type = EpisodeType.text
        if source.lower() == 'message':
            source_type = EpisodeType.message
        elif source.lower() == 'json':
            source_type = EpisodeType.json

        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id if group_id is not None else config.group_id

        # Cast group_id to str to satisfy type checker
        # The Graphiti client expects a str for group_id, not Optional[str]
        group_id_str = str(effective_group_id) if effective_group_id is not None else ''

        # We've already checked that graphiti_client is not None above
        # This assert statement helps type checkers understand that graphiti_client is defined
        assert graphiti_client is not None, 'graphiti_client should not be None here'

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Define the episode processing function
        async def process_episode():
            try:
                logger.info(f"Processing queued episode '{name}' for group_id: {group_id_str}")
                # Use all entity types if use_custom_entities is enabled, otherwise use empty dict
                entity_types = ENTITY_TYPES if config.use_custom_entities else {}

                await client.add_episode(
                    name=name,
                    episode_body=episode_body,
                    source=source_type,
                    source_description=source_description,
                    group_id=group_id_str,  # Using the string version of group_id
                    uuid=uuid,
                    reference_time=datetime.now(timezone.utc),
                    entity_types=entity_types,
                )
                logger.info(f"Episode '{name}' processed successfully")
            except Exception as e:
                error_msg = str(e)
                logger.error(
                    f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
                )

        # Initialize queue for this group_id if it doesn't exist
        if group_id_str not in episode_queues:
            episode_queues[group_id_str] = asyncio.Queue()

        # Add the episode processing function to the queue
        await episode_queues[group_id_str].put(process_episode)

        # Start a worker for this queue if one isn't already running
        if not queue_workers.get(group_id_str, False):
            asyncio.create_task(process_episode_queue(group_id_str))

        # Return immediately with a success message
        return {
            'message': f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
        }
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error queuing episode task: {error_msg}')
        return {'error': f'Error queuing episode task: {error_msg}'}


@mcp.tool()
async def search_memory_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    center_node_uuid: str | None = None,
    entity: str = '',  # cursor seems to break with None
) -> NodeSearchResponse | ErrorResponse:
    """Search the graph memory for relevant node summaries.

    These contain a summary of all of a node's relationships with other nodes.

    Note: entity is a single entity type to filter results (permitted: "Preference", "Procedure").

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
        entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
    """
    global graphiti_client

    if graphiti_client is None:
        return ErrorResponse(error='Graphiti client not initialized')

    try:
        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        # Configure the search
        if center_node_uuid is not None:
            search_config = NODE_HYBRID_SEARCH_NODE_DISTANCE.model_copy(deep=True)
        else:
            search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        search_config.limit = max_nodes

        filters = SearchFilters()
        if entity != '':
            filters.node_labels = [entity]

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Perform the search using the _search method
        search_results = await client._search(
            query=query,
            config=search_config,
            group_ids=effective_group_ids,
            center_node_uuid=center_node_uuid,
            search_filter=filters,
        )

        if not search_results.nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Format the node results
        formatted_nodes: list[NodeResult] = [
            {
                'uuid': node.uuid,
                'name': node.name,
                'summary': node.summary if hasattr(node, 'summary') else '',
                'labels': node.labels if hasattr(node, 'labels') else [],
                'group_id': node.group_id,
                'created_at': node.created_at.isoformat(),
                'attributes': node.attributes if hasattr(node, 'attributes') else {},
            }
            for node in search_results.nodes
        ]

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=formatted_nodes)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
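
# Hypothetical call from an MCP client (argument values are illustrative):
#
#   search_memory_nodes(
#       query="user's food preferences",
#       group_ids=["default"],
#       max_nodes=5,
#       entity="Preference",
#   )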


@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return {'message': 'No relevant facts found', 'facts': []}

        facts = [format_fact_result(edge) for edge in relevant_edges]
        return {'message': 'Facts retrieved successfully', 'facts': facts}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching facts: {error_msg}')
        return {'error': f'Error searching facts: {error_msg}'}


@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Get the entity edge by UUID
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        # Delete the edge using its delete method
        await entity_edge.delete(client.driver)
        return {'message': f'Entity edge with UUID {uuid} deleted successfully'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
        return {'error': f'Error deleting entity edge: {error_msg}'}


@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Get the episodic node by UUID - EpisodicNode is already imported at the top
        episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
        # Delete the node using its delete method
        await episodic_node.delete(client.driver)
        return {'message': f'Episode with UUID {uuid} deleted successfully'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting episode: {error_msg}')
        return {'error': f'Error deleting episode: {error_msg}'}


@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Get the entity edge directly using the EntityEdge class method
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)

        # Use the format_fact_result function to serialize the edge
        # Return the Python dict directly - MCP will handle serialization
        return format_fact_result(entity_edge)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting entity edge: {error_msg}')
        return {'error': f'Error getting entity edge: {error_msg}'}


@mcp.tool()
async def get_episodes(
    group_id: str | None = None, last_n: int = 10
) -> list[dict[str, Any]] | EpisodeSearchResponse | ErrorResponse:
    """Get the most recent memory episodes for a specific group.

    Args:
        group_id: ID of the group to retrieve episodes from. If not provided, uses the default group_id.
        last_n: Number of most recent episodes to retrieve (default: 10)
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id if group_id is not None else config.group_id

        if not isinstance(effective_group_id, str):
            return {'error': 'Group ID must be a string'}

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        episodes = await client.retrieve_episodes(
            group_ids=[effective_group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
        )

        if not episodes:
            return {'message': f'No episodes found for group {effective_group_id}', 'episodes': []}

        # Use Pydantic's model_dump method for EpisodicNode serialization
        formatted_episodes = [
            # Use mode='json' to handle datetime serialization
            episode.model_dump(mode='json')
            for episode in episodes
        ]

        # Return the Python list directly - MCP will handle serialization
        return formatted_episodes
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return {'error': f'Error getting episodes: {error_msg}'}


@mcp.tool()
async def clear_graph() -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph memory and rebuild indices."""
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # clear_data is already imported at the top
        await clear_data(client.driver)
        await client.build_indices_and_constraints()
        return {'message': 'Graph cleared successfully and indices rebuilt'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return {'error': f'Error clearing graph: {error_msg}'}


@mcp.resource('http://graphiti/status')
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and Neo4j connection."""
    global graphiti_client

    if graphiti_client is None:
        return {'status': 'error', 'message': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Test Neo4j connection
        await client.driver.verify_connectivity()
        return {'status': 'ok', 'message': 'Graphiti MCP server is running and connected to Neo4j'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error checking Neo4j connection: {error_msg}')
        return {
            'status': 'error',
            'message': f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}',
        }


async def initialize_server() -> MCPConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration."""
    global config

    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with optional LLM client'
    )
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. This is an arbitrary string used to organize related data. '
        'If not provided, the group_id "default" is used.',
    )
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio'],
        default='sse',
        help='Transport to use for communication with the client. (default: sse)',
    )
    parser.add_argument(
        '--model', help=f'Model name to use with the LLM client. (default: {DEFAULT_LLM_MODEL})'
    )
    parser.add_argument(
        '--small-model',
        help=f'Small model name to use with the LLM client. (default: {SMALL_LLM_MODEL})',
    )
    parser.add_argument(
        '--temperature',
        type=float,
        help='Temperature setting for the LLM (0.0-2.0). Lower values make output more deterministic. (default: 0.0)',
    )
    parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs')
    parser.add_argument(
        '--use-custom-entities',
        action='store_true',
        help='Enable entity extraction using the predefined ENTITY_TYPES',
    )
    parser.add_argument(
        '--host',
        default=os.environ.get('MCP_SERVER_HOST'),
        help='Host to bind the MCP server to (default: MCP_SERVER_HOST environment variable)',
    )

    args = parser.parse_args()

    # Build configuration from CLI arguments and environment variables
    config = GraphitiConfig.from_cli_and_env(args)

    # Log the group ID configuration
    if args.group_id:
        logger.info(f'Using provided group_id: {config.group_id}')
    else:
        logger.info(f'Using default group_id: {config.group_id}')

    # Log entity extraction configuration
    if config.use_custom_entities:
        logger.info('Entity extraction enabled using predefined ENTITY_TYPES')
    else:
        logger.info('Entity extraction disabled (no custom entities will be used)')

    # Initialize Graphiti
    await initialize_graphiti()

    if args.host:
        logger.info(f'Setting MCP server host to: {args.host}')
        # Set MCP server host from CLI or env
        mcp.settings.host = args.host

    # Return MCP configuration
    return MCPConfig.from_cli(args)
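
# Example invocation (hypothetical values):
#
#   python graphiti_mcp_server.py --transport sse --group-id my-project --use-custom-entities
#
# with NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, and OPENAI_API_KEY exported (or
# set in a .env file) beforehand.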


async def run_mcp_server():
    """Run the MCP server in the current event loop."""
    # Initialize the server
    mcp_config = await initialize_server()

    # Run the server with the configured transport in the same event loop
    logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
    if mcp_config.transport == 'stdio':
        await mcp.run_stdio_async()
    elif mcp_config.transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        await mcp.run_sse_async()


def main():
    """Main function to run the Graphiti MCP server."""
    try:
        # Run everything in a single event loop
        asyncio.run(run_mcp_server())
    except Exception as e:
        logger.error(f'Error initializing Graphiti MCP server: {str(e)}')
        raise


if __name__ == '__main__':
    main()