2024-09-16 14:03:05 -04:00
|
|
|
"""
|
|
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
"""
|
|
|
|
|
2024-12-17 13:08:18 -05:00
|
|
|
import asyncio
|
2024-10-21 12:33:32 -04:00
|
|
|
import os
|
2024-12-17 13:08:18 -05:00
|
|
|
from collections.abc import Coroutine
|
2024-08-27 16:18:01 -04:00
|
|
|
from datetime import datetime
|
2025-05-12 14:00:38 -04:00
|
|
|
from typing import Any
|
2024-08-27 16:18:01 -04:00
|
|
|
|
2024-10-08 13:55:10 -04:00
|
|
|
import numpy as np
|
2024-10-22 08:49:14 -04:00
|
|
|
from dotenv import load_dotenv
|
2024-08-27 16:18:01 -04:00
|
|
|
from neo4j import time as neo4j_time
|
2025-04-26 00:24:23 -04:00
|
|
|
from typing_extensions import LiteralString
|
2024-08-27 16:18:01 -04:00
|
|
|
|
2024-10-22 08:49:14 -04:00
|
|
|
# Runs at import time, ahead of the os.getenv() lookups that follow, so values
# supplied through a .env file (python-dotenv) are visible to them.
load_dotenv()
|
|
|
|
|
2024-10-21 12:33:32 -04:00
|
|
|
# Optional database name; None when the variable is not set.
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE')
# Boolean feature flag read from the environment. bool() on the raw string is
# True for ANY non-empty value (including 'false' and '0'), so the text is
# parsed instead: unset, empty, '0', 'false', 'no' and 'off' (case-insensitive)
# turn the flag off; every other value turns it on, as before.
USE_PARALLEL_RUNTIME = os.getenv('USE_PARALLEL_RUNTIME', '').strip().lower() not in (
    '',
    '0',
    'false',
    'no',
    'off',
)
# Integer settings; int() raises ValueError if the variable is not a valid integer.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', '20'))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', '0'))
# Fixed default; not configurable through the environment.
DEFAULT_PAGE_LIMIT = 20
|
2024-10-21 12:33:32 -04:00
|
|
|
|
2025-04-26 00:24:23 -04:00
|
|
|
# Cypher option line selected by USE_PARALLEL_RUNTIME; an empty string when the
# flag is off. Presumably prepended to query text by callers -- verify at use sites.
RUNTIME_QUERY: LiteralString = (
    '' if not USE_PARALLEL_RUNTIME else 'CYPHER runtime = parallel parallelRuntimeSupport=all\n'
)
|
|
|
|
|
2024-08-27 16:18:01 -04:00
|
|
|
|
|
|
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j ``DateTime`` into a native ``datetime``.

    Args:
        neo_date: Temporal value to convert, or ``None``.

    Returns:
        The result of ``neo_date.to_native()``, or ``None`` when ``neo_date``
        is ``None``.
    """
    # Compare against None explicitly so that only a missing value is skipped,
    # never a value that merely evaluates as falsy.
    if neo_date is None:
        return None
    return neo_date.to_native()
|
2024-09-26 16:12:38 -04:00
|
|
|
|
|
|
|
|
|
|
|
# Characters with special meaning in Lucene query syntax:
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
_LUCENE_SPECIAL_CHARS = '+-&|!(){}[]^"~*?:\\/'
# Upper-case letters that are escaped wherever they occur. They are the letters
# of AND / OR / NOT, so this is presumably meant to neutralise those operators.
_LUCENE_OPERATOR_LETTERS = 'ORNTAD'
# Translation table mapping each character above to itself prefixed with a
# backslash. Built once at import time rather than on every call.
_LUCENE_ESCAPE_MAP = str.maketrans(
    {char: '\\' + char for char in _LUCENE_SPECIAL_CHARS + _LUCENE_OPERATOR_LETTERS}
)


def lucene_sanitize(query: str) -> str:
    """Escape special characters in a query before passing it to Lucene.

    Args:
        query: Raw query text.

    Returns:
        ``query`` with a backslash inserted before every character listed in
        ``_LUCENE_SPECIAL_CHARS`` and ``_LUCENE_OPERATOR_LETTERS``; all other
        characters are left unchanged.
    """
    return query.translate(_LUCENE_ESCAPE_MAP)
|
2024-10-08 13:55:10 -04:00
|
|
|
|
|
|
|
|
2025-01-24 10:14:49 -05:00
|
|
|
def normalize_l2(embedding: list[float]):
    """Scale an embedding to unit L2 norm.

    Args:
        embedding: A vector of floats. Input that converts to an array with
            more than one dimension is treated as a batch and each row
            (axis 1) is normalised independently.

    Returns:
        A list (or list of lists for batch input) holding the normalised
        values. A vector whose norm is zero is returned as all zeros instead
        of being divided.
    """
    embedding_array = np.array(embedding)

    if embedding_array.ndim == 1:
        norm = np.linalg.norm(embedding_array)
        if norm == 0:
            return [0.0] * len(embedding)
        return (embedding_array / norm).tolist()

    norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
    # np.where evaluates both branches eagerly, so dividing by the raw norm
    # would compute 0/0 for zero rows and emit a RuntimeWarning. Divide by 1
    # for those rows instead; they stay all-zero either way.
    safe_norm = np.where(norm == 0, 1.0, norm)
    return (embedding_array / safe_norm).tolist()
|
2024-12-17 13:08:18 -05:00
|
|
|
|
|
|
|
|
|
|
|
# Use this instead of asyncio.gather() to bound the number of coroutines running at once
async def semaphore_gather(
    *coroutines: Coroutine,
    max_coroutines: int = SEMAPHORE_LIMIT,
) -> list[Any]:
    """Await all coroutines with at most ``max_coroutines`` running concurrently.

    Every coroutine is scheduled up front, but each must acquire a shared
    semaphore before it starts, so a new one begins as soon as any running one
    finishes. There is no per-batch barrier: one slow coroutine does not hold
    back the rest.

    Args:
        *coroutines: Coroutines to run.
        max_coroutines: Upper bound on concurrently running coroutines. Must be
            at least 1; a semaphore created with 0 never admits anything.

    Returns:
        The coroutine results, in the same order as ``coroutines``.

    Raises:
        Whatever a coroutine raises; ``asyncio.gather`` propagates the first
        exception to the caller.
    """
    semaphore = asyncio.Semaphore(max_coroutines)

    async def _bounded(coro: Coroutine) -> Any:
        # Hold a semaphore slot for the full lifetime of the wrapped coroutine.
        async with semaphore:
            return await coro

    return await asyncio.gather(*(_bounded(coro) for coro in coroutines))
|