Skip to content

Commit fd9c0ce

Browse files
authored
Merge pull request #2187 from codeaucafe/codeaucafe/1648-values-clause-type-inference
#1648: fix VALUES clause type inference
2 parents ee471e2 + 6ca7492 commit fd9c0ce

6 files changed

Lines changed: 986 additions & 5 deletions

File tree

server/analyzer/init.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ const (
4949
ruleId_ValidateCreateSchema // validateCreateSchema
5050
ruleId_ResolveAlterColumn // resolveAlterColumn
5151
ruleId_ValidateCreateFunction
52+
ruleId_ResolveValuesTypes // resolveValuesTypes
5253
)
5354

5455
// Init adds additional rules to the analyzer to handle Doltgres-specific functionality.
5556
func Init() {
5657
analyzer.AlwaysBeforeDefault = append(analyzer.AlwaysBeforeDefault,
5758
analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType},
5859
analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer},
60+
analyzer.Rule{Id: ruleId_ResolveValuesTypes, Apply: ResolveValuesTypes},
5961
analyzer.Rule{Id: ruleId_GenerateForeignKeyName, Apply: generateForeignKeyName},
6062
analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints},
6163
analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults},
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package analyzer
16+
17+
import (
18+
"strings"
19+
20+
"github.com/cockroachdb/errors"
21+
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/analyzer"
23+
"github.com/dolthub/go-mysql-server/sql/expression"
24+
"github.com/dolthub/go-mysql-server/sql/plan"
25+
"github.com/dolthub/go-mysql-server/sql/transform"
26+
27+
pgexprs "github.com/dolthub/doltgresql/server/expression"
28+
"github.com/dolthub/doltgresql/server/functions/framework"
29+
pgtransform "github.com/dolthub/doltgresql/server/transform"
30+
pgtypes "github.com/dolthub/doltgresql/server/types"
31+
)
32+
33+
// ResolveValuesTypes determines the common type for each column in a VALUES clause
34+
// by examining all rows, following PostgreSQL's type resolution rules.
35+
// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer.
36+
func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
37+
// Walk the tree and wrap mixed-type VALUES columns with ImplicitCast.
38+
// We record which VDTs changed so we can fix up GetField types afterward.
39+
transformedVDTs := make(map[sql.TableId]sql.Schema)
40+
node, same, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
41+
newNode, same, err := transformValuesNode(n)
42+
if err != nil {
43+
return nil, same, err
44+
}
45+
if !same {
46+
if vdt, ok := newNode.(*plan.ValueDerivedTable); ok {
47+
transformedVDTs[vdt.Id()] = vdt.Schema()
48+
}
49+
}
50+
return newNode, same, err
51+
})
52+
if err != nil {
53+
return nil, transform.SameTree, err
54+
}
55+
56+
// Now, fix GetField types that reference a transformed VDT. For example,
57+
// after wrapping VALUES(1),(2.5) with ImplicitCast to numeric, any
58+
// GetField reading column "n" from that VDT still says int4 and needs
59+
// to be updated to numeric.
60+
if len(transformedVDTs) > 0 {
61+
node, _, err = pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
62+
gf, ok := expr.(*expression.GetField)
63+
if !ok {
64+
return expr, transform.SameTree, nil
65+
}
66+
newSch, ok := transformedVDTs[gf.TableId()]
67+
if !ok {
68+
return expr, transform.SameTree, nil
69+
}
70+
71+
// We match by column name because GetField indices are global
72+
// across all tables in a JOIN (e.g., a.n=0, b.id=1, b.label=2).
73+
// We can't convert a global index to a per-table position without
74+
// knowing the table's starting offset, which we don't have here.
75+
schemaIdx := -1
76+
for i, col := range newSch {
77+
if col.Name == gf.Name() {
78+
schemaIdx = i
79+
break
80+
}
81+
}
82+
if schemaIdx < 0 {
83+
return expr, transform.SameTree, nil
84+
}
85+
86+
newType := newSch[schemaIdx].Type
87+
if gf.Type() == newType {
88+
return expr, transform.SameTree, nil
89+
}
90+
91+
return expression.NewGetFieldWithTable(
92+
gf.Index(), int(gf.TableId()), newType,
93+
gf.Database(), gf.Table(), gf.Name(), gf.IsNullable(),
94+
), transform.NewTree, nil
95+
})
96+
if err != nil {
97+
return nil, transform.SameTree, err
98+
}
99+
100+
// The pass above only fixed GetFields that read directly from a VDT
101+
// (matched by tableId). But changing a VDT column's type can have a
102+
// ripple effect: if that column feeds into an aggregate like MIN or
103+
// MAX, the aggregate's return type changes too. Parent nodes that
104+
// read the aggregate result still have the old type. For example:
105+
//
106+
// SELECT MIN(n) FROM (VALUES(1),(2.5)) v(n)
107+
//
108+
// Project [GetField("min(v.n)", tableId=GroupBy, type=int4)]
109+
// └── GroupBy [MIN(GetField("n", tableId=VDT, type=numeric))]
110+
// └── VDT [n: int4 → numeric]
111+
//
112+
// The pass above fixed "n" inside MIN because its tableId=VDT.
113+
// MIN now returns numeric, so GroupBy produces numeric. But the
114+
// Project's GetField still says int4 because its tableId=GroupBy,
115+
// which wasn't in transformedVDTs. At runtime this causes a panic
116+
// because the actual value is decimal.Decimal but the type says int32.
117+
//
118+
// This pass catches those: for each GetField, check if its type
119+
// disagrees with what the child node actually produces.
120+
node, _, err = pgtransform.NodeExprsWithNodeWithOpaque(node, func(n sql.Node, expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
121+
gf, ok := expr.(*expression.GetField)
122+
if !ok {
123+
return expr, transform.SameTree, nil
124+
}
125+
// Skip VDT GetFields — the first pass already handled these
126+
if _, isVDT := transformedVDTs[gf.TableId()]; isVDT {
127+
return expr, transform.SameTree, nil
128+
}
129+
// Collect the schema that this node's children produce
130+
var childSchema sql.Schema
131+
for _, child := range n.Children() {
132+
childSchema = append(childSchema, child.Schema()...)
133+
}
134+
// TODO: GMS is case-insensitive for identifiers, so aggregate
135+
// GetField names and child schema names may differ in casing.
136+
// We use strings.ToLower to handle this, but Postgres requires
137+
// case-sensitivity for quoted identifiers, which this breaks.
138+
gfName := strings.ToLower(gf.Name())
139+
for _, col := range childSchema {
140+
if strings.ToLower(col.Name) == gfName && gf.Type() != col.Type {
141+
return expression.NewGetFieldWithTable(
142+
gf.Index(), int(gf.TableId()), col.Type,
143+
gf.Database(), gf.Table(), gf.Name(), gf.IsNullable(),
144+
), transform.NewTree, nil
145+
}
146+
}
147+
return expr, transform.SameTree, nil
148+
})
149+
if err != nil {
150+
return nil, transform.SameTree, err
151+
}
152+
}
153+
154+
return node, same, nil
155+
}
156+
157+
// transformValuesNode transforms a plan.Values or plan.ValueDerivedTable node to use common types
158+
func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
159+
var values *plan.Values
160+
var expressionerNode sql.Expressioner
161+
switch v := n.(type) {
162+
case *plan.ValueDerivedTable:
163+
values = v.Values
164+
expressionerNode = v
165+
case *plan.Values:
166+
values = v
167+
expressionerNode = v
168+
default:
169+
return n, transform.SameTree, nil
170+
}
171+
172+
// Skip if no rows or single row (nothing to unify)
173+
if len(values.ExpressionTuples) <= 1 {
174+
return n, transform.SameTree, nil
175+
}
176+
numCols := len(values.ExpressionTuples[0])
177+
for i := 1; i < len(values.ExpressionTuples); i++ {
178+
if len(values.ExpressionTuples[i]) != numCols {
179+
return nil, transform.NewTree, errors.Errorf("VALUES: row %d has %d columns, expected %d", i+1, len(values.ExpressionTuples[i]), numCols)
180+
}
181+
}
182+
if numCols == 0 {
183+
return n, transform.SameTree, nil
184+
}
185+
186+
// Collect types for each column across all rows
187+
columnTypes := make([][]*pgtypes.DoltgresType, numCols)
188+
for colIdx := 0; colIdx < numCols; colIdx++ {
189+
columnTypes[colIdx] = make([]*pgtypes.DoltgresType, len(values.ExpressionTuples))
190+
for rowIdx, row := range values.ExpressionTuples {
191+
exprType := row[colIdx].Type()
192+
if exprType == nil {
193+
columnTypes[colIdx][rowIdx] = pgtypes.Unknown
194+
} else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok {
195+
columnTypes[colIdx][rowIdx] = pgType
196+
} else {
197+
return nil, transform.NewTree, errors.Errorf("VALUES: non-Doltgres type found in row %d, column %d: %s", rowIdx, colIdx, exprType.String())
198+
}
199+
}
200+
}
201+
202+
// Find common type for each column
203+
var newTuples [][]sql.Expression
204+
for colIdx := 0; colIdx < numCols; colIdx++ {
205+
commonType, requiresCasts, err := framework.FindCommonType(columnTypes[colIdx])
206+
if err != nil {
207+
return nil, transform.NewTree, err
208+
}
209+
// If we require any casts, then we'll add casting to all expressions in the list
210+
if requiresCasts {
211+
if len(newTuples) == 0 {
212+
// Deep copy to avoid mutating the original expression tuples.
213+
newTuples = make([][]sql.Expression, len(values.ExpressionTuples))
214+
for i, row := range values.ExpressionTuples {
215+
newTuples[i] = make([]sql.Expression, len(row))
216+
copy(newTuples[i], row)
217+
}
218+
}
219+
for rowIdx := 0; rowIdx < len(newTuples); rowIdx++ {
220+
newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(
221+
newTuples[rowIdx][colIdx], columnTypes[colIdx][rowIdx], commonType)
222+
}
223+
}
224+
}
225+
// If we didn't require any casts, then we can simply return our old node
226+
if len(newTuples) == 0 {
227+
return n, transform.SameTree, nil
228+
}
229+
230+
// Flatten the new tuples into a single expression slice for WithExpressions
231+
flatExprs := make([]sql.Expression, 0, len(newTuples)*len(newTuples[0]))
232+
for _, row := range newTuples {
233+
flatExprs = append(flatExprs, row...)
234+
}
235+
newNode, err := expressionerNode.WithExpressions(flatExprs...)
236+
if err != nil {
237+
return nil, transform.NewTree, err
238+
}
239+
return newNode, transform.NewTree, nil
240+
}

server/expression/explicit_cast.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,11 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
9696
baseCastToType := checkForDomainType(c.castToType)
9797
castFunction := framework.GetExplicitCast(fromType, baseCastToType)
9898
if castFunction == nil {
99-
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
100-
fromType.String(), c.castToType.String(), c.sqlChild.String())
99+
return nil, errors.Errorf(
100+
"EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
101+
fromType.String(), c.castToType.String(), c.sqlChild.String(),
102+
)
103+
101104
}
102105
castResult, err := castFunction(ctx, val, c.castToType)
103106
if err != nil {

testing/bats/types.bats

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ SQL
3737
[[ "$output" =~ '4,"{f,f}"' ]] || false
3838
[[ "$output" =~ '5,{t}' ]] || false
3939
[[ "$output" =~ '6,{f}' ]] || false
40-
}
40+
}

0 commit comments

Comments
 (0)