1515package ast
1616
1717import (
18+ "strings"
19+
1820 "github.com/cockroachdb/errors"
1921
2022 vitess "github.com/dolthub/vitess/go/vt/sqlparser"
@@ -26,7 +28,7 @@ import (
2628)
2729
2830// nodeResolvableTypeReference handles tree.ResolvableTypeReference nodes.
29- func nodeResolvableTypeReference (ctx * Context , typ tree.ResolvableTypeReference ) (* vitess.ConvertType , * pgtypes.DoltgresType , error ) {
31+ func nodeResolvableTypeReference (ctx * Context , typ tree.ResolvableTypeReference , mayBeTrigger bool ) (* vitess.ConvertType , * pgtypes.DoltgresType , error ) {
3032 if typ == nil {
3133 // TODO: use UNKNOWN?
3234 return nil , nil , nil
@@ -35,36 +37,36 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
3537 var columnTypeName string
3638 var columnTypeLength * vitess.SQLVal
3739 var columnTypeScale * vitess.SQLVal
38- var resolvedType * pgtypes.DoltgresType
40+ var doltgresType * pgtypes.DoltgresType
3941 var err error
4042 switch columnType := typ .(type ) {
4143 case * tree.ArrayTypeReference :
4244 if uon , ok := columnType .ElementType .(* tree.UnresolvedObjectName ); ok {
43- return nodeResolvableTypeReference (ctx , uon )
45+ return nodeResolvableTypeReference (ctx , uon , mayBeTrigger )
4446 }
4547 return nil , nil , errors .Errorf ("the given array type is not yet supported" )
4648 case * tree.OIDTypeReference :
4749 return nil , nil , errors .Errorf ("referencing types by their OID is not yet supported" )
4850 case * tree.UnresolvedObjectName :
4951 tn := columnType .ToTableName ()
5052 columnTypeName = tn .Object ()
51- resolvedType = pgtypes .NewUnresolvedDoltgresType (tn .Schema (), columnTypeName )
53+ doltgresType = pgtypes .NewUnresolvedDoltgresType (tn .Schema (), columnTypeName )
5254 case * types.GeoMetadata :
5355 return nil , nil , errors .Errorf ("geometry types are not yet supported" )
5456 case * types.T :
5557 columnTypeName = columnType .SQLStandardName ()
5658 if columnType .Family () == types .ArrayFamily {
57- _ , baseResolvedType , err := nodeResolvableTypeReference (ctx , columnType .ArrayContents ())
59+ _ , baseResolvedType , err := nodeResolvableTypeReference (ctx , columnType .ArrayContents (), mayBeTrigger )
5860 if err != nil {
5961 return nil , nil , err
6062 }
6163 if baseResolvedType .IsResolvedType () {
6264 // currently the built-in types will be resolved, so it can retrieve its array type
63- resolvedType = baseResolvedType .ToArrayType ()
65+ doltgresType = baseResolvedType .ToArrayType ()
6466 } else {
6567 // TODO: handle array type of non-built-in types
6668 baseResolvedType .TypCategory = pgtypes .TypeCategory_ArrayTypes
67- resolvedType = baseResolvedType
69+ doltgresType = baseResolvedType
6870 }
6971 } else if columnType .Family () == types .GeometryFamily {
7072 return nil , nil , errors .Errorf ("geometry types are not yet supported" )
@@ -73,20 +75,20 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
7375 } else {
7476 switch columnType .Oid () {
7577 case oid .T_record :
76- resolvedType = pgtypes .Record
78+ doltgresType = pgtypes .Record
7779 case oid .T_bool :
78- resolvedType = pgtypes .Bool
80+ doltgresType = pgtypes .Bool
7981 case oid .T_bytea :
80- resolvedType = pgtypes .Bytea
82+ doltgresType = pgtypes .Bytea
8183 case oid .T_bpchar :
8284 width := uint32 (columnType .Width ())
8385 if width > pgtypes .StringMaxLength {
8486 return nil , nil , errors .Errorf ("length for type bpchar cannot exceed %d" , pgtypes .StringMaxLength )
8587 } else if width == 0 {
8688 // TODO: need to differentiate between definitions 'bpchar' (valid) and 'char(0)' (invalid)
87- resolvedType = pgtypes .BpChar
89+ doltgresType = pgtypes .BpChar
8890 } else {
89- resolvedType , err = pgtypes .NewCharType (int32 (width ))
91+ doltgresType , err = pgtypes .NewCharType (int32 (width ))
9092 if err != nil {
9193 return nil , nil , err
9294 }
@@ -99,80 +101,80 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
99101 if width == 0 {
100102 width = 1
101103 }
102- resolvedType = pgtypes .InternalChar
104+ doltgresType = pgtypes .InternalChar
103105 case oid .T_date :
104- resolvedType = pgtypes .Date
106+ doltgresType = pgtypes .Date
105107 case oid .T_float4 :
106- resolvedType = pgtypes .Float32
108+ doltgresType = pgtypes .Float32
107109 case oid .T_float8 :
108- resolvedType = pgtypes .Float64
110+ doltgresType = pgtypes .Float64
109111 case oid .T_int2 :
110- resolvedType = pgtypes .Int16
112+ doltgresType = pgtypes .Int16
111113 case oid .T_int4 :
112- resolvedType = pgtypes .Int32
114+ doltgresType = pgtypes .Int32
113115 case oid .T_int8 :
114- resolvedType = pgtypes .Int64
116+ doltgresType = pgtypes .Int64
115117 case oid .T_interval :
116- resolvedType = pgtypes .Interval
118+ doltgresType = pgtypes .Interval
117119 case oid .T_json :
118- resolvedType = pgtypes .Json
120+ doltgresType = pgtypes .Json
119121 case oid .T_jsonb :
120- resolvedType = pgtypes .JsonB
122+ doltgresType = pgtypes .JsonB
121123 case oid .T_name :
122- resolvedType = pgtypes .Name
124+ doltgresType = pgtypes .Name
123125 case oid .T_numeric :
124126 if columnType .Precision () == 0 && columnType .Scale () == 0 {
125- resolvedType = pgtypes .Numeric
127+ doltgresType = pgtypes .Numeric
126128 } else {
127- resolvedType , err = pgtypes .NewNumericTypeWithPrecisionAndScale (columnType .Precision (), columnType .Scale ())
129+ doltgresType , err = pgtypes .NewNumericTypeWithPrecisionAndScale (columnType .Precision (), columnType .Scale ())
128130 if err != nil {
129131 return nil , nil , err
130132 }
131133 }
132134 case oid .T_oid :
133- resolvedType = pgtypes .Oid
135+ doltgresType = pgtypes .Oid
134136 case oid .T_regclass :
135- resolvedType = pgtypes .Regclass
137+ doltgresType = pgtypes .Regclass
136138 case oid .T_regproc :
137- resolvedType = pgtypes .Regproc
139+ doltgresType = pgtypes .Regproc
138140 case oid .T_regtype :
139- resolvedType = pgtypes .Regtype
141+ doltgresType = pgtypes .Regtype
140142 case oid .T_text :
141- resolvedType = pgtypes .Text
143+ doltgresType = pgtypes .Text
142144 case oid .T_time :
143- resolvedType = pgtypes .Time
145+ doltgresType = pgtypes .Time
144146 case oid .T_timestamp :
145- resolvedType = pgtypes .Timestamp
147+ doltgresType = pgtypes .Timestamp
146148 case oid .T_timestamptz :
147- resolvedType = pgtypes .TimestampTZ
149+ doltgresType = pgtypes .TimestampTZ
148150 case oid .T_timetz :
149- resolvedType = pgtypes .TimeTZ
151+ doltgresType = pgtypes .TimeTZ
150152 case oid .T_uuid :
151- resolvedType = pgtypes .Uuid
153+ doltgresType = pgtypes .Uuid
152154 case oid .T_varchar :
153155 width := uint32 (columnType .Width ())
154156 if width > pgtypes .StringMaxLength {
155157 return nil , nil , errors .Errorf ("length for type varchar cannot exceed %d" , pgtypes .StringMaxLength )
156158 } else if width == 0 {
157159 // TODO: need to differentiate between definitions 'varchar' (valid) and 'varchar(0)' (invalid)
158- resolvedType = pgtypes .VarChar
160+ doltgresType = pgtypes .VarChar
159161 } else {
160- resolvedType , err = pgtypes .NewVarCharType (int32 (width ))
162+ doltgresType , err = pgtypes .NewVarCharType (int32 (width ))
161163 if err != nil {
162164 return nil , nil , err
163165 }
164166 }
165167 case oid .T_xid :
166- resolvedType = pgtypes .Xid
168+ doltgresType = pgtypes .Xid
167169 case oid .T_bit :
168170 width := uint32 (columnType .Width ())
169171 if width > pgtypes .StringMaxLength {
170172 return nil , nil , errors .Errorf ("length for type bit cannot exceed %d" , pgtypes .StringMaxLength )
171173 } else if width == 0 {
172174 // TODO: need to differentiate between definitions 'bit' (valid) and 'bit(0)' (invalid)
173- resolvedType = pgtypes .Bit
175+ doltgresType = pgtypes .Bit
174176 } else {
175- resolvedType , err = pgtypes .NewBitType (int32 (width ))
177+ doltgresType , err = pgtypes .NewBitType (int32 (width ))
176178 if err != nil {
177179 return nil , nil , err
178180 }
@@ -183,9 +185,9 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
183185 return nil , nil , errors .Errorf ("length for type varbit cannot exceed %d" , pgtypes .StringMaxLength )
184186 } else if width == 0 {
185187 // TODO: need to differentiate between definitions 'varbit' (valid) and 'varbit(0)' (invalid)
186- resolvedType = pgtypes .VarBit
188+ doltgresType = pgtypes .VarBit
187189 } else {
188- resolvedType , err = pgtypes .NewVarBitType (int32 (width ))
190+ doltgresType , err = pgtypes .NewVarBitType (int32 (width ))
189191 if err != nil {
190192 return nil , nil , err
191193 }
@@ -194,12 +196,14 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
194196 return nil , nil , errors .Errorf ("unknown type with oid: %d" , uint32 (columnType .Oid ()))
195197 }
196198 }
199+ default :
200+ doltgresType = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (typ .SQLString ()))
197201 }
198202
199203 return & vitess.ConvertType {
200204 Type : columnTypeName ,
201205 Length : columnTypeLength ,
202206 Scale : columnTypeScale ,
203207 Charset : "" , // TODO
204- }, resolvedType , nil
208+ }, doltgresType , nil
205209}
0 commit comments