Skip to content

Commit 93d0f7c

Browse files
committed
refactor(analyzer): simplify VALUES type resolution implementation
Refactor the ResolveValuesTypes analyzer rule to use a simpler implementation based on PR review feedback. The changes centralize unknown type handling and eliminate fragile tree traversal logic.

Changes:
- Use TableId-based lookup instead of recursive tree traversal to update GetField types, eliminating the dependency on specific node types like SubqueryAlias
- Leverage pgtransform.NodeExprsWithOpaque for expression updates instead of manual recursion through four helper functions
- Move unknown type handling into the cast functions (GetExplicitCast, GetAssignmentCast, GetImplicitCast) to eliminate scattered checks across call sites
- Add a requiresCasts return value to FindCommonType to optimize the case where no type conversion is needed
- Simplify VALUES node transformation using the sql.Expressioner interface to handle both ValueDerivedTable and Values uniformly
- Add comprehensive test coverage for VALUES with GROUP BY, DISTINCT, LIMIT/OFFSET, ORDER BY, subqueries, WHERE clause, aggregates, and combined operations

This refactoring reduces code complexity from ~300 lines to ~180 lines while improving maintainability and eliminating potential bugs from manual tree walking.

Refs: #1648
1 parent 7cdc339 commit 93d0f7c

8 files changed

Lines changed: 324 additions & 263 deletions

File tree

server/analyzer/resolve_values_types.go

Lines changed: 65 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
package analyzer
1616

1717
import (
18+
"github.com/cockroachdb/errors"
1819
"github.com/dolthub/go-mysql-server/sql"
1920
"github.com/dolthub/go-mysql-server/sql/analyzer"
2021
"github.com/dolthub/go-mysql-server/sql/expression"
2122
"github.com/dolthub/go-mysql-server/sql/plan"
2223
"github.com/dolthub/go-mysql-server/sql/transform"
2324

25+
pgtransform "github.com/dolthub/doltgresql/server/transform"
26+
2427
pgexprs "github.com/dolthub/doltgresql/server/expression"
2528
"github.com/dolthub/doltgresql/server/functions/framework"
2629
pgtypes "github.com/dolthub/doltgresql/server/types"
@@ -30,18 +33,17 @@ import (
3033
// by examining all rows, following PostgreSQL's type resolution rules.
3134
// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer.
3235
func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
33-
// Track which VDTs we transform so we can update parent nodes
34-
transformedVDTs := make(map[*plan.ValueDerivedTable]sql.Schema)
35-
36-
// First pass: transform VDTs and record their new schemas
37-
node, same1, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
36+
// Track which VDTs we transform so we can update GetField nodes
37+
transformedVDTs := make(map[sql.TableId]sql.Schema)
38+
// First we transform VDTs and record their new schemas
39+
node, same, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
3840
newNode, same, err := transformValuesNode(n)
3941
if err != nil {
4042
return nil, same, err
4143
}
4244
if !same {
4345
if vdt, ok := newNode.(*plan.ValueDerivedTable); ok {
44-
transformedVDTs[vdt] = vdt.Schema()
46+
transformedVDTs[vdt.Id()] = vdt.Schema()
4547
}
4648
}
4749
return newNode, same, err
@@ -50,183 +52,61 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s
5052
return nil, transform.SameTree, err
5153
}
5254

53-
// Second pass: update GetField types in parent nodes that reference transformed VDTs
55+
// Next we update all GetField expressions that refer to a transformed VDT
5456
if len(transformedVDTs) > 0 {
55-
node, _, err = transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
56-
return updateGetFieldTypes(n, transformedVDTs)
57-
})
58-
if err != nil {
59-
return nil, transform.SameTree, err
60-
}
61-
}
62-
63-
return node, same1, nil
64-
}
65-
66-
// getSourceSchema traverses through wrapper nodes (GroupBy, Filter, etc.) to find
67-
// the actual source schema from a VDT or other data source. This is needed because
68-
// nodes like GroupBy produce a different output schema than their input schema.
69-
func getSourceSchema(n sql.Node) sql.Schema {
70-
switch node := n.(type) {
71-
case *plan.GroupBy:
72-
// GroupBy's Schema() returns aggregate output, but we need the source schema
73-
return getSourceSchema(node.Child)
74-
case *plan.Filter:
75-
return getSourceSchema(node.Child)
76-
case *plan.Sort:
77-
return getSourceSchema(node.Child)
78-
case *plan.Limit:
79-
return getSourceSchema(node.Child)
80-
case *plan.Offset:
81-
return getSourceSchema(node.Child)
82-
case *plan.Distinct:
83-
return getSourceSchema(node.Child)
84-
case *plan.SubqueryAlias:
85-
// SubqueryAlias wraps a VDT - get the child's schema
86-
return node.Child.Schema()
87-
case *plan.ValueDerivedTable:
88-
return node.Schema()
89-
default:
90-
// For other nodes, return their schema directly
91-
return n.Schema()
92-
}
93-
}
94-
95-
// updateGetFieldTypes updates GetField expressions that reference transformed VDT columns
96-
func updateGetFieldTypes(n sql.Node, transformedVDTs map[*plan.ValueDerivedTable]sql.Schema) (sql.Node, transform.TreeIdentity, error) {
97-
// Only handle nodes that have expressions (like Project)
98-
exprNode, ok := n.(sql.Expressioner)
99-
if !ok {
100-
return n, transform.SameTree, nil
101-
}
102-
103-
// Get the source schema by traversing through wrapper nodes like GroupBy
104-
// This ensures we get the VDT's schema, not the aggregate output schema
105-
var childSchema sql.Schema
106-
switch node := n.(type) {
107-
case *plan.Project:
108-
childSchema = getSourceSchema(node.Child)
109-
case *plan.SubqueryAlias:
110-
childSchema = node.Child.Schema()
111-
default:
112-
return n, transform.SameTree, nil
113-
}
57+
node, _, err = pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
58+
gf, ok := expr.(*expression.GetField)
59+
if !ok {
60+
return expr, transform.SameTree, nil
61+
}
62+
newSch, ok := transformedVDTs[gf.TableId()]
63+
if !ok {
64+
return expr, transform.SameTree, nil
65+
}
11466

115-
if childSchema == nil {
116-
return n, transform.SameTree, nil
117-
}
67+
// GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access
68+
schemaIdx := gf.Index() - 1
69+
if schemaIdx < 0 || schemaIdx >= len(newSch) {
70+
return nil, transform.NewTree, errors.Newf("GetField `%s` on table `%s` uses invalid index `%d`",
71+
gf.Name(), gf.Table(), gf.Index())
72+
}
11873

119-
// Transform expressions to update GetField types (recursively for nested expressions)
120-
exprs := exprNode.Expressions()
121-
newExprs := make([]sql.Expression, len(exprs))
122-
changed := false
74+
newType := newSch[schemaIdx].Type
75+
if gf.Type() == newType {
76+
return expr, transform.SameTree, nil
77+
}
12378

124-
for i, expr := range exprs {
125-
newExpr, exprChanged, err := updateGetFieldExprRecursive(expr, childSchema)
79+
// Create a new expression with the updated type
80+
newGf := expression.NewGetFieldWithTable(
81+
gf.Index(),
82+
int(gf.TableId()),
83+
newType,
84+
gf.Database(),
85+
gf.Table(),
86+
gf.Name(),
87+
gf.IsNullable(),
88+
)
89+
return newGf, transform.NewTree, nil
90+
})
12691
if err != nil {
12792
return nil, transform.SameTree, err
12893
}
129-
newExprs[i] = newExpr
130-
if exprChanged {
131-
changed = true
132-
}
133-
}
134-
135-
if !changed {
136-
return n, transform.SameTree, nil
13794
}
13895

139-
newNode, err := exprNode.WithExpressions(newExprs...)
140-
if err != nil {
141-
return nil, transform.SameTree, err
142-
}
143-
return newNode.(sql.Node), transform.NewTree, nil
96+
return node, same, nil
14497
}
14598

146-
// updateGetFieldExprRecursive recursively updates GetField expressions in the expression tree
147-
func updateGetFieldExprRecursive(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) {
148-
// First try to update if this is a GetField
149-
if _, ok := expr.(*expression.GetField); ok {
150-
return updateGetFieldExpr(expr, childSchema)
151-
}
152-
153-
// Recursively process children
154-
children := expr.Children()
155-
if len(children) == 0 {
156-
return expr, false, nil
157-
}
158-
159-
newChildren := make([]sql.Expression, len(children))
160-
changed := false
161-
for i, child := range children {
162-
newChild, childChanged, err := updateGetFieldExprRecursive(child, childSchema)
163-
if err != nil {
164-
return nil, false, err
165-
}
166-
newChildren[i] = newChild
167-
if childChanged {
168-
changed = true
169-
}
170-
}
171-
172-
if !changed {
173-
return expr, false, nil
174-
}
175-
176-
newExpr, err := expr.WithChildren(newChildren...)
177-
if err != nil {
178-
return nil, false, err
179-
}
180-
return newExpr, true, nil
181-
}
182-
183-
// updateGetFieldExpr updates a GetField expression to use the correct type from the child schema
184-
func updateGetFieldExpr(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) {
185-
gf, ok := expr.(*expression.GetField)
186-
if !ok {
187-
return expr, false, nil
188-
}
189-
190-
idx := gf.Index()
191-
// GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access
192-
schemaIdx := idx - 1
193-
if schemaIdx < 0 || schemaIdx >= len(childSchema) {
194-
return expr, false, nil
195-
}
196-
197-
newType := childSchema[schemaIdx].Type
198-
if gf.Type() == newType {
199-
return expr, false, nil
200-
}
201-
202-
// Create a new GetField with the updated type
203-
newGf := expression.NewGetFieldWithTable(
204-
idx,
205-
int(gf.TableId()),
206-
newType,
207-
gf.Database(),
208-
gf.Table(),
209-
gf.Name(),
210-
gf.IsNullable(),
211-
)
212-
return newGf, true, nil
213-
}
214-
215-
// transformValuesNode transforms a VALUES or ValueDerivedTable node to use common types
99+
// transformValuesNode transforms a plan.Values or plan.ValueDerivedTable node to use common types
216100
func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
217-
// Handle both ValueDerivedTable and Values nodes
218101
var values *plan.Values
219-
var vdt *plan.ValueDerivedTable
220-
var isVDT bool
221-
102+
var expressionerNode sql.Expressioner
222103
switch v := n.(type) {
223104
case *plan.ValueDerivedTable:
224-
vdt = v
225105
values = v.Values
226-
isVDT = true
106+
expressionerNode = v
227107
case *plan.Values:
228108
values = v
229-
isVDT = false
109+
expressionerNode = v
230110
default:
231111
return n, transform.SameTree, nil
232112
}
@@ -235,8 +115,12 @@ func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
235115
if len(values.ExpressionTuples) <= 1 {
236116
return n, transform.SameTree, nil
237117
}
238-
239118
numCols := len(values.ExpressionTuples[0])
119+
for i := 1; i < len(values.ExpressionTuples); i++ {
120+
if len(values.ExpressionTuples[i]) != numCols {
121+
return nil, transform.NewTree, errors.New("VALUES lists must all be the same length")
122+
}
123+
}
240124
if numCols == 0 {
241125
return n, transform.SameTree, nil
242126
}
@@ -252,78 +136,41 @@ func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
252136
} else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok {
253137
columnTypes[colIdx][rowIdx] = pgType
254138
} else {
255-
// Non-DoltgresType encountered - should have been sanitized
256-
// Return unchanged and let TypeSanitizer handle it
257-
return n, transform.SameTree, nil
139+
return n, transform.NewTree, errors.New("VALUES cannot use GMS types")
258140
}
259141
}
260142
}
261143

262144
// Find common type for each column
263-
commonTypes := make([]*pgtypes.DoltgresType, numCols)
145+
var newTuples [][]sql.Expression
264146
for colIdx := 0; colIdx < numCols; colIdx++ {
265-
commonType, err := framework.FindCommonType(columnTypes[colIdx])
147+
commonType, requiresCasts, err := framework.FindCommonType(columnTypes[colIdx])
266148
if err != nil {
267149
return nil, transform.NewTree, err
268150
}
269-
commonTypes[colIdx] = commonType
270-
}
271-
272-
// Check if any changes are needed
273-
needsChange := false
274-
for colIdx := 0; colIdx < numCols; colIdx++ {
275-
for rowIdx := 0; rowIdx < len(values.ExpressionTuples); rowIdx++ {
276-
if !columnTypes[colIdx][rowIdx].Equals(commonTypes[colIdx]) {
277-
needsChange = true
278-
break
151+
// If we require any casts, then we'll add casting to all expressions in the list
152+
if requiresCasts {
153+
if len(newTuples) == 0 {
154+
newTuples = make([][]sql.Expression, len(values.ExpressionTuples))
155+
copy(newTuples, values.ExpressionTuples)
156+
}
157+
for rowIdx := 0; rowIdx < len(newTuples); rowIdx++ {
158+
newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(
159+
newTuples[rowIdx][colIdx], columnTypes[colIdx][rowIdx], commonType)
279160
}
280-
}
281-
if needsChange {
282-
break
283161
}
284162
}
285-
286-
if !needsChange {
163+
// If we didn't require any casts, then we can simply return our old node
164+
if len(newTuples) == 0 {
287165
return n, transform.SameTree, nil
288166
}
289167

290-
// Create new expression tuples with implicit casts where needed
291-
newTuples := make([][]sql.Expression, len(values.ExpressionTuples))
292-
for rowIdx, row := range values.ExpressionTuples {
293-
newTuples[rowIdx] = make([]sql.Expression, numCols)
294-
for colIdx, expr := range row {
295-
fromType := columnTypes[colIdx][rowIdx]
296-
toType := commonTypes[colIdx]
297-
if fromType.Equals(toType) {
298-
newTuples[rowIdx][colIdx] = expr
299-
} else if fromType.ID == pgtypes.Unknown.ID {
300-
// Unknown type can be coerced to any type without explicit cast
301-
// Use UnknownCoercion to report the target type while passing through values
302-
newTuples[rowIdx][colIdx] = pgexprs.NewUnknownCoercion(expr, toType)
303-
} else {
304-
newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(expr, fromType, toType)
305-
}
306-
}
307-
}
308-
309168
// Flatten the new tuples into a single expression slice for WithExpressions
310169
var flatExprs []sql.Expression
311170
for _, row := range newTuples {
312171
flatExprs = append(flatExprs, row...)
313172
}
314-
315-
if isVDT {
316-
// Use WithExpressions to preserve all VDT fields (name, columns, id, cols)
317-
// while updating the expressions and recalculating the schema
318-
newNode, err := vdt.WithExpressions(flatExprs...)
319-
if err != nil {
320-
return nil, transform.NewTree, err
321-
}
322-
return newNode, transform.NewTree, nil
323-
}
324-
325-
// For standalone Values node, use WithExpressions as well
326-
newNode, err := values.WithExpressions(flatExprs...)
173+
newNode, err := expressionerNode.WithExpressions(flatExprs...)
327174
if err != nil {
328175
return nil, transform.NewTree, err
329176
}

server/expression/array.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) {
8282
// We always cast the element, as there may be parameter restrictions in place
8383
castFunc := framework.GetImplicitCast(doltgresType, resultTyp)
8484
if castFunc == nil {
85-
if doltgresType.ID == pgtypes.Unknown.ID {
86-
castFunc = framework.UnknownLiteralCast
87-
} else {
88-
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
89-
}
85+
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
9086
}
9187

9288
values[i], err = castFunc(ctx, val, resultTyp)
@@ -175,7 +171,7 @@ func (array *Array) getTargetType(children ...sql.Expression) (*pgtypes.Doltgres
175171
childrenTypes = append(childrenTypes, childType)
176172
}
177173
}
178-
targetType, err := framework.FindCommonType(childrenTypes)
174+
targetType, _, err := framework.FindCommonType(childrenTypes)
179175
if err != nil {
180176
return nil, errors.Errorf("ARRAY %s", err.Error())
181177
}

0 commit comments

Comments
 (0)