| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  | from collections.abc import Mapping | 
					
						
							|  |  |  | from typing import Any, Optional | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | from core.app.app_config.entities import AppConfig | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  | from core.moderation.base import ModerationAction, ModerationError | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.moderation.factory import ModerationFactory | 
					
						
							| 
									
										
										
										
											2024-08-09 15:22:16 +08:00
										 |  |  | from core.ops.entities.trace_entity import TraceTaskName | 
					
						
							|  |  |  | from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  | from core.ops.utils import measure_time | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  | class InputModeration: | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |     def check( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         self, | 
					
						
							|  |  |  |         app_id: str, | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |         tenant_id: str, | 
					
						
							|  |  |  |         app_config: AppConfig, | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         inputs: Mapping[str, Any], | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |         query: str, | 
					
						
							|  |  |  |         message_id: str, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         trace_manager: Optional[TraceQueueManager] = None, | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |     ) -> tuple[bool, Mapping[str, Any], str]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Process sensitive_word_avoidance. | 
					
						
							|  |  |  |         :param app_id: app id | 
					
						
							|  |  |  |         :param tenant_id: tenant id | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |         :param app_config: app config | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :param inputs: inputs | 
					
						
							|  |  |  |         :param query: query | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |         :param message_id: message id | 
					
						
							|  |  |  |         :param trace_manager: trace manager | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         inputs = dict(inputs) | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |         if not app_config.sensitive_word_avoidance: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             return False, inputs, query | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-08 18:51:46 +08:00
										 |  |  |         sensitive_word_avoidance_config = app_config.sensitive_word_avoidance | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         moderation_type = sensitive_word_avoidance_config.type | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         moderation_factory = ModerationFactory( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |         with measure_time() as timer: | 
					
						
							|  |  |  |             moderation_result = moderation_factory.moderation_for_inputs(inputs, query) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |         if trace_manager: | 
					
						
							|  |  |  |             trace_manager.add_trace_task( | 
					
						
							|  |  |  |                 TraceTask( | 
					
						
							|  |  |  |                     TraceTaskName.MODERATION_TRACE, | 
					
						
							|  |  |  |                     message_id=message_id, | 
					
						
							|  |  |  |                     moderation_result=moderation_result, | 
					
						
							|  |  |  |                     inputs=inputs, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     timer=timer, | 
					
						
							| 
									
										
										
										
											2024-06-26 17:33:29 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-06-28 00:24:37 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         if not moderation_result.flagged: | 
					
						
							|  |  |  |             return False, inputs, query | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if moderation_result.action == ModerationAction.DIRECT_OUTPUT: | 
					
						
							| 
									
										
										
										
											2024-09-11 16:40:52 +08:00
										 |  |  |             raise ModerationError(moderation_result.preset_response) | 
					
						
							| 
									
										
										
										
											2024-09-09 22:46:13 +07:00
										 |  |  |         elif moderation_result.action == ModerationAction.OVERRIDDEN: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             inputs = moderation_result.inputs | 
					
						
							|  |  |  |             query = moderation_result.query | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return True, inputs, query |