mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-24 23:48:40 +00:00 
			
		
		
		
	 6a9e0b1005
			
		
	
	
		6a9e0b1005
		
			
		
	
	
	
	
		
			
			- Introduce `WorkflowDraftVariable` model and the corresponding migration. - Implement `EnumText`, a custom column type for SQLAlchemy designed to work seamlessly with enumeration classes based on `StrEnum`.
		
			
				
	
	
		
			78 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			78 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import enum
 | |
| from typing import Generic, TypeVar
 | |
| 
 | |
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator
 | |
| from sqlalchemy.dialects.postgresql import UUID
 | |
| 
 | |
| 
 | |
| class StringUUID(TypeDecorator):
 | |
|     impl = CHAR
 | |
|     cache_ok = True
 | |
| 
 | |
|     def process_bind_param(self, value, dialect):
 | |
|         if value is None:
 | |
|             return value
 | |
|         elif dialect.name == "postgresql":
 | |
|             return str(value)
 | |
|         else:
 | |
|             return value.hex
 | |
| 
 | |
|     def load_dialect_impl(self, dialect):
 | |
|         if dialect.name == "postgresql":
 | |
|             return dialect.type_descriptor(UUID())
 | |
|         else:
 | |
|             return dialect.type_descriptor(CHAR(36))
 | |
| 
 | |
|     def process_result_value(self, value, dialect):
 | |
|         if value is None:
 | |
|             return value
 | |
|         return str(value)
 | |
| 
 | |
| 
 | |
| _E = TypeVar("_E", bound=enum.StrEnum)
 | |
| 
 | |
| 
 | |
| class EnumText(TypeDecorator, Generic[_E]):
 | |
|     impl = VARCHAR
 | |
|     cache_ok = True
 | |
| 
 | |
|     _length: int
 | |
|     _enum_class: type[_E]
 | |
| 
 | |
|     def __init__(self, enum_class: type[_E], length: int | None = None):
 | |
|         self._enum_class = enum_class
 | |
|         max_enum_value_len = max(len(e.value) for e in enum_class)
 | |
|         if length is not None:
 | |
|             if length < max_enum_value_len:
 | |
|                 raise ValueError("length should be greater than enum value length.")
 | |
|             self._length = length
 | |
|         else:
 | |
|             # leave some rooms for future longer enum values.
 | |
|             self._length = max(max_enum_value_len, 20)
 | |
| 
 | |
|     def process_bind_param(self, value: _E | str | None, dialect):
 | |
|         if value is None:
 | |
|             return value
 | |
|         if isinstance(value, self._enum_class):
 | |
|             return value.value
 | |
|         elif isinstance(value, str):
 | |
|             self._enum_class(value)
 | |
|             return value
 | |
|         else:
 | |
|             raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
 | |
| 
 | |
|     def load_dialect_impl(self, dialect):
 | |
|         return dialect.type_descriptor(VARCHAR(self._length))
 | |
| 
 | |
|     def process_result_value(self, value, dialect) -> _E | None:
 | |
|         if value is None:
 | |
|             return value
 | |
|         if not isinstance(value, str):
 | |
|             raise TypeError(f"expected str, got {type(value)}")
 | |
|         return self._enum_class(value)
 | |
| 
 | |
|     def compare_values(self, x, y):
 | |
|         if x is None or y is None:
 | |
|             return x is y
 | |
|         return x == y
 |