mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
feat(kag): ollama vectorize model (#562)
* update DSL query string * fix(tools): update ner construct params * feat(kag): ollama vectorize model * feat(kag): formatter * feat(common): remove LLM default setting of max_tokens * feat(common): remove LLM default setting of max_tokens * fix(tools): wrapper entity type with '`' in generate_label()
This commit is contained in:
parent
881d6e5d0c
commit
164f518691
@ -12,7 +12,7 @@
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
|
||||
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI, NOT_GIVEN
|
||||
|
||||
from kag.interface import LLMClient
|
||||
from typing import Callable, Optional
|
||||
@ -121,7 +121,7 @@ class OpenAIClient(LLMClient):
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
tools=tools,
|
||||
max_tokens=self.max_tokens,
|
||||
max_tokens=self.max_tokens if self.max_tokens > 0 else NOT_GIVEN,
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": self.think}},
|
||||
)
|
||||
if not self.stream:
|
||||
@ -213,7 +213,7 @@ class OpenAIClient(LLMClient):
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
tools=tools,
|
||||
max_tokens=self.max_tokens,
|
||||
max_tokens=self.max_tokens if self.max_tokens > 0 else NOT_GIVEN,
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": self.think}},
|
||||
)
|
||||
if not self.stream:
|
||||
|
@ -14,6 +14,7 @@ from kag.common.vectorize_model.local_bge_model import (
|
||||
LocalBGEVectorizeModel,
|
||||
LocalBGEM3VectorizeModel,
|
||||
)
|
||||
from kag.common.vectorize_model.ollama_model import OllamaVectorizeModel
|
||||
from kag.common.vectorize_model.openai_model import OpenAIVectorizeModel
|
||||
from kag.common.vectorize_model.mock_model import MockVectorizeModel
|
||||
from kag.common.vectorize_model.vectorize_model_config_checker import (
|
||||
@ -25,6 +26,7 @@ __all__ = [
|
||||
"LocalBGEM3VectorizeModel",
|
||||
"LocalBGEVectorizeModel",
|
||||
"OpenAIVectorizeModel",
|
||||
"OllamaVectorizeModel",
|
||||
"MockVectorizeModel",
|
||||
"VectorizeModelConfigChecker",
|
||||
]
|
||||
|
119
kag/common/vectorize_model/ollama_model.py
Normal file
119
kag/common/vectorize_model/ollama_model.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2023 OpenSPG Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
# in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
import asyncio
|
||||
from typing import Union, Iterable, List
|
||||
|
||||
from ollama import AsyncClient
|
||||
|
||||
from kag.interface import VectorizeModelABC, EmbeddingVector
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@VectorizeModelABC.register("Ollama")
|
||||
@VectorizeModelABC.register("ollama")
|
||||
class OllamaVectorizeModel(VectorizeModelABC):
|
||||
"""
|
||||
A class that extends the VectorizeModelABC base class.
|
||||
It invokes Ollama embedding services to convert texts into embedding vectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "bge-m3",
|
||||
base_url: str = "",
|
||||
vector_dimensions: int = None,
|
||||
timeout: float = None,
|
||||
max_rate: float = 1000,
|
||||
time_period: float = 1,
|
||||
batch_size: int = 8,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the OpenAIVectorizeModel instance.
|
||||
|
||||
Args:
|
||||
model (str, optional): The model to use for embedding. Defaults to "text-embedding-3-small".
|
||||
api_key (str, optional): The API key for accessing the OpenAI service. Defaults to "".
|
||||
base_url (str, optional): The base URL for the OpenAI service. Defaults to "".
|
||||
vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
|
||||
"""
|
||||
name = self.generate_key(base_url, model)
|
||||
super().__init__(name, vector_dimensions, max_rate, time_period)
|
||||
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self.aclient = AsyncClient(host=self.base_url, timeout=self.timeout)
|
||||
|
||||
@classmethod
|
||||
def generate_key(cls, base_url, model, *args, **kwargs) -> str:
|
||||
return f"{cls}_{base_url}_{model}"
|
||||
|
||||
def vectorize(
|
||||
self, texts: Union[str, Iterable[str]]
|
||||
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
|
||||
return asyncio.run(self.avectorize(texts))
|
||||
|
||||
async def avectorize(
|
||||
self, texts: Union[str, Iterable[str]]
|
||||
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
|
||||
"""
|
||||
Vectorize a text string into an embedding vector or multiple text strings into multiple embedding vectors.
|
||||
|
||||
Args:
|
||||
texts (Union[str, Iterable[str]]): The text or texts to vectorize.
|
||||
|
||||
Returns:
|
||||
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
|
||||
"""
|
||||
|
||||
# Handle empty strings in the input
|
||||
if isinstance(texts, list):
|
||||
# Create a set of original texts to remove empty and duplicated strings
|
||||
filtered_texts = [x for x in set(texts) if x]
|
||||
|
||||
if not filtered_texts:
|
||||
return [[] for _ in texts] # Return empty vectors for all inputs
|
||||
|
||||
embeddings = await self._execute_batch_vectorize(filtered_texts)
|
||||
results = {
|
||||
text: embedding for text, embedding in zip(filtered_texts, embeddings)
|
||||
}
|
||||
|
||||
return [results[text] if text else [] for text in texts]
|
||||
|
||||
elif isinstance(texts, str) and not texts.strip():
|
||||
return [] # Return empty vector for empty string
|
||||
else:
|
||||
embeddings = await self._execute_batch_vectorize([texts])
|
||||
return embeddings[0]
|
||||
|
||||
async def _execute_batch_vectorize(self, texts: List[str]) -> List[EmbeddingVector]:
|
||||
async def do_task_with_semaphore(_semaphore, _input: str) -> EmbeddingVector:
|
||||
async with _semaphore:
|
||||
embeddings_response = await self.aclient.embeddings(
|
||||
prompt=_input, model=self.model
|
||||
)
|
||||
return embeddings_response.embedding
|
||||
|
||||
try:
|
||||
semaphore = asyncio.Semaphore(self.batch_size)
|
||||
embeddings = await asyncio.gather(
|
||||
*[do_task_with_semaphore(semaphore, text) for text in texts]
|
||||
)
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
logger.error(f"input: {texts}")
|
||||
logger.error(f"model: {self.model}")
|
@ -16,7 +16,7 @@ def replace_qota(s: str):
|
||||
|
||||
def generate_label(s: SPOBase, heads: List[EntityData], schema):
|
||||
if heads:
|
||||
return list(set([f"{h.type}" for h in heads]))
|
||||
return list(set([f"`{h.type}`" for h in heads]))
|
||||
|
||||
if not isinstance(s, SPOEntity):
|
||||
return ["Entity"]
|
||||
|
@ -7,7 +7,6 @@ from kag.interface import VectorizeModelABC
|
||||
|
||||
@pytest.mark.skip(reason="Missing API key")
|
||||
def test_openai_vectorize_model():
|
||||
|
||||
conf = {
|
||||
"type": "openai",
|
||||
"model": "BAAI/bge-m3",
|
||||
@ -21,6 +20,19 @@ def test_openai_vectorize_model():
|
||||
assert res1 is not None and res1 == res2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Missing model")
|
||||
def test_ollama_vectorize_model():
|
||||
conf = {
|
||||
"type": "ollama",
|
||||
"model": "",
|
||||
"base_url": "http://127.0.0.1:11434/",
|
||||
"vector_dimensions": 1024,
|
||||
}
|
||||
vectorize_model = VectorizeModelABC.from_config(copy.deepcopy(conf))
|
||||
emb = vectorize_model.vectorize("你好")
|
||||
assert len(emb) == vectorize_model.get_vector_dimensions()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Missing model file")
|
||||
def test_bge_vectorize_model():
|
||||
conf = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user