| 
									
										
										
										
											2022-08-03 12:01:57 +02:00
										 |  |  | #  Copyright 2021 Collate | 
					
						
							|  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  | #  http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  | #  limitations under the License. | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  | Test helpers module | 
					
						
							|  |  |  | """
 | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  | import uuid | 
					
						
							| 
									
										
										
										
											2022-08-03 12:01:57 +02:00
										 |  |  | from unittest import TestCase | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  | from metadata.generated.schema.entity.data.table import Column, DataType, Table | 
					
						
							|  |  |  | from metadata.generated.schema.type.tagLabel import ( | 
					
						
							|  |  |  |     LabelType, | 
					
						
							|  |  |  |     State, | 
					
						
							|  |  |  |     TagLabel, | 
					
						
							|  |  |  |     TagSource, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2022-10-05 16:09:33 +02:00
										 |  |  | from metadata.utils.helpers import ( | 
					
						
							|  |  |  |     clean_up_starting_ending_double_quotes_in_string, | 
					
						
							| 
									
										
										
										
											2023-05-09 12:05:35 +02:00
										 |  |  |     deep_size_of_dict, | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  |     get_entity_tier_from_tags, | 
					
						
							| 
									
										
										
										
											2023-06-02 09:41:31 +02:00
										 |  |  |     is_safe_sql_query, | 
					
						
							| 
									
										
										
										
											2022-10-05 16:09:33 +02:00
										 |  |  |     list_to_dict, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2022-08-03 12:01:57 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TestHelpers(TestCase): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Test helpers module | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_list_to_dict(self): | 
					
						
							|  |  |  |         original = ["key=value", "a=b"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertEqual(list_to_dict(original=original), {"key": "value", "a": "b"}) | 
					
						
							|  |  |  |         self.assertEqual(list_to_dict([]), {}) | 
					
						
							|  |  |  |         self.assertEqual(list_to_dict(None), {}) | 
					
						
							| 
									
										
										
										
											2022-10-05 16:09:33 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_clean_up_starting_ending_double_quotes_in_string(self): | 
					
						
							|  |  |  |         input_ = '"password"' | 
					
						
							|  |  |  |         output_ = "password" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert clean_up_starting_ending_double_quotes_in_string(input_) == output_ | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_get_entity_tier_from_tags(self): | 
					
						
							|  |  |  |         """test correct entity tier are returned""" | 
					
						
							|  |  |  |         table_entity_w_tier = Table( | 
					
						
							|  |  |  |             id=uuid.uuid4(), | 
					
						
							|  |  |  |             name="table_entity_test", | 
					
						
							|  |  |  |             columns=[Column(name="col1", dataType=DataType.STRING)], | 
					
						
							|  |  |  |             tags=[ | 
					
						
							|  |  |  |                 TagLabel( | 
					
						
							|  |  |  |                     tagFQN="Tier.Tier1", | 
					
						
							| 
									
										
										
										
											2023-03-09 00:30:36 -08:00
										 |  |  |                     source=TagSource.Classification, | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  |                     labelType=LabelType.Automated, | 
					
						
							|  |  |  |                     state=State.Confirmed, | 
					
						
							|  |  |  |                 ), | 
					
						
							|  |  |  |                 TagLabel( | 
					
						
							|  |  |  |                     tagFQN="Foo.Bar", | 
					
						
							| 
									
										
										
										
											2023-03-09 00:30:36 -08:00
										 |  |  |                     source=TagSource.Classification, | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  |                     labelType=LabelType.Automated, | 
					
						
							|  |  |  |                     state=State.Confirmed, | 
					
						
							|  |  |  |                 ), | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert get_entity_tier_from_tags(table_entity_w_tier.tags) == "Tier.Tier1" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         table_entity_wo_tier = Table( | 
					
						
							|  |  |  |             id=uuid.uuid4(), | 
					
						
							|  |  |  |             name="table_entity_test", | 
					
						
							|  |  |  |             columns=[Column(name="col1", dataType=DataType.STRING)], | 
					
						
							|  |  |  |             tags=[ | 
					
						
							|  |  |  |                 TagLabel( | 
					
						
							|  |  |  |                     tagFQN="Foo.Bar", | 
					
						
							| 
									
										
										
										
											2023-03-09 00:30:36 -08:00
										 |  |  |                     source=TagSource.Classification, | 
					
						
							| 
									
										
										
										
											2022-10-26 11:18:08 +02:00
										 |  |  |                     labelType=LabelType.Automated, | 
					
						
							|  |  |  |                     state=State.Confirmed, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert get_entity_tier_from_tags(table_entity_wo_tier.tags) is None | 
					
						
							| 
									
										
										
										
											2023-05-09 12:05:35 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_deep_size_of_dict(self): | 
					
						
							|  |  |  |         """test deep size of dict""" | 
					
						
							|  |  |  |         test_dict = { | 
					
						
							|  |  |  |             "a": 1, | 
					
						
							|  |  |  |             "b": {"c": 2, "d": {"e": "Hello World", "f": [4, 5, 6]}}, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert deep_size_of_dict(test_dict) >= 1000 | 
					
						
							|  |  |  |         assert deep_size_of_dict(test_dict) <= 1500 | 
					
						
							| 
									
										
										
										
											2023-06-02 09:41:31 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_is_safe_sql_query(self): | 
					
						
							|  |  |  |         """Test is_safe_sql_query function""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         delete_query = """
 | 
					
						
							|  |  |  |          DELETE FROM airflow_task_instance | 
					
						
							|  |  |  |          WHERE dag_id = 'test_dag_id' | 
					
						
							|  |  |  |          """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         drop_query = """
 | 
					
						
							|  |  |  |          DROP TABLE IF EXISTS test_table | 
					
						
							|  |  |  |          """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         create_query = """
 | 
					
						
							|  |  |  |          CREATE TABLE test_table ( | 
					
						
							|  |  |  |              id INT, | 
					
						
							|  |  |  |              name VARCHAR(255) | 
					
						
							|  |  |  |          ) | 
					
						
							|  |  |  |          """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         select_query = """
 | 
					
						
							|  |  |  |          SELECT * FROM test_table | 
					
						
							|  |  |  |          """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         cte_query = """
 | 
					
						
							|  |  |  |          WITH foo AS ( | 
					
						
							|  |  |  |              SELECT * FROM test_table | 
					
						
							|  |  |  |          ) | 
					
						
							|  |  |  |          SELECT * FROM foo | 
					
						
							|  |  |  |          """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         transaction_query = """
 | 
					
						
							|  |  |  |          BEGIN TRAN T1;   | 
					
						
							|  |  |  |              UPDATE table1 ...;   | 
					
						
							|  |  |  |              BEGIN TRAN M2 WITH MARK;   | 
					
						
							|  |  |  |                  UPDATE table2 ...;   | 
					
						
							|  |  |  |                  SELECT * from table1;   | 
					
						
							|  |  |  |              COMMIT TRAN M2;   | 
					
						
							|  |  |  |              UPDATE table3 ...;   | 
					
						
							|  |  |  |          COMMIT TRAN T1;   | 
					
						
							|  |  |  |          """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.assertFalse(is_safe_sql_query(delete_query)) | 
					
						
							|  |  |  |         self.assertFalse(is_safe_sql_query(drop_query)) | 
					
						
							|  |  |  |         self.assertFalse(is_safe_sql_query(create_query)) | 
					
						
							|  |  |  |         self.assertTrue(is_safe_sql_query(select_query)) | 
					
						
							|  |  |  |         self.assertTrue(is_safe_sql_query(cte_query)) | 
					
						
							|  |  |  |         self.assertFalse(is_safe_sql_query(transaction_query)) |