Skip to content

Commit f324664

Browse files
committed
fix(analyzer): fix JOIN and aggregate type propagation bugs
Fix two bugs in `ResolveValuesTypes` func that were introduced by our initial code implementation. Both bugs only showed up when VALUES type inference interacted with JOINs or aggregates: - Bug 1: JOIN GetField index: The original code used gf.Index() - 1 to look up columns in VDT schemas, but GetField indices are global across joined tables (e.g., a.n=0, b.id=1, b.label=2), not per-table offsets. This caused out-of-bounds errors in JOIN's. Fixed by matching cols by name instead of index calc'ing. - Bug 2: Aggregate type propagation: The first pass updates GetFields that read directly from a VDT, BUT when a type change ripples through an aggregate (e.g., int4 to numeric inside MIN), the aggregate return type changes while parent nodes still have GetFields with the old type. This can cause runtime panics from type/value mismatches. Fixed by adding a second pass that syncs each GetField type with the child node's actual schema. Test updates: SUM now returns numeric instead of float64 when operating on numeric inputs (matches PostgreSQL behavior). Unskipped 3 tests (2 JOIN, 1 MIN/MAX) that now pass. Refs: #1648
1 parent d4e1d58 commit f324664

2 files changed

Lines changed: 90 additions & 34 deletions

File tree

server/analyzer/resolve_values_types.go

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package analyzer
1616

1717
import (
18+
"strings"
19+
1820
"github.com/cockroachdb/errors"
1921
"github.com/dolthub/go-mysql-server/sql"
2022
"github.com/dolthub/go-mysql-server/sql/analyzer"
@@ -33,9 +35,9 @@ import (
3335
// by examining all rows, following PostgreSQL's type resolution rules.
3436
// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer.
3537
func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
36-
// Track which VDTs we transform so we can update GetField nodes
38+
// Walk the tree and wrap mixed-type VALUES columns with ImplicitCast.
39+
// We record which VDTs changed so we can fix up GetField types afterward.
3740
transformedVDTs := make(map[sql.TableId]sql.Schema)
38-
// First we transform VDTs and record their new schemas
3941
node, same, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
4042
newNode, same, err := transformValuesNode(n)
4143
if err != nil {
@@ -52,7 +54,10 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s
5254
return nil, transform.SameTree, err
5355
}
5456

55-
// Next we update all GetField expressions that refer to a transformed VDT
57+
// Now, fix GetField types that reference a transformed VDT. For example,
58+
// after wrapping VALUES(1),(2.5) with ImplicitCast to numeric, any
59+
// GetField reading column "n" from that VDT still says int4 and needs
60+
// to be updated to numeric.
5661
if len(transformedVDTs) > 0 {
5762
node, _, err = pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
5863
gf, ok := expr.(*expression.GetField)
@@ -64,29 +69,74 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s
6469
return expr, transform.SameTree, nil
6570
}
6671

67-
// GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access
68-
schemaIdx := gf.Index() - 1
69-
if schemaIdx < 0 || schemaIdx >= len(newSch) {
70-
return nil, transform.NewTree, errors.Errorf("VALUES: GetField `%s` on table `%s` uses invalid index `%d`",
71-
gf.Name(), gf.Table(), gf.Index())
72+
// We match by column name because GetField indices are global
73+
// across all tables in a JOIN (e.g., a.n=0, b.id=1, b.label=2).
74+
// We can't convert a global index to a per-table position without
75+
// knowing the table's starting offset, which we don't have here.
76+
schemaIdx := -1
77+
for i, col := range newSch {
78+
if col.Name == gf.Name() {
79+
schemaIdx = i
80+
break
81+
}
82+
}
83+
if schemaIdx < 0 {
84+
return expr, transform.SameTree, nil
7285
}
7386

7487
newType := newSch[schemaIdx].Type
7588
if gf.Type() == newType {
7689
return expr, transform.SameTree, nil
7790
}
7891

79-
// Create a new expression with the updated type
80-
newGf := expression.NewGetFieldWithTable(
81-
gf.Index(),
82-
int(gf.TableId()),
83-
newType,
84-
gf.Database(),
85-
gf.Table(),
86-
gf.Name(),
87-
gf.IsNullable(),
88-
)
89-
return newGf, transform.NewTree, nil
92+
return getFieldWithType(gf, newType), transform.NewTree, nil
93+
})
94+
if err != nil {
95+
return nil, transform.SameTree, err
96+
}
97+
98+
// The pass above only fixed GetFields that read directly from a VDT
99+
// (matched by tableId). But changing a VDT column's type can have a
100+
// ripple effect: if that column feeds into an aggregate like MIN or
101+
// MAX, the aggregate's return type changes too. Parent nodes that
102+
// read the aggregate result still have the old type. For example:
103+
//
104+
// SELECT MIN(n) FROM (VALUES(1),(2.5)) v(n)
105+
//
106+
// Project [GetField("min(v.n)", tableId=GroupBy, type=int4)]
107+
// └── GroupBy [MIN(GetField("n", tableId=VDT, type=numeric))]
108+
// └── VDT [n: int4 → numeric]
109+
//
110+
// The pass above fixed "n" inside MIN because its tableId=VDT.
111+
// MIN now returns numeric, so GroupBy produces numeric. But the
112+
// Project's GetField still says int4 because its tableId=GroupBy,
113+
// which wasn't in transformedVDTs. At runtime this causes a panic
114+
// because the actual value is decimal.Decimal but the type says int32.
115+
//
116+
// This pass catches those: for each GetField, check if its type
117+
// disagrees with what the child node actually produces.
118+
node, _, err = pgtransform.NodeExprsWithNodeWithOpaque(node, func(n sql.Node, expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
119+
gf, ok := expr.(*expression.GetField)
120+
if !ok {
121+
return expr, transform.SameTree, nil
122+
}
123+
// Skip VDT GetFields — the first pass already handled these
124+
if _, isVDT := transformedVDTs[gf.TableId()]; isVDT {
125+
return expr, transform.SameTree, nil
126+
}
127+
// Collect the schema that this node's children produce
128+
var childSchema sql.Schema
129+
for _, child := range n.Children() {
130+
childSchema = append(childSchema, child.Schema()...)
131+
}
132+
// Find the matching column by name and update if the type changed
133+
gfNameLower := strings.ToLower(gf.Name())
134+
for _, col := range childSchema {
135+
if strings.ToLower(col.Name) == gfNameLower && gf.Type() != col.Type {
136+
return getFieldWithType(gf, col.Type), transform.NewTree, nil
137+
}
138+
}
139+
return expr, transform.SameTree, nil
90140
})
91141
if err != nil {
92142
return nil, transform.SameTree, err
@@ -96,6 +146,19 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s
96146
return node, same, nil
97147
}
98148

149+
// getFieldWithType returns a copy of the GetField with a new type.
150+
func getFieldWithType(gf *expression.GetField, newType sql.Type) *expression.GetField {
151+
return expression.NewGetFieldWithTable(
152+
gf.Index(),
153+
int(gf.TableId()),
154+
newType,
155+
gf.Database(),
156+
gf.Table(),
157+
gf.Name(),
158+
gf.IsNullable(),
159+
)
160+
}
161+
99162
// transformValuesNode transforms a plan.Values or plan.ValueDerivedTable node to use common types
100163
func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
101164
var values *plan.Values
@@ -170,7 +233,7 @@ func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
170233
}
171234

172235
// Flatten the new tuples into a single expression slice for WithExpressions
173-
var flatExprs []sql.Expression
236+
flatExprs := make([]sql.Expression, 0, len(newTuples)*len(newTuples[0]))
174237
for _, row := range newTuples {
175238
flatExprs = append(flatExprs, row...)
176239
}

testing/go/values_statement_test.go

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,18 @@ var ValuesStatementTests = []ScriptTest{
7777
},
7878
{
7979
// SUM should work directly now that VALUES has correct type
80-
// Note: SUM returns float64 (double precision) for numeric input
8180
Query: `SELECT SUM(n) FROM (VALUES(1),(2.01),(3)) v(n);`,
82-
Expected: []sql.Row{{6.01}},
81+
Expected: []sql.Row{{Numeric("6.01")}},
8382
},
8483
{
8584
// Exact repro from issue #1648: integer first, explicit cast to numeric
8685
Query: `SELECT SUM(n::numeric) FROM (VALUES(1),(2.01),(3)) v(n);`,
87-
Expected: []sql.Row{{6.01}},
86+
Expected: []sql.Row{{Numeric("6.01")}},
8887
},
8988
{
9089
// Exact repro from issue #1648: decimal first, explicit cast to numeric
9190
Query: `SELECT SUM(n::numeric) FROM (VALUES(1.01),(2),(3)) v(n);`,
92-
Expected: []sql.Row{{6.01}},
91+
Expected: []sql.Row{{Numeric("6.01")}},
9392
},
9493
},
9594
},
@@ -123,8 +122,8 @@ var ValuesStatementTests = []ScriptTest{
123122
// SUM with GROUP BY
124123
Query: `SELECT category, SUM(amount) FROM (VALUES('a', 1),('b', 2.5),('a', 3),('b', 4.5)) v(category, amount) GROUP BY category ORDER BY category;`,
125124
Expected: []sql.Row{
126-
{"a", 4.0},
127-
{"b", 7.0},
125+
{"a", Numeric("4")},
126+
{"b", Numeric("7.0")},
128127
},
129128
},
130129
},
@@ -266,9 +265,6 @@ var ValuesStatementTests = []ScriptTest{
266265
},
267266
{
268267
// MIN/MAX on mixed types
269-
// TODO: ImplicitCast type/value mismatch causes panic; reported type is numeric but
270-
// underlying Go value is int32 for integer literals. See Hydrocharged's review comment.
271-
Skip: true,
272268
Query: `SELECT MIN(n), MAX(n) FROM (VALUES(1),(2.5),(3),(0.5)) v(n);`,
273269
Expected: []sql.Row{
274270
{Numeric("0.5"), Numeric("3")},
@@ -504,7 +500,7 @@ var ValuesStatementTests = []ScriptTest{
504500
{
505501
// SUM over CTE
506502
Query: `WITH nums AS (SELECT * FROM (VALUES(1),(2.5),(3)) v(n)) SELECT SUM(n) FROM nums;`,
507-
Expected: []sql.Row{{6.5}},
503+
Expected: []sql.Row{{Numeric("6.5")}},
508504
},
509505
},
510506
},
@@ -513,8 +509,6 @@ var ValuesStatementTests = []ScriptTest{
513509
SetUpScript: []string{},
514510
Assertions: []ScriptTestAssertion{
515511
{
516-
// TODO: GetField indices are global across joined tables but treated as per-table
517-
Skip: true,
518512
Query: `SELECT a.n, b.label FROM (VALUES(1),(2),(3)) a(n) JOIN (VALUES(1, 'one'),(2, 'two'),(3, 'three')) b(id, label) ON a.n = b.id;`,
519513
Expected: []sql.Row{
520514
{int32(1), "one"},
@@ -523,8 +517,7 @@ var ValuesStatementTests = []ScriptTest{
523517
},
524518
},
525519
{
526-
// TODO: same GetField index issue as above
527-
Skip: true,
520+
// Mixed types in one of the joined VALUES
528521
Query: `SELECT a.n, b.label FROM (VALUES(1),(2.5),(3)) a(n) JOIN (VALUES(1, 'one'),(3, 'three')) b(id, label) ON a.n = b.id;`,
529522
Expected: []sql.Row{
530523
{Numeric("1"), "one"},

0 commit comments

Comments
 (0)