1515package analyzer
1616
1717import (
18+ "github.com/cockroachdb/errors"
1819 "github.com/dolthub/go-mysql-server/sql"
1920 "github.com/dolthub/go-mysql-server/sql/analyzer"
2021 "github.com/dolthub/go-mysql-server/sql/expression"
2122 "github.com/dolthub/go-mysql-server/sql/plan"
2223 "github.com/dolthub/go-mysql-server/sql/transform"
2324
25+ pgtransform "github.com/dolthub/doltgresql/server/transform"
26+
2427 pgexprs "github.com/dolthub/doltgresql/server/expression"
2528 "github.com/dolthub/doltgresql/server/functions/framework"
2629 pgtypes "github.com/dolthub/doltgresql/server/types"
@@ -30,18 +33,17 @@ import (
3033// by examining all rows, following PostgreSQL's type resolution rules.
3134// This ensures VALUES(1),(2.01),(3) correctly infers numeric type, not integer.
3235func ResolveValuesTypes (ctx * sql.Context , a * analyzer.Analyzer , node sql.Node , scope * plan.Scope , selector analyzer.RuleSelector , qFlags * sql.QueryFlags ) (sql.Node , transform.TreeIdentity , error ) {
33- // Track which VDTs we transform so we can update parent nodes
34- transformedVDTs := make (map [* plan.ValueDerivedTable ]sql.Schema )
35-
36- // First pass: transform VDTs and record their new schemas
37- node , same1 , err := transform .NodeWithOpaque (node , func (n sql.Node ) (sql.Node , transform.TreeIdentity , error ) {
36+ // Track which VDTs we transform so we can update GetField nodes
37+ transformedVDTs := make (map [sql.TableId ]sql.Schema )
38+ // First we transform VDTs and record their new schemas
39+ node , same , err := transform .NodeWithOpaque (node , func (n sql.Node ) (sql.Node , transform.TreeIdentity , error ) {
3840 newNode , same , err := transformValuesNode (n )
3941 if err != nil {
4042 return nil , same , err
4143 }
4244 if ! same {
4345 if vdt , ok := newNode .(* plan.ValueDerivedTable ); ok {
44- transformedVDTs [vdt ] = vdt .Schema ()
46+ transformedVDTs [vdt . Id () ] = vdt .Schema ()
4547 }
4648 }
4749 return newNode , same , err
@@ -50,183 +52,61 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s
5052 return nil , transform .SameTree , err
5153 }
5254
53- // Second pass: update GetField types in parent nodes that reference transformed VDTs
55+ // Next we update all GetField expressions that refer to a transformed VDT
5456 if len (transformedVDTs ) > 0 {
55- node , _ , err = transform .NodeWithOpaque (node , func (n sql.Node ) (sql.Node , transform.TreeIdentity , error ) {
56- return updateGetFieldTypes (n , transformedVDTs )
57- })
58- if err != nil {
59- return nil , transform .SameTree , err
60- }
61- }
62-
63- return node , same1 , nil
64- }
65-
66- // getSourceSchema traverses through wrapper nodes (GroupBy, Filter, etc.) to find
67- // the actual source schema from a VDT or other data source. This is needed because
68- // nodes like GroupBy produce a different output schema than their input schema.
69- func getSourceSchema (n sql.Node ) sql.Schema {
70- switch node := n .(type ) {
71- case * plan.GroupBy :
72- // GroupBy's Schema() returns aggregate output, but we need the source schema
73- return getSourceSchema (node .Child )
74- case * plan.Filter :
75- return getSourceSchema (node .Child )
76- case * plan.Sort :
77- return getSourceSchema (node .Child )
78- case * plan.Limit :
79- return getSourceSchema (node .Child )
80- case * plan.Offset :
81- return getSourceSchema (node .Child )
82- case * plan.Distinct :
83- return getSourceSchema (node .Child )
84- case * plan.SubqueryAlias :
85- // SubqueryAlias wraps a VDT - get the child's schema
86- return node .Child .Schema ()
87- case * plan.ValueDerivedTable :
88- return node .Schema ()
89- default :
90- // For other nodes, return their schema directly
91- return n .Schema ()
92- }
93- }
94-
95- // updateGetFieldTypes updates GetField expressions that reference transformed VDT columns
96- func updateGetFieldTypes (n sql.Node , transformedVDTs map [* plan.ValueDerivedTable ]sql.Schema ) (sql.Node , transform.TreeIdentity , error ) {
97- // Only handle nodes that have expressions (like Project)
98- exprNode , ok := n .(sql.Expressioner )
99- if ! ok {
100- return n , transform .SameTree , nil
101- }
102-
103- // Get the source schema by traversing through wrapper nodes like GroupBy
104- // This ensures we get the VDT's schema, not the aggregate output schema
105- var childSchema sql.Schema
106- switch node := n .(type ) {
107- case * plan.Project :
108- childSchema = getSourceSchema (node .Child )
109- case * plan.SubqueryAlias :
110- childSchema = node .Child .Schema ()
111- default :
112- return n , transform .SameTree , nil
113- }
57+ node , _ , err = pgtransform .NodeExprsWithOpaque (node , func (expr sql.Expression ) (sql.Expression , transform.TreeIdentity , error ) {
58+ gf , ok := expr .(* expression.GetField )
59+ if ! ok {
60+ return expr , transform .SameTree , nil
61+ }
62+ newSch , ok := transformedVDTs [gf .TableId ()]
63+ if ! ok {
64+ return expr , transform .SameTree , nil
65+ }
11466
115- if childSchema == nil {
116- return n , transform .SameTree , nil
117- }
67+ // GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access
68+ schemaIdx := gf .Index () - 1
69+ if schemaIdx < 0 || schemaIdx >= len (newSch ) {
70+ return nil , transform .NewTree , errors .Newf ("GetField `%s` on table `%s` uses invalid index `%d`" ,
71+ gf .Name (), gf .Table (), gf .Index ())
72+ }
11873
119- // Transform expressions to update GetField types (recursively for nested expressions)
120- exprs := exprNode . Expressions ()
121- newExprs := make ([]sql. Expression , len ( exprs ))
122- changed := false
74+ newType := newSch [ schemaIdx ]. Type
75+ if gf . Type () == newType {
76+ return expr , transform . SameTree , nil
77+ }
12378
124- for i , expr := range exprs {
125- newExpr , exprChanged , err := updateGetFieldExprRecursive (expr , childSchema )
79+ // Create a new expression with the updated type
80+ newGf := expression .NewGetFieldWithTable (
81+ gf .Index (),
82+ int (gf .TableId ()),
83+ newType ,
84+ gf .Database (),
85+ gf .Table (),
86+ gf .Name (),
87+ gf .IsNullable (),
88+ )
89+ return newGf , transform .NewTree , nil
90+ })
12691 if err != nil {
12792 return nil , transform .SameTree , err
12893 }
129- newExprs [i ] = newExpr
130- if exprChanged {
131- changed = true
132- }
133- }
134-
135- if ! changed {
136- return n , transform .SameTree , nil
13794 }
13895
139- newNode , err := exprNode .WithExpressions (newExprs ... )
140- if err != nil {
141- return nil , transform .SameTree , err
142- }
143- return newNode .(sql.Node ), transform .NewTree , nil
96+ return node , same , nil
14497}
14598
146- // updateGetFieldExprRecursive recursively updates GetField expressions in the expression tree
147- func updateGetFieldExprRecursive (expr sql.Expression , childSchema sql.Schema ) (sql.Expression , bool , error ) {
148- // First try to update if this is a GetField
149- if _ , ok := expr .(* expression.GetField ); ok {
150- return updateGetFieldExpr (expr , childSchema )
151- }
152-
153- // Recursively process children
154- children := expr .Children ()
155- if len (children ) == 0 {
156- return expr , false , nil
157- }
158-
159- newChildren := make ([]sql.Expression , len (children ))
160- changed := false
161- for i , child := range children {
162- newChild , childChanged , err := updateGetFieldExprRecursive (child , childSchema )
163- if err != nil {
164- return nil , false , err
165- }
166- newChildren [i ] = newChild
167- if childChanged {
168- changed = true
169- }
170- }
171-
172- if ! changed {
173- return expr , false , nil
174- }
175-
176- newExpr , err := expr .WithChildren (newChildren ... )
177- if err != nil {
178- return nil , false , err
179- }
180- return newExpr , true , nil
181- }
182-
183- // updateGetFieldExpr updates a GetField expression to use the correct type from the child schema
184- func updateGetFieldExpr (expr sql.Expression , childSchema sql.Schema ) (sql.Expression , bool , error ) {
185- gf , ok := expr .(* expression.GetField )
186- if ! ok {
187- return expr , false , nil
188- }
189-
190- idx := gf .Index ()
191- // GetField indices are 1-based in GMS planbuilder, so subtract 1 for schema access
192- schemaIdx := idx - 1
193- if schemaIdx < 0 || schemaIdx >= len (childSchema ) {
194- return expr , false , nil
195- }
196-
197- newType := childSchema [schemaIdx ].Type
198- if gf .Type () == newType {
199- return expr , false , nil
200- }
201-
202- // Create a new GetField with the updated type
203- newGf := expression .NewGetFieldWithTable (
204- idx ,
205- int (gf .TableId ()),
206- newType ,
207- gf .Database (),
208- gf .Table (),
209- gf .Name (),
210- gf .IsNullable (),
211- )
212- return newGf , true , nil
213- }
214-
215- // transformValuesNode transforms a VALUES or ValueDerivedTable node to use common types
99+ // transformValuesNode transforms a plan.Values or plan.ValueDerivedTable node to use common types
216100func transformValuesNode (n sql.Node ) (sql.Node , transform.TreeIdentity , error ) {
217- // Handle both ValueDerivedTable and Values nodes
218101 var values * plan.Values
219- var vdt * plan.ValueDerivedTable
220- var isVDT bool
221-
102+ var expressionerNode sql.Expressioner
222103 switch v := n .(type ) {
223104 case * plan.ValueDerivedTable :
224- vdt = v
225105 values = v .Values
226- isVDT = true
106+ expressionerNode = v
227107 case * plan.Values :
228108 values = v
229- isVDT = false
109+ expressionerNode = v
230110 default :
231111 return n , transform .SameTree , nil
232112 }
@@ -235,8 +115,12 @@ func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
235115 if len (values .ExpressionTuples ) <= 1 {
236116 return n , transform .SameTree , nil
237117 }
238-
239118 numCols := len (values .ExpressionTuples [0 ])
119+ for i := 1 ; i < len (values .ExpressionTuples ); i ++ {
120+ if len (values .ExpressionTuples [i ]) != numCols {
121+ return nil , transform .NewTree , errors .New ("VALUES lists must all be the same length" )
122+ }
123+ }
240124 if numCols == 0 {
241125 return n , transform .SameTree , nil
242126 }
@@ -252,78 +136,41 @@ func transformValuesNode(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
252136 } else if pgType , ok := exprType .(* pgtypes.DoltgresType ); ok {
253137 columnTypes [colIdx ][rowIdx ] = pgType
254138 } else {
255- // Non-DoltgresType encountered - should have been sanitized
256- // Return unchanged and let TypeSanitizer handle it
257- return n , transform .SameTree , nil
139+ return n , transform .NewTree , errors .New ("VALUES cannot use GMS types" )
258140 }
259141 }
260142 }
261143
262144 // Find common type for each column
263- commonTypes := make ([] * pgtypes. DoltgresType , numCols )
145+ var newTuples [][]sql. Expression
264146 for colIdx := 0 ; colIdx < numCols ; colIdx ++ {
265- commonType , err := framework .FindCommonType (columnTypes [colIdx ])
147+ commonType , requiresCasts , err := framework .FindCommonType (columnTypes [colIdx ])
266148 if err != nil {
267149 return nil , transform .NewTree , err
268150 }
269- commonTypes [colIdx ] = commonType
270- }
271-
272- // Check if any changes are needed
273- needsChange := false
274- for colIdx := 0 ; colIdx < numCols ; colIdx ++ {
275- for rowIdx := 0 ; rowIdx < len (values .ExpressionTuples ); rowIdx ++ {
276- if ! columnTypes [colIdx ][rowIdx ].Equals (commonTypes [colIdx ]) {
277- needsChange = true
278- break
151+ // If we require any casts, then we'll add casting to all expressions in the list
152+ if requiresCasts {
153+ if len (newTuples ) == 0 {
154+ newTuples = make ([][]sql.Expression , len (values .ExpressionTuples ))
155+ copy (newTuples , values .ExpressionTuples )
156+ }
157+ for rowIdx := 0 ; rowIdx < len (newTuples ); rowIdx ++ {
158+ newTuples [rowIdx ][colIdx ] = pgexprs .NewImplicitCast (
159+ newTuples [rowIdx ][colIdx ], columnTypes [colIdx ][rowIdx ], commonType )
279160 }
280- }
281- if needsChange {
282- break
283161 }
284162 }
285-
286- if ! needsChange {
163+ // If we didn't require any casts, then we can simply return our old node
164+ if len ( newTuples ) == 0 {
287165 return n , transform .SameTree , nil
288166 }
289167
290- // Create new expression tuples with implicit casts where needed
291- newTuples := make ([][]sql.Expression , len (values .ExpressionTuples ))
292- for rowIdx , row := range values .ExpressionTuples {
293- newTuples [rowIdx ] = make ([]sql.Expression , numCols )
294- for colIdx , expr := range row {
295- fromType := columnTypes [colIdx ][rowIdx ]
296- toType := commonTypes [colIdx ]
297- if fromType .Equals (toType ) {
298- newTuples [rowIdx ][colIdx ] = expr
299- } else if fromType .ID == pgtypes .Unknown .ID {
300- // Unknown type can be coerced to any type without explicit cast
301- // Use UnknownCoercion to report the target type while passing through values
302- newTuples [rowIdx ][colIdx ] = pgexprs .NewUnknownCoercion (expr , toType )
303- } else {
304- newTuples [rowIdx ][colIdx ] = pgexprs .NewImplicitCast (expr , fromType , toType )
305- }
306- }
307- }
308-
309168 // Flatten the new tuples into a single expression slice for WithExpressions
310169 var flatExprs []sql.Expression
311170 for _ , row := range newTuples {
312171 flatExprs = append (flatExprs , row ... )
313172 }
314-
315- if isVDT {
316- // Use WithExpressions to preserve all VDT fields (name, columns, id, cols)
317- // while updating the expressions and recalculating the schema
318- newNode , err := vdt .WithExpressions (flatExprs ... )
319- if err != nil {
320- return nil , transform .NewTree , err
321- }
322- return newNode , transform .NewTree , nil
323- }
324-
325- // For standalone Values node, use WithExpressions as well
326- newNode , err := values .WithExpressions (flatExprs ... )
173+ newNode , err := expressionerNode .WithExpressions (flatExprs ... )
327174 if err != nil {
328175 return nil , transform .NewTree , err
329176 }
0 commit comments