55import sqlalchemy
66from sqlalchemy .engine .interfaces import Dialect
77from sqlalchemy .ext .compiler import compiles
8+ from sqlalchemy .types import TypeDecorator , UserDefinedType
89
910from databricks .sql .utils import ParamEscaper
1011
@@ -26,6 +27,11 @@ def process_literal_param_hack(value: Any):
2627 return value
2728
2829
def identity_processor(value):
    """Fallback bind processor: return *value* unchanged.

    Used when a wrapped SQLAlchemy type does not supply a bind
    processor of its own.
    """
    return value
33+
34+
2935@compiles (sqlalchemy .types .Enum , "databricks" )
3036@compiles (sqlalchemy .types .String , "databricks" )
3137@compiles (sqlalchemy .types .Text , "databricks" )
@@ -321,3 +327,73 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
@compiles(TINYINT, "databricks")
def compile_tinyint(type_, compiler, **kw):
    """Render the custom TINYINT type as Databricks' ``TINYINT`` DDL keyword."""
    return "TINYINT"
330+
331+
class DatabricksArray(UserDefinedType):
    """
    A custom array type that can wrap any other SQLAlchemy type.

    Examples:
        DatabricksArray(String) -> ARRAY<STRING>
        DatabricksArray(Integer) -> ARRAY<INT>
        DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
    """

    def __init__(self, item_type):
        # Accept either a type class (e.g. String) or an instance
        # (e.g. String(50)); normalize to an instance.
        self.item_type = item_type() if isinstance(item_type, type) else item_type

    def bind_processor(self, dialect):
        """Return a processor that applies the element type's bind
        processor to every element of a bound list.

        Falls back to ``identity_processor`` when the element type has
        no bind processor of its own.
        """
        item_processor = self.item_type.bind_processor(dialect)
        if item_processor is None:
            item_processor = identity_processor

        def process(value):
            # A NULL array binds as None; pass it through instead of
            # raising TypeError on iteration.
            if value is None:
                return None
            return [item_processor(val) for val in value]

        return process
354+
355+
@compiles(DatabricksArray, "databricks")
def compile_databricks_array(type_, compiler, **kw):
    """Render a DatabricksArray as ``ARRAY<element_type>`` DDL."""
    return f"ARRAY<{compiler.process(type_.item_type, **kw)}>"
361+
362+
class DatabricksMap(UserDefinedType):
    """
    A custom map type that can wrap any other SQLAlchemy types for both key and value.

    Examples:
        DatabricksMap(String, String) -> MAP<STRING,STRING>
        DatabricksMap(Integer, String) -> MAP<INT,STRING>
        DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
    """

    def __init__(self, key_type, value_type):
        # Accept type classes or instances for both sides; normalize to instances.
        self.key_type = key_type() if isinstance(key_type, type) else key_type
        self.value_type = value_type() if isinstance(value_type, type) else value_type

    def bind_processor(self, dialect):
        """Return a processor that applies the key and value types' bind
        processors to every entry of a bound dict.

        Falls back to ``identity_processor`` for either side when the
        wrapped type has no bind processor of its own.
        """
        key_processor = self.key_type.bind_processor(dialect)
        value_processor = self.value_type.bind_processor(dialect)

        if key_processor is None:
            key_processor = identity_processor
        if value_processor is None:
            value_processor = identity_processor

        def process(value):
            # A NULL map binds as None; pass it through instead of
            # raising AttributeError on .items().
            if value is None:
                return None
            # Loop variables renamed so they no longer shadow the
            # `value` parameter (previous code reused the name).
            return {key_processor(k): value_processor(v) for k, v in value.items()}

        return process
393+
394+
@compiles(DatabricksMap, "databricks")
def compile_databricks_map(type_, compiler, **kw):
    """Render a DatabricksMap as ``MAP<key_type,value_type>`` DDL."""
    rendered_key = compiler.process(type_.key_type, **kw)
    rendered_value = compiler.process(type_.value_type, **kw)
    return f"MAP<{rendered_key},{rendered_value}>"