#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import argparse
import pickle
import random
import time
from copy import deepcopy
from multiprocessing.connection import Listener
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


def torch_gc():
    # Best-effort release of cached accelerator memory (CUDA or Apple MPS).
    # Any failure here is ignored so that cleanup never breaks an RPC call.
    try:
        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception:
                pass
    except Exception:
        pass


class RPCHandler:
    """Dispatches pickled (func_name, args, kwargs) requests to registered functions."""

    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass


def rpc_server(hdlr, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            client = sock.accept()
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("【EXCEPTION】:", str(e))
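

# Minimal client-side sketch (illustration only; the real caller lives elsewhere in
# the repository). It mirrors the wire format handled by RPCHandler.handle_connection:
# the client sends a pickled (func_name, args, kwargs) tuple and receives a pickled
# result back. The helper name, host and port below are assumptions, not part of
# this module's API.
def _example_rpc_client(messages, gen_conf, host="127.0.0.1", port=7860):
    from multiprocessing.connection import Client
    conn = Client((host, port), authkey=b'infiniflow-token4kevinhu')
    try:
        # Ask the server to run chat(messages, gen_conf) and return its result.
        conn.send(pickle.dumps(("chat", (messages, gen_conf), {})))
        return pickle.loads(conn.recv())
    finally:
        conn.close()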


models = []
tokenizer = None


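# Illustrative payloads for chat() / chat_streamly() (an assumption based on how the
# arguments are consumed below, not taken from the repository's actual callers):
#   messages: [{"role": "user", "content": "Hello"}]    # chat-template style turns
#   gen_conf: {"max_tokens": 256, "temperature": 0.1}   # decoding settings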
def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1))
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        generated_ids = model.generate(
            model_inputs.input_ids,
            **conf
        )
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        return tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)


def chat_streamly(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        streamer = TextStreamer(tokenizer)
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        conf["max_new_tokens"] = conf["max_tokens"]
        del conf["max_tokens"]
        # Generation runs in a background thread while the streamer yields text chunks.
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
    except Exception as e:
        yield "**ERROR**: " + str(e)


def Model():
    # Return one of the loaded model replicas at random (a single replica is
    # loaded in the __main__ block below).
    global models
    random.seed(time.time())
    return random.choice(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)
    handler.register_function(chat_streamly)

    models = []
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto')
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    rpc_server(handler, ('0.0.0.0', args.port),
               authkey=b'infiniflow-token4kevinhu')
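
# Example launch (the script name and model id are placeholders/assumptions; any
# Hugging Face causal LM that ships a chat template should work):
#
#   python llm_rpc_server.py --model_name Qwen/Qwen1.5-7B-Chat --port 7860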