feat(kag): ollama vectorize model (#562)

* update DSL query string

* fix(tools): update ner construct params

* feat(kag): ollama vectorize model

* feat(kag): formatter

* feat(common): remove LLM default setting of max_tokens

* feat(common): remove LLM default setting of max_tokens

* fix(tools): wrapper entity type with '`' in generate_label()
This commit is contained in:
thundax 2025-05-31 16:53:01 +08:00 committed by GitHub
parent 881d6e5d0c
commit 164f518691
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 138 additions and 5 deletions

View File

@ -12,7 +12,7 @@
import logging
import asyncio
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI, NOT_GIVEN
from kag.interface import LLMClient
from typing import Callable, Optional
@ -121,7 +121,7 @@ class OpenAIClient(LLMClient):
temperature=self.temperature,
timeout=self.timeout,
tools=tools,
max_tokens=self.max_tokens,
max_tokens=self.max_tokens if self.max_tokens > 0 else NOT_GIVEN,
extra_body={"chat_template_kwargs": {"enable_thinking": self.think}},
)
if not self.stream:
@ -213,7 +213,7 @@ class OpenAIClient(LLMClient):
temperature=self.temperature,
timeout=self.timeout,
tools=tools,
max_tokens=self.max_tokens,
max_tokens=self.max_tokens if self.max_tokens > 0 else NOT_GIVEN,
extra_body={"chat_template_kwargs": {"enable_thinking": self.think}},
)
if not self.stream:

View File

@ -14,6 +14,7 @@ from kag.common.vectorize_model.local_bge_model import (
LocalBGEVectorizeModel,
LocalBGEM3VectorizeModel,
)
from kag.common.vectorize_model.ollama_model import OllamaVectorizeModel
from kag.common.vectorize_model.openai_model import OpenAIVectorizeModel
from kag.common.vectorize_model.mock_model import MockVectorizeModel
from kag.common.vectorize_model.vectorize_model_config_checker import (
@ -25,6 +26,7 @@ __all__ = [
"LocalBGEM3VectorizeModel",
"LocalBGEVectorizeModel",
"OpenAIVectorizeModel",
"OllamaVectorizeModel",
"MockVectorizeModel",
"VectorizeModelConfigChecker",
]

View File

@ -0,0 +1,119 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import asyncio
from typing import Union, Iterable, List
from ollama import AsyncClient
from kag.interface import VectorizeModelABC, EmbeddingVector
import logging
logger = logging.getLogger(__name__)
@VectorizeModelABC.register("Ollama")
@VectorizeModelABC.register("ollama")
class OllamaVectorizeModel(VectorizeModelABC):
"""
A class that extends the VectorizeModelABC base class.
It invokes Ollama embedding services to convert texts into embedding vectors.
"""
def __init__(
self,
model: str = "bge-m3",
base_url: str = "",
vector_dimensions: int = None,
timeout: float = None,
max_rate: float = 1000,
time_period: float = 1,
batch_size: int = 8,
**kwargs,
):
"""
Initializes the OpenAIVectorizeModel instance.
Args:
model (str, optional): The model to use for embedding. Defaults to "text-embedding-3-small".
api_key (str, optional): The API key for accessing the OpenAI service. Defaults to "".
base_url (str, optional): The base URL for the OpenAI service. Defaults to "".
vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
"""
name = self.generate_key(base_url, model)
super().__init__(name, vector_dimensions, max_rate, time_period)
self.model = model
self.timeout = timeout
self.base_url = base_url
self.batch_size = batch_size
self.aclient = AsyncClient(host=self.base_url, timeout=self.timeout)
@classmethod
def generate_key(cls, base_url, model, *args, **kwargs) -> str:
return f"{cls}_{base_url}_{model}"
def vectorize(
self, texts: Union[str, Iterable[str]]
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
return asyncio.run(self.avectorize(texts))
async def avectorize(
self, texts: Union[str, Iterable[str]]
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
"""
Vectorize a text string into an embedding vector or multiple text strings into multiple embedding vectors.
Args:
texts (Union[str, Iterable[str]]): The text or texts to vectorize.
Returns:
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
"""
# Handle empty strings in the input
if isinstance(texts, list):
# Create a set of original texts to remove empty and duplicated strings
filtered_texts = [x for x in set(texts) if x]
if not filtered_texts:
return [[] for _ in texts] # Return empty vectors for all inputs
embeddings = await self._execute_batch_vectorize(filtered_texts)
results = {
text: embedding for text, embedding in zip(filtered_texts, embeddings)
}
return [results[text] if text else [] for text in texts]
elif isinstance(texts, str) and not texts.strip():
return [] # Return empty vector for empty string
else:
embeddings = await self._execute_batch_vectorize([texts])
return embeddings[0]
async def _execute_batch_vectorize(self, texts: List[str]) -> List[EmbeddingVector]:
async def do_task_with_semaphore(_semaphore, _input: str) -> EmbeddingVector:
async with _semaphore:
embeddings_response = await self.aclient.embeddings(
prompt=_input, model=self.model
)
return embeddings_response.embedding
try:
semaphore = asyncio.Semaphore(self.batch_size)
embeddings = await asyncio.gather(
*[do_task_with_semaphore(semaphore, text) for text in texts]
)
return embeddings
except Exception as e:
logger.error(f"Error: {e}")
logger.error(f"input: {texts}")
logger.error(f"model: {self.model}")

View File

@ -16,7 +16,7 @@ def replace_qota(s: str):
def generate_label(s: SPOBase, heads: List[EntityData], schema):
if heads:
return list(set([f"{h.type}" for h in heads]))
return list(set([f"`{h.type}`" for h in heads]))
if not isinstance(s, SPOEntity):
return ["Entity"]

View File

@ -7,7 +7,6 @@ from kag.interface import VectorizeModelABC
@pytest.mark.skip(reason="Missing API key")
def test_openai_vectorize_model():
conf = {
"type": "openai",
"model": "BAAI/bge-m3",
@ -21,6 +20,19 @@ def test_openai_vectorize_model():
assert res1 is not None and res1 == res2
@pytest.mark.skip(reason="Missing model")
def test_ollama_vectorize_model():
conf = {
"type": "ollama",
"model": "",
"base_url": "http://127.0.0.1:11434/",
"vector_dimensions": 1024,
}
vectorize_model = VectorizeModelABC.from_config(copy.deepcopy(conf))
emb = vectorize_model.vectorize("你好")
assert len(emb) == vectorize_model.get_vector_dimensions()
@pytest.mark.skip(reason="Missing model file")
def test_bge_vectorize_model():
conf = {