mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-31 19:03:09 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			178 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			178 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import re
 | |
| from typing import Union
 | |
| 
 | |
| import pytest
 | |
| from _pytest.monkeypatch import MonkeyPatch
 | |
| from requests import Response
 | |
| from requests.exceptions import ConnectionError
 | |
| from requests.sessions import Session
 | |
| from xinference_client.client.restful.restful_client import (
 | |
|     Client,
 | |
|     RESTfulChatglmCppChatModelHandle,
 | |
|     RESTfulChatModelHandle,
 | |
|     RESTfulEmbeddingModelHandle,
 | |
|     RESTfulGenerateModelHandle,
 | |
|     RESTfulRerankModelHandle,
 | |
| )
 | |
| from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
 | |
| 
 | |
| 
 | |
# Model uids are accepted either as one of the mock's fixed aliases or as a
# UUID-looking string. NOTE: used with .match(), so only the start of the uid
# is checked (no end anchor), exactly as the original inline patterns did.
_UUID_PATTERN = re.compile(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}')
# Loose "looks like an http(s) URL" check used to simulate unreachable servers.
_URL_PATTERN = re.compile(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$')

# Fixed model uids the mock knows about, mapped to the handle class returned
# by the patched Client.get_model.
_MODEL_HANDLE_CLASSES = {
    'generate': RESTfulGenerateModelHandle,
    'chat': RESTfulChatModelHandle,
    'embedding': RESTfulEmbeddingModelHandle,
    'rerank': RESTfulRerankModelHandle,
}

_NOT_FOUND_BODY = b'{}'

# Canned body served by the patched Session.get for v1/models/{generate,chat}.
_LLM_MODEL_DESCRIPTION = b'''{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}'''

# Canned body served by the patched Session.get for v1/models/embedding.
_EMBEDDING_MODEL_DESCRIPTION = b'''{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}'''

_CLUSTER_AUTH_BODY = b'''{
    "auth": true
}'''


def _is_known_model_uid(model_uid: str, *aliases: str) -> bool:
    """Return True if *model_uid* looks like a UUID or is one of *aliases*."""
    return bool(_UUID_PATTERN.match(model_uid)) or model_uid in aliases


def _json_response(status_code: int, content: bytes) -> Response:
    """Build a requests ``Response`` with the given status and raw body bytes."""
    response = Response()
    response.status_code = status_code
    # Setting the private _content is how a Response is populated without a
    # real transport; .json()/.text read from it.
    response._content = content
    return response


class MockXinferenceClass:
    """Fake implementations of Xinference client / requests methods.

    The methods here are not called on ``MockXinferenceClass`` instances:
    ``setup_xinference_mock`` monkeypatches them onto ``Client``, ``Session``,
    ``RESTfulEmbeddingModelHandle`` and ``RESTfulRerankModelHandle``, so
    ``self`` is an instance of whichever class the function was patched onto
    (as reflected in each ``self`` annotation).
    """

    def get_chat_model(self: Client, model_uid: str) -> Union[
        RESTfulChatglmCppChatModelHandle,
        RESTfulGenerateModelHandle,
        RESTfulChatModelHandle,
        RESTfulEmbeddingModelHandle,
        RESTfulRerankModelHandle,
    ]:
        """Replacement for ``Client.get_model``.

        Returns a model handle for the fixed uids 'generate', 'chat',
        'embedding' and 'rerank'.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the client's base URL is not an
                http(s) URL or the uid is not one of the fixed aliases.
        """
        if not _URL_PATTERN.match(self.base_url):
            raise RuntimeError('404 Not Found')

        handle_cls = _MODEL_HANDLE_CLASSES.get(model_uid)
        if handle_cls is None:
            raise RuntimeError('404 Not Found')
        return handle_cls(model_uid, base_url=self.base_url, auth_headers={})

    def get(self: Session, url: str, **kwargs) -> Response:
        """Replacement for ``requests.Session.get``.

        Serves canned JSON for ``.../v1/models/<uid>`` (LLM and embedding
        descriptions) and ``.../v1/cluster/auth``. Any other request, including
        model uids the mock has no description for (e.g. 'rerank' or a UUID),
        gets a 404 response with an empty JSON object. The original
        implementation returned ``None`` in those cases, which broke callers
        that inspect ``status_code``.
        """
        if 'v1/models/' in url:
            model_uid = url.split('/')[-1]
            if not _is_known_model_uid(model_uid, 'generate', 'chat', 'embedding', 'rerank'):
                return _json_response(404, _NOT_FOUND_BODY)

            # Simulate an unreachable server when the URL is malformed.
            if not _URL_PATTERN.match(url):
                return _json_response(404, _NOT_FOUND_BODY)

            if model_uid in ('generate', 'chat'):
                return _json_response(200, _LLM_MODEL_DESCRIPTION)
            if model_uid == 'embedding':
                return _json_response(200, _EMBEDDING_MODEL_DESCRIPTION)
        elif 'v1/cluster/auth' in url:
            return _json_response(200, _CLUSTER_AUTH_BODY)

        # Unhandled path or model: answer like a server that doesn't know it.
        return _json_response(404, _NOT_FOUND_BODY)

    def _check_cluster_authenticated(self):
        """Replacement for ``Client._check_cluster_authenticated``.

        Skips the network round-trip and just marks the client as authenticated.
        """
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle,
        documents: list[str],
        query: str,
        top_n: Union[int, None],
        return_documents: bool,
    ) -> dict:
        """Replacement for ``RESTfulRerankModelHandle.rerank``.

        Returns the first *top_n* documents (1 if *top_n* is None), in input
        order, each with a fixed relevance score of 0.9. *query* and
        *return_documents* are accepted for signature compatibility but ignored.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the handle's uid is neither a
                UUID nor 'rerank', or its base URL is not an http(s) URL.
        """
        if not _is_known_model_uid(self._model_uid, 'rerank'):
            raise RuntimeError('404 Not Found')

        if not _URL_PATTERN.match(self._base_url):
            raise RuntimeError('404 Not Found')

        if top_n is None:
            top_n = 1

        return {
            'results': [
                {
                    'index': i,
                    'document': doc,
                    'relevance_score': 0.9,
                }
                for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: RESTfulEmbeddingModelHandle,
        input: Union[str, list[str]],
        **kwargs
    ) -> dict:
        """Replacement for ``RESTfulEmbeddingModelHandle.create_embedding``.

        Returns one constant 768-dimensional vector per input string. The usage
        figures count one token per input string.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the handle's uid is neither a
                UUID nor 'embedding'.
        """
        if not _is_known_model_uid(self._model_uid, 'embedding'):
            raise RuntimeError('404 Not Found')

        # A bare string is treated as a single-item batch.
        if isinstance(input, str):
            input = [input]
        input_count = len(input)

        return Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(
                    index=i,
                    object="embedding",
                    embedding=[1919.810] * 768,
                )
                for i in range(input_count)
            ],
            usage=EmbeddingUsage(
                prompt_tokens=input_count,
                total_tokens=input_count,
            ),
        )
 | |
| 
 | |
# Mock switch: true only when the MOCK_SWITCH environment variable equals
# "true" (case-insensitive). When false, setup_xinference_mock patches nothing.
MOCK = 'true' == os.environ.get('MOCK_SWITCH', 'false').lower()
 | |
| 
 | |
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Install the Xinference/requests fakes for one test when MOCK is enabled.

    When MOCK is false the fixture is a no-op. When it is true, the listed
    attributes are replaced with the MockXinferenceClass functions before the
    test and the monkeypatch is explicitly undone after it.
    """
    if MOCK:
        # (target class, attribute name, replacement), applied in this order.
        replacements = (
            (Client, 'get_model', MockXinferenceClass.get_chat_model),
            (Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated),
            (Session, 'get', MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank),
        )
        for target, attribute, replacement in replacements:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    if MOCK:
        monkeypatch.undo()
