@@ -112,13 +112,33 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
112112 }
113113 case "sql" :
114114 as , ok := options [tree .OptionAs1 ]
115- if ! ok {
116- return nil , errors .Errorf ("CREATE FUNCTION definition needed for LANGUAGE SQL" )
115+ if ok {
116+ sqlDef , sqlDefParsedStmts , err = handleLanguageSQLAs (as .Definition , params )
117+ if err != nil {
118+ return nil , err
119+ }
120+ break
117121 }
118- sqlDef , sqlDefParsedStmts , err = handleLanguageSQL (as .Definition , params )
119- if err != nil {
120- return nil , err
122+ sqlBody , ok := options [tree .OptionSqlBody ]
123+ if ok {
124+ beginAtomic , ok := sqlBody .SqlBody .(* tree.BeginEndBlock )
125+ if ! ok {
126+ return nil , errors .Errorf ("Expected BEGIN ATOMIC in CREATE FUNCTION definition, got %T" , sqlBody .SqlBody )
127+ }
128+ stmts := make ([]parser.Statement , len (beginAtomic .Statements ))
129+ for i , s := range beginAtomic .Statements {
130+ stmts [i ] = parser.Statement {
131+ AST : s ,
132+ SQL : s .String (),
133+ }
134+ }
135+ sqlDef , sqlDefParsedStmts , err = convertSQLStmts (stmts , params )
136+ if err != nil {
137+ return nil , err
138+ }
139+ break
121140 }
141+ return nil , errors .Errorf ("CREATE FUNCTION definition needed for LANGUAGE SQL" )
122142 case "c" :
123143 symbolOption , ok := options [tree .OptionAs2 ]
124144 if ! ok {
@@ -185,13 +205,17 @@ func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.Do
185205 return pgtypes .NewCompositeType (context .Background (), id .Null , id .NullType , typeId , attrs )
186206}
187207
188- // handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
189- func handleLanguageSQL (definition string , params []pgnodes.RoutineArg ) (string , []vitess.Statement , error ) {
208+ // handleLanguageSQLAs handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
209+ func handleLanguageSQLAs (definition string , params []pgnodes.RoutineArg ) (string , []vitess.Statement , error ) {
190210 stmts , err := parser .Parse (definition )
191211 if err != nil {
192212 return "" , nil , err
193213 }
194214
215+ return convertSQLStmts (stmts , params )
216+ }
217+
218+ func convertSQLStmts (stmts []parser.Statement , params []pgnodes.RoutineArg ) (string , []vitess.Statement , error ) {
195219 paramMap := make (map [string ]* framework.ParamTypAndValue , len (params ))
196220 for i , param := range params {
197221 tv := & framework.ParamTypAndValue {
@@ -212,14 +236,17 @@ func handleLanguageSQL(definition string, params []pgnodes.RoutineArg) (string,
212236 var vitessASTs = make ([]vitess.Statement , len (stmts ))
213237 for i , stmt := range stmts {
214238 sqlDefs [i ] = stmt .AST .String ()
215- err = framework .ReplaceFunctionColumn (stmt .AST , paramMap )
239+ err : = framework .ReplaceFunctionColumn (stmt .AST , paramMap )
216240 if err != nil {
217241 return "" , nil , err
218242 }
219243 // stmt.AST is updated at this point with FunctionColumn
220244 vitessASTs [i ], err = Convert (stmt )
245+ if err != nil {
246+ return "" , nil , err
247+ }
221248 }
222- return strings .Join (sqlDefs , ";" ), vitessASTs , err
249+ return strings .Join (sqlDefs , ";" ), vitessASTs , nil
223250}
224251
225252// validateRoutineOptions ensures that each option is defined only once. Returns a map containing all options, or an
0 commit comments