| 
									
										
										
										
											2024-04-15 14:43:44 +08:00
										 |  |  | # | 
					
						
							|  |  |  | #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  | #  limitations under the License. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | from datetime import datetime, timedelta | 
					
						
							|  |  |  | from flask import request | 
					
						
							|  |  |  | from flask_login import login_required, current_user | 
					
						
							|  |  |  | from api.db.db_models import APIToken, API4Conversation | 
					
						
							|  |  |  | from api.db.services.api_service import APITokenService, API4ConversationService | 
					
						
							|  |  |  | from api.db.services.dialog_service import DialogService, chat | 
					
						
							|  |  |  | from api.db.services.user_service import UserTenantService | 
					
						
							|  |  |  | from api.settings import RetCode | 
					
						
							|  |  |  | from api.utils import get_uuid, current_timestamp, datetime_format | 
					
						
							|  |  |  | from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request | 
					
						
							|  |  |  | from itsdangerous import URLSafeTimedSerializer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def generate_confirmation_token(tenent_id): | 
					
						
							|  |  |  |     serializer = URLSafeTimedSerializer(tenent_id) | 
					
						
							|  |  |  |     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/new_token', methods=['POST']) | 
					
						
							|  |  |  | @validate_request("dialog_id") | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | def new_token(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenants = UserTenantService.query(user_id=current_user.id) | 
					
						
							|  |  |  |         if not tenants: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = tenants[0].tenant_id | 
					
						
							|  |  |  |         obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id), | 
					
						
							|  |  |  |                "dialog_id": req["dialog_id"], | 
					
						
							|  |  |  |                "create_time": current_timestamp(), | 
					
						
							|  |  |  |                "create_date": datetime_format(datetime.now()), | 
					
						
							|  |  |  |                "update_time": None, | 
					
						
							|  |  |  |                "update_date": None | 
					
						
							|  |  |  |                } | 
					
						
							|  |  |  |         if not APITokenService.save(**obj): | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Fail to new a dialog!") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return get_json_result(data=obj) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/token_list', methods=['GET']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | def token_list(): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenants = UserTenantService.query(user_id=current_user.id) | 
					
						
							|  |  |  |         if not tenants: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=request.args["dialog_id"]) | 
					
						
							|  |  |  |         return get_json_result(data=[o.to_dict() for o in objs]) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/rm', methods=['POST']) | 
					
						
							|  |  |  | @validate_request("tokens", "tenant_id") | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | def rm(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         for token in req["tokens"]: | 
					
						
							|  |  |  |             APITokenService.filter_delete( | 
					
						
							|  |  |  |                 [APIToken.tenant_id == req["tenant_id"], APIToken.token == token]) | 
					
						
							|  |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/stats', methods=['GET']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | def stats(): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenants = UserTenantService.query(user_id=current_user.id) | 
					
						
							|  |  |  |         if not tenants: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							|  |  |  |         objs = API4ConversationService.stats( | 
					
						
							|  |  |  |             tenants[0].tenant_id, | 
					
						
							|  |  |  |             request.args.get( | 
					
						
							|  |  |  |                 "from_date", | 
					
						
							|  |  |  |                 (datetime.now() - | 
					
						
							|  |  |  |                  timedelta( | 
					
						
							|  |  |  |                     days=7)).strftime("%Y-%m-%d 24:00:00")), | 
					
						
							|  |  |  |             request.args.get( | 
					
						
							|  |  |  |                 "to_date", | 
					
						
							|  |  |  |                 datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) | 
					
						
							|  |  |  |         res = { | 
					
						
							|  |  |  |             "pv": [(o["dt"], o["pv"]) for o in objs], | 
					
						
							|  |  |  |             "uv": [(o["dt"], o["uv"]) for o in objs], | 
					
						
							| 
									
										
										
										
											2024-04-19 18:02:53 +08:00
										 |  |  |             "speed": [(o["dt"], float(o["tokens"])/(float(o["duration"]+0.1))) for o in objs], | 
					
						
							| 
									
										
										
										
											2024-04-18 09:37:23 +08:00
										 |  |  |             "tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs], | 
					
						
							| 
									
										
										
										
											2024-04-15 14:43:44 +08:00
										 |  |  |             "round": [(o["dt"], o["round"]) for o in objs], | 
					
						
							|  |  |  |             "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs] | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return get_json_result(data=res) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-18 09:37:23 +08:00
										 |  |  | @manager.route('/new_conversation', methods=['GET']) | 
					
						
							| 
									
										
										
										
											2024-04-15 14:43:44 +08:00
										 |  |  | def set_conversation(): | 
					
						
							|  |  |  |     token = request.headers.get('Authorization').split()[1] | 
					
						
							|  |  |  |     objs = APIToken.query(token=token) | 
					
						
							|  |  |  |     if not objs: | 
					
						
							|  |  |  |         return get_json_result( | 
					
						
							|  |  |  |             data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         e, dia = DialogService.get_by_id(objs[0].dialog_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Dialog not found") | 
					
						
							|  |  |  |         conv = { | 
					
						
							|  |  |  |             "id": get_uuid(), | 
					
						
							|  |  |  |             "dialog_id": dia.id, | 
					
						
							| 
									
										
										
										
											2024-04-18 09:37:23 +08:00
										 |  |  |             "user_id": request.args.get("user_id", ""), | 
					
						
							| 
									
										
										
										
											2024-04-15 14:43:44 +08:00
										 |  |  |             "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         API4ConversationService.save(**conv) | 
					
						
							|  |  |  |         e, conv = API4ConversationService.get_by_id(conv["id"]) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Fail to new a conversation!") | 
					
						
							|  |  |  |         conv = conv.to_dict() | 
					
						
							|  |  |  |         return get_json_result(data=conv) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/completion', methods=['POST']) | 
					
						
							|  |  |  | @validate_request("conversation_id", "messages") | 
					
						
							|  |  |  | def completion(): | 
					
						
							|  |  |  |     token = request.headers.get('Authorization').split()[1] | 
					
						
							|  |  |  |     if not APIToken.query(token=token): | 
					
						
							|  |  |  |         return get_json_result( | 
					
						
							|  |  |  |             data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     e, conv = API4ConversationService.get_by_id(req["conversation_id"]) | 
					
						
							|  |  |  |     if not e: | 
					
						
							|  |  |  |         return get_data_error_result(retmsg="Conversation not found!") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     msg = [] | 
					
						
							|  |  |  |     for m in req["messages"]: | 
					
						
							|  |  |  |         if m["role"] == "system": | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |         if m["role"] == "assistant" and not msg: | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |         msg.append({"role": m["role"], "content": m["content"]}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         conv.message.append(msg[-1]) | 
					
						
							|  |  |  |         e, dia = DialogService.get_by_id(conv.dialog_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Dialog not found!") | 
					
						
							|  |  |  |         del req["conversation_id"] | 
					
						
							|  |  |  |         del req["messages"] | 
					
						
							|  |  |  |         ans = chat(dia, msg, **req) | 
					
						
							|  |  |  |         if not conv.reference: | 
					
						
							|  |  |  |             conv.reference = [] | 
					
						
							|  |  |  |         conv.reference.append(ans["reference"]) | 
					
						
							|  |  |  |         conv.message.append({"role": "assistant", "content": ans["answer"]}) | 
					
						
							|  |  |  |         API4ConversationService.append_message(conv.id, conv.to_dict()) | 
					
						
							|  |  |  |         return get_json_result(data=ans) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/conversation/<conversation_id>', methods=['GET']) | 
					
						
							|  |  |  | # @login_required | 
					
						
							|  |  |  | def get(conversation_id): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         e, conv = API4ConversationService.get_by_id(conversation_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Conversation not found!") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return get_json_result(data=conv.to_dict()) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) |