Set max tokens by prompt (#255)

* set max tokens

* update generic openai client

* mypy updates

* fix: dockerfile

---------

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
Preston Rasmussen 2025-01-24 10:14:49 -05:00 committed by GitHub
parent 77cb67cdfe
commit 0f50b74735
11 changed files with 1373 additions and 1488 deletions
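
In caller terms, the change lets each prompt choose its own completion budget instead of relying on a single client-wide ceiling. A minimal caller-side sketch, assuming the module paths implied by the relative imports in the hunks below and that Message takes role/content as shown there; the function name and prompt text are illustrative:

import typing

from graphiti_core.llm_client.client import LLMClient
from graphiti_core.prompts.models import Message

async def summarize_episode(llm_client: LLMClient, episode_text: str) -> dict[str, typing.Any]:
    messages = [
        Message(role='system', content='Respond only with JSON.'),
        Message(role='user', content=f'Summarize this episode: {episode_text}'),
    ]
    # Most prompts now inherit DEFAULT_MAX_TOKENS (1024); a prompt that needs more
    # headroom passes its own per-call budget.
    return await llm_client.generate_response(messages, max_tokens=4096)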


@@ -23,7 +23,7 @@ RUN poetry build && pip install dist/*.whl
 # Install server dependencies
 WORKDIR /app/server
-RUN poetry install --no-interaction --no-ansi --no-dev
+RUN poetry install --no-interaction --no-ansi --only main --no-root
 FROM python:3.12-slim


@@ -73,12 +73,12 @@ def lucene_sanitize(query: str) -> str:
     return sanitized
-def normalize_l2(embedding: list[float]) -> list[float]:
+def normalize_l2(embedding: list[float]):
     embedding_array = np.array(embedding)
     if embedding_array.ndim == 1:
         norm = np.linalg.norm(embedding_array)
         if norm == 0:
-            return embedding_array.tolist()
+            return [0.0] * len(embedding)
         return (embedding_array / norm).tolist()
     else:
         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
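For reference, the whole helper after this change would look roughly like the sketch below. The tail of the else branch is not shown in the hunk, so its final lines are an assumption (zero-norm rows are left unchanged):

import numpy as np

def normalize_l2(embedding: list[float]):
    embedding_array = np.array(embedding)
    if embedding_array.ndim == 1:
        norm = np.linalg.norm(embedding_array)
        if norm == 0:
            # Zero vector: return a plain list of floats rather than the raw array
            return [0.0] * len(embedding)
        return (embedding_array / norm).tolist()
    else:
        # Row-wise L2 norms; the two lines below are an assumed continuation,
        # not shown in the hunk above.
        norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
        norm = np.where(norm == 0, 1.0, norm)  # avoid dividing zero-norm rows
        return (embedding_array / norm).tolist()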


@@ -48,7 +48,10 @@ class AnthropicClient(LLMClient):
     )
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         system_message = messages[0]
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
@@ -59,7 +62,7 @@ class AnthropicClient(LLMClient):
         result = await self.client.messages.create(
             system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
             + system_message.content,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             temperature=self.temperature,
             messages=user_messages,  # type: ignore
             model=self.model or DEFAULT_MODEL,
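The "max_tokens or self.max_tokens" expression prefers the per-call budget and falls back to the client-level value only when the per-call value is falsy. A tiny standalone sketch of that semantics (the helper name is illustrative, not library code):

def resolve_max_tokens(per_call: int | None, instance_default: int) -> int:
    # `or` returns the right-hand operand when the left is falsy, so an
    # explicit 0 (or None) falls back to the instance default.
    return per_call or instance_default

assert resolve_max_tokens(16384, 1024) == 16384
assert resolve_max_tokens(None, 1024) == 1024
assert resolve_max_tokens(0, 1024) == 1024  # 0 is treated as "not set"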


@@ -26,7 +26,7 @@ from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 from ..prompts.models import Message
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError
 DEFAULT_TEMPERATURE = 0
@@ -90,16 +90,22 @@ class LLMClient(ABC):
         reraise=True,
     )
     async def _generate_response_with_retry(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         try:
-            return await self._generate_response(messages, response_model)
+            return await self._generate_response(messages, response_model, max_tokens)
         except (httpx.HTTPStatusError, RateLimitError) as e:
             raise e
     @abstractmethod
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         pass
@@ -110,7 +116,10 @@ class LLMClient(ABC):
         return hashlib.md5(key_str.encode()).hexdigest()
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         if response_model is not None:
             serialized_model = json.dumps(response_model.model_json_schema())
@@ -131,7 +140,7 @@ class LLMClient(ABC):
         for message in messages:
             message.content = self._clean_input(message.content)
-        response = await self._generate_response_with_retry(messages, response_model)
+        response = await self._generate_response_with_retry(messages, response_model, max_tokens)
         if self.cache_enabled:
             self.cache_dir.set(cache_key, response)
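Since _generate_response is abstract, any custom client must now accept the extra parameter as well. A minimal subclass sketch, assuming the module paths implied by the relative imports above; the echo behavior is purely illustrative:

import typing

from pydantic import BaseModel

from graphiti_core.llm_client.client import LLMClient
from graphiti_core.llm_client.config import DEFAULT_MAX_TOKENS
from graphiti_core.prompts.models import Message

class EchoClient(LLMClient):
    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> dict[str, typing.Any]:
        # A real client would forward max_tokens to its provider SDK,
        # e.g. max_tokens=max_tokens or self.max_tokens.
        return {'content': messages[-1].content, 'max_tokens_used': max_tokens}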


@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-DEFAULT_MAX_TOKENS = 16384
+DEFAULT_MAX_TOKENS = 1024
 DEFAULT_TEMPERATURE = 0


@@ -45,7 +45,10 @@ class GroqClient(LLMClient):
         self.client = AsyncGroq(api_key=config.api_key)
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -58,7 +61,7 @@ class GroqClient(LLMClient):
             model=self.model or DEFAULT_MODEL,
             messages=msgs,
             temperature=self.temperature,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             response_format={'type': 'json_object'},
         )
         result = response.choices[0].message.content or ''


@@ -25,7 +25,7 @@ from pydantic import BaseModel
 from ..prompts.models import Message
 from .client import LLMClient
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError, RefusalError
 logger = logging.getLogger(__name__)
@@ -58,7 +58,11 @@ class OpenAIClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ):
         """
         Initialize the OpenAIClient with the provided configuration, cache setting, and client.
@@ -84,7 +88,10 @@ class OpenAIClient(LLMClient):
         self.client = client
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -98,7 +105,7 @@ class OpenAIClient(LLMClient):
             model=self.model or DEFAULT_MODEL,
             messages=openai_messages,
             temperature=self.temperature,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens or self.max_tokens,
             response_format=response_model,  # type: ignore
         )
@@ -119,14 +126,17 @@ class OpenAIClient(LLMClient):
             raise
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> dict[str, typing.Any]:
         retry_count = 0
         last_error = None
         while retry_count <= self.MAX_RETRIES:
             try:
-                response = await self._generate_response(messages, response_model)
+                response = await self._generate_response(messages, response_model, max_tokens)
                 return response
             except (RateLimitError, RefusalError):
                 # These errors should not trigger retries


@@ -26,7 +26,7 @@ from pydantic import BaseModel
 from ..prompts.models import Message
 from .client import LLMClient
-from .config import LLMConfig
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
 from .errors import RateLimitError, RefusalError
 logger = logging.getLogger(__name__)
@@ -85,7 +85,10 @@ class OpenAIGenericClient(LLMClient):
         self.client = client
     async def _generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
@@ -111,7 +114,10 @@ class OpenAIGenericClient(LLMClient):
             raise
     async def generate_response(
-        self, messages: list[Message], response_model: type[BaseModel] | None = None
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
     ) -> dict[str, typing.Any]:
         retry_count = 0
         last_error = None
@@ -126,7 +132,9 @@ class OpenAIGenericClient(LLMClient):
         while retry_count <= self.MAX_RETRIES:
             try:
-                response = await self._generate_response(messages, response_model)
+                response = await self._generate_response(
+                    messages, response_model, max_tokens=max_tokens
+                )
                 return response
             except (RateLimitError, RefusalError):
                 # These errors should not trigger retries


@@ -79,6 +79,8 @@ async def extract_edges(
 ) -> list[EntityEdge]:
     start = time()
+    EXTRACT_EDGES_MAX_TOKENS = 16384
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
     # Prepare context for LLM
@@ -93,7 +95,9 @@
     reflexion_iterations = 0
     while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
         llm_response = await llm_client.generate_response(
-            prompt_library.extract_edges.edge(context), response_model=ExtractedEdges
+            prompt_library.extract_edges.edge(context),
+            response_model=ExtractedEdges,
+            max_tokens=EXTRACT_EDGES_MAX_TOKENS,
         )
         edges_data = llm_response.get('edges', [])

poetry.lock (generated): 2772 changed lines; diff suppressed because it is too large.


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.5.1"
+version = "0.5.2"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",