| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2024-01-19 19:51:57 +08:00
										 |  |  | #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | # | 
					
						
							|  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  | #  limitations under the License. | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  | import io | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  | from abc import ABC | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from PIL import Image | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  | from openai import OpenAI | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import base64 | 
					
						
							|  |  |  | from io import BytesIO | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  | from api.utils import get_uuid | 
					
						
							|  |  |  | from api.utils.file_utils import get_project_base_directory | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class Base(ABC): | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |     def __init__(self, key, model_name): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  |     def describe(self, image, max_tokens=300): | 
					
						
							|  |  |  |         raise NotImplementedError("Please implement encode method!") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def image2base64(self, image): | 
					
						
							| 
									
										
										
										
											2024-01-22 19:51:38 +08:00
										 |  |  |         if isinstance(image, bytes): | 
					
						
							|  |  |  |             return base64.b64encode(image).decode("utf-8") | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  |         if isinstance(image, BytesIO): | 
					
						
							|  |  |  |             return base64.b64encode(image.getvalue()).decode("utf-8") | 
					
						
							|  |  |  |         buffered = BytesIO() | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             image.save(buffered, format="JPEG") | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             image.save(buffered, format="PNG") | 
					
						
							|  |  |  |         return base64.b64encode(buffered.getvalue()).decode("utf-8") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def prompt(self, b64): | 
					
						
							|  |  |  |         return [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "role": "user", | 
					
						
							|  |  |  |                 "content": [ | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         "type": "image_url", | 
					
						
							|  |  |  |                         "image_url": { | 
					
						
							|  |  |  |                             "url": f"data:image/jpeg;base64,{b64}" | 
					
						
							|  |  |  |                         }, | 
					
						
							|  |  |  |                     }, | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |                     { | 
					
						
							|  |  |  |                         "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ | 
					
						
							|  |  |  |                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", | 
					
						
							|  |  |  |                     }, | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  |                 ], | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GptV4(Base): | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese"): | 
					
						
							| 
									
										
										
										
											2024-02-08 17:01:01 +08:00
										 |  |  |         self.client = OpenAI(api_key=key) | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |         self.model_name = model_name | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |         self.lang = lang | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def describe(self, image, max_tokens=300): | 
					
						
							|  |  |  |         b64 = self.image2base64(image) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         res = self.client.chat.completions.create( | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |             model=self.model_name, | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  |             messages=self.prompt(b64), | 
					
						
							|  |  |  |             max_tokens=max_tokens, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-23 19:45:36 +08:00
										 |  |  |         return res.choices[0].message.content.strip(), res.usage.total_tokens | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QWenCV(Base): | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |     def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese"): | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |         import dashscope | 
					
						
							|  |  |  |         dashscope.api_key = key | 
					
						
							|  |  |  |         self.model_name = model_name | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |         self.lang = lang | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def prompt(self, binary): | 
					
						
							|  |  |  |         # stupid as hell | 
					
						
							|  |  |  |         tmp_dir = get_project_base_directory("tmp") | 
					
						
							|  |  |  |         if not os.path.exists(tmp_dir): os.mkdir(tmp_dir) | 
					
						
							|  |  |  |         path = os.path.join(tmp_dir, "%s.jpg"%get_uuid()) | 
					
						
							|  |  |  |         Image.open(io.BytesIO(binary)).save(path) | 
					
						
							|  |  |  |         return [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "role": "user", | 
					
						
							|  |  |  |                 "content": [ | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         "image": f"file://{path}" | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ | 
					
						
							|  |  |  |                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                 ], | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  |     def describe(self, image, max_tokens=300): | 
					
						
							|  |  |  |         from http import HTTPStatus | 
					
						
							|  |  |  |         from dashscope import MultiModalConversation | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |         response = MultiModalConversation.call(model=self.model_name, | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |                                                messages=self.prompt(image)) | 
					
						
							| 
									
										
										
										
											2023-12-28 13:50:13 +08:00
										 |  |  |         if response.status_code == HTTPStatus.OK: | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |             return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens | 
					
						
							| 
									
										
										
										
											2024-01-23 19:45:36 +08:00
										 |  |  |         return response.message, 0 | 
					
						
							| 
									
										
										
										
											2024-02-08 17:01:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from zhipuai import ZhipuAI | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Zhipu4V(Base): | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |     def __init__(self, key, model_name="glm-4v", lang="Chinese"): | 
					
						
							| 
									
										
										
										
											2024-02-08 17:01:01 +08:00
										 |  |  |         self.client = ZhipuAI(api_key=key) | 
					
						
							|  |  |  |         self.model_name = model_name | 
					
						
							| 
									
										
										
										
											2024-02-23 18:28:12 +08:00
										 |  |  |         self.lang = lang | 
					
						
							| 
									
										
										
										
											2024-02-08 17:01:01 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def describe(self, image, max_tokens=1024): | 
					
						
							|  |  |  |         b64 = self.image2base64(image) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         res = self.client.chat.completions.create( | 
					
						
							|  |  |  |             model=self.model_name, | 
					
						
							|  |  |  |             messages=self.prompt(b64), | 
					
						
							|  |  |  |             max_tokens=max_tokens, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         return res.choices[0].message.content.strip(), res.usage.total_tokens | 
					
						
							| 
									
										
										
										
											2024-03-12 11:57:08 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LocalCV(Base): | 
					
						
							|  |  |  |     def __init__(self, key, model_name="glm-4v", lang="Chinese"): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def describe(self, image, max_tokens=1024): | 
					
						
							|  |  |  |         return "", 0 |