235 lines
29 KiB
Python
Raw Normal View History

import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_validate_credentials(setup_google_mock):
    """Check credential validation for the ``gemini-pro`` model.

    A bogus API key must raise ``CredentialsValidateFailedError``, while the
    key read from the ``GOOGLE_API_KEY`` environment variable must be accepted
    without raising.
    """
    llm = GoogleLargeLanguageModel()

    # An obviously invalid key has to be rejected.
    bad_credentials = {'google_api_key': 'invalid_key'}
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model='gemini-pro', credentials=bad_credentials)

    # The configured key is expected to validate cleanly (no exception).
    env_credentials = {'google_api_key': os.environ.get('GOOGLE_API_KEY')}
    llm.validate_credentials(model='gemini-pro', credentials=env_credentials)
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_model(setup_google_mock):
    """Invoke ``gemini-pro`` with ``stream=False`` and expect a non-empty ``LLMResult``.

    The conversation mixes plain-string messages with a final user message
    whose content is a list of two text parts, and passes a stop sequence.
    """
    llm = GoogleLargeLanguageModel()

    api_credentials = {'google_api_key': os.environ.get('GOOGLE_API_KEY')}

    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Give me your worst dad joke or i will unplug you'),
        AssistantPromptMessage(
            content='Why did the scarecrow win an award? Because he was outstanding in his field!'
        ),
        # Last user turn is multi-part: two text parts inside one message.
        UserPromptMessage(
            content=[
                TextPromptMessageContent(data="ok something snarkier pls"),
                TextPromptMessageContent(data="i may still unplug you"),
            ]
        ),
    ]

    sampling_params = {
        'temperature': 0.5,
        'top_p': 1.0,
        'max_tokens_to_sample': 2048,
    }

    result = llm.invoke(
        model='gemini-pro',
        credentials=api_credentials,
        prompt_messages=conversation,
        model_parameters=sampling_params,
        stop=['How'],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_stream_model(setup_google_mock):
    """Invoke ``gemini-pro`` with ``stream=True`` and validate every streamed chunk.

    The call must return a generator; each yielded item must be an
    ``LLMResultChunk`` carrying an ``LLMResultChunkDelta`` with an
    ``AssistantPromptMessage``. Chunks without a finish reason must have
    non-empty content.
    """
    llm = GoogleLargeLanguageModel()

    api_credentials = {'google_api_key': os.environ.get('GOOGLE_API_KEY')}

    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Give me your worst dad joke or i will unplug you'),
        AssistantPromptMessage(
            content='Why did the scarecrow win an award? Because he was outstanding in his field!'
        ),
        # Last user turn is multi-part: two text parts inside one message.
        UserPromptMessage(
            content=[
                TextPromptMessageContent(data="ok something snarkier pls"),
                TextPromptMessageContent(data="i may still unplug you"),
            ]
        ),
    ]

    sampling_params = {
        'temperature': 0.2,
        'top_k': 5,
        'max_tokens_to_sample': 2048,
    }

    stream = llm.invoke(
        model='gemini-pro',
        credentials=api_credentials,
        prompt_messages=conversation,
        model_parameters=sampling_params,
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)

    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        delta = piece.delta
        assert isinstance(delta, LLMResultChunkDelta)
        assert isinstance(delta.message, AssistantPromptMessage)
        # Only chunks without a finish reason are required to carry text.
        if delta.finish_reason is None:
            assert len(delta.message.content) > 0
# Non-streaming vision test: sends a system message plus one user message that
# combines a text part and an image part to 'gemini-pro-vision', then checks
# that the call returns an LLMResult whose message content is non-empty.
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_chat_model_with_vision(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model='gemini-pro-vision',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what do you see?"
),
ImagePromptMessageContent(
# NOTE(review): the string literal below is unterminated in this copy of the
# file - the image payload appears to have been lost. Restore the original
# image data before running this test.
data='
)
]
)
],
model_parameters={
'temperature': 0.3,
'top_p': 0.2,
'top_k': 3,
'max_tokens': 100
},
stream=False,
user="abc-123"
)
# The non-streaming call must yield a single LLMResult with some content.
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
# Non-streaming multi-turn vision test: the conversation contains two user
# messages that each pair a text part with an image part, separated by an
# assistant reply, and is sent to 'gemini-pro-vision'. The call must return an
# LLMResult whose message content is non-empty.
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
model = GoogleLargeLanguageModel()
result = model.invoke(
model='gemini-pro-vision',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.'
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what do you see?"
),
ImagePromptMessageContent(
# NOTE(review): the string literal below is unterminated in this copy of the
# file - the first image payload appears to have been lost. Restore the
# original image data before running this test.
data='
)
]
),
AssistantPromptMessage(
content="I see a blue letter 'D' with a gradient from light blue to dark blue."
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what about now?"
),
ImagePromptMessageContent(
# NOTE(review): the second image is an empty string here - presumably its
# payload was lost as well; confirm against the original file.
data=''
)
]
)
],
model_parameters={
'temperature': 0.3,
'top_p': 0.2,
'top_k': 3,
'max_tokens': 100
},
stream=False,
user="abc-123"
)
# Debug print of the model's reply content.
print(f"resultz: {result.message.content}")
# The non-streaming call must yield a single LLMResult with some content.
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
def test_get_num_tokens():
    """Count tokens for a short two-message prompt on ``gemini-pro``.

    Only positivity is asserted, not a specific count.
    """
    llm = GoogleLargeLanguageModel()

    api_credentials = {'google_api_key': os.environ.get('GOOGLE_API_KEY')}
    prompt = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Hello World!'),
    ]

    token_count = llm.get_num_tokens(
        model='gemini-pro',
        credentials=api_credentials,
        prompt_messages=prompt,
    )

    # The exact number of tokens may vary based on the model's tokenization.
    assert token_count > 0