Skip to content

Commit 7cdc339

Browse files
committed
feat(analyzer): fix VALUES clause type inference
Add a ResolveValuesTypes analyzer rule that computes common types across all VALUES rows, not just the first row. Previously, DoltgreSQL would incorrectly use only the first value to determine column types, causing errors when subsequent values had different types, e.g. VALUES(1),(2.01),(3).

Changes:
- Two-pass transformation strategy: the first pass transforms VDT nodes with unified types; the second pass updates GetField expressions in parent nodes
- Use FindCommonType() to resolve types per PostgreSQL rules
- Apply ImplicitCast for type conversions and UnknownCoercion for unknown-typed literals
- Handle aggregates via getSourceSchema()
- Add an UnknownCoercion expression type for unknown -> target coercion without conversion

Tests:
- Add 4 bats integration tests for mixed int/decimal VALUES
- Add 3 Go test cases covering int-first, decimal-first, SUM aggregate, and multi-column scenarios

Refs: #1648
1 parent bf08254 commit 7cdc339

6 files changed

Lines changed: 477 additions & 4 deletions

File tree

server/analyzer/init.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ const (
4949
ruleId_ValidateCreateSchema // validateCreateSchema
5050
ruleId_ResolveAlterColumn // resolveAlterColumn
5151
ruleId_ValidateCreateFunction
52+
ruleId_ResolveValuesTypes // resolveValuesTypes
5253
)
5354

5455
// Init adds additional rules to the analyzer to handle Doltgres-specific functionality.
5556
func Init() {
5657
analyzer.AlwaysBeforeDefault = append(analyzer.AlwaysBeforeDefault,
5758
analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType},
5859
analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer},
60+
analyzer.Rule{Id: ruleId_ResolveValuesTypes, Apply: ResolveValuesTypes},
5961
analyzer.Rule{Id: ruleId_GenerateForeignKeyName, Apply: generateForeignKeyName},
6062
analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints},
6163
analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults},
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package analyzer
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/analyzer"
20+
"github.com/dolthub/go-mysql-server/sql/expression"
21+
"github.com/dolthub/go-mysql-server/sql/plan"
22+
"github.com/dolthub/go-mysql-server/sql/transform"
23+
24+
pgexprs "github.com/dolthub/doltgresql/server/expression"
25+
"github.com/dolthub/doltgresql/server/functions/framework"
26+
pgtypes "github.com/dolthub/doltgresql/server/types"
27+
)
28+
29+
// ResolveValuesTypes determines the common type for each column in a VALUES clause
30+
// by examining all rows, following PostgreSQL's type resolution rules.
31+
// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer.
32+
func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
33+
// Track which VDTs we transform so we can update parent nodes
34+
transformedVDTs := make(map[*plan.ValueDerivedTable]sql.Schema)
35+
36+
// First pass: transform VDTs and record their new schemas
37+
node, same1, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
38+
newNode, same, err := transformValuesNode(n)
39+
if err != nil {
40+
return nil, same, err
41+
}
42+
if !same {
43+
if vdt, ok := newNode.(*plan.ValueDerivedTable); ok {
44+
transformedVDTs[vdt] = vdt.Schema()
45+
}
46+
}
47+
return newNode, same, err
48+
})
49+
if err != nil {
50+
return nil, transform.SameTree, err
51+
}
52+
53+
// Second pass: update GetField types in parent nodes that reference transformed VDTs
54+
if len(transformedVDTs) > 0 {
55+
node, _, err = transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
56+
return updateGetFieldTypes(n, transformedVDTs)
57+
})
58+
if err != nil {
59+
return nil, transform.SameTree, err
60+
}
61+
}
62+
63+
return node, same1, nil
64+
}
65+
66+
// getSourceSchema traverses through wrapper nodes (GroupBy, Filter, etc.) to find
67+
// the actual source schema from a VDT or other data source. This is needed because
68+
// nodes like GroupBy produce a different output schema than their input schema.
69+
func getSourceSchema(n sql.Node) sql.Schema {
70+
switch node := n.(type) {
71+
case *plan.GroupBy:
72+
// GroupBy's Schema() returns aggregate output, but we need the source schema
73+
return getSourceSchema(node.Child)
74+
case *plan.Filter:
75+
return getSourceSchema(node.Child)
76+
case *plan.Sort:
77+
return getSourceSchema(node.Child)
78+
case *plan.Limit:
79+
return getSourceSchema(node.Child)
80+
case *plan.Offset:
81+
return getSourceSchema(node.Child)
82+
case *plan.Distinct:
83+
return getSourceSchema(node.Child)
84+
case *plan.SubqueryAlias:
85+
// SubqueryAlias wraps a VDT - get the child's schema
86+
return node.Child.Schema()
87+
case *plan.ValueDerivedTable:
88+
return node.Schema()
89+
default:
90+
// For other nodes, return their schema directly
91+
return n.Schema()
92+
}
93+
}
94+
95+
// updateGetFieldTypes updates GetField expressions that reference transformed VDT columns
96+
func updateGetFieldTypes(n sql.Node, transformedVDTs map[*plan.ValueDerivedTable]sql.Schema) (sql.Node, transform.TreeIdentity, error) {
97+
// Only handle nodes that have expressions (like Project)
98+
exprNode, ok := n.(sql.Expressioner)
99+
if !ok {
100+
return n, transform.SameTree, nil
101+
}
102+
103+
// Get the source schema by traversing through wrapper nodes like GroupBy
104+
// This ensures we get the VDT's schema, not the aggregate output schema
105+
var childSchema sql.Schema
106+
switch node := n.(type) {
107+
case *plan.Project:
108+
childSchema = getSourceSchema(node.Child)
109+
case *plan.SubqueryAlias:
110+
childSchema = node.Child.Schema()
111+
default:
112+
return n, transform.SameTree, nil
113+
}
114+
115+
if childSchema == nil {
116+
return n, transform.SameTree, nil
117+
}
118+
119+
// Transform expressions to update GetField types (recursively for nested expressions)
120+
exprs := exprNode.Expressions()
121+
newExprs := make([]sql.Expression, len(exprs))
122+
changed := false
123+
124+
for i, expr := range exprs {
125+
newExpr, exprChanged, err := updateGetFieldExprRecursive(expr, childSchema)
126+
if err != nil {
127+
return nil, transform.SameTree, err
128+
}
129+
newExprs[i] = newExpr
130+
if exprChanged {
131+
changed = true
132+
}
133+
}
134+
135+
if !changed {
136+
return n, transform.SameTree, nil
137+
}
138+
139+
newNode, err := exprNode.WithExpressions(newExprs...)
140+
if err != nil {
141+
return nil, transform.SameTree, err
142+
}
143+
return newNode.(sql.Node), transform.NewTree, nil
144+
}
145+
146+
// updateGetFieldExprRecursive recursively updates GetField expressions in the expression tree
147+
func updateGetFieldExprRecursive(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) {
148+
// First try to update if this is a GetField
149+
if _, ok := expr.(*expression.GetField); ok {
150+
return updateGetFieldExpr(expr, childSchema)
151+
}
152+
153+
// Recursively process children
154+
children := expr.Children()
155+
if len(children) == 0 {
156+
return expr, false, nil
157+
}
158+
159+
newChildren := make([]sql.Expression, len(children))
160+
changed := false
161+
for i, child := range children {
162+
newChild, childChanged, err := updateGetFieldExprRecursive(child, childSchema)
163+
if err != nil {
164+
return nil, false, err
165+
}
166+
newChildren[i] = newChild
167+
if childChanged {
168+
changed = true
169+
}
170+
}
171+
172+
if !changed {
173+
return expr, false, nil
174+
}
175+
176+
newExpr, err := expr.WithChildren(newChildren...)
177+
if err != nil {
178+
return nil, false, err
179+
}
180+
return newExpr, true, nil
181+
}
182+
183+
// updateGetFieldExpr updates a GetField expression to use the correct type from the child schema
184+
func updateGetFieldExpr(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) {
185+
gf, ok := expr.(*expression.GetField)
186+
if !ok {
187+
return expr, false, nil
188+
}
189+
190+
idx := gf.Index()
191+
// GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access
192+
schemaIdx := idx - 1
193+
if schemaIdx < 0 || schemaIdx >= len(childSchema) {
194+
return expr, false, nil
195+
}
196+
197+
newType := childSchema[schemaIdx].Type
198+
if gf.Type() == newType {
199+
return expr, false, nil
200+
}
201+
202+
// Create a new GetField with the updated type
203+
newGf := expression.NewGetFieldWithTable(
204+
idx,
205+
int(gf.TableId()),
206+
newType,
207+
gf.Database(),
208+
gf.Table(),
209+
gf.Name(),
210+
gf.IsNullable(),
211+
)
212+
return newGf, true, nil
213+
}
214+
215+
// transformValuesNode transforms a VALUES or ValueDerivedTable node to use common types
216+
func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
217+
// Handle both ValueDerivedTable and Values nodes
218+
var values *plan.Values
219+
var vdt *plan.ValueDerivedTable
220+
var isVDT bool
221+
222+
switch v := n.(type) {
223+
case *plan.ValueDerivedTable:
224+
vdt = v
225+
values = v.Values
226+
isVDT = true
227+
case *plan.Values:
228+
values = v
229+
isVDT = false
230+
default:
231+
return n, transform.SameTree, nil
232+
}
233+
234+
// Skip if no rows or single row (nothing to unify)
235+
if len(values.ExpressionTuples) <= 1 {
236+
return n, transform.SameTree, nil
237+
}
238+
239+
numCols := len(values.ExpressionTuples[0])
240+
if numCols == 0 {
241+
return n, transform.SameTree, nil
242+
}
243+
244+
// Collect types for each column across all rows
245+
columnTypes := make([][]*pgtypes.DoltgresType, numCols)
246+
for colIdx := 0; colIdx < numCols; colIdx++ {
247+
columnTypes[colIdx] = make([]*pgtypes.DoltgresType, len(values.ExpressionTuples))
248+
for rowIdx, row := range values.ExpressionTuples {
249+
exprType := row[colIdx].Type()
250+
if exprType == nil {
251+
columnTypes[colIdx][rowIdx] = pgtypes.Unknown
252+
} else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok {
253+
columnTypes[colIdx][rowIdx] = pgType
254+
} else {
255+
// Non-DoltgresType encountered - should have been sanitized
256+
// Return unchanged and let TypeSanitizer handle it
257+
return n, transform.SameTree, nil
258+
}
259+
}
260+
}
261+
262+
// Find common type for each column
263+
commonTypes := make([]*pgtypes.DoltgresType, numCols)
264+
for colIdx := 0; colIdx < numCols; colIdx++ {
265+
commonType, err := framework.FindCommonType(columnTypes[colIdx])
266+
if err != nil {
267+
return nil, transform.NewTree, err
268+
}
269+
commonTypes[colIdx] = commonType
270+
}
271+
272+
// Check if any changes are needed
273+
needsChange := false
274+
for colIdx := 0; colIdx < numCols; colIdx++ {
275+
for rowIdx := 0; rowIdx < len(values.ExpressionTuples); rowIdx++ {
276+
if !columnTypes[colIdx][rowIdx].Equals(commonTypes[colIdx]) {
277+
needsChange = true
278+
break
279+
}
280+
}
281+
if needsChange {
282+
break
283+
}
284+
}
285+
286+
if !needsChange {
287+
return n, transform.SameTree, nil
288+
}
289+
290+
// Create new expression tuples with implicit casts where needed
291+
newTuples := make([][]sql.Expression, len(values.ExpressionTuples))
292+
for rowIdx, row := range values.ExpressionTuples {
293+
newTuples[rowIdx] = make([]sql.Expression, numCols)
294+
for colIdx, expr := range row {
295+
fromType := columnTypes[colIdx][rowIdx]
296+
toType := commonTypes[colIdx]
297+
if fromType.Equals(toType) {
298+
newTuples[rowIdx][colIdx] = expr
299+
} else if fromType.ID == pgtypes.Unknown.ID {
300+
// Unknown type can be coerced to any type without explicit cast
301+
// Use UnknownCoercion to report the target type while passing through values
302+
newTuples[rowIdx][colIdx] = pgexprs.NewUnknownCoercion(expr, toType)
303+
} else {
304+
newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(expr, fromType, toType)
305+
}
306+
}
307+
}
308+
309+
// Flatten the new tuples into a single expression slice for WithExpressions
310+
var flatExprs []sql.Expression
311+
for _, row := range newTuples {
312+
flatExprs = append(flatExprs, row...)
313+
}
314+
315+
if isVDT {
316+
// Use WithExpressions to preserve all VDT fields (name, columns, id, cols)
317+
// while updating the expressions and recalculating the schema
318+
newNode, err := vdt.WithExpressions(flatExprs...)
319+
if err != nil {
320+
return nil, transform.NewTree, err
321+
}
322+
return newNode, transform.NewTree, nil
323+
}
324+
325+
// For standalone Values node, use WithExpressions as well
326+
newNode, err := values.WithExpressions(flatExprs...)
327+
if err != nil {
328+
return nil, transform.NewTree, err
329+
}
330+
return newNode, transform.NewTree, nil
331+
}

0 commit comments

Comments
 (0)