[REFACTOR][FIX] Move away from DEFAULT_DATABASE environment variable in favour of driver-config support (dc) (#699)

* fix: remove global DEFAULT_DATABASE usage in favor of driver-specific
config

Fixes bugs introduced in PR #607. This removes reliance on the global
DEFAULT_DATABASE environment variable. It specifies the database within
each driver. PR #607 introduced a Neo4j compatability, as the database
names are different when attempting to support FalkorDB.

This refactor improves compatability across database types and ensures
future reliance by isolating the configuraiton to the driver level.

* fix: make falkordb support optional

This ensures that the the optional dependency and subsequent import is compliant with the graphiti-core project dependencies.

* chore: fmt code

* chore: undo changes to uv.lock

* fix: undo potentially breaking changes to drive interface

* fix: ensure a default database of "None" is provided - falling back to internal default

* chore: ensure default value exists for session and delete_all_indexes

* chore: fix typos and grammar

* chore: update package versions and dependencies in uv.lock and bulk_utils.py

* docs: update database configuration instructions for Neo4j and FalkorDB

Clarified default database names and how to override them in driver constructors. Updated testing requirements to include specific commands for running integration and unit tests.

* fix: ensure params defaults to an empty dictionary in Neo4jDriver

Updated the execute_query method to initialize params as an empty dictionary if not provided, ensuring compatibility with the database configuration.

---------

Co-authored-by: Urmzd <urmzd@dal.ca>
This commit is contained in:
Daniel Chalef 2025-07-10 14:25:39 -07:00 committed by GitHub
parent e5a61de931
commit aa6e38856a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 2124 additions and 2108 deletions

View File

@ -12,7 +12,6 @@ FALKORDB_PORT=
FALKORDB_USER=
FALKORDB_PASSWORD=
DEFAULT_DATABASE=
USE_PARALLEL_RUNTIME=
SEMAPHORE_LIMIT=
GITHUB_SHA=

View File

@ -102,7 +102,11 @@ docker-compose up
### Database Setup
- **Neo4j**: Version 5.26+ required, available via Neo4j Desktop
- Database name defaults to `neo4j` (hardcoded in Neo4jDriver)
- Override by passing `database` parameter to driver constructor
- **FalkorDB**: Version 1.1.2+ as alternative backend
- Database name defaults to `default_db` (hardcoded in FalkorDriver)
- Override by passing `database` parameter to driver constructor
## Development Guidelines
@ -117,8 +121,12 @@ docker-compose up
### Testing Requirements
- Run tests with `make test` or `pytest`
- Integration tests require database connections
- Integration tests require database connections and are marked with `_int` suffix
- Use `pytest-xdist` for parallel test execution
- Run specific test files: `pytest tests/test_specific_file.py`
- Run specific test methods: `pytest tests/test_file.py::test_method_name`
- Run only integration tests: `pytest tests/ -k "_int"`
- Run only unit tests: `pytest tests/ -k "not _int"`
### LLM Provider Support

View File

@ -215,22 +215,18 @@ must be set.
### Database Configuration
`DEFAULT_DATABASE` specifies the database name to use for graph operations. This is particularly important for Neo4j 5+ users:
Database names are configured directly in the driver constructors:
- **Neo4j 5+**: The default database name is `neo4j` (not `default_db`)
- **Neo4j 4**: The default database name is `default_db`
- **FalkorDB**: The default graph name is `default_db`
- **Neo4j**: Database name defaults to `neo4j` (hardcoded in Neo4jDriver)
- **FalkorDB**: Database name defaults to `default_db` (hardcoded in FalkorDriver)
If you encounter the error `Graph not found: default_db` when using Neo4j 5, set:
To use a different database name, pass the `database` parameter when creating the driver:
```bash
export DEFAULT_DATABASE=neo4j
```
```python
from graphiti_core.driver.neo4j_driver import Neo4jDriver
Or add to your `.env` file:
```
DEFAULT_DATABASE=neo4j
# Use custom database name
driver = Neo4jDriver(uri="bolt://localhost:7687", user="neo4j", password="password", database="my_db")
```
### Performance Configuration

View File

@ -42,8 +42,7 @@ export NEO4J_PASSWORD=password
# Optional FalkorDB connection parameters (defaults shown)
export FALKORDB_URI=falkor://localhost:6379
# Database configuration (required for Neo4j 5+)
export DEFAULT_DATABASE=neo4j
# To use a different database, modify the driver constructor in the script
```
3. Run the example:
@ -78,21 +77,18 @@ After running this example, you can:
### "Graph not found: default_db" Error
If you encounter the error `Neo.ClientError.Database.DatabaseNotFound: Graph not found: default_db`, this typically occurs with Neo4j 5+ where the default database name is `neo4j` instead of `default_db`.
If you encounter the error `Neo.ClientError.Database.DatabaseNotFound: Graph not found: default_db`, this occurs when the driver is trying to connect to a database that doesn't exist.
**Solution:**
Set the `DEFAULT_DATABASE` environment variable to `neo4j`:
The Neo4j driver defaults to using `neo4j` as the database name. If you need to use a different database, modify the driver constructor in the script:
```bash
export DEFAULT_DATABASE=neo4j
```
```python
# In quickstart_neo4j.py, change:
driver = Neo4jDriver(uri=neo4j_uri, user=neo4j_user, password=neo4j_password)
Or add it to your `.env` file:
# To specify a different database:
driver = Neo4jDriver(uri=neo4j_uri, user=neo4j_user, password=neo4j_password, database="your_db_name")
```
DEFAULT_DATABASE=neo4j
```
This tells Graphiti to use the correct database name for your Neo4j version.
## Understanding the Output

View File

@ -19,8 +19,6 @@ from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Any
from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
@ -54,7 +52,7 @@ class GraphDriver(ABC):
raise NotImplementedError()
@abstractmethod
def session(self, database: str) -> GraphDriverSession:
def session(self, database: str | None = None) -> GraphDriverSession:
raise NotImplementedError()
@abstractmethod
@ -62,5 +60,5 @@ class GraphDriver(ABC):
raise NotImplementedError()
@abstractmethod
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
def delete_all_indexes(self, database_: str | None = None) -> Coroutine:
raise NotImplementedError()

View File

@ -33,7 +33,6 @@ else:
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
@ -81,6 +80,7 @@ class FalkorDriver(GraphDriver):
username: str | None = None,
password: str | None = None,
falkor_db: FalkorDB | None = None,
database: str = 'default_db',
):
"""
Initialize the FalkorDB driver.
@ -95,15 +95,16 @@ class FalkorDriver(GraphDriver):
self.client = falkor_db
else:
self.client = FalkorDB(host=host, port=port, username=username, password=password)
self._database = database
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
if graph_name is None:
graph_name = DEFAULT_DATABASE
graph_name = self._database
return self.client.select_graph(graph_name)
async def execute_query(self, cypher_query_, **kwargs: Any):
graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
graph_name = kwargs.pop('database_', self._database)
graph = self._get_graph(graph_name)
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
@ -136,7 +137,7 @@ class FalkorDriver(GraphDriver):
return records, header, None
def session(self, database: str | None) -> GraphDriverSession:
def session(self, database: str | None = None) -> GraphDriverSession:
return FalkorDriverSession(self._get_graph(database))
async def close(self) -> None:
@ -148,10 +149,11 @@ class FalkorDriver(GraphDriver):
elif hasattr(self.client.connection, 'close'):
await self.client.connection.close()
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
async def delete_all_indexes(self, database_: str | None = None) -> None:
database = database_ or self._database
await self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_,
database_=database,
)

View File

@ -22,7 +22,6 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
@ -30,34 +29,36 @@ logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver):
provider: str = 'neo4j'
def __init__(
self,
uri: str,
user: str | None,
password: str | None,
):
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
auth=(user or '', password or ''),
)
self._database = database
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs.
# If not populated, set the value to retain backwards compatibility
params = kwargs.pop('params', None)
if params is None:
params = {}
params.setdefault('database_', self._database)
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
return result
def session(self, database: str) -> GraphDriverSession:
return self.client.session(database=database) # type: ignore
def session(self, database: str | None = None) -> GraphDriverSession:
_database = database or self._database
return self.client.session(database=_database) # type: ignore
async def close(self) -> None:
return await self.client.close()
def delete_all_indexes(
self, database_: str = DEFAULT_DATABASE
) -> Coroutine[Any, Any, EagerResult]:
def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]:
database = database_ or self._database
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_,
database_=database,
)

View File

@ -27,7 +27,7 @@ from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE,
@ -71,7 +71,6 @@ class Edge(BaseModel, ABC):
DELETE e
""",
uuid=self.uuid,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Deleted Edge: {self.uuid}')
@ -99,7 +98,6 @@ class EpisodicEdge(Edge):
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
@ -119,7 +117,6 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -143,7 +140,6 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -183,7 +179,6 @@ class EpisodicEdge(Edge):
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -231,9 +226,7 @@ class EntityEdge(Edge):
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
records, _, _ = await driver.execute_query(
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
if len(records) == 0:
raise EdgeNotFoundError(self.uuid)
@ -261,7 +254,6 @@ class EntityEdge(Edge):
result = await driver.execute_query(
ENTITY_EDGE_SAVE,
edge_data=edge_data,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
@ -276,7 +268,6 @@ class EntityEdge(Edge):
"""
+ ENTITY_EDGE_RETURN,
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -298,7 +289,6 @@ class EntityEdge(Edge):
"""
+ ENTITY_EDGE_RETURN,
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -331,7 +321,6 @@ class EntityEdge(Edge):
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -349,9 +338,7 @@ class EntityEdge(Edge):
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
)
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
edges = [get_entity_edge_from_record(record) for record in records]
@ -367,7 +354,6 @@ class CommunityEdge(Edge):
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
@ -387,7 +373,6 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -409,7 +394,6 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -447,7 +431,6 @@ class CommunityEdge(Edge):
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)

View File

@ -30,7 +30,6 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (
DEFAULT_DATABASE,
semaphore_gather,
validate_excluded_entity_types,
validate_group_id,
@ -168,7 +167,6 @@ class Graphiti:
raise ValueError('uri must be provided when graph_driver is None')
self.driver = Neo4jDriver(uri, user, password)
self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content
self.max_coroutines = max_coroutines
if llm_client:
@ -921,9 +919,7 @@ class Graphiti:
nodes_to_delete: list[EntityNode] = []
for node in nodes:
query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
records, _, _ = await self.driver.execute_query(
query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
for record in records:
if record['episode_count'] == 1:

View File

@ -32,7 +32,6 @@ from graphiti_core.errors import GroupIdValidationError
load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))

View File

@ -28,7 +28,7 @@ from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
@ -103,7 +103,6 @@ class Node(BaseModel, ABC):
DETACH DELETE n
""",
uuid=self.uuid,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Deleted Node: {self.uuid}')
@ -126,7 +125,6 @@ class Node(BaseModel, ABC):
DETACH DELETE n
""",
group_id=group_id,
database_=DEFAULT_DATABASE,
)
return 'SUCCESS'
@ -162,7 +160,6 @@ class EpisodicNode(Node):
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
@ -185,7 +182,6 @@ class EpisodicNode(Node):
e.entity_edges AS entity_edges
""",
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -213,7 +209,6 @@ class EpisodicNode(Node):
e.entity_edges AS entity_edges
""",
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -254,7 +249,6 @@ class EpisodicNode(Node):
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -279,7 +273,6 @@ class EpisodicNode(Node):
e.entity_edges AS entity_edges
""",
entity_node_uuid=entity_node_uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -309,9 +302,7 @@ class EntityNode(Node):
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
@ -334,7 +325,6 @@ class EntityNode(Node):
ENTITY_NODE_SAVE,
labels=self.labels + ['Entity'],
entity_data=entity_data,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
@ -352,7 +342,6 @@ class EntityNode(Node):
records, _, _ = await driver.execute_query(
query,
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -371,7 +360,6 @@ class EntityNode(Node):
"""
+ ENTITY_NODE_RETURN,
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -403,7 +391,6 @@ class EntityNode(Node):
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -425,7 +412,6 @@ class CommunityNode(Node):
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
@ -446,9 +432,7 @@ class CommunityNode(Node):
MATCH (c:Community {uuid: $uuid})
RETURN c.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
)
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
@ -468,7 +452,6 @@ class CommunityNode(Node):
n.summary AS summary
""",
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -492,7 +475,6 @@ class CommunityNode(Node):
n.summary AS summary
""",
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -529,7 +511,6 @@ class CommunityNode(Node):
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)

View File

@ -31,7 +31,6 @@ from graphiti_core.graph_queries import (
get_vector_cosine_func_query,
)
from graphiti_core.helpers import (
DEFAULT_DATABASE,
RUNTIME_QUERY,
lucene_sanitize,
normalize_l2,
@ -116,7 +115,6 @@ async def get_mentioned_nodes(
records, _, _ = await driver.execute_query(
query,
uuids=episode_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -143,7 +141,6 @@ async def get_communities_by_nodes(
records, _, _ = await driver.execute_query(
query,
uuids=node_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -198,7 +195,6 @@ async def edge_fulltext_search(
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -274,7 +270,6 @@ async def edge_similarity_search(
group_ids=group_ids,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -329,7 +324,6 @@ async def edge_bfs_search(
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -371,7 +365,6 @@ async def node_fulltext_search(
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -425,7 +418,6 @@ async def node_similarity_search(
group_ids=group_ids,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -465,7 +457,6 @@ async def node_bfs_search(
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -511,7 +502,6 @@ async def episode_fulltext_search(
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -551,7 +541,6 @@ async def community_fulltext_search(
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
@ -603,7 +592,6 @@ async def community_similarity_search(
group_ids=group_ids,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
@ -764,7 +752,6 @@ async def get_relevant_nodes(
group_id=group_id,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -834,7 +821,6 @@ async def get_relevant_edges(
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -905,7 +891,6 @@ async def get_edge_invalidation_candidates(
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
@ -955,7 +940,6 @@ async def node_distance_reranker(
query,
node_uuids=filtered_uuids,
center_uuid=center_node_uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
if driver.provider == 'falkordb':
@ -997,7 +981,6 @@ async def episode_mentions_reranker(
results, _, _ = await driver.execute_query(
query,
node_uuids=sorted_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -1060,7 +1043,7 @@ async def get_embeddings_for_nodes(
"""
results, _, _ = await driver.execute_query(
query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
query, node_uuids=[node.uuid for node in nodes], routing_='r'
)
embeddings_dict: dict[str, list[float]] = {}
@ -1086,7 +1069,6 @@ async def get_embeddings_for_communities(
results, _, _ = await driver.execute_query(
query,
community_uuids=[community.uuid for community in communities],
database_=DEFAULT_DATABASE,
routing_='r',
)
@ -1113,7 +1095,6 @@ async def get_embeddings_for_edges(
results, _, _ = await driver.execute_query(
query,
edge_uuids=[edge.uuid for edge in edges],
database_=DEFAULT_DATABASE,
routing_='r',
)

View File

@ -30,7 +30,7 @@ from graphiti_core.graph_queries import (
get_entity_node_save_bulk_query,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, normalize_l2, semaphore_gather
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
EPISODIC_EDGE_SAVE_BULK,
)
@ -91,7 +91,7 @@ async def add_nodes_and_edges_bulk(
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
):
session = driver.session(database=DEFAULT_DATABASE)
session = driver.session()
try:
await session.execute_write(
add_nodes_and_edges_bulk_tx,

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
@ -37,7 +37,6 @@ async def get_community_clusters(
RETURN
collect(DISTINCT n.group_id) AS group_ids
""",
database_=DEFAULT_DATABASE,
)
group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@ -56,7 +55,6 @@ async def get_community_clusters(
""",
uuid=node.uuid,
group_id=group_id,
database_=DEFAULT_DATABASE,
)
projection[node.uuid] = [
@ -224,7 +222,6 @@ async def remove_communities(driver: GraphDriver):
MATCH (c:Community)
DETACH DELETE c
""",
database_=DEFAULT_DATABASE,
)
@ -243,7 +240,6 @@ async def determine_entity_community(
c.summary AS summary
""",
entity_uuid=entity.uuid,
database_=DEFAULT_DATABASE,
)
if len(records) > 0:
@ -261,7 +257,6 @@ async def determine_entity_community(
c.summary AS summary
""",
entity_uuid=entity.uuid,
database_=DEFAULT_DATABASE,
)
communities: list[CommunityNode] = [

View File

@ -29,7 +29,7 @@ from graphiti_core.edges import (
create_entity_edge_embeddings,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
@ -539,7 +539,6 @@ async def filter_existing_duplicate_of_edges(
records, _, _ = await driver.execute_query(
query,
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
database_=DEFAULT_DATABASE,
routing_='r',
)

View File

@ -21,7 +21,7 @@ from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date, semaphore_gather
from graphiti_core.helpers import parse_db_date, semaphore_gather
from graphiti_core.nodes import EpisodeType, EpisodicNode
EPISODE_WINDOW_LEN = 3
@ -35,7 +35,6 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
"""
SHOW INDEXES YIELD name
""",
database_=DEFAULT_DATABASE,
)
index_names = [record['name'] for record in records]
await semaphore_gather(
@ -43,7 +42,6 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
driver.execute_query(
"""DROP INDEX $name""",
name=name,
database_=DEFAULT_DATABASE,
)
for name in index_names
]
@ -58,7 +56,6 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
*[
driver.execute_query(
query,
database_=DEFAULT_DATABASE,
)
for query in index_queries
]
@ -66,7 +63,7 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async with driver.session(database=DEFAULT_DATABASE) as session:
async with driver.session() as session:
async def delete_all(tx):
await tx.run('MATCH (n) DETACH DELETE n')
@ -134,7 +131,6 @@ async def retrieve_episodes(
source=source.name if source is not None else None,
num_episodes=last_n,
group_ids=group_ids,
database_=DEFAULT_DATABASE,
)
episodes = [

View File

@ -1,6 +1,6 @@
[project]
name = "mcp-server"
version = "0.2.1"
version = "0.3.0"
description = "Graphiti MCP Server"
readme = "README.md"
requires-python = ">=3.10,<4"

View File

@ -21,8 +21,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.helpers import DEFAULT_DATABASE
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
@ -83,13 +81,13 @@ class TestFalkorDriver:
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_get_graph_with_none_defaults_to_default_database(self):
"""Test _get_graph with None defaults to DEFAULT_DATABASE."""
"""Test _get_graph with None defaults to default_db."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph(None)
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
self.mock_client.select_graph.assert_called_once_with('default_db')
assert result is mock_graph
@pytest.mark.asyncio
@ -184,7 +182,7 @@ class TestFalkorDriver:
session = self.driver.session(None)
assert isinstance(session, FalkorDriverSession)
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
self.mock_client.select_graph.assert_called_once_with('default_db')
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')

4024
uv.lock generated

File diff suppressed because it is too large Load Diff