Skip to content

Commit bbea5ed

Browse files
committed
use nodeResolvableTypeReference function
1 parent 8f66d1a commit bbea5ed

8 files changed

Lines changed: 58 additions & 73 deletions

File tree

server/ast/alter_table.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ func nodeAlterTableAlterColumnType(ctx *Context, node *tree.AlterTableAlterColum
324324
return nil, errors.Errorf("ALTER TABLE with USING is not supported yet")
325325
}
326326

327-
convertType, resolvedType, err := nodeResolvableTypeReference(ctx, node.ToType)
327+
convertType, resolvedType, err := nodeResolvableTypeReference(ctx, node.ToType, false)
328328
if err != nil {
329329
return nil, err
330330
}

server/ast/column_table_def.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func nodeColumnTableDef(ctx *Context, node *tree.ColumnTableDef) (*vitess.Column
3737
len(node.UniqueConstraintName) > 0 {
3838
return nil, errors.Errorf("non-foreign key column constraint names are not yet supported")
3939
}
40-
convertType, resolvedType, err := nodeResolvableTypeReference(ctx, node.Type)
40+
convertType, resolvedType, err := nodeResolvableTypeReference(ctx, node.Type, false)
4141
if err != nil {
4242
return nil, err
4343
}

server/ast/create_domain.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func nodeCreateDomain(ctx *Context, node *tree.CreateDomain) (vitess.Statement,
3434
if err != nil {
3535
return nil, err
3636
}
37-
_, dataType, err := nodeResolvableTypeReference(ctx, node.DataType)
37+
_, dataType, err := nodeResolvableTypeReference(ctx, node.DataType, false)
3838
if err != nil {
3939
return nil, err
4040
}

server/ast/create_function.go

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/dolthub/doltgresql/core/id"
2626
"github.com/dolthub/doltgresql/postgres/parser/parser"
2727
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
28-
"github.com/dolthub/doltgresql/postgres/parser/types"
2928
"github.com/dolthub/doltgresql/server/functions/framework"
3029
pgnodes "github.com/dolthub/doltgresql/server/node"
3130
"github.com/dolthub/doltgresql/server/plpgsql"
@@ -45,7 +44,7 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
4544
retType = pgtypes.Void
4645
} else if !node.ReturnsTable {
4746
// Return types may specify "trigger", but this doesn't apply elsewhere
48-
retType, err = getDoltgresType(ctx, node.RetType[0].Type, true)
47+
_, retType, err = nodeResolvableTypeReference(ctx, node.RetType[0].Type, true)
4948
if err != nil {
5049
return nil, err
5150
}
@@ -57,7 +56,7 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
5756
paramTypes := make([]*pgtypes.DoltgresType, len(node.Args))
5857
for i, arg := range node.Args {
5958
paramNames[i] = arg.Name.String()
60-
paramTypes[i], err = getDoltgresType(ctx, arg.Type, false)
59+
_, paramTypes[i], err = nodeResolvableTypeReference(ctx, arg.Type, false)
6160
if err != nil {
6261
return nil, err
6362
}
@@ -90,7 +89,7 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
9089
if err != nil {
9190
return nil, err
9291
}
93-
dt, err := getDoltgresType(ctx, declareTyp, false)
92+
_, dt, err := nodeResolvableTypeReference(ctx, declareTyp, false)
9493
if err != nil {
9594
return nil, err
9695
}
@@ -220,21 +219,3 @@ func validateRoutineOptions(ctx *Context, options []tree.RoutineOption) (map[tre
220219
}
221220
return optDefined, nil
222221
}
223-
224-
// getDoltgresType converts ResolvableTypeReference into *DoltgresType.
225-
func getDoltgresType(ctx *Context, rt tree.ResolvableTypeReference, mayBeTrigger bool) (*pgtypes.DoltgresType, error) {
226-
switch argType := rt.(type) {
227-
case *types.T:
228-
return pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name())), nil
229-
case *tree.UnresolvedObjectName:
230-
if mayBeTrigger && argType.NumParts == 1 && argType.SQLString() == "trigger" {
231-
return pgtypes.Trigger, nil
232-
} else {
233-
_, retType, err := nodeResolvableTypeReference(ctx, argType)
234-
return retType, err
235-
}
236-
default:
237-
// return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
238-
return pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString())), nil
239-
}
240-
}

server/ast/create_sequence.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func nodeCreateSequence(ctx *Context, node *tree.CreateSequence) (vitess.Stateme
6767
if !dataType.IsEmptyType() {
6868
return nil, errors.Errorf("conflicting or redundant options")
6969
}
70-
_, dataType, err = nodeResolvableTypeReference(ctx, option.AsType)
70+
_, dataType, err = nodeResolvableTypeReference(ctx, option.AsType, false)
7171
if err != nil {
7272
return nil, err
7373
}

server/ast/create_type.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func nodeCreateType(ctx *Context, node *tree.CreateType) (vitess.Statement, erro
4040
case tree.Composite:
4141
typs := make([]pgnodes.CompositeAsType, len(node.Composite.Types))
4242
for i, t := range node.Composite.Types {
43-
_, dataType, err := nodeResolvableTypeReference(ctx, t.Type)
43+
_, dataType, err := nodeResolvableTypeReference(ctx, t.Type, false)
4444
if err != nil {
4545
return nil, err
4646
}

server/ast/expr.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
110110
unresolvedChildren := make([]vitess.Expr, len(node.Exprs))
111111
var coercedType *pgtypes.DoltgresType
112112
if node.HasResolvedType() {
113-
_, resolvedType, err := nodeResolvableTypeReference(ctx, node.ResolvedType())
113+
_, resolvedType, err := nodeResolvableTypeReference(ctx, node.ResolvedType(), false)
114114
if err != nil {
115115
return nil, err
116116
}
@@ -259,7 +259,7 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
259259
return nil, errors.Errorf("unknown cast syntax")
260260
}
261261

262-
convertType, resolvedType, err := nodeResolvableTypeReference(ctx, node.Type)
262+
convertType, resolvedType, err := nodeResolvableTypeReference(ctx, node.Type, false)
263263
if err != nil {
264264
return nil, err
265265
}
@@ -593,7 +593,7 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) {
593593
defVal := &vitess.Default{ColName: ""}
594594
return defVal, nil
595595
case tree.DomainColumn:
596-
_, dataType, err := nodeResolvableTypeReference(ctx, node.Typ)
596+
_, dataType, err := nodeResolvableTypeReference(ctx, node.Typ, false)
597597
if err != nil {
598598
return nil, err
599599
}

server/ast/resolvable_type_reference.go

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package ast
1616

1717
import (
18+
"strings"
19+
1820
"github.com/cockroachdb/errors"
1921

2022
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
@@ -26,7 +28,7 @@ import (
2628
)
2729

2830
// nodeResolvableTypeReference handles tree.ResolvableTypeReference nodes.
29-
func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference) (*vitess.ConvertType, *pgtypes.DoltgresType, error) {
31+
func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference, mayBeTrigger bool) (*vitess.ConvertType, *pgtypes.DoltgresType, error) {
3032
if typ == nil {
3133
// TODO: use UNKNOWN?
3234
return nil, nil, nil
@@ -35,36 +37,36 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
3537
var columnTypeName string
3638
var columnTypeLength *vitess.SQLVal
3739
var columnTypeScale *vitess.SQLVal
38-
var resolvedType *pgtypes.DoltgresType
40+
var doltgresType *pgtypes.DoltgresType
3941
var err error
4042
switch columnType := typ.(type) {
4143
case *tree.ArrayTypeReference:
4244
if uon, ok := columnType.ElementType.(*tree.UnresolvedObjectName); ok {
43-
return nodeResolvableTypeReference(ctx, uon)
45+
return nodeResolvableTypeReference(ctx, uon, mayBeTrigger)
4446
}
4547
return nil, nil, errors.Errorf("the given array type is not yet supported")
4648
case *tree.OIDTypeReference:
4749
return nil, nil, errors.Errorf("referencing types by their OID is not yet supported")
4850
case *tree.UnresolvedObjectName:
4951
tn := columnType.ToTableName()
5052
columnTypeName = tn.Object()
51-
resolvedType = pgtypes.NewUnresolvedDoltgresType(tn.Schema(), columnTypeName)
53+
doltgresType = pgtypes.NewUnresolvedDoltgresType(tn.Schema(), columnTypeName)
5254
case *types.GeoMetadata:
5355
return nil, nil, errors.Errorf("geometry types are not yet supported")
5456
case *types.T:
5557
columnTypeName = columnType.SQLStandardName()
5658
if columnType.Family() == types.ArrayFamily {
57-
_, baseResolvedType, err := nodeResolvableTypeReference(ctx, columnType.ArrayContents())
59+
_, baseResolvedType, err := nodeResolvableTypeReference(ctx, columnType.ArrayContents(), mayBeTrigger)
5860
if err != nil {
5961
return nil, nil, err
6062
}
6163
if baseResolvedType.IsResolvedType() {
6264
// currently the built-in types will be resolved, so it can retrieve its array type
63-
resolvedType = baseResolvedType.ToArrayType()
65+
doltgresType = baseResolvedType.ToArrayType()
6466
} else {
6567
// TODO: handle array type of non-built-in types
6668
baseResolvedType.TypCategory = pgtypes.TypeCategory_ArrayTypes
67-
resolvedType = baseResolvedType
69+
doltgresType = baseResolvedType
6870
}
6971
} else if columnType.Family() == types.GeometryFamily {
7072
return nil, nil, errors.Errorf("geometry types are not yet supported")
@@ -73,20 +75,20 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
7375
} else {
7476
switch columnType.Oid() {
7577
case oid.T_record:
76-
resolvedType = pgtypes.Record
78+
doltgresType = pgtypes.Record
7779
case oid.T_bool:
78-
resolvedType = pgtypes.Bool
80+
doltgresType = pgtypes.Bool
7981
case oid.T_bytea:
80-
resolvedType = pgtypes.Bytea
82+
doltgresType = pgtypes.Bytea
8183
case oid.T_bpchar:
8284
width := uint32(columnType.Width())
8385
if width > pgtypes.StringMaxLength {
8486
return nil, nil, errors.Errorf("length for type bpchar cannot exceed %d", pgtypes.StringMaxLength)
8587
} else if width == 0 {
8688
// TODO: need to differentiate between definitions 'bpchar' (valid) and 'char(0)' (invalid)
87-
resolvedType = pgtypes.BpChar
89+
doltgresType = pgtypes.BpChar
8890
} else {
89-
resolvedType, err = pgtypes.NewCharType(int32(width))
91+
doltgresType, err = pgtypes.NewCharType(int32(width))
9092
if err != nil {
9193
return nil, nil, err
9294
}
@@ -99,80 +101,80 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
99101
if width == 0 {
100102
width = 1
101103
}
102-
resolvedType = pgtypes.InternalChar
104+
doltgresType = pgtypes.InternalChar
103105
case oid.T_date:
104-
resolvedType = pgtypes.Date
106+
doltgresType = pgtypes.Date
105107
case oid.T_float4:
106-
resolvedType = pgtypes.Float32
108+
doltgresType = pgtypes.Float32
107109
case oid.T_float8:
108-
resolvedType = pgtypes.Float64
110+
doltgresType = pgtypes.Float64
109111
case oid.T_int2:
110-
resolvedType = pgtypes.Int16
112+
doltgresType = pgtypes.Int16
111113
case oid.T_int4:
112-
resolvedType = pgtypes.Int32
114+
doltgresType = pgtypes.Int32
113115
case oid.T_int8:
114-
resolvedType = pgtypes.Int64
116+
doltgresType = pgtypes.Int64
115117
case oid.T_interval:
116-
resolvedType = pgtypes.Interval
118+
doltgresType = pgtypes.Interval
117119
case oid.T_json:
118-
resolvedType = pgtypes.Json
120+
doltgresType = pgtypes.Json
119121
case oid.T_jsonb:
120-
resolvedType = pgtypes.JsonB
122+
doltgresType = pgtypes.JsonB
121123
case oid.T_name:
122-
resolvedType = pgtypes.Name
124+
doltgresType = pgtypes.Name
123125
case oid.T_numeric:
124126
if columnType.Precision() == 0 && columnType.Scale() == 0 {
125-
resolvedType = pgtypes.Numeric
127+
doltgresType = pgtypes.Numeric
126128
} else {
127-
resolvedType, err = pgtypes.NewNumericTypeWithPrecisionAndScale(columnType.Precision(), columnType.Scale())
129+
doltgresType, err = pgtypes.NewNumericTypeWithPrecisionAndScale(columnType.Precision(), columnType.Scale())
128130
if err != nil {
129131
return nil, nil, err
130132
}
131133
}
132134
case oid.T_oid:
133-
resolvedType = pgtypes.Oid
135+
doltgresType = pgtypes.Oid
134136
case oid.T_regclass:
135-
resolvedType = pgtypes.Regclass
137+
doltgresType = pgtypes.Regclass
136138
case oid.T_regproc:
137-
resolvedType = pgtypes.Regproc
139+
doltgresType = pgtypes.Regproc
138140
case oid.T_regtype:
139-
resolvedType = pgtypes.Regtype
141+
doltgresType = pgtypes.Regtype
140142
case oid.T_text:
141-
resolvedType = pgtypes.Text
143+
doltgresType = pgtypes.Text
142144
case oid.T_time:
143-
resolvedType = pgtypes.Time
145+
doltgresType = pgtypes.Time
144146
case oid.T_timestamp:
145-
resolvedType = pgtypes.Timestamp
147+
doltgresType = pgtypes.Timestamp
146148
case oid.T_timestamptz:
147-
resolvedType = pgtypes.TimestampTZ
149+
doltgresType = pgtypes.TimestampTZ
148150
case oid.T_timetz:
149-
resolvedType = pgtypes.TimeTZ
151+
doltgresType = pgtypes.TimeTZ
150152
case oid.T_uuid:
151-
resolvedType = pgtypes.Uuid
153+
doltgresType = pgtypes.Uuid
152154
case oid.T_varchar:
153155
width := uint32(columnType.Width())
154156
if width > pgtypes.StringMaxLength {
155157
return nil, nil, errors.Errorf("length for type varchar cannot exceed %d", pgtypes.StringMaxLength)
156158
} else if width == 0 {
157159
// TODO: need to differentiate between definitions 'varchar' (valid) and 'varchar(0)' (invalid)
158-
resolvedType = pgtypes.VarChar
160+
doltgresType = pgtypes.VarChar
159161
} else {
160-
resolvedType, err = pgtypes.NewVarCharType(int32(width))
162+
doltgresType, err = pgtypes.NewVarCharType(int32(width))
161163
if err != nil {
162164
return nil, nil, err
163165
}
164166
}
165167
case oid.T_xid:
166-
resolvedType = pgtypes.Xid
168+
doltgresType = pgtypes.Xid
167169
case oid.T_bit:
168170
width := uint32(columnType.Width())
169171
if width > pgtypes.StringMaxLength {
170172
return nil, nil, errors.Errorf("length for type bit cannot exceed %d", pgtypes.StringMaxLength)
171173
} else if width == 0 {
172174
// TODO: need to differentiate between definitions 'bit' (valid) and 'bit(0)' (invalid)
173-
resolvedType = pgtypes.Bit
175+
doltgresType = pgtypes.Bit
174176
} else {
175-
resolvedType, err = pgtypes.NewBitType(int32(width))
177+
doltgresType, err = pgtypes.NewBitType(int32(width))
176178
if err != nil {
177179
return nil, nil, err
178180
}
@@ -183,9 +185,9 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
183185
return nil, nil, errors.Errorf("length for type varbit cannot exceed %d", pgtypes.StringMaxLength)
184186
} else if width == 0 {
185187
// TODO: need to differentiate between definitions 'varbit' (valid) and 'varbit(0)' (invalid)
186-
resolvedType = pgtypes.VarBit
188+
doltgresType = pgtypes.VarBit
187189
} else {
188-
resolvedType, err = pgtypes.NewVarBitType(int32(width))
190+
doltgresType, err = pgtypes.NewVarBitType(int32(width))
189191
if err != nil {
190192
return nil, nil, err
191193
}
@@ -194,12 +196,14 @@ func nodeResolvableTypeReference(ctx *Context, typ tree.ResolvableTypeReference)
194196
return nil, nil, errors.Errorf("unknown type with oid: %d", uint32(columnType.Oid()))
195197
}
196198
}
199+
default:
200+
doltgresType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.SQLString()))
197201
}
198202

199203
return &vitess.ConvertType{
200204
Type: columnTypeName,
201205
Length: columnTypeLength,
202206
Scale: columnTypeScale,
203207
Charset: "", // TODO
204-
}, resolvedType, nil
208+
}, doltgresType, nil
205209
}

0 commit comments

Comments
 (0)