import base64
import copy
import json
import os
import struct
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Union
import aioboto3
import aiohttp
import numpy as np
import ollama
import torch
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)
from pydantic import BaseModel, Field
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
from .base import BaseKVStorage
from .utils import (
compute_args_hash,
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
quantize_embedding,
get_best_cached_response,
)
import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
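    """Complete `prompt` with an OpenAI (or OpenAI-compatible) chat model.

    Builds the message list from `system_prompt`, `history_messages` and
    `prompt`, returns a cached answer from `hashing_kv` when one exists,
    otherwise calls the async OpenAI client (using
    `beta.chat.completions.parse` when a `response_format` is supplied)
    and caches the result.
    """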
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
    # Handle cache
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
if r"\u" in content:
content = content.encode("utf-8").decode("unicode_escape")
    # Save to cache
    await save_to_cache(
        hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
api_version=None,
**kwargs,
):
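    """Azure OpenAI variant of `openai_complete_if_cache`.

    Credentials and endpoint are taken from the arguments when given,
    otherwise from the AZURE_OPENAI_* environment variables. Responses are
    cached in `hashing_kv` keyed by a hash of the model and messages.
    """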
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
mode = kwargs.pop("mode", "default")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
# Handle cache
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
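    """Complete `prompt` via the Amazon Bedrock Converse API, with caching.

    AWS credentials fall back to the environment when not passed explicitly.
    Inference parameters such as `max_tokens`, `top_p` and `stop_sequences`
    are mapped to their Converse `inferenceConfig` names. Raises
    `BedrockError` on any failure from the Bedrock client.
    """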
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
    # Handle cache
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
        # Save to cache
        await save_to_cache(
            hashing_kv,
CacheData(
args_hash=args_hash,
content=response["output"]["message"]["content"][0]["text"],
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response["output"]["message"]["content"][0]["text"]
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
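    """Load and cache a Hugging Face tokenizer/model pair for `model_name`.

    The result is memoized with `lru_cache`, so repeated calls reuse the
    already-loaded weights. A missing pad token is mapped to the EOS token.
    """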
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
return hf_model, hf_tokenizer
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def hf_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
**kwargs,
) -> str:
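    """Complete `prompt` with a local Hugging Face causal LM, with caching.

    The tokenizer's chat template is applied when available; otherwise the
    messages are flattened into `<role>...</role>` blocks before generation.
    """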
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
    input_ids = hf_tokenizer(
        input_prompt, return_tensors="pt", padding=True, truncation=True
    )
    inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
    output = hf_model.generate(
        **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
    )
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response_text,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response_text
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def ollama_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
**kwargs,
) -> Union[str, AsyncIterator[str]]:
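    """Complete `prompt` with an Ollama-served model, with caching.

    When `stream=True` is passed, an async iterator over response chunks is
    returned and the result is not cached; otherwise the full response text
    is cached in `hashing_kv` and returned.
    """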
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        # streaming responses cannot be cached
async def inner():
async for chunk in response:
yield chunk["message"]["content"]
return inner()
else:
result = response["message"]["content"]
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return result
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
    from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig

    lmdeploy_pipe = pipeline(
        model_path=model,
        backend_config=TurbomindEngineConfig(
            tp=tp, model_format=model_format, quant_policy=quant_policy
        ),
        chat_template_config=(
            ChatTemplateConfig(model_name=chat_template) if chat_template else None
        ),
        log_level=log_level,
    )
return lmdeploy_pipe
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
        chat_template (str): needed when the model is a PyTorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on,
            and when the model name of a local path does not match the
            original model name on HF.
        tp (int): tensor parallel degree.
        prompt (Union[str, List[str]]): input texts to be completed.
        do_preprocess (bool): whether to pre-process the messages. Defaults to
            True, which means the chat_template will be applied.
        skip_special_tokens (bool): Whether or not to remove special tokens
            in the decoding. Defaults to True.
        do_sample (bool): Whether or not to use sampling; greedy decoding is
            used otherwise. Defaults to False, which means greedy decoding
            will be applied.
"""
    try:
        import lmdeploy
        from lmdeploy import version_info, GenerationConfig
    except Exception:
        raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
    version = version_info
    if do_sample and version < (0, 6, 0):
        raise RuntimeError(
            "`do_sample` parameter is not supported by lmdeploy until "
            f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
        )
    if version >= (0, 6, 0):
        # only lmdeploy >= 0.6.0 accepts `do_sample` in GenerationConfig
        gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response
class GPTKeywordExtractionFormat(BaseModel):
high_level_keywords: List[str]
low_level_keywords: List[str]
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def nvidia_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
result = await openai_complete_if_cache(
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
result = await azure_openai_complete_if_cache(
"conversation-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
result = await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
result = await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None) or keyword_extraction
    if keyword_extraction:
        kwargs["format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
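    """Embed `texts` with an OpenAI embedding model.

    Returns a 2D numpy array with one embedding vector per input text.
    """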
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def nvidia_openai_embedding(
texts: list[str],
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
base_url: str = "https://integrate.api.nvidia.com/v1",
api_key: str = None,
input_type: str = "passage", # query for retrieval, passage for embedding
trunc: str = "NONE", # NONE or START or END
encode: str = "float", # float or base64
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model,
input=texts,
encoding_format=encode,
extra_body={"input_type": input_type, "truncate": trunc},
)
return np.array([dp.embedding for dp in response.data])
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
api_version: str = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def siliconcloud_embedding(
texts: list[str],
model: str = "netease-youdao/bce-embedding-base_v1",
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
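    """Embed `texts` via the SiliconFlow embeddings HTTP API.

    Texts are truncated to `max_token_size` characters, embeddings are
    requested in base64 form and decoded into little-endian float32 arrays.
    """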
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if "code" in content:
raise ValueError(content)
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embedding(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
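    """Embed `texts` with an Amazon Bedrock embedding model.

    Supports Amazon Titan (v1/v2) models, which are invoked one text at a
    time, and Cohere models, which accept the whole batch in a single call.
    """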
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
                modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
            response_body = json.loads(await response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
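    """Embed `texts` with a local Hugging Face encoder by mean-pooling the
    last hidden state; bfloat16 outputs are cast to float32 before returning
    a numpy array."""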
device = next(embed_model.parameters()).device
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids.to(device)
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
if embeddings.dtype == torch.bfloat16:
return embeddings.detach().to(torch.float32).cpu().numpy()
else:
return embeddings.detach().cpu().numpy()
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
"""
Deprecated in favor of `embed`.
"""
embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts:
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
ollama_client = ollama.Client(**kwargs)
data = ollama_client.embed(model=embed_model, input=texts)
return data["embeddings"]
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
Attributes:
gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
The function should take any argument and return a string.
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
This could include parameters such as the model name, API key, etc.
Example usage:
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
"""
gen_func: Callable[[Any], str] = Field(
...,
description="A function that generates the response from the llm. The response must be a string",
)
kwargs: Dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
class Config:
arbitrary_types_allowed = True
class MultiModel:
"""
    Distributes the load across multiple language models. Useful for circumventing low rate limits with certain API providers, especially if you are on the free tier.
    Could also be used for splitting across different models or providers.
Attributes:
models (List[Model]): A list of language models to be used.
Usage example:
```python
models = [
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
]
multi_model = MultiModel(models)
rag = LightRAG(
llm_model_func=multi_model.llm_model_func
            # ...other args
)
```
"""
def __init__(self, models: List[Model]):
self._models = models
self._current_model = 0
def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model]
async def llm_model_func(
self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model()
args = dict(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
**next_model.kwargs,
)
return await next_model.gen_func(**args)
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
"""Generic cache handling function"""
if hashing_kv is None:
return None, None, None, None
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
# Use regular cache
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, quantized, min_val, max_val
@dataclass
class CacheData:
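    """Payload stored by `save_to_cache`: the generated content plus the
    (optionally quantized) prompt embedding used for similarity lookups."""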
args_hash: str
content: str
model: str
prompt: str
quantized: Optional[np.ndarray] = None
min_val: Optional[float] = None
max_val: Optional[float] = None
mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
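    """Persist `cache_data` into `hashing_kv` under its mode bucket.

    The quantized prompt embedding, when present, is stored as a hex string
    together with its shape and min/max values so it can be reconstructed.
    """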
if hashing_kv is None:
return
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"model": cache_data.model,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
"embedding_shape": cache_data.quantized.shape
if cache_data.quantized is not None
else None,
"embedding_min": cache_data.min_val,
"embedding_max": cache_data.max_val,
"original_prompt": cache_data.prompt,
}
await hashing_kv.upsert({cache_data.mode: mode_cache})
if __name__ == "__main__":
import asyncio
async def main():
result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())