Mirror of https://github.com/infiniflow/ragflow.git

### What problem does this PR solve?

Fix the keys of Xinference-deployed models, especially when a model shares its name with a publicly hosted model.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: 0000sir <0000sir@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
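
A hedged sketch of the idea behind the fix, assuming keys are disambiguated by suffixing the provider name (the exact key format ragflow uses may differ):

```python
# Illustrative only: a locally served "whisper-small" must not collide with a
# public model of the same name, so the provider is folded into the key.
def model_key(model_name: str, provider: str) -> str:
    return f"{model_name}@{provider}"

assert model_key("whisper-small", "Xinference") != model_key("whisper-small", "OpenAI")
```
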
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC

import requests
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI

from rag.utils import num_tokens_from_string


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def transcription(self, audio, **kwargs):
        transcription = self.client.audio.transcriptions.create(
            model=self.model_name,
            file=audio,
            response_format="text"
        )
        return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

    def audio2base64(self, audio):
        if isinstance(audio, bytes):
            return base64.b64encode(audio).decode("utf-8")
        if isinstance(audio, io.BytesIO):
            return base64.b64encode(audio.getvalue()).decode("utf-8")
        raise TypeError("The input audio file should be in binary format.")
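
# Base.transcription expects the subclass to set an OpenAI-compatible
# `self.client` and a `self.model_name`; subclasses that talk to other
# backends (DashScope, Xinference, Tencent Cloud) override it instead.

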
class GPTSeq2txt(Base):
    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name


class QWenSeq2txt(Base):
    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format):
        from http import HTTPStatus
        from dashscope.audio.asr import Recognition

        recognition = Recognition(model=self.model_name,
                                  format=format,
                                  sample_rate=16000,
                                  callback=None)
        result = recognition.call(audio)

        ans = ""
        if result.status_code == HTTPStatus.OK:
            for sentence in result.get_sentence():
                ans += sentence.text.decode('utf-8') + '\n'
            return ans, num_tokens_from_string(ans)

        return "**ERROR**: " + result.message, 0
class AzureSeq2txt(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name
        self.lang = lang
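
# Usage sketch (hypothetical key and file): GPTSeq2txt and AzureSeq2txt both
# inherit Base.transcription, which passes the open file handle to an
# OpenAI-compatible /audio/transcriptions endpoint.
#
#   model = GPTSeq2txt(key="sk-...", model_name="whisper-1")
#   with open("sample.wav", "rb") as f:
#       text, tokens = model.transcription(f)

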
class XinferenceSeq2txt(Base):
    def __init__(self, key, model_name="whisper-small", **kwargs):
        self.base_url = kwargs.get('base_url', None)
        self.model_name = model_name
        self.key = key

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        if isinstance(audio, str):
            # Treat a string as a file path; close the handle once read.
            with open(audio, 'rb') as audio_file:
                audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            audio_data = audio
            audio_file_name = "audio.wav"

        payload = {
            "model": self.model_name,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature
        }

        files = {
            "file": (audio_file_name, audio_data, 'audio/wav')
        }

        try:
            response = requests.post(
                f"{self.base_url}/v1/audio/transcriptions",
                files=files,
                data=payload
            )
            response.raise_for_status()
            result = response.json()

            if 'text' in result:
                transcription_text = result['text'].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            else:
                return "**ERROR**: Failed to retrieve transcription.", 0

        except requests.exceptions.RequestException as e:
            return f"**ERROR**: {str(e)}", 0
class TencentCloudSeq2txt(Base):
    def __init__(
            self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
    ):
        from tencentcloud.common import credential
        from tencentcloud.asr.v20190614 import asr_client

        key = json.loads(key)
        sid = key.get("tencent_cloud_sid", "")
        sk = key.get("tencent_cloud_sk", "")
        cred = credential.Credential(sid, sk)
        self.client = asr_client.AsrClient(cred, "")
        self.model_name = model_name

    def transcription(self, audio, max_retries=60, retry_interval=5):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
        from tencentcloud.asr.v20190614 import models
        import time

        b64 = self.audio2base64(audio)
        try:
            # dispatch the recognition task
            req = models.CreateRecTaskRequest()
            params = {
                "EngineModelType": self.model_name,
                "ChannelNum": 1,
                "ResTextFormat": 0,
                "SourceType": 1,
                "Data": b64,
            }
            req.from_json_string(json.dumps(params))
            resp = self.client.CreateRecTask(req)

            # poll until the task succeeds, fails, or retries run out
            req = models.DescribeTaskStatusRequest()
            params = {"TaskId": resp.Data.TaskId}
            req.from_json_string(json.dumps(params))
            retries = 0
            while retries < max_retries:
                resp = self.client.DescribeTaskStatus(req)
                if resp.Data.StatusStr == "success":
                    # strip the [mm:ss.xxx,mm:ss.xxx] timestamps from the result
                    text = re.sub(
                        r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
                    ).strip()
                    return text, num_tokens_from_string(text)
                elif resp.Data.StatusStr == "failed":
                    return (
                        "**ERROR**: Failed to retrieve speech recognition results.",
                        0,
                    )
                else:
                    time.sleep(retry_interval)
                    retries += 1
            return "**ERROR**: Max retries exceeded. Task may still be processing.", 0

        except TencentCloudSDKException as e:
            return "**ERROR**: " + str(e), 0
        except Exception as e:
            return "**ERROR**: " + str(e), 0