|
| 1 | +// Copyright 2026 Dolthub, Inc. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +package analyzer |
| 16 | + |
| 17 | +import ( |
| 18 | + "github.com/dolthub/go-mysql-server/sql" |
| 19 | + "github.com/dolthub/go-mysql-server/sql/analyzer" |
| 20 | + "github.com/dolthub/go-mysql-server/sql/expression" |
| 21 | + "github.com/dolthub/go-mysql-server/sql/plan" |
| 22 | + "github.com/dolthub/go-mysql-server/sql/transform" |
| 23 | + |
| 24 | + pgexprs "github.com/dolthub/doltgresql/server/expression" |
| 25 | + "github.com/dolthub/doltgresql/server/functions/framework" |
| 26 | + pgtypes "github.com/dolthub/doltgresql/server/types" |
| 27 | +) |
| 28 | + |
| 29 | +// ResolveValuesTypes determines the common type for each column in a VALUES clause |
| 30 | +// by examining all rows, following PostgreSQL's type resolution rules. |
| 31 | +// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer. |
| 32 | +func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { |
| 33 | + // Track which VDTs we transform so we can update parent nodes |
| 34 | + transformedVDTs := make(map[*plan.ValueDerivedTable]sql.Schema) |
| 35 | + |
| 36 | + // First pass: transform VDTs and record their new schemas |
| 37 | + node, same1, err := transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { |
| 38 | + newNode, same, err := transformValuesNode(n) |
| 39 | + if err != nil { |
| 40 | + return nil, same, err |
| 41 | + } |
| 42 | + if !same { |
| 43 | + if vdt, ok := newNode.(*plan.ValueDerivedTable); ok { |
| 44 | + transformedVDTs[vdt] = vdt.Schema() |
| 45 | + } |
| 46 | + } |
| 47 | + return newNode, same, err |
| 48 | + }) |
| 49 | + if err != nil { |
| 50 | + return nil, transform.SameTree, err |
| 51 | + } |
| 52 | + |
| 53 | + // Second pass: update GetField types in parent nodes that reference transformed VDTs |
| 54 | + if len(transformedVDTs) > 0 { |
| 55 | + node, _, err = transform.NodeWithOpaque(node, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { |
| 56 | + return updateGetFieldTypes(n, transformedVDTs) |
| 57 | + }) |
| 58 | + if err != nil { |
| 59 | + return nil, transform.SameTree, err |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + return node, same1, nil |
| 64 | +} |
| 65 | + |
| 66 | +// getSourceSchema traverses through wrapper nodes (GroupBy, Filter, etc.) to find |
| 67 | +// the actual source schema from a VDT or other data source. This is needed because |
| 68 | +// nodes like GroupBy produce a different output schema than their input schema. |
| 69 | +func getSourceSchema(n sql.Node) sql.Schema { |
| 70 | + switch node := n.(type) { |
| 71 | + case *plan.GroupBy: |
| 72 | + // GroupBy's Schema() returns aggregate output, but we need the source schema |
| 73 | + return getSourceSchema(node.Child) |
| 74 | + case *plan.Filter: |
| 75 | + return getSourceSchema(node.Child) |
| 76 | + case *plan.Sort: |
| 77 | + return getSourceSchema(node.Child) |
| 78 | + case *plan.Limit: |
| 79 | + return getSourceSchema(node.Child) |
| 80 | + case *plan.Offset: |
| 81 | + return getSourceSchema(node.Child) |
| 82 | + case *plan.Distinct: |
| 83 | + return getSourceSchema(node.Child) |
| 84 | + case *plan.SubqueryAlias: |
| 85 | + // SubqueryAlias wraps a VDT - get the child's schema |
| 86 | + return node.Child.Schema() |
| 87 | + case *plan.ValueDerivedTable: |
| 88 | + return node.Schema() |
| 89 | + default: |
| 90 | + // For other nodes, return their schema directly |
| 91 | + return n.Schema() |
| 92 | + } |
| 93 | +} |
| 94 | + |
| 95 | +// updateGetFieldTypes updates GetField expressions that reference transformed VDT columns |
| 96 | +func updateGetFieldTypes(n sql.Node, transformedVDTs map[*plan.ValueDerivedTable]sql.Schema) (sql.Node, transform.TreeIdentity, error) { |
| 97 | + // Only handle nodes that have expressions (like Project) |
| 98 | + exprNode, ok := n.(sql.Expressioner) |
| 99 | + if !ok { |
| 100 | + return n, transform.SameTree, nil |
| 101 | + } |
| 102 | + |
| 103 | + // Get the source schema by traversing through wrapper nodes like GroupBy |
| 104 | + // This ensures we get the VDT's schema, not the aggregate output schema |
| 105 | + var childSchema sql.Schema |
| 106 | + switch node := n.(type) { |
| 107 | + case *plan.Project: |
| 108 | + childSchema = getSourceSchema(node.Child) |
| 109 | + case *plan.SubqueryAlias: |
| 110 | + childSchema = node.Child.Schema() |
| 111 | + default: |
| 112 | + return n, transform.SameTree, nil |
| 113 | + } |
| 114 | + |
| 115 | + if childSchema == nil { |
| 116 | + return n, transform.SameTree, nil |
| 117 | + } |
| 118 | + |
| 119 | + // Transform expressions to update GetField types (recursively for nested expressions) |
| 120 | + exprs := exprNode.Expressions() |
| 121 | + newExprs := make([]sql.Expression, len(exprs)) |
| 122 | + changed := false |
| 123 | + |
| 124 | + for i, expr := range exprs { |
| 125 | + newExpr, exprChanged, err := updateGetFieldExprRecursive(expr, childSchema) |
| 126 | + if err != nil { |
| 127 | + return nil, transform.SameTree, err |
| 128 | + } |
| 129 | + newExprs[i] = newExpr |
| 130 | + if exprChanged { |
| 131 | + changed = true |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + if !changed { |
| 136 | + return n, transform.SameTree, nil |
| 137 | + } |
| 138 | + |
| 139 | + newNode, err := exprNode.WithExpressions(newExprs...) |
| 140 | + if err != nil { |
| 141 | + return nil, transform.SameTree, err |
| 142 | + } |
| 143 | + return newNode.(sql.Node), transform.NewTree, nil |
| 144 | +} |
| 145 | + |
| 146 | +// updateGetFieldExprRecursive recursively updates GetField expressions in the expression tree |
| 147 | +func updateGetFieldExprRecursive(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) { |
| 148 | + // First try to update if this is a GetField |
| 149 | + if _, ok := expr.(*expression.GetField); ok { |
| 150 | + return updateGetFieldExpr(expr, childSchema) |
| 151 | + } |
| 152 | + |
| 153 | + // Recursively process children |
| 154 | + children := expr.Children() |
| 155 | + if len(children) == 0 { |
| 156 | + return expr, false, nil |
| 157 | + } |
| 158 | + |
| 159 | + newChildren := make([]sql.Expression, len(children)) |
| 160 | + changed := false |
| 161 | + for i, child := range children { |
| 162 | + newChild, childChanged, err := updateGetFieldExprRecursive(child, childSchema) |
| 163 | + if err != nil { |
| 164 | + return nil, false, err |
| 165 | + } |
| 166 | + newChildren[i] = newChild |
| 167 | + if childChanged { |
| 168 | + changed = true |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + if !changed { |
| 173 | + return expr, false, nil |
| 174 | + } |
| 175 | + |
| 176 | + newExpr, err := expr.WithChildren(newChildren...) |
| 177 | + if err != nil { |
| 178 | + return nil, false, err |
| 179 | + } |
| 180 | + return newExpr, true, nil |
| 181 | +} |
| 182 | + |
| 183 | +// updateGetFieldExpr updates a GetField expression to use the correct type from the child schema |
| 184 | +func updateGetFieldExpr(expr sql.Expression, childSchema sql.Schema) (sql.Expression, bool, error) { |
| 185 | + gf, ok := expr.(*expression.GetField) |
| 186 | + if !ok { |
| 187 | + return expr, false, nil |
| 188 | + } |
| 189 | + |
| 190 | + idx := gf.Index() |
| 191 | + // GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access |
| 192 | + schemaIdx := idx - 1 |
| 193 | + if schemaIdx < 0 || schemaIdx >= len(childSchema) { |
| 194 | + return expr, false, nil |
| 195 | + } |
| 196 | + |
| 197 | + newType := childSchema[schemaIdx].Type |
| 198 | + if gf.Type() == newType { |
| 199 | + return expr, false, nil |
| 200 | + } |
| 201 | + |
| 202 | + // Create a new GetField with the updated type |
| 203 | + newGf := expression.NewGetFieldWithTable( |
| 204 | + idx, |
| 205 | + int(gf.TableId()), |
| 206 | + newType, |
| 207 | + gf.Database(), |
| 208 | + gf.Table(), |
| 209 | + gf.Name(), |
| 210 | + gf.IsNullable(), |
| 211 | + ) |
| 212 | + return newGf, true, nil |
| 213 | +} |
| 214 | + |
| 215 | +// transformValuesNode transforms a VALUES or ValueDerivedTable node to use common types |
| 216 | +func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) { |
| 217 | + // Handle both ValueDerivedTable and Values nodes |
| 218 | + var values *plan.Values |
| 219 | + var vdt *plan.ValueDerivedTable |
| 220 | + var isVDT bool |
| 221 | + |
| 222 | + switch v := n.(type) { |
| 223 | + case *plan.ValueDerivedTable: |
| 224 | + vdt = v |
| 225 | + values = v.Values |
| 226 | + isVDT = true |
| 227 | + case *plan.Values: |
| 228 | + values = v |
| 229 | + isVDT = false |
| 230 | + default: |
| 231 | + return n, transform.SameTree, nil |
| 232 | + } |
| 233 | + |
| 234 | + // Skip if no rows or single row (nothing to unify) |
| 235 | + if len(values.ExpressionTuples) <= 1 { |
| 236 | + return n, transform.SameTree, nil |
| 237 | + } |
| 238 | + |
| 239 | + numCols := len(values.ExpressionTuples[0]) |
| 240 | + if numCols == 0 { |
| 241 | + return n, transform.SameTree, nil |
| 242 | + } |
| 243 | + |
| 244 | + // Collect types for each column across all rows |
| 245 | + columnTypes := make([][]*pgtypes.DoltgresType, numCols) |
| 246 | + for colIdx := 0; colIdx < numCols; colIdx++ { |
| 247 | + columnTypes[colIdx] = make([]*pgtypes.DoltgresType, len(values.ExpressionTuples)) |
| 248 | + for rowIdx, row := range values.ExpressionTuples { |
| 249 | + exprType := row[colIdx].Type() |
| 250 | + if exprType == nil { |
| 251 | + columnTypes[colIdx][rowIdx] = pgtypes.Unknown |
| 252 | + } else if pgType, ok := exprType.(*pgtypes.DoltgresType); ok { |
| 253 | + columnTypes[colIdx][rowIdx] = pgType |
| 254 | + } else { |
| 255 | + // Non-DoltgresType encountered - should have been sanitized |
| 256 | + // Return unchanged and let TypeSanitizer handle it |
| 257 | + return n, transform.SameTree, nil |
| 258 | + } |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + // Find common type for each column |
| 263 | + commonTypes := make([]*pgtypes.DoltgresType, numCols) |
| 264 | + for colIdx := 0; colIdx < numCols; colIdx++ { |
| 265 | + commonType, err := framework.FindCommonType(columnTypes[colIdx]) |
| 266 | + if err != nil { |
| 267 | + return nil, transform.NewTree, err |
| 268 | + } |
| 269 | + commonTypes[colIdx] = commonType |
| 270 | + } |
| 271 | + |
| 272 | + // Check if any changes are needed |
| 273 | + needsChange := false |
| 274 | + for colIdx := 0; colIdx < numCols; colIdx++ { |
| 275 | + for rowIdx := 0; rowIdx < len(values.ExpressionTuples); rowIdx++ { |
| 276 | + if !columnTypes[colIdx][rowIdx].Equals(commonTypes[colIdx]) { |
| 277 | + needsChange = true |
| 278 | + break |
| 279 | + } |
| 280 | + } |
| 281 | + if needsChange { |
| 282 | + break |
| 283 | + } |
| 284 | + } |
| 285 | + |
| 286 | + if !needsChange { |
| 287 | + return n, transform.SameTree, nil |
| 288 | + } |
| 289 | + |
| 290 | + // Create new expression tuples with implicit casts where needed |
| 291 | + newTuples := make([][]sql.Expression, len(values.ExpressionTuples)) |
| 292 | + for rowIdx, row := range values.ExpressionTuples { |
| 293 | + newTuples[rowIdx] = make([]sql.Expression, numCols) |
| 294 | + for colIdx, expr := range row { |
| 295 | + fromType := columnTypes[colIdx][rowIdx] |
| 296 | + toType := commonTypes[colIdx] |
| 297 | + if fromType.Equals(toType) { |
| 298 | + newTuples[rowIdx][colIdx] = expr |
| 299 | + } else if fromType.ID == pgtypes.Unknown.ID { |
| 300 | + // Unknown type can be coerced to any type without explicit cast |
| 301 | + // Use UnknownCoercion to report the target type while passing through values |
| 302 | + newTuples[rowIdx][colIdx] = pgexprs.NewUnknownCoercion(expr, toType) |
| 303 | + } else { |
| 304 | + newTuples[rowIdx][colIdx] = pgexprs.NewImplicitCast(expr, fromType, toType) |
| 305 | + } |
| 306 | + } |
| 307 | + } |
| 308 | + |
| 309 | + // Flatten the new tuples into a single expression slice for WithExpressions |
| 310 | + var flatExprs []sql.Expression |
| 311 | + for _, row := range newTuples { |
| 312 | + flatExprs = append(flatExprs, row...) |
| 313 | + } |
| 314 | + |
| 315 | + if isVDT { |
| 316 | + // Use WithExpressions to preserve all VDT fields (name, columns, id, cols) |
| 317 | + // while updating the expressions and recalculating the schema |
| 318 | + newNode, err := vdt.WithExpressions(flatExprs...) |
| 319 | + if err != nil { |
| 320 | + return nil, transform.NewTree, err |
| 321 | + } |
| 322 | + return newNode, transform.NewTree, nil |
| 323 | + } |
| 324 | + |
| 325 | + // For standalone Values node, use WithExpressions as well |
| 326 | + newNode, err := values.WithExpressions(flatExprs...) |
| 327 | + if err != nil { |
| 328 | + return nil, transform.NewTree, err |
| 329 | + } |
| 330 | + return newNode, transform.NewTree, nil |
| 331 | +} |
0 commit comments