diff --git a/kag/common/llm/openai_client.py b/kag/common/llm/openai_client.py
index 1432aa66..3eeeebd8 100644
--- a/kag/common/llm/openai_client.py
+++ b/kag/common/llm/openai_client.py
@@ -12,7 +12,7 @@
 import logging
 import asyncio
 
-from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
+from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI, NOT_GIVEN
 
 from kag.interface import LLMClient
 from typing import Callable, Optional
@@ -121,7 +121,7 @@ class OpenAIClient(LLMClient):
             temperature=self.temperature,
             timeout=self.timeout,
             tools=tools,
-            max_tokens=self.max_tokens,
+            max_tokens=self.max_tokens if self.max_tokens > 0 else NOT_GIVEN,
             extra_body={"chat_template_kwargs": {"enable_thinking": self.think}},
         )
         if not self.stream:
@@ -213,7 +213,7 @@ class OpenAIClient(LLMClient):
             temperature=self.temperature,
             timeout=self.timeout,
             tools=tools,
-            max_tokens=self.max_tokens,
+            max_tokens=self.max_tokens if self.max_tokens > 0 else NOT_GIVEN,
             extra_body={"chat_template_kwargs": {"enable_thinking": self.think}},
         )
         if not self.stream:
diff --git a/kag/common/vectorize_model/__init__.py b/kag/common/vectorize_model/__init__.py
index 1af8cfd3..0f007f6d 100644
--- a/kag/common/vectorize_model/__init__.py
+++ b/kag/common/vectorize_model/__init__.py
@@ -14,6 +14,7 @@ from kag.common.vectorize_model.local_bge_model import (
     LocalBGEVectorizeModel,
     LocalBGEM3VectorizeModel,
 )
+from kag.common.vectorize_model.ollama_model import OllamaVectorizeModel
 from kag.common.vectorize_model.openai_model import OpenAIVectorizeModel
 from kag.common.vectorize_model.mock_model import MockVectorizeModel
 from kag.common.vectorize_model.vectorize_model_config_checker import (
@@ -25,6 +26,7 @@ __all__ = [
     "LocalBGEM3VectorizeModel",
     "LocalBGEVectorizeModel",
     "OpenAIVectorizeModel",
+    "OllamaVectorizeModel",
     "MockVectorizeModel",
     "VectorizeModelConfigChecker",
 ]
diff --git a/kag/common/vectorize_model/ollama_model.py b/kag/common/vectorize_model/ollama_model.py
new file mode 100644
index 00000000..262688b9
--- /dev/null
+++ b/kag/common/vectorize_model/ollama_model.py
@@ -0,0 +1,123 @@
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+import asyncio
+from typing import Union, Iterable, List
+
+from ollama import AsyncClient
+
+from kag.interface import VectorizeModelABC, EmbeddingVector
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+@VectorizeModelABC.register("Ollama")
+@VectorizeModelABC.register("ollama")
+class OllamaVectorizeModel(VectorizeModelABC):
+    """
+    A class that extends the VectorizeModelABC base class.
+    It invokes Ollama embedding services to convert texts into embedding vectors.
+    """
+
+    def __init__(
+        self,
+        model: str = "bge-m3",
+        base_url: str = "",
+        vector_dimensions: int = None,
+        timeout: float = None,
+        max_rate: float = 1000,
+        time_period: float = 1,
+        batch_size: int = 8,
+        **kwargs,
+    ):
+        """
+        Initializes the OllamaVectorizeModel instance.
+
+        Args:
+            model (str, optional): The model to use for embedding. Defaults to "bge-m3".
+            base_url (str, optional): The base URL of the Ollama service. Defaults to "".
+            vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
+            timeout (float, optional): Request timeout in seconds. Defaults to None.
+            batch_size (int, optional): Maximum number of concurrent embedding requests. Defaults to 8.
+        """
+        name = self.generate_key(base_url, model)
+        super().__init__(name, vector_dimensions, max_rate, time_period)
+
+        self.model = model
+        self.timeout = timeout
+        self.base_url = base_url
+        self.batch_size = batch_size
+        self.aclient = AsyncClient(host=self.base_url, timeout=self.timeout)
+
+    @classmethod
+    def generate_key(cls, base_url, model, *args, **kwargs) -> str:
+        # Use the class name rather than the class repr to build a stable key.
+        return f"{cls.__name__}_{base_url}_{model}"
+
+    def vectorize(
+        self, texts: Union[str, Iterable[str]]
+    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
+        return asyncio.run(self.avectorize(texts))
+
+    async def avectorize(
+        self, texts: Union[str, Iterable[str]]
+    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
+        """
+        Vectorize a text string into an embedding vector or multiple text strings into multiple embedding vectors.
+
+        Args:
+            texts (Union[str, Iterable[str]]): The text or texts to vectorize.
+
+        Returns:
+            Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
+        """
+
+        # Handle empty strings in the input
+        if isinstance(texts, list):
+            # Deduplicate the input texts and drop empty strings
+            filtered_texts = [x for x in set(texts) if x]
+
+            if not filtered_texts:
+                return [[] for _ in texts]  # Return empty vectors for all inputs
+
+            embeddings = await self._execute_batch_vectorize(filtered_texts)
+            results = {
+                text: embedding for text, embedding in zip(filtered_texts, embeddings)
+            }
+
+            return [results[text] if text else [] for text in texts]
+
+        elif isinstance(texts, str) and not texts.strip():
+            return []  # Return empty vector for empty string
+        else:
+            embeddings = await self._execute_batch_vectorize([texts])
+            return embeddings[0]
+
+    async def _execute_batch_vectorize(self, texts: List[str]) -> List[EmbeddingVector]:
+        async def do_task_with_semaphore(_semaphore, _input: str) -> EmbeddingVector:
+            async with _semaphore:
+                embeddings_response = await self.aclient.embeddings(
+                    prompt=_input, model=self.model
+                )
+                return embeddings_response.embedding
+
+        try:
+            # Cap concurrency at batch_size in-flight embedding requests.
+            semaphore = asyncio.Semaphore(self.batch_size)
+            embeddings = await asyncio.gather(
+                *[do_task_with_semaphore(semaphore, text) for text in texts]
+            )
+            return embeddings
+        except Exception as e:
+            logger.error(f"Error: {e}")
+            logger.error(f"input: {texts}")
+            logger.error(f"model: {self.model}")
+            raise  # re-raise so callers see the failure instead of an implicit None
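For reviewers: a minimal usage sketch of the new `OllamaVectorizeModel`, mirroring the unit test added below. The model name and the locally running Ollama server at its default port are assumptions, not part of this change.

```python
# Usage sketch only; assumes an Ollama server on the default port with a
# pulled embedding model. Values below are illustrative, not prescribed.
from kag.interface import VectorizeModelABC

conf = {
    "type": "ollama",                       # registered alias (also "Ollama")
    "model": "bge-m3",                      # any embedding model available in Ollama
    "base_url": "http://127.0.0.1:11434/",  # default Ollama endpoint
    "vector_dimensions": 1024,
}
vectorize_model = VectorizeModelABC.from_config(conf)

# Single string -> one embedding vector.
emb = vectorize_model.vectorize("hello")
assert len(emb) == vectorize_model.get_vector_dimensions()

# List input: duplicates are embedded once, empty strings map to [].
embs = vectorize_model.vectorize(["hello", "hello", ""])
assert embs[0] == embs[1] and embs[2] == []
```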
diff --git a/kag/tools/graph_api/graph_api_abc.py b/kag/tools/graph_api/graph_api_abc.py
index dd983ef9..576f0398 100644
--- a/kag/tools/graph_api/graph_api_abc.py
+++ b/kag/tools/graph_api/graph_api_abc.py
@@ -16,7 +16,7 @@ def replace_qota(s: str):
 
 def generate_label(s: SPOBase, heads: List[EntityData], schema):
     if heads:
-        return list(set([f"{h.type}" for h in heads]))
+        return list(set([f"`{h.type}`" for h in heads]))
 
     if not isinstance(s, SPOEntity):
         return ["Entity"]
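A note on the `generate_label` change above: wrapping each head type in backticks escapes the label when it is later spliced into a Cypher query. This matters because KAG/OpenSPG types are commonly namespaced with a dot, and a bare dotted label is a Cypher parse error. A sketch of the effect; the example type and query are hypothetical:

```python
# Hypothetical head type; the dot comes from KAG's namespaced schema labels.
head_type = "RiskMining.Person"

# Before this change the generated query would be MATCH (n:RiskMining.Person),
# which Cypher rejects. Backtick-escaping the label makes it parse:
label = f"`{head_type}`"
print(f"MATCH (n:{label}) RETURN n LIMIT 1")
# -> MATCH (n:`RiskMining.Person`) RETURN n LIMIT 1
```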
diff --git a/tests/unit/common/vectorize_model/test_vectorize_model.py b/tests/unit/common/vectorize_model/test_vectorize_model.py
index ad5761e8..50b2e7e6 100644
--- a/tests/unit/common/vectorize_model/test_vectorize_model.py
+++ b/tests/unit/common/vectorize_model/test_vectorize_model.py
@@ -7,7 +7,6 @@ from kag.interface import VectorizeModelABC
 
 @pytest.mark.skip(reason="Missing API key")
 def test_openai_vectorize_model():
-
     conf = {
         "type": "openai",
         "model": "BAAI/bge-m3",
@@ -21,6 +20,19 @@ def test_openai_vectorize_model():
     assert res1 is not None and res1 == res2
 
 
+@pytest.mark.skip(reason="Missing model")
+def test_ollama_vectorize_model():
+    conf = {
+        "type": "ollama",
+        "model": "",
+        "base_url": "http://127.0.0.1:11434/",
+        "vector_dimensions": 1024,
+    }
+    vectorize_model = VectorizeModelABC.from_config(copy.deepcopy(conf))
+    emb = vectorize_model.vectorize("你好")
+    assert len(emb) == vectorize_model.get_vector_dimensions()
+
+
 @pytest.mark.skip(reason="Missing model file")
 def test_bge_vectorize_model():
     conf = {
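Finally, on the `max_tokens` guard in `openai_client.py`: `NOT_GIVEN` is the openai SDK's sentinel for "omit this field from the request", so a non-positive `max_tokens` in the KAG config is no longer sent to the server as a literal value. A standalone sketch of the pattern; the endpoint and model name are placeholders:

```python
# Standalone sketch of the NOT_GIVEN pattern (not KAG code). The SDK drops
# NOT_GIVEN parameters from the request body, so the server applies its own
# default rather than receiving max_tokens=0.
from openai import OpenAI, NOT_GIVEN

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")  # placeholder endpoint

max_tokens = 0  # i.e. "not configured" in the KAG config
response = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "hello"}],
    max_tokens=max_tokens if max_tokens > 0 else NOT_GIVEN,
)
print(response.choices[0].message.content)
```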