| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2024-01-19 19:51:57 +08:00
										 |  |  | #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | # | 
					
						
							|  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  | #  limitations under the License. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | import operator | 
					
						
							|  |  |  | from functools import reduce | 
					
						
							|  |  |  | from typing import Dict, Type, Union | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | from api.utils import current_timestamp, timestamp_to_date | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | from api.db.db_models import DB, DataBaseModel | 
					
						
							|  |  |  | from api.db.runtime_config import RuntimeConfig | 
					
						
							|  |  |  | from api.utils.log_utils import getLogger | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | from enum import Enum | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Module-level logger, obtained from the project's
# ``api.utils.log_utils.getLogger`` helper (called with no arguments).
LOGGER = getLogger()
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @DB.connection_context() | 
					
						
							|  |  |  | def bulk_insert_into_db(model, data_source, replace_on_conflict=False): | 
					
						
							|  |  |  |     DB.create_tables([model]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |     for i, data in enumerate(data_source): | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  |         current_time = current_timestamp() + i | 
					
						
							| 
									
										
										
										
											2024-01-31 19:57:45 +08:00
										 |  |  |         current_date = timestamp_to_date(current_time) | 
					
						
							|  |  |  |         if 'create_time' not in data: | 
					
						
							|  |  |  |             data['create_time'] = current_time | 
					
						
							|  |  |  |         data['create_date'] = timestamp_to_date(data['create_time']) | 
					
						
							|  |  |  |         data['update_time'] = current_time | 
					
						
							|  |  |  |         data['update_date'] = current_date | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-31 19:57:45 +08:00
										 |  |  |     preserve = tuple(data_source[0].keys() - {'create_time', 'create_date'}) | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-31 19:57:45 +08:00
										 |  |  |     batch_size = 1000 | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     for i in range(0, len(data_source), batch_size): | 
					
						
							|  |  |  |         with DB.atomic(): | 
					
						
							|  |  |  |             query = model.insert_many(data_source[i:i + batch_size]) | 
					
						
							|  |  |  |             if replace_on_conflict: | 
					
						
							|  |  |  |                 query = query.on_conflict(preserve=preserve) | 
					
						
							|  |  |  |             query.execute() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_dynamic_db_model(base, job_id): | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |     return type(base.model( | 
					
						
							|  |  |  |         table_index=get_dynamic_tracking_table_index(job_id=job_id))) | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_dynamic_tracking_table_index(job_id): | 
					
						
							|  |  |  |     return job_id[:8] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def fill_db_model_object(model_object, human_model_dict): | 
					
						
							|  |  |  |     for k, v in human_model_dict.items(): | 
					
						
							|  |  |  |         attr_name = 'f_%s' % k | 
					
						
							|  |  |  |         if hasattr(model_object.__class__, attr_name): | 
					
						
							|  |  |  |             setattr(model_object, attr_name, v) | 
					
						
							|  |  |  |     return model_object | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # https://docs.peewee-orm.com/en/latest/peewee/query_operators.html | 
					
						
							|  |  |  | supported_operators = { | 
					
						
							|  |  |  |     '==': operator.eq, | 
					
						
							|  |  |  |     '<': operator.lt, | 
					
						
							|  |  |  |     '<=': operator.le, | 
					
						
							|  |  |  |     '>': operator.gt, | 
					
						
							|  |  |  |     '>=': operator.ge, | 
					
						
							|  |  |  |     '!=': operator.ne, | 
					
						
							|  |  |  |     '<<': operator.lshift, | 
					
						
							|  |  |  |     '>>': operator.rshift, | 
					
						
							|  |  |  |     '%': operator.mod, | 
					
						
							|  |  |  |     '**': operator.pow, | 
					
						
							|  |  |  |     '^': operator.xor, | 
					
						
							|  |  |  |     '~': operator.inv, | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def query_dict2expression( | 
					
						
							|  |  |  |         model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |     expression = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for field, value in query.items(): | 
					
						
							|  |  |  |         if not isinstance(value, (list, tuple)): | 
					
						
							|  |  |  |             value = ('==', value) | 
					
						
							|  |  |  |         op, *val = value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         field = getattr(model, f'f_{field}') | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |         value = supported_operators[op]( | 
					
						
							|  |  |  |             field, val[0]) if op in supported_operators else getattr( | 
					
						
							|  |  |  |             field, op)( | 
					
						
							|  |  |  |             *val) | 
					
						
							| 
									
										
										
										
											2024-01-15 08:46:22 +08:00
										 |  |  |         expression.append(value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return reduce(operator.iand, expression) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0, | 
					
						
							|  |  |  |              query: dict = None, order_by: Union[str, list, tuple] = None): | 
					
						
							|  |  |  |     data = model.select() | 
					
						
							|  |  |  |     if query: | 
					
						
							|  |  |  |         data = data.where(query_dict2expression(model, query)) | 
					
						
							|  |  |  |     count = data.count() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not order_by: | 
					
						
							|  |  |  |         order_by = 'create_time' | 
					
						
							|  |  |  |     if not isinstance(order_by, (list, tuple)): | 
					
						
							|  |  |  |         order_by = (order_by, 'asc') | 
					
						
							|  |  |  |     order_by, order = order_by | 
					
						
							|  |  |  |     order_by = getattr(model, f'f_{order_by}') | 
					
						
							|  |  |  |     order_by = getattr(order_by, order)() | 
					
						
							|  |  |  |     data = data.order_by(order_by) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if limit > 0: | 
					
						
							|  |  |  |         data = data.limit(limit) | 
					
						
							|  |  |  |     if offset > 0: | 
					
						
							|  |  |  |         data = data.offset(offset) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return list(data), count |