#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from flask import request
from flask_login import login_required, current_user

from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, get_json_result
from rag.llm import EmbeddingModel, ChatModel, RerankModel


@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """List the LLM factories (model vendors) available for configuration, excluding a few built-in ones."""
    try:
        fac = LLMFactoriesService.get_all()
        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
    except Exception as e:
        return server_error_response(e)
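
# Example (illustrative only): /factories returns a JSON list of the factory
# records, whatever each factory's to_dict() produces. The field names shown
# here are assumptions, not a guaranteed schema:
#   [{"name": "OpenAI", "tags": "LLM,TEXT EMBEDDING", "status": "1"}, ...]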
					
						
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Verify an API key against the given factory's models and store it for the current tenant."""
    req = request.json
    # Probe the factory's models to verify that the submitted api key actually works.
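    # A minimal sketch of the expected request body (illustrative values;
    # "base_url" is optional):
    #   {"llm_factory": "OpenAI", "api_key": "sk-...", "base_url": ""}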
					
						
    chat_passed, embd_passed, rerank_passed = False, False, False
    factory = req["llm_factory"]
    msg = ""
    for llm in LLMService.query(fid=factory):
        if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
            mdl = EmbeddingModel[factory](
                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
            try:
                arr, tc = mdl.encode(["Test if the api key is available"])
                if len(arr[0]) == 0 or tc == 0:
                    raise Exception("Fail")
                embd_passed = True
            except Exception as e:
                msg += f"\nFailed to access embedding model({llm.llm_name}) using this api key. " + str(e)
        elif not chat_passed and llm.model_type == LLMType.CHAT.value:
            mdl = ChatModel[factory](
                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
            try:
                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
                                 {"temperature": 0.9})
                if not tc:
                    raise Exception(m)
                chat_passed = True
            except Exception as e:
                msg += f"\nFailed to access model({llm.llm_name}) using this api key. " + str(e)
        elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
            mdl = RerankModel[factory](
                req["api_key"], llm.llm_name, base_url=req.get("base_url"))
            try:
                arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
                if len(arr) == 0 or tc == 0:
                    raise Exception("Fail")
                rerank_passed = True
            except Exception as e:
                msg += f"\nFailed to access model({llm.llm_name}) using this api key. " + str(e)

    if msg:
        return get_data_error_result(retmsg=msg)

    llm = {
        "api_key": req["api_key"],
        "api_base": req.get("base_url", "")
    }
    for n in ["model_type", "llm_name"]:
        if n in req:
            llm[n] = req[n]

    if not TenantLLMService.filter_update(
            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
        for llm in LLMService.query(fid=factory):
            TenantLLMService.save(
                tenant_id=current_user.id,
                llm_factory=factory,
                llm_name=llm.llm_name,
                model_type=llm.model_type,
                api_key=req["api_key"],
                api_base=req.get("base_url", "")
            )

    return get_json_result(data=True)


@manager.route('/add_llm', methods=['POST'])
@login_required
@validate_request("llm_factory", "llm_name", "model_type")
def add_llm():
    """Register a manually configured model for the current tenant."""
    req = request.json
    factory = req["llm_factory"]
    # VolcEngine uses a special authentication scheme, so its volc_ak, volc_sk
    # and endpoint_id are assembled into the api_key field.
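    # Illustrative only: for VolcEngine, "llm_name" is expected to arrive as a
    # serialized one-entry mapping of model name to endpoint id, and the
    # assembled api_key string ends up looking roughly like
    #   {"volc_ak": "AK...", "volc_sk": "SK...", "ep_id": "ep-...", }
    # (a dict-style string with a trailing comma, produced by the
    # concatenation below).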
					
						
    if factory == "VolcEngine":
        temp = list(eval(req["llm_name"]).items())[0]
        llm_name = temp[0]
        endpoint_id = temp[1]
        api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
                        f'"volc_sk": "{req.get("volc_sk", "")}", ' \
                        f'"ep_id": "{endpoint_id}", ' + '}'
    else:
        llm_name = req["llm_name"]
        api_key = "xxxxxxxxxxxxxxx"

    llm = {
        "tenant_id": current_user.id,
        "llm_factory": factory,
        "model_type": req["model_type"],
        "llm_name": llm_name,
        "api_base": req.get("api_base", ""),
        "api_key": api_key
    }

    msg = ""
    if llm["model_type"] == LLMType.EMBEDDING.value:
        mdl = EmbeddingModel[factory](
            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
        try:
            arr, tc = mdl.encode(["Test if the api key is available"])
            if len(arr[0]) == 0 or tc == 0:
                raise Exception("Fail")
        except Exception as e:
            msg += f"\nFailed to access embedding model({llm['llm_name']}). " + str(e)
    elif llm["model_type"] == LLMType.CHAT.value:
        mdl = ChatModel[factory](
            key=llm['api_key'] if factory == "VolcEngine" else None,
            model_name=llm["llm_name"],
            base_url=llm["api_base"]
        )
        try:
            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
                             {"temperature": 0.9})
            if not tc:
                raise Exception(m)
        except Exception as e:
            msg += f"\nFailed to access model({llm['llm_name']}). " + str(e)
    else:
        # TODO: check other types of models
        pass

    if msg:
        return get_data_error_result(retmsg=msg)

    if not TenantLLMService.filter_update(
            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
        TenantLLMService.save(**llm)

    return get_json_result(data=True)
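
# A minimal sketch of an /add_llm request body (illustrative values; for
# VolcEngine the extra "volc_ak"/"volc_sk" fields described above are also
# expected):
#   {"llm_factory": "Ollama", "llm_name": "llama3", "model_type": "chat", "api_base": "http://localhost:11434"}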
					
						


@manager.route('/delete_llm', methods=['POST'])
@login_required
@validate_request("llm_factory", "llm_name")
def delete_llm():
    """Remove one of the current tenant's configured models."""
    req = request.json
    TenantLLMService.filter_delete(
        [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]])
    return get_json_result(data=True)
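
# Illustrative request body for /delete_llm:
#   {"llm_factory": "Ollama", "llm_name": "llama3"}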
					
						


@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """Return the current tenant's configured models, grouped by factory."""
    try:
        res = {}
        for o in TenantLLMService.get_my_llms(current_user.id):
            if o["llm_factory"] not in res:
                res[o["llm_factory"]] = {
                    "tags": o["tags"],
                    "llm": []
                }
            res[o["llm_factory"]]["llm"].append({
                "type": o["model_type"],
                "name": o["llm_name"],
                "used_token": o["used_tokens"]
            })
        return get_json_result(data=res)
    except Exception as e:
        return server_error_response(e)
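
# Illustrative only: /my_llms returns the structure built above, keyed by
# factory (example values are made up):
#   {"OpenAI": {"tags": "LLM,TEXT EMBEDDING", "llm": [{"type": "chat", "name": "gpt-4", "used_token": 12345}]}}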
					
						


@manager.route('/list', methods=['GET'])
@login_required
def list_app():
    """List all selectable models grouped by factory, flagging which are available to the current tenant."""
    model_type = request.args.get("model_type")
    try:
        objs = TenantLLMService.query(tenant_id=current_user.id)
        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
        llms = LLMService.get_all()
        llms = [m.to_dict()
                for m in llms if m.status == StatusEnum.VALID.value]
        for m in llms:
            m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao", "FastEmbed", "BAAI"]

        llm_set = set([m["llm_name"] for m in llms])
        for o in objs:
            if not o.api_key:
                continue
            if o.llm_name in llm_set:
                continue
            llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})

        res = {}
        for m in llms:
            if model_type and m["model_type"].find(model_type) < 0:
                continue
            if m["fid"] not in res:
                res[m["fid"]] = []
            res[m["fid"]].append(m)

        return get_json_result(data=res)
    except Exception as e:
        return server_error_response(e)
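
# Illustrative only: /list maps each factory id ("fid") to its models. Each
# entry is one of the model dicts built above, so it carries at least
# "llm_name", "model_type", "fid" and the computed "available" flag, e.g.:
#   {"OpenAI": [{"llm_name": "gpt-4", "model_type": "chat", "fid": "OpenAI", "available": True}, ...]}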