@@ -38,8 +38,11 @@ type InterpretedFunction interface {
3838 GetParameterNames () []string
3939 GetReturn () * pgtypes.DoltgresType
4040 GetStatements () []InterpreterOperation
41- QueryMultiReturn (ctx * sql.Context , stack InterpreterStack , stmt string , bindings []string ) (rows []sql.Row , err error )
41+ QueryMultiReturn (ctx * sql.Context , stack InterpreterStack , stmt string , bindings []string ) (schema sql. Schema , rows []sql.Row , err error )
4242 QuerySingleReturn (ctx * sql.Context , stack InterpreterStack , stmt string , targetType * pgtypes.DoltgresType , bindings []string ) (val any , err error )
43+ // IsSRF returns whether the function is a set returning function, meaning whether the
44+ // function returns one or more rows as a result.
45+ IsSRF () bool
4346}
4447
4548// GetTypesCollectionFromContext is declared within the core package, but is assigned to this variable to work around
@@ -159,7 +162,7 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
159162 return nil , err
160163 }
161164 } else {
162- _ , err := iFunc .QueryMultiReturn (ctx , stack , operation .PrimaryData , operation .SecondaryData )
165+ _ , _ , err := iFunc .QueryMultiReturn (ctx , stack , operation .PrimaryData , operation .SecondaryData )
163166 if err != nil {
164167 return nil , err
165168 }
@@ -200,7 +203,7 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
200203 case OpCode_InsertInto :
201204 // TODO: implement
202205 case OpCode_Perform :
203- _ , err := iFunc .QueryMultiReturn (ctx , stack , operation .PrimaryData , operation .SecondaryData )
206+ _ , _ , err := iFunc .QueryMultiReturn (ctx , stack , operation .PrimaryData , operation .SecondaryData )
204207 if err != nil {
205208 return nil , err
206209 }
@@ -229,9 +232,22 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
229232 sess .Notice (noticeResponse )
230233 }
231234 case OpCode_Return :
235+ // If RETURN QUERY results are being buffered, return those
236+ if len (stack .ReturnQueryResults ()) > 0 {
237+ records := stack .ReturnQueryResults ()
238+
239+ rows := make ([]sql.Row , len (records ))
240+ for i , record := range records {
241+ rows [i ] = sql.Row {record }
242+ }
243+
244+ return sql .RowsToRowIter (rows ... ), nil
245+ }
246+
232247 if len (operation .PrimaryData ) == 0 {
233248 return nil , nil
234249 }
250+
235251 // TODO: handle record types properly, we'll special case triggers for now
236252 if iFunc .GetReturn ().ID == pgtypes .Trigger .ID && len (operation .SecondaryData ) == 1 {
237253 normalized := strings .ReplaceAll (strings .ToLower (operation .PrimaryData ), " " , "" )
@@ -243,7 +259,22 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
243259 }
244260 }
245261 }
246- return iFunc .QuerySingleReturn (ctx , stack , operation .PrimaryData , iFunc .GetReturn (), operation .SecondaryData )
262+ val , err := iFunc .QuerySingleReturn (ctx , stack , operation .PrimaryData , iFunc .GetReturn (), operation .SecondaryData )
263+
264+ // If this is a set returning function, then we need to return a RowIter and wrap
265+ // the composite value in a sql.Row.
266+ if iFunc .IsSRF () {
267+ return sql .RowsToRowIter (sql.Row {val }), nil
268+ }
269+ return val , err
270+
271+ case OpCode_ReturnQuery :
272+ schema , rows , err := iFunc .QueryMultiReturn (ctx , stack , operation .PrimaryData , operation .SecondaryData )
273+ if err != nil {
274+ return nil , err
275+ }
276+ stack .BufferReturnQueryResults (convertRowsToRecords (schema , rows ))
277+
247278 case OpCode_ScopeBegin :
248279 stack .PushScope ()
249280 case OpCode_ScopeEnd :
@@ -259,6 +290,30 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
259290 return nil , nil
260291}
261292
293+ // convertRowsToRecords iterates overs |rows| and converts each field in each row
294+ // into a RecordValue. |schema| is specified for type information.
295+ func convertRowsToRecords (schema sql.Schema , rows []sql.Row ) [][]pgtypes.RecordValue {
296+ records := make ([][]pgtypes.RecordValue , 0 , len (rows ))
297+ for _ , row := range rows {
298+ record := make ([]pgtypes.RecordValue , len (row ))
299+ for i , field := range row {
300+ t := schema [i ].Type
301+ doltgresType , ok := t .(* pgtypes.DoltgresType )
302+ if ! ok {
303+ panic ("expected Doltgres type" )
304+ }
305+
306+ record [i ] = pgtypes.RecordValue {
307+ Value : field ,
308+ Type : doltgresType ,
309+ }
310+ }
311+ records = append (records , record )
312+ }
313+
314+ return records
315+ }
316+
262317// applyNoticeOptions adds the specified |options| to the |noticeResponse|.
263318func applyNoticeOptions (ctx * sql.Context , noticeResponse * pgproto3.NoticeResponse , options map [string ]string ) error {
264319 for key , value := range options {
0 commit comments