| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  | from core.tools.__base.tool_provider import ToolProviderController | 
					
						
							|  |  |  | from core.tools.builtin_tool.provider import BuiltinToolProviderController | 
					
						
							|  |  |  | from core.tools.custom_tool.provider import ApiToolProviderController | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | from core.tools.entities.values import default_tool_label_name_list | 
					
						
							| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  | from core.tools.workflow_as_tool.provider import WorkflowToolProviderController | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | from extensions.ext_database import db | 
					
						
							|  |  |  | from models.tools import ToolLabelBinding | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ToolLabelManager: | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Filter tool labels | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] | 
					
						
							|  |  |  |         return list(set(tool_labels)) | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  |     @classmethod | 
					
						
							|  |  |  |     def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Update tool labels | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         labels = cls.filter_tool_labels(labels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): | 
					
						
							|  |  |  |             provider_id = controller.provider_id | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             raise ValueError("Unsupported tool type") | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # delete old labels | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # insert new labels | 
					
						
							|  |  |  |         for label in labels: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             db.session.add( | 
					
						
							|  |  |  |                 ToolLabelBinding( | 
					
						
							|  |  |  |                     tool_id=provider_id, | 
					
						
							|  |  |  |                     tool_type=controller.provider_type.value, | 
					
						
							|  |  |  |                     label_name=label, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get tool labels | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): | 
					
						
							|  |  |  |             provider_id = controller.provider_id | 
					
						
							|  |  |  |         elif isinstance(controller, BuiltinToolProviderController): | 
					
						
							|  |  |  |             return controller.tool_labels | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             raise ValueError("Unsupported tool type") | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  |         labels = ( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             db.session.query(ToolLabelBinding.label_name) | 
					
						
							|  |  |  |             .filter( | 
					
						
							|  |  |  |                 ToolLabelBinding.tool_id == provider_id, | 
					
						
							|  |  |  |                 ToolLabelBinding.tool_type == controller.provider_type.value, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return [label.label_name for label in labels] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get tools labels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tool_providers: list of tool providers | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :return: dict of tool labels | 
					
						
							|  |  |  |             :key: tool id | 
					
						
							|  |  |  |             :value: list of tool labels | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if not tool_providers: | 
					
						
							|  |  |  |             return {} | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  |         for controller in tool_providers: | 
					
						
							|  |  |  |             if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 raise ValueError("Unsupported tool type") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-17 17:05:13 +08:00
										 |  |  |         provider_ids = [] | 
					
						
							|  |  |  |         for controller in tool_providers: | 
					
						
							|  |  |  |             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) | 
					
						
							|  |  |  |             provider_ids.append(controller.provider_id) | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         labels: list[ToolLabelBinding] = ( | 
					
						
							|  |  |  |             db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-24 18:38:51 +08:00
										 |  |  |         tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} | 
					
						
							| 
									
										
										
										
											2024-05-27 22:01:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         for label in labels: | 
					
						
							|  |  |  |             tool_labels[label.tool_id].append(label.label_name) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         return tool_labels |