Mirror of https://github.com/getzep/graphiti.git (synced 2025-06-27 02:00:02 +00:00)
Set max tokens by prompt (#255)

* set max tokens
* update generic openai client
* mypy updates
* fix: dockerfile

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>

Parent: 77cb67cdfe
Commit: 0f50b74735
@@ -23,7 +23,7 @@ RUN poetry build && pip install dist/*.whl
 
 # Install server dependencies
 WORKDIR /app/server
-RUN poetry install --no-interaction --no-ansi --no-dev
+RUN poetry install --no-interaction --no-ansi --only main --no-root
 
 FROM python:3.12-slim
@@ -73,12 +73,12 @@ def lucene_sanitize(query: str) -> str:
     return sanitized
 
 
-def normalize_l2(embedding: list[float]) -> list[float]:
+def normalize_l2(embedding: list[float]):
     embedding_array = np.array(embedding)
     if embedding_array.ndim == 1:
         norm = np.linalg.norm(embedding_array)
         if norm == 0:
-            return embedding_array.tolist()
+            return [0.0] * len(embedding)
         return (embedding_array / norm).tolist()
     else:
         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
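For context, a minimal runnable sketch of the 1-D path of this helper after the change. The 2-D branch is cut off by the hunk and omitted here, and the function name is shortened for the example; the point is that the zero-norm guard now returns a concrete list of floats rather than `ndarray.tolist()`:

```python
import numpy as np


def normalize_l2_1d(embedding: list[float]) -> list[float]:
    """Sketch of the 1-D path of normalize_l2 after this hunk.

    The 2-D (batched) branch from the hunk is omitted; this only illustrates
    the zero-norm guard that now returns a plain list[float].
    """
    embedding_array = np.array(embedding)
    norm = np.linalg.norm(embedding_array)
    if norm == 0:
        # Zero vector: return zeros as concrete floats instead of array.tolist()
        return [0.0] * len(embedding)
    return (embedding_array / norm).tolist()


print(normalize_l2_1d([0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0]
print(normalize_l2_1d([3.0, 4.0]))       # [0.6, 0.8]
```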
@@ -48,7 +48,10 @@ class AnthropicClient(LLMClient):
         )
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         system_message = messages[0]
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
@@ -59,7 +62,7 @@ class AnthropicClient(LLMClient):
         result = await self.client.messages.create(
             system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
             + system_message.content,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             temperature=self.temperature,
             messages=user_messages,  # type: ignore
             model=self.model or DEFAULT_MODEL,
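The `max_tokens or self.max_tokens` expression repeated across the client hunks means an explicit per-call budget wins, and the client-level setting is only used when the call passes a falsy value (0 or None). A tiny illustration with a hypothetical stand-in class, not the real AnthropicClient:

```python
class _FakeClient:
    """Hypothetical stand-in showing the fallback used in these hunks."""

    def __init__(self, max_tokens: int = 1024):
        self.max_tokens = max_tokens

    def effective_max_tokens(self, max_tokens: int | None = None) -> int:
        # Same expression as in the diff: explicit per-call value wins,
        # otherwise fall back to the client-level setting.
        return max_tokens or self.max_tokens


c = _FakeClient(max_tokens=2048)
print(c.effective_max_tokens(16384))  # 16384 (per-call override)
print(c.effective_max_tokens(0))      # 2048  (falsy -> client-level fallback)
print(c.effective_max_tokens(None))   # 2048
```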
@@ -26,7 +26,7 @@ from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
@@ -90,16 +90,22 @@ class LLMClient(ABC):
         reraise=True,
     )
     async def _generate_response_with_retry(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         try:
-            return await self._generate_response(messages, response_model)
+            return await self._generate_response(messages, response_model, max_tokens)
         except (httpx.HTTPStatusError, RateLimitError) as e:
             raise e
 
     @abstractmethod
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         pass
 
@@ -110,7 +116,10 @@ class LLMClient(ABC):
         return hashlib.md5(key_str.encode()).hexdigest()
 
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         if response_model is not None:
             serialized_model = json.dumps(response_model.model_json_schema())
@@ -131,7 +140,7 @@ class LLMClient(ABC):
         for message in messages:
             message.content = self._clean_input(message.content)
 
-        response = await self._generate_response_with_retry(messages, response_model)
+        response = await self._generate_response_with_retry(messages, response_model, max_tokens)
 
         if self.cache_enabled:
             self.cache_dir.set(cache_key, response)
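A minimal sketch of how the new parameter flows through the base class after this hunk: generate_response accepts max_tokens, hands it to the retry wrapper, which hands it to the provider-specific _generate_response. This is a simplified stand-alone model (no caching, tenacity, or real providers), not the library code itself:

```python
import asyncio
from abc import ABC, abstractmethod

DEFAULT_MAX_TOKENS = 1024  # mirrors the new default in the config hunk below


class MiniLLMClient(ABC):
    """Simplified stand-in for LLMClient, showing only the max_tokens plumbing."""

    async def generate_response(self, messages, max_tokens: int = DEFAULT_MAX_TOKENS):
        return await self._generate_response_with_retry(messages, max_tokens)

    async def _generate_response_with_retry(self, messages, max_tokens: int):
        # The real code wraps this in tenacity retries; omitted here.
        return await self._generate_response(messages, max_tokens)

    @abstractmethod
    async def _generate_response(self, messages, max_tokens: int): ...


class EchoClient(MiniLLMClient):
    async def _generate_response(self, messages, max_tokens: int):
        # A provider client would pass max_tokens to its SDK call here.
        return {'echo': messages, 'max_tokens_used': max_tokens}


async def main():
    client = EchoClient()
    print(await client.generate_response(['hi']))                    # uses the 1024 default
    print(await client.generate_response(['hi'], max_tokens=16384))  # per-prompt override


asyncio.run(main())
```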
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-DEFAULT_MAX_TOKENS = 16384
+DEFAULT_MAX_TOKENS = 1024
 DEFAULT_TEMPERATURE = 0
 
 
@@ -45,7 +45,10 @@ class GroqClient(LLMClient):
         self.client = AsyncGroq(api_key=config.api_key)
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -58,7 +61,7 @@ class GroqClient(LLMClient):
             model=self.model or DEFAULT_MODEL,
             messages=msgs,
             temperature=self.temperature,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             response_format={'type': 'json_object'},
         )
         result = response.choices[0].message.content or ''
@@ -25,7 +25,7 @@ from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError, RefusalError
 
 logger = logging.getLogger(__name__)
@@ -58,7 +58,11 @@ class OpenAIClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ):
         """
         Initialize the OpenAIClient with the provided configuration, cache setting, and client.
@@ -84,7 +88,10 @@ class OpenAIClient(LLMClient):
         self.client = client
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -98,7 +105,7 @@ class OpenAIClient(LLMClient):
             model=self.model or DEFAULT_MODEL,
             messages=openai_messages,
             temperature=self.temperature,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             response_format=response_model,  # type: ignore
         )
 
@@ -119,14 +126,17 @@ class OpenAIClient(LLMClient):
             raise
 
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         retry_count = 0
         last_error = None
 
         while retry_count <= self.MAX_RETRIES:
             try:
-                response = await self._generate_response(messages, response_model)
+                response = await self._generate_response(messages, response_model, max_tokens)
                 return response
             except (RateLimitError, RefusalError):
                 # These errors should not trigger retries
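From the caller's side, the change lets a client be built with its own cap while individual prompts still request a larger completion. A usage sketch, not library documentation: the import paths and the LLMConfig keyword are assumptions based on the relative imports visible in these hunks, Message is assumed to take role/content as the hunks suggest, and a real OpenAI API key would be required to run it:

```python
import asyncio

from pydantic import BaseModel

# Assumed import paths, inferred from the relative imports in the hunks above
# (from ..prompts.models import Message / from .client import LLMClient).
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.prompts.models import Message


class Summary(BaseModel):
    summary: str


async def main():
    # Client-level cap set at construction; individual calls may still override it.
    client = OpenAIClient(config=LLMConfig(api_key='sk-...'), max_tokens=1024)

    messages = [
        Message(role='system', content='Only include JSON in the response.'),
        Message(role='user', content='Summarize: Graphiti builds temporal knowledge graphs.'),
    ]

    # Per-call budget: this prompt is allowed a larger completion than the default.
    response = await client.generate_response(messages, response_model=Summary, max_tokens=8192)
    print(response)


asyncio.run(main())
```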
@@ -26,7 +26,7 @@ from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import LLMClient
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError, RefusalError
 
 logger = logging.getLogger(__name__)
@@ -85,7 +85,10 @@ class OpenAIGenericClient(LLMClient):
         self.client = client
 
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -111,7 +114,10 @@ class OpenAIGenericClient(LLMClient):
             raise
 
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         retry_count = 0
         last_error = None
@@ -126,7 +132,9 @@ class OpenAIGenericClient(LLMClient):
 
         while retry_count <= self.MAX_RETRIES:
             try:
-                response = await self._generate_response(messages, response_model)
+                response = await self._generate_response(
+                    messages, response_model, max_tokens=max_tokens
+                )
                 return response
             except (RateLimitError, RefusalError):
                 # These errors should not trigger retries
@@ -79,6 +79,8 @@ async def extract_edges(
 ) -> list[EntityEdge]:
     start = time()
 
+    EXTRACT_EDGES_MAX_TOKENS = 16384
+
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
 
     # Prepare context for LLM
@@ -93,7 +95,9 @@ async def extract_edges(
     reflexion_iterations = 0
     while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
         llm_response = await llm_client.generate_response(
-            prompt_library.extract_edges.edge(context), response_model=ExtractedEdges
+            prompt_library.extract_edges.edge(context),
+            response_model=ExtractedEdges,
+            max_tokens=EXTRACT_EDGES_MAX_TOKENS,
        )
         edges_data = llm_response.get('edges', [])
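The extract_edges hunk above is the pattern this commit enables: a prompt known to produce long output declares its own budget and passes it per call, while everything else falls back to the lowered 1024-token default. A schematic sketch of that pattern, using a hypothetical stub client rather than the real LLMClient:

```python
import asyncio

# Per-prompt budget, as in the extract_edges hunk above.
EXTRACT_EDGES_MAX_TOKENS = 16384
DEFAULT_MAX_TOKENS = 1024


class StubLLMClient:
    """Hypothetical stub standing in for an LLMClient subclass."""

    async def generate_response(self, prompt: str, max_tokens: int = DEFAULT_MAX_TOKENS) -> dict:
        return {'prompt': prompt, 'budget': max_tokens}


async def main():
    client = StubLLMClient()
    # Edge extraction can emit many JSON objects, so it asks for a large completion...
    print(await client.generate_response('extract_edges prompt', max_tokens=EXTRACT_EDGES_MAX_TOKENS))
    # ...while a short prompt simply uses the default budget.
    print(await client.generate_response('short prompt'))


asyncio.run(main())
```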
poetry.lock: 2772 changes (generated file; diff suppressed because it is too large)

pyproject.toml:
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.5.1"
+version = "0.5.2"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",