mirror of
https://github.com/getzep/graphiti.git
synced 2025-11-16 18:27:32 +00:00
[REFACTOR][FIX] Move away from DEFAULT_DATABASE environment variable in favour of driver-config support (dc) (#699)
* fix: remove global DEFAULT_DATABASE usage in favor of driver-specific config Fixes bugs introduced in PR #607. This removes reliance on the global DEFAULT_DATABASE environment variable. It specifies the database within each driver. PR #607 introduced a Neo4j compatability, as the database names are different when attempting to support FalkorDB. This refactor improves compatability across database types and ensures future reliance by isolating the configuraiton to the driver level. * fix: make falkordb support optional This ensures that the the optional dependency and subsequent import is compliant with the graphiti-core project dependencies. * chore: fmt code * chore: undo changes to uv.lock * fix: undo potentially breaking changes to drive interface * fix: ensure a default database of "None" is provided - falling back to internal default * chore: ensure default value exists for session and delete_all_indexes * chore: fix typos and grammar * chore: update package versions and dependencies in uv.lock and bulk_utils.py * docs: update database configuration instructions for Neo4j and FalkorDB Clarified default database names and how to override them in driver constructors. Updated testing requirements to include specific commands for running integration and unit tests. * fix: ensure params defaults to an empty dictionary in Neo4jDriver Updated the execute_query method to initialize params as an empty dictionary if not provided, ensuring compatibility with the database configuration. --------- Co-authored-by: Urmzd <urmzd@dal.ca>
This commit is contained in:
parent
e5a61de931
commit
aa6e38856a
@ -12,7 +12,6 @@ FALKORDB_PORT=
|
||||
FALKORDB_USER=
|
||||
FALKORDB_PASSWORD=
|
||||
|
||||
DEFAULT_DATABASE=
|
||||
USE_PARALLEL_RUNTIME=
|
||||
SEMAPHORE_LIMIT=
|
||||
GITHUB_SHA=
|
||||
|
||||
10
CLAUDE.md
10
CLAUDE.md
@ -102,7 +102,11 @@ docker-compose up
|
||||
### Database Setup
|
||||
|
||||
- **Neo4j**: Version 5.26+ required, available via Neo4j Desktop
|
||||
- Database name defaults to `neo4j` (hardcoded in Neo4jDriver)
|
||||
- Override by passing `database` parameter to driver constructor
|
||||
- **FalkorDB**: Version 1.1.2+ as alternative backend
|
||||
- Database name defaults to `default_db` (hardcoded in FalkorDriver)
|
||||
- Override by passing `database` parameter to driver constructor
|
||||
|
||||
## Development Guidelines
|
||||
|
||||
@ -117,8 +121,12 @@ docker-compose up
|
||||
### Testing Requirements
|
||||
|
||||
- Run tests with `make test` or `pytest`
|
||||
- Integration tests require database connections
|
||||
- Integration tests require database connections and are marked with `_int` suffix
|
||||
- Use `pytest-xdist` for parallel test execution
|
||||
- Run specific test files: `pytest tests/test_specific_file.py`
|
||||
- Run specific test methods: `pytest tests/test_file.py::test_method_name`
|
||||
- Run only integration tests: `pytest tests/ -k "_int"`
|
||||
- Run only unit tests: `pytest tests/ -k "not _int"`
|
||||
|
||||
### LLM Provider Support
|
||||
|
||||
|
||||
20
README.md
20
README.md
@ -215,22 +215,18 @@ must be set.
|
||||
|
||||
### Database Configuration
|
||||
|
||||
`DEFAULT_DATABASE` specifies the database name to use for graph operations. This is particularly important for Neo4j 5+ users:
|
||||
Database names are configured directly in the driver constructors:
|
||||
|
||||
- **Neo4j 5+**: The default database name is `neo4j` (not `default_db`)
|
||||
- **Neo4j 4**: The default database name is `default_db`
|
||||
- **FalkorDB**: The default graph name is `default_db`
|
||||
- **Neo4j**: Database name defaults to `neo4j` (hardcoded in Neo4jDriver)
|
||||
- **FalkorDB**: Database name defaults to `default_db` (hardcoded in FalkorDriver)
|
||||
|
||||
If you encounter the error `Graph not found: default_db` when using Neo4j 5, set:
|
||||
To use a different database name, pass the `database` parameter when creating the driver:
|
||||
|
||||
```bash
|
||||
export DEFAULT_DATABASE=neo4j
|
||||
```
|
||||
```python
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
|
||||
Or add to your `.env` file:
|
||||
|
||||
```
|
||||
DEFAULT_DATABASE=neo4j
|
||||
# Use custom database name
|
||||
driver = Neo4jDriver(uri="bolt://localhost:7687", user="neo4j", password="password", database="my_db")
|
||||
```
|
||||
|
||||
### Performance Configuration
|
||||
|
||||
@ -42,8 +42,7 @@ export NEO4J_PASSWORD=password
|
||||
# Optional FalkorDB connection parameters (defaults shown)
|
||||
export FALKORDB_URI=falkor://localhost:6379
|
||||
|
||||
# Database configuration (required for Neo4j 5+)
|
||||
export DEFAULT_DATABASE=neo4j
|
||||
# To use a different database, modify the driver constructor in the script
|
||||
```
|
||||
|
||||
3. Run the example:
|
||||
@ -78,21 +77,18 @@ After running this example, you can:
|
||||
|
||||
### "Graph not found: default_db" Error
|
||||
|
||||
If you encounter the error `Neo.ClientError.Database.DatabaseNotFound: Graph not found: default_db`, this typically occurs with Neo4j 5+ where the default database name is `neo4j` instead of `default_db`.
|
||||
If you encounter the error `Neo.ClientError.Database.DatabaseNotFound: Graph not found: default_db`, this occurs when the driver is trying to connect to a database that doesn't exist.
|
||||
|
||||
**Solution:**
|
||||
Set the `DEFAULT_DATABASE` environment variable to `neo4j`:
|
||||
The Neo4j driver defaults to using `neo4j` as the database name. If you need to use a different database, modify the driver constructor in the script:
|
||||
|
||||
```bash
|
||||
export DEFAULT_DATABASE=neo4j
|
||||
```
|
||||
```python
|
||||
# In quickstart_neo4j.py, change:
|
||||
driver = Neo4jDriver(uri=neo4j_uri, user=neo4j_user, password=neo4j_password)
|
||||
|
||||
Or add it to your `.env` file:
|
||||
# To specify a different database:
|
||||
driver = Neo4jDriver(uri=neo4j_uri, user=neo4j_user, password=neo4j_password, database="your_db_name")
|
||||
```
|
||||
DEFAULT_DATABASE=neo4j
|
||||
```
|
||||
|
||||
This tells Graphiti to use the correct database name for your Neo4j version.
|
||||
|
||||
## Understanding the Output
|
||||
|
||||
|
||||
@ -19,8 +19,6 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -54,7 +52,7 @@ class GraphDriver(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def session(self, database: str) -> GraphDriverSession:
|
||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
@ -62,5 +60,5 @@ class GraphDriver(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||
def delete_all_indexes(self, database_: str | None = None) -> Coroutine:
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -33,7 +33,6 @@ else:
|
||||
) from None
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -81,6 +80,7 @@ class FalkorDriver(GraphDriver):
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
falkor_db: FalkorDB | None = None,
|
||||
database: str = 'default_db',
|
||||
):
|
||||
"""
|
||||
Initialize the FalkorDB driver.
|
||||
@ -95,15 +95,16 @@ class FalkorDriver(GraphDriver):
|
||||
self.client = falkor_db
|
||||
else:
|
||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||
self._database = database
|
||||
|
||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
||||
if graph_name is None:
|
||||
graph_name = DEFAULT_DATABASE
|
||||
graph_name = self._database
|
||||
return self.client.select_graph(graph_name)
|
||||
|
||||
async def execute_query(self, cypher_query_, **kwargs: Any):
|
||||
graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
|
||||
graph_name = kwargs.pop('database_', self._database)
|
||||
graph = self._get_graph(graph_name)
|
||||
|
||||
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
|
||||
@ -136,7 +137,7 @@ class FalkorDriver(GraphDriver):
|
||||
|
||||
return records, header, None
|
||||
|
||||
def session(self, database: str | None) -> GraphDriverSession:
|
||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||
return FalkorDriverSession(self._get_graph(database))
|
||||
|
||||
async def close(self) -> None:
|
||||
@ -148,10 +149,11 @@ class FalkorDriver(GraphDriver):
|
||||
elif hasattr(self.client.connection, 'close'):
|
||||
await self.client.connection.close()
|
||||
|
||||
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
|
||||
async def delete_all_indexes(self, database_: str | None = None) -> None:
|
||||
database = database_ or self._database
|
||||
await self.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
database_=database_,
|
||||
database_=database,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -22,7 +22,6 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -30,34 +29,36 @@ logger = logging.getLogger(__name__)
|
||||
class Neo4jDriver(GraphDriver):
|
||||
provider: str = 'neo4j'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
user: str | None,
|
||||
password: str | None,
|
||||
):
|
||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
||||
super().__init__()
|
||||
self.client = AsyncGraphDatabase.driver(
|
||||
uri=uri,
|
||||
auth=(user or '', password or ''),
|
||||
)
|
||||
self._database = database
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
# If not populated, set the value to retain backwards compatibility
|
||||
params = kwargs.pop('params', None)
|
||||
if params is None:
|
||||
params = {}
|
||||
params.setdefault('database_', self._database)
|
||||
|
||||
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
def session(self, database: str) -> GraphDriverSession:
|
||||
return self.client.session(database=database) # type: ignore
|
||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||
_database = database or self._database
|
||||
return self.client.session(database=_database) # type: ignore
|
||||
|
||||
async def close(self) -> None:
|
||||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(
|
||||
self, database_: str = DEFAULT_DATABASE
|
||||
) -> Coroutine[Any, Any, EagerResult]:
|
||||
def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]:
|
||||
database = database_ or self._database
|
||||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
database_=database_,
|
||||
database_=database,
|
||||
)
|
||||
|
||||
@ -27,7 +27,7 @@ from typing_extensions import LiteralString
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
COMMUNITY_EDGE_SAVE,
|
||||
ENTITY_EDGE_SAVE,
|
||||
@ -71,7 +71,6 @@ class Edge(BaseModel, ABC):
|
||||
DELETE e
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||
@ -99,7 +98,6 @@ class EpisodicEdge(Edge):
|
||||
uuid=self.uuid,
|
||||
group_id=self.group_id,
|
||||
created_at=self.created_at,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
||||
@ -119,7 +117,6 @@ class EpisodicEdge(Edge):
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -143,7 +140,6 @@ class EpisodicEdge(Edge):
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -183,7 +179,6 @@ class EpisodicEdge(Edge):
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -231,9 +226,7 @@ class EntityEdge(Edge):
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN e.fact_embedding AS fact_embedding
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(
|
||||
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
||||
|
||||
if len(records) == 0:
|
||||
raise EdgeNotFoundError(self.uuid)
|
||||
@ -261,7 +254,6 @@ class EntityEdge(Edge):
|
||||
result = await driver.execute_query(
|
||||
ENTITY_EDGE_SAVE,
|
||||
edge_data=edge_data,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
||||
@ -276,7 +268,6 @@ class EntityEdge(Edge):
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN,
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -298,7 +289,6 @@ class EntityEdge(Edge):
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN,
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -331,7 +321,6 @@ class EntityEdge(Edge):
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -349,9 +338,7 @@ class EntityEdge(Edge):
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
@ -367,7 +354,6 @@ class CommunityEdge(Edge):
|
||||
uuid=self.uuid,
|
||||
group_id=self.group_id,
|
||||
created_at=self.created_at,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
||||
@ -387,7 +373,6 @@ class CommunityEdge(Edge):
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -409,7 +394,6 @@ class CommunityEdge(Edge):
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -447,7 +431,6 @@ class CommunityEdge(Edge):
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
|
||||
@ -30,7 +30,6 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import (
|
||||
DEFAULT_DATABASE,
|
||||
semaphore_gather,
|
||||
validate_excluded_entity_types,
|
||||
validate_group_id,
|
||||
@ -168,7 +167,6 @@ class Graphiti:
|
||||
raise ValueError('uri must be provided when graph_driver is None')
|
||||
self.driver = Neo4jDriver(uri, user, password)
|
||||
|
||||
self.database = DEFAULT_DATABASE
|
||||
self.store_raw_episode_content = store_raw_episode_content
|
||||
self.max_coroutines = max_coroutines
|
||||
if llm_client:
|
||||
@ -921,9 +919,7 @@ class Graphiti:
|
||||
nodes_to_delete: list[EntityNode] = []
|
||||
for node in nodes:
|
||||
query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
|
||||
records, _, _ = await self.driver.execute_query(
|
||||
query, uuid=node.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
|
||||
|
||||
for record in records:
|
||||
if record['episode_count'] == 1:
|
||||
|
||||
@ -32,7 +32,6 @@ from graphiti_core.errors import GroupIdValidationError
|
||||
|
||||
load_dotenv()
|
||||
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
|
||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
||||
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
||||
|
||||
@ -28,7 +28,7 @@ from typing_extensions import LiteralString
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import NodeNotFoundError
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
COMMUNITY_NODE_SAVE,
|
||||
ENTITY_NODE_SAVE,
|
||||
@ -103,7 +103,6 @@ class Node(BaseModel, ABC):
|
||||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Node: {self.uuid}')
|
||||
@ -126,7 +125,6 @@ class Node(BaseModel, ABC):
|
||||
DETACH DELETE n
|
||||
""",
|
||||
group_id=group_id,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
return 'SUCCESS'
|
||||
@ -162,7 +160,6 @@ class EpisodicNode(Node):
|
||||
created_at=self.created_at,
|
||||
valid_at=self.valid_at,
|
||||
source=self.source.value,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||
@ -185,7 +182,6 @@ class EpisodicNode(Node):
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -213,7 +209,6 @@ class EpisodicNode(Node):
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -254,7 +249,6 @@ class EpisodicNode(Node):
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -279,7 +273,6 @@ class EpisodicNode(Node):
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
entity_node_uuid=entity_node_uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -309,9 +302,7 @@ class EntityNode(Node):
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN n.name_embedding AS name_embedding
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(
|
||||
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
||||
|
||||
if len(records) == 0:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
@ -334,7 +325,6 @@ class EntityNode(Node):
|
||||
ENTITY_NODE_SAVE,
|
||||
labels=self.labels + ['Entity'],
|
||||
entity_data=entity_data,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||
@ -352,7 +342,6 @@ class EntityNode(Node):
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -371,7 +360,6 @@ class EntityNode(Node):
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN,
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -403,7 +391,6 @@ class EntityNode(Node):
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -425,7 +412,6 @@ class CommunityNode(Node):
|
||||
summary=self.summary,
|
||||
name_embedding=self.name_embedding,
|
||||
created_at=self.created_at,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||
@ -446,9 +432,7 @@ class CommunityNode(Node):
|
||||
MATCH (c:Community {uuid: $uuid})
|
||||
RETURN c.name_embedding AS name_embedding
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(
|
||||
query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
|
||||
)
|
||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
||||
|
||||
if len(records) == 0:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
@ -468,7 +452,6 @@ class CommunityNode(Node):
|
||||
n.summary AS summary
|
||||
""",
|
||||
uuid=uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -492,7 +475,6 @@ class CommunityNode(Node):
|
||||
n.summary AS summary
|
||||
""",
|
||||
uuids=uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -529,7 +511,6 @@ class CommunityNode(Node):
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
|
||||
@ -31,7 +31,6 @@ from graphiti_core.graph_queries import (
|
||||
get_vector_cosine_func_query,
|
||||
)
|
||||
from graphiti_core.helpers import (
|
||||
DEFAULT_DATABASE,
|
||||
RUNTIME_QUERY,
|
||||
lucene_sanitize,
|
||||
normalize_l2,
|
||||
@ -116,7 +115,6 @@ async def get_mentioned_nodes(
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
uuids=episode_uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -143,7 +141,6 @@ async def get_communities_by_nodes(
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
uuids=node_uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -198,7 +195,6 @@ async def edge_fulltext_search(
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -274,7 +270,6 @@ async def edge_similarity_search(
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -329,7 +324,6 @@ async def edge_bfs_search(
|
||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||
depth=bfs_max_depth,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -371,7 +365,6 @@ async def node_fulltext_search(
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -425,7 +418,6 @@ async def node_similarity_search(
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -465,7 +457,6 @@ async def node_bfs_search(
|
||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||
depth=bfs_max_depth,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
@ -511,7 +502,6 @@ async def episode_fulltext_search(
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
@ -551,7 +541,6 @@ async def community_fulltext_search(
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
@ -603,7 +592,6 @@ async def community_similarity_search(
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
@ -764,7 +752,6 @@ async def get_relevant_nodes(
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -834,7 +821,6 @@ async def get_relevant_edges(
|
||||
edges=[edge.model_dump() for edge in edges],
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -905,7 +891,6 @@ async def get_edge_invalidation_candidates(
|
||||
edges=[edge.model_dump() for edge in edges],
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
||||
@ -955,7 +940,6 @@ async def node_distance_reranker(
|
||||
query,
|
||||
node_uuids=filtered_uuids,
|
||||
center_uuid=center_node_uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
if driver.provider == 'falkordb':
|
||||
@ -997,7 +981,6 @@ async def episode_mentions_reranker(
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
node_uuids=sorted_uuids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -1060,7 +1043,7 @@ async def get_embeddings_for_nodes(
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
|
||||
query, node_uuids=[node.uuid for node in nodes], routing_='r'
|
||||
)
|
||||
|
||||
embeddings_dict: dict[str, list[float]] = {}
|
||||
@ -1086,7 +1069,6 @@ async def get_embeddings_for_communities(
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
community_uuids=[community.uuid for community in communities],
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
@ -1113,7 +1095,6 @@ async def get_embeddings_for_edges(
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
edge_uuids=[edge.uuid for edge in edges],
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ from graphiti_core.graph_queries import (
|
||||
get_entity_node_save_bulk_query,
|
||||
)
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, normalize_l2, semaphore_gather
|
||||
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
EPISODIC_EDGE_SAVE_BULK,
|
||||
)
|
||||
@ -91,7 +91,7 @@ async def add_nodes_and_edges_bulk(
|
||||
entity_edges: list[EntityEdge],
|
||||
embedder: EmbedderClient,
|
||||
):
|
||||
session = driver.session(database=DEFAULT_DATABASE)
|
||||
session = driver.session()
|
||||
try:
|
||||
await session.execute_write(
|
||||
add_nodes_and_edges_bulk_tx,
|
||||
|
||||
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.edges import CommunityEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
||||
from graphiti_core.prompts import prompt_library
|
||||
@ -37,7 +37,6 @@ async def get_community_clusters(
|
||||
RETURN
|
||||
collect(DISTINCT n.group_id) AS group_ids
|
||||
""",
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
group_ids = group_id_values[0]['group_ids'] if group_id_values else []
|
||||
@ -56,7 +55,6 @@ async def get_community_clusters(
|
||||
""",
|
||||
uuid=node.uuid,
|
||||
group_id=group_id,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
projection[node.uuid] = [
|
||||
@ -224,7 +222,6 @@ async def remove_communities(driver: GraphDriver):
|
||||
MATCH (c:Community)
|
||||
DETACH DELETE c
|
||||
""",
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
|
||||
@ -243,7 +240,6 @@ async def determine_entity_community(
|
||||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
if len(records) > 0:
|
||||
@ -261,7 +257,6 @@ async def determine_entity_community(
|
||||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities: list[CommunityNode] = [
|
||||
|
||||
@ -29,7 +29,7 @@ from graphiti_core.edges import (
|
||||
create_entity_edge_embeddings,
|
||||
)
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.llm_client.config import ModelSize
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||
@ -539,7 +539,6 @@ async def filter_existing_duplicate_of_edges(
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
duplicate_node_uuids=list(duplicate_nodes_map.keys()),
|
||||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date, semaphore_gather
|
||||
from graphiti_core.helpers import parse_db_date, semaphore_gather
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
|
||||
EPISODE_WINDOW_LEN = 3
|
||||
@ -35,7 +35,6 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
||||
"""
|
||||
SHOW INDEXES YIELD name
|
||||
""",
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
index_names = [record['name'] for record in records]
|
||||
await semaphore_gather(
|
||||
@ -43,7 +42,6 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
||||
driver.execute_query(
|
||||
"""DROP INDEX $name""",
|
||||
name=name,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
for name in index_names
|
||||
]
|
||||
@ -58,7 +56,6 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
for query in index_queries
|
||||
]
|
||||
@ -66,7 +63,7 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
||||
|
||||
|
||||
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
async with driver.session() as session:
|
||||
|
||||
async def delete_all(tx):
|
||||
await tx.run('MATCH (n) DETACH DELETE n')
|
||||
@ -134,7 +131,6 @@ async def retrieve_episodes(
|
||||
source=source.name if source is not None else None,
|
||||
num_episodes=last_n,
|
||||
group_ids=group_ids,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "mcp-server"
|
||||
version = "0.2.1"
|
||||
version = "0.3.0"
|
||||
description = "Graphiti MCP Server"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<4"
|
||||
|
||||
@ -21,8 +21,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
|
||||
try:
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
|
||||
|
||||
@ -83,13 +81,13 @@ class TestFalkorDriver:
|
||||
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
def test_get_graph_with_none_defaults_to_default_database(self):
|
||||
"""Test _get_graph with None defaults to DEFAULT_DATABASE."""
|
||||
"""Test _get_graph with None defaults to default_db."""
|
||||
mock_graph = MagicMock()
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
result = self.driver._get_graph(None)
|
||||
|
||||
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
|
||||
self.mock_client.select_graph.assert_called_once_with('default_db')
|
||||
assert result is mock_graph
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -184,7 +182,7 @@ class TestFalkorDriver:
|
||||
session = self.driver.session(None)
|
||||
|
||||
assert isinstance(session, FalkorDriverSession)
|
||||
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
|
||||
self.mock_client.select_graph.assert_called_once_with('default_db')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user