mirror of
https://github.com/getzep/graphiti.git
synced 2025-11-22 13:12:03 +00:00
* fix: remove global DEFAULT_DATABASE usage in favor of driver-specific config Fixes bugs introduced in PR #607. This removes reliance on the global DEFAULT_DATABASE environment variable. It specifies the database within each driver. PR #607 introduced a Neo4j compatibility, as the database names are different when attempting to support FalkorDB. This refactor improves compatibility across database types and ensures future reliance by isolating the configuration to the driver level. * fix: make falkordb support optional This ensures that the optional dependency and subsequent import is compliant with the graphiti-core project dependencies. * chore: fmt code * chore: undo changes to uv.lock * fix: undo potentially breaking changes to driver interface * fix: ensure a default database of "None" is provided - falling back to internal default * chore: ensure default value exists for session and delete_all_indexes * chore: fix typos and grammar * chore: update package versions and dependencies in uv.lock and bulk_utils.py * docs: update database configuration instructions for Neo4j and FalkorDB Clarified default database names and how to override them in driver constructors. Updated testing requirements to include specific commands for running integration and unit tests. * fix: ensure params defaults to an empty dictionary in Neo4jDriver Updated the execute_query method to initialize params as an empty dictionary if not provided, ensuring compatibility with the database configuration. --------- Co-authored-by: Urmzd <urmzd@dal.ca>
171 lines
6.2 KiB
Python
171 lines
6.2 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
if TYPE_CHECKING:
|
|
from falkordb import Graph as FalkorGraph
|
|
from falkordb.asyncio import FalkorDB
|
|
else:
|
|
try:
|
|
from falkordb import Graph as FalkorGraph
|
|
from falkordb.asyncio import FalkorDB
|
|
except ImportError:
|
|
# If falkordb is not installed, raise an ImportError
|
|
raise ImportError(
|
|
'falkordb is required for FalkorDriver. '
|
|
'Install it with: pip install graphiti-core[falkordb]'
|
|
) from None
|
|
|
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FalkorDriverSession(GraphDriverSession):
    """Session object that runs Cypher statements against a single FalkorDB graph.

    The session holds no resources of its own: entering, exiting and closing
    are all no-ops that exist only so the object can be used wherever a
    ``GraphDriverSession`` is expected.
    """

    def __init__(self, graph: FalkorGraph):
        # Graph handle that every statement issued through this session targets.
        self.graph = graph

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release for Falkor; the method is required by the
        # async context-manager protocol.
        return None

    async def close(self):
        # FalkorDB needs no explicit session close; kept for interface parity.
        return None

    async def execute_write(self, func, *args, **kwargs):
        """Await ``func`` with this session as its first argument and return its result."""
        outcome = await func(self, *args, **kwargs)
        return outcome

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute one statement, or a list of ``(cypher, params)`` pairs, in order.

        Args:
            query: Either a single Cypher statement, or a list whose items
                unpack into ``(cypher, params)``. A list is accepted because
                FalkorDB does not support an argument for a label set, so
                callers split such work into several statements.
            **kwargs: Query parameters for the single-statement form. They are
                ignored when ``query`` is a list.

        Returns:
            Always ``None``; query results are discarded.
        """
        # Normalise both call shapes into a sequence of (cypher, params) pairs
        # so that a single loop handles them.
        if isinstance(query, list):
            statements = query
        else:
            statements = [(query, dict(kwargs))]

        for cypher, raw_params in statements:
            # datetime values are serialised to ISO strings before sending.
            prepared = convert_datetimes_to_strings(raw_params)
            # `graph.query` is assumed to be awaitable here; a synchronous
            # client would need wrapping in an executor.
            await self.graph.query(str(cypher), prepared)  # type: ignore[reportUnknownArgumentType]

        return None
|
|
|
|
|
|
class FalkorDriver(GraphDriver):
    """``GraphDriver`` implementation that talks to FalkorDB through its asyncio client."""

    provider: str = 'falkordb'

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6379,
        username: str | None = None,
        password: str | None = None,
        falkor_db: FalkorDB | None = None,
        database: str = 'default_db',
    ):
        """
        Set up the FalkorDB driver.

        FalkorDB is a multi-tenant graph database. Either pass an existing
        client via ``falkor_db`` or supply connection details; the defaults
        assume a local (on-premises) FalkorDB instance.

        Args:
            host: Server host used when no client is supplied.
            port: Server port used when no client is supplied.
            username: Optional username forwarded to the client constructor.
            password: Optional password forwarded to the client constructor.
            falkor_db: Pre-built client; when given, the connection arguments
                above are not used.
            database: Name of the graph used when a call does not name one.
        """
        super().__init__()
        # Prefer the caller's client; only open a new connection when none was given.
        self.client = (
            falkor_db
            if falkor_db is not None
            else FalkorDB(host=host, port=port, username=username, password=password)
        )
        self._database = database

    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
        """Return the graph handle for ``graph_name``, or for the driver default when it is ``None``."""
        # FalkorDB requires a non-None graph name for multi-tenant graphs.
        selected = self._database if graph_name is None else graph_name
        return self.client.select_graph(selected)

    async def execute_query(self, cypher_query_, **kwargs: Any):
        """Run a Cypher query and reshape the result into a list of dicts.

        Args:
            cypher_query_: The query text handed to ``graph.query``.
            **kwargs: Query parameters. The optional ``database_`` entry is
                removed and used as the target graph name; everything else is
                sent as parameters after datetimes are turned into ISO strings.

        Returns:
            ``(records, header, None)`` where ``records`` is a list of dicts
            keyed by column name, or ``None`` when the query failed with a
            message containing ``'already indexed'``.

        Raises:
            Exception: Any other error from the client is logged and re-raised.
        """
        target = kwargs.pop('database_', self._database)
        graph = self._get_graph(target)

        # FalkorDB does not accept datetime objects directly.
        query_params = convert_datetimes_to_strings(dict(kwargs))

        try:
            result = await graph.query(cypher_query_, query_params)  # type: ignore[reportUnknownArgumentType]
        except Exception as e:
            if 'already indexed' not in str(e):
                logger.error(f'Error executing FalkorDB query: {e}')
                raise
            # Creating an index that already exists is treated as a no-op.
            logger.info(f'Index already exists: {e}')
            return None

        # The column name is the second item of each header entry.
        header = [column[1] for column in result.header]

        # Turn FalkorDB's list-of-lists into the list-of-dicts Graphiti expects;
        # columns missing from a short row are filled with None.
        records = [
            {
                field_name: (row[position] if position < len(row) else None)
                for position, field_name in enumerate(header)
            }
            for row in result.result_set
        ]

        return records, header, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Create a session bound to ``database`` (or the default graph when ``None``)."""
        graph = self._get_graph(database)
        return FalkorDriverSession(graph)

    async def close(self) -> None:
        """Close the driver connection."""
        client = self.client
        # Use the client's own async close when it has one.
        if hasattr(client, 'aclose'):
            await client.aclose()  # type: ignore[reportUnknownMemberType]
            return

        # Otherwise fall back to whatever the underlying connection offers.
        connection = client.connection
        if hasattr(connection, 'aclose'):
            await connection.aclose()
        elif hasattr(connection, 'close'):
            await connection.close()

    async def delete_all_indexes(self, database_: str | None = None) -> None:
        """Issue the index-drop query against ``database_`` or the default graph."""
        target = database_ if database_ else self._database
        # NOTE(review): confirm that FalkorDB accepts this combined
        # CALL ... YIELD ... DROP INDEX statement as written.
        await self.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
            database_=target,
        )
|
|
|
|
|
|
def convert_datetimes_to_strings(obj):
    """Return a copy of ``obj`` with every ``datetime`` replaced by its ISO-8601 string.

    Dicts, lists and tuples are walked recursively (dict keys are left as they
    are; only values are converted). Any other value is returned unchanged.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {key: convert_datetimes_to_strings(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_datetimes_to_strings(element) for element in obj]
    if isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(element) for element in obj)
    # Scalars and unrecognised containers pass straight through.
    return obj
|