Skip to content

Commit 1430450

Browse files
committed
resolve DECLARE variable type in CREATE FUNCTION
1 parent 61355a3 commit 1430450

2 files changed

Lines changed: 45 additions & 26 deletions

File tree

server/ast/create_function.go

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,11 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
4343
var retType *pgtypes.DoltgresType
4444
if len(node.RetType) == 0 {
4545
retType = pgtypes.Void
46-
} else if !node.ReturnsTable { // Return types may specify "trigger", but this doesn't apply elsewhere
47-
switch typ := node.RetType[0].Type.(type) {
48-
case *types.T:
49-
retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name()))
50-
case *tree.UnresolvedObjectName:
51-
if typ.NumParts == 1 && typ.SQLString() == "trigger" {
52-
retType = pgtypes.Trigger
53-
} else {
54-
_, retType, err = nodeResolvableTypeReference(ctx, typ)
55-
if err != nil {
56-
return nil, err
57-
}
58-
}
59-
default:
60-
return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
46+
} else if !node.ReturnsTable {
47+
// Return types may specify "trigger", but this doesn't apply elsewhere
48+
retType, err = getDoltgresType(ctx, node.RetType[0].Type, true)
49+
if err != nil {
50+
return nil, err
6151
}
6252
} else {
6353
retType = createAnonymousCompositeType(node.RetType)
@@ -67,16 +57,9 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
6757
paramTypes := make([]*pgtypes.DoltgresType, len(node.Args))
6858
for i, arg := range node.Args {
6959
paramNames[i] = arg.Name.String()
70-
switch argType := arg.Type.(type) {
71-
case *types.T:
72-
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name()))
73-
case *tree.UnresolvedObjectName:
74-
_, paramTypes[i], err = nodeResolvableTypeReference(ctx, argType)
75-
if err != nil {
76-
return nil, err
77-
}
78-
default:
79-
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString()))
60+
paramTypes[i], err = getDoltgresType(ctx, arg.Type, false)
61+
if err != nil {
62+
return nil, err
8063
}
8164
}
8265
var strict bool
@@ -99,6 +82,25 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
9982
if err != nil {
10083
return nil, err
10184
}
85+
// parse types
86+
for i, op := range parsedBody {
87+
switch op.OpCode {
88+
case plpgsql.OpCode_Declare:
89+
declareTyp, err := parser.ParseType(op.PrimaryData)
90+
if err != nil {
91+
return nil, err
92+
}
93+
dt, err := getDoltgresType(ctx, declareTyp, false)
94+
if err != nil {
95+
return nil, err
96+
}
97+
dtName := dt.Name()
98+
if dt.Schema() != "" {
99+
dtName = fmt.Sprintf("%s.%s", dt.Schema(), dtName)
100+
}
101+
parsedBody[i].PrimaryData = dtName
102+
}
103+
}
102104
case "sql":
103105
as, ok := options[tree.OptionAs1]
104106
if !ok {
@@ -218,3 +220,21 @@ func validateRoutineOptions(ctx *Context, options []tree.RoutineOption) (map[tre
218220
}
219221
return optDefined, nil
220222
}
223+
224+
// getDoltgresType converts ResolvableTypeReference into *DoltgresType.
225+
func getDoltgresType(ctx *Context, rt tree.ResolvableTypeReference, mayBeTrigger bool) (*pgtypes.DoltgresType, error) {
226+
switch argType := rt.(type) {
227+
case *types.T:
228+
return pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name())), nil
229+
case *tree.UnresolvedObjectName:
230+
if mayBeTrigger && argType.NumParts == 1 && argType.SQLString() == "trigger" {
231+
return pgtypes.Trigger, nil
232+
} else {
233+
_, retType, err := nodeResolvableTypeReference(ctx, argType)
234+
return retType, err
235+
}
236+
default:
237+
// return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
238+
return pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString())), nil
239+
}
240+
}

testing/go/create_function_plpgsql_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1413,7 +1413,6 @@ $$;`,
14131413
Query: "set search_path to 'public'",
14141414
},
14151415
{
1416-
Skip: true,
14171416
Query: "SELECT public.ambienttempdetail_insertupdate(101, 25.5, 15);",
14181417
Expected: []sql.Row{{101}},
14191418
},

0 commit comments

Comments
 (0)