@@ -43,21 +43,11 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
4343 var retType * pgtypes.DoltgresType
4444 if len (node .RetType ) == 0 {
4545 retType = pgtypes .Void
46- } else if ! node .ReturnsTable { // Return types may specify "trigger", but this doesn't apply elsewhere
47- switch typ := node .RetType [0 ].Type .(type ) {
48- case * types.T :
49- retType = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (typ .Name ()))
50- case * tree.UnresolvedObjectName :
51- if typ .NumParts == 1 && typ .SQLString () == "trigger" {
52- retType = pgtypes .Trigger
53- } else {
54- _ , retType , err = nodeResolvableTypeReference (ctx , typ )
55- if err != nil {
56- return nil , err
57- }
58- }
59- default :
60- return nil , fmt .Errorf ("unsupported ResolvableTypeReference type: %T" , typ )
46+ } else if ! node .ReturnsTable {
47+ // Return types may specify "trigger", but this doesn't apply elsewhere
48+ retType , err = getDoltgresType (ctx , node .RetType [0 ].Type , true )
49+ if err != nil {
50+ return nil , err
6151 }
6252 } else {
6353 retType = createAnonymousCompositeType (node .RetType )
@@ -67,16 +57,9 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
6757 paramTypes := make ([]* pgtypes.DoltgresType , len (node .Args ))
6858 for i , arg := range node .Args {
6959 paramNames [i ] = arg .Name .String ()
70- switch argType := arg .Type .(type ) {
71- case * types.T :
72- paramTypes [i ] = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (argType .Name ()))
73- case * tree.UnresolvedObjectName :
74- _ , paramTypes [i ], err = nodeResolvableTypeReference (ctx , argType )
75- if err != nil {
76- return nil , err
77- }
78- default :
79- paramTypes [i ] = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (argType .SQLString ()))
60+ paramTypes [i ], err = getDoltgresType (ctx , arg .Type , false )
61+ if err != nil {
62+ return nil , err
8063 }
8164 }
8265 var strict bool
@@ -99,6 +82,25 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
9982 if err != nil {
10083 return nil , err
10184 }
85+ // parse types
86+ for i , op := range parsedBody {
87+ switch op .OpCode {
88+ case plpgsql .OpCode_Declare :
89+ declareTyp , err := parser .ParseType (op .PrimaryData )
90+ if err != nil {
91+ return nil , err
92+ }
93+ dt , err := getDoltgresType (ctx , declareTyp , false )
94+ if err != nil {
95+ return nil , err
96+ }
97+ dtName := dt .Name ()
98+ if dt .Schema () != "" {
99+ dtName = fmt .Sprintf ("%s.%s" , dt .Schema (), dtName )
100+ }
101+ parsedBody [i ].PrimaryData = dtName
102+ }
103+ }
102104 case "sql" :
103105 as , ok := options [tree .OptionAs1 ]
104106 if ! ok {
@@ -218,3 +220,21 @@ func validateRoutineOptions(ctx *Context, options []tree.RoutineOption) (map[tre
218220 }
219221 return optDefined , nil
220222}
223+
224+ // getDoltgresType converts ResolvableTypeReference into *DoltgresType.
225+ func getDoltgresType (ctx * Context , rt tree.ResolvableTypeReference , mayBeTrigger bool ) (* pgtypes.DoltgresType , error ) {
226+ switch argType := rt .(type ) {
227+ case * types.T :
228+ return pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (argType .Name ())), nil
229+ case * tree.UnresolvedObjectName :
230+ if mayBeTrigger && argType .NumParts == 1 && argType .SQLString () == "trigger" {
231+ return pgtypes .Trigger , nil
232+ } else {
233+ _ , retType , err := nodeResolvableTypeReference (ctx , argType )
234+ return retType , err
235+ }
236+ default :
237+ // return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
238+ return pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (argType .SQLString ())), nil
239+ }
240+ }
0 commit comments