Skip to content

Commit e726eb2

Browse files
authored
Merge pull request #2190 from dolthub/fulghum/set-returning-udfs
Initial support for UDFs with `RETURNS SETOF`
2 parents 090d26b + 7b3a5bc commit e726eb2

9 files changed

Lines changed: 282 additions & 42 deletions

File tree

server/functions/framework/functions.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ type FunctionInterface interface {
3939
// IsStrict returns whether the function is STRICT, which means if any parameter is NULL, then it returns NULL.
4040
// Otherwise, if it's not, the NULL input must be handled by user.
4141
IsStrict() bool
42-
// IsSRF returns whether the function is set returning function, meaning whether the function returns one or more
43-
// rows as a result.
42+
// IsSRF returns whether the function is a set returning function, meaning whether the
43+
// function returns one or more rows as a result.
4444
IsSRF() bool
4545
// InternalID returns the ID associated with this function.
4646
InternalID() id.Id

server/functions/framework/interpreted_function.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ type InterpretedFunction struct {
3838
Variadic bool
3939
IsNonDeterministic bool
4040
Strict bool
41-
SRF bool
4241
Statements []plpgsql.InterpreterOperation
4342
}
4443

@@ -87,7 +86,12 @@ func (iFunc InterpretedFunction) IsStrict() bool {
8786

8887
// IsSRF implements the interface FunctionInterface.
8988
func (iFunc InterpretedFunction) IsSRF() bool {
90-
return iFunc.SRF
89+
switch iFunc.ReturnType.TypCategory {
90+
case pgtypes.TypeCategory_CompositeTypes:
91+
return true
92+
default:
93+
return false
94+
}
9195
}
9296

9397
// NonDeterministic implements the interface FunctionInterface.
@@ -107,6 +111,7 @@ func (iFunc InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgs
107111
if err != nil {
108112
return nil, err
109113
}
114+
110115
return sql.RunInterpreted(ctx, func(subCtx *sql.Context) (any, error) {
111116
sch, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
112117
if err != nil {
@@ -164,13 +169,15 @@ func (iFunc InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgs
164169
}
165170

166171
// QueryMultiReturn handles queries that may return multiple values over multiple rows.
167-
func (iFunc InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error) {
172+
func (iFunc InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (schema sql.Schema, rows []sql.Row, err error) {
168173
stmt, _, err = iFunc.ApplyBindings(ctx, stack, stmt, bindings, true)
169174
if err != nil {
170-
return nil, err
175+
return nil, nil, err
171176
}
172-
return sql.RunInterpreted(ctx, func(subCtx *sql.Context) ([]sql.Row, error) {
173-
_, rowIter, _, err := stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
177+
178+
rows, err = sql.RunInterpreted(ctx, func(subCtx *sql.Context) (rows []sql.Row, err error) {
179+
var rowIter sql.RowIter
180+
schema, rowIter, _, err = stack.Runner().QueryWithBindings(subCtx, stmt, nil, nil, nil)
174181
if err != nil {
175182
return nil, err
176183
}
@@ -179,9 +186,10 @@ func (iFunc InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsq
179186
// fine.
180187
return sql.RowIterToRows(subCtx, rowIter)
181188
})
189+
return schema, rows, err
182190
}
183191

184-
// ApplyBindings applies the given bindings to the statement. If `varFound` is false, then the error will be state that
192+
// ApplyBindings applies the given bindings to the statement. If `varFound` is false, then the error will state that
185193
// the variable was not found (which means the error may be ignored if you're only concerned with finding a variable).
186194
// If `varFound` is true, then the error is related to formatting the variable. `enforceType` adds casting and quotes to
187195
// ensure that the value is correctly represented in the string.

server/plpgsql/interpreter_logic.go

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@ type InterpretedFunction interface {
3838
GetParameterNames() []string
3939
GetReturn() *pgtypes.DoltgresType
4040
GetStatements() []InterpreterOperation
41-
QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rows []sql.Row, err error)
41+
QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (schema sql.Schema, rows []sql.Row, err error)
4242
QuerySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error)
43+
// IsSRF returns whether the function is a set returning function, meaning whether the
44+
// function returns one or more rows as a result.
45+
IsSRF() bool
4346
}
4447

4548
// GetTypesCollectionFromContext is declared within the core package, but is assigned to this variable to work around
@@ -159,7 +162,7 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
159162
return nil, err
160163
}
161164
} else {
162-
_, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
165+
_, _, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
163166
if err != nil {
164167
return nil, err
165168
}
@@ -200,7 +203,7 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
200203
case OpCode_InsertInto:
201204
// TODO: implement
202205
case OpCode_Perform:
203-
_, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
206+
_, _, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
204207
if err != nil {
205208
return nil, err
206209
}
@@ -229,9 +232,22 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
229232
sess.Notice(noticeResponse)
230233
}
231234
case OpCode_Return:
235+
// If RETURN QUERY results are being buffered, return those
236+
if len(stack.ReturnQueryResults()) > 0 {
237+
records := stack.ReturnQueryResults()
238+
239+
rows := make([]sql.Row, len(records))
240+
for i, record := range records {
241+
rows[i] = sql.Row{record}
242+
}
243+
244+
return sql.RowsToRowIter(rows...), nil
245+
}
246+
232247
if len(operation.PrimaryData) == 0 {
233248
return nil, nil
234249
}
250+
235251
// TODO: handle record types properly, we'll special case triggers for now
236252
if iFunc.GetReturn().ID == pgtypes.Trigger.ID && len(operation.SecondaryData) == 1 {
237253
normalized := strings.ReplaceAll(strings.ToLower(operation.PrimaryData), " ", "")
@@ -243,7 +259,22 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
243259
}
244260
}
245261
}
246-
return iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, iFunc.GetReturn(), operation.SecondaryData)
262+
val, err := iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, iFunc.GetReturn(), operation.SecondaryData)
263+
264+
// If this is a set returning function, then we need to return a RowIter and wrap
265+
// the composite value in a sql.Row.
266+
if iFunc.IsSRF() {
267+
return sql.RowsToRowIter(sql.Row{val}), nil
268+
}
269+
return val, err
270+
271+
case OpCode_ReturnQuery:
272+
schema, rows, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData)
273+
if err != nil {
274+
return nil, err
275+
}
276+
stack.BufferReturnQueryResults(convertRowsToRecords(schema, rows))
277+
247278
case OpCode_ScopeBegin:
248279
stack.PushScope()
249280
case OpCode_ScopeEnd:
@@ -259,6 +290,30 @@ func call(ctx *sql.Context, iFunc InterpretedFunction, stack InterpreterStack) (
259290
return nil, nil
260291
}
261292

293+
// convertRowsToRecords iterates overs |rows| and converts each field in each row
294+
// into a RecordValue. |schema| is specified for type information.
295+
func convertRowsToRecords(schema sql.Schema, rows []sql.Row) [][]pgtypes.RecordValue {
296+
records := make([][]pgtypes.RecordValue, 0, len(rows))
297+
for _, row := range rows {
298+
record := make([]pgtypes.RecordValue, len(row))
299+
for i, field := range row {
300+
t := schema[i].Type
301+
doltgresType, ok := t.(*pgtypes.DoltgresType)
302+
if !ok {
303+
panic("expected Doltgres type")
304+
}
305+
306+
record[i] = pgtypes.RecordValue{
307+
Value: field,
308+
Type: doltgresType,
309+
}
310+
}
311+
records = append(records, record)
312+
}
313+
314+
return records
315+
}
316+
262317
// applyNoticeOptions adds the specified |options| to the |noticeResponse|.
263318
func applyNoticeOptions(ctx *sql.Context, noticeResponse *pgproto3.NoticeResponse, options map[string]string) error {
264319
for key, value := range options {

server/plpgsql/interpreter_operation.go

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,29 @@ package plpgsql
1919
type OpCode uint16
2020

2121
const (
22-
OpCode_Alias OpCode = iota // https://www.postgresql.org/docs/15/plpgsql-declarations.html#PLPGSQL-DECLARATION-ALIAS
23-
OpCode_Assign // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-ASSIGNMENT
24-
OpCode_Case // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS
25-
OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html
26-
OpCode_DeleteInto // https://www.postgresql.org/docs/15/plpgsql-statements.html
27-
OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING
28-
OpCode_Execute // Executing a standard SQL statement (expects no rows returned unless Target is specified)
29-
OpCode_Get // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-DIAGNOSTICS
30-
OpCode_Goto // All control-flow structures can be represented using Goto
31-
OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS
32-
OpCode_InsertInto // https://www.postgresql.org/docs/15/plpgsql-statements.html
33-
OpCode_Perform // https://www.postgresql.org/docs/15/plpgsql-statements.html
34-
OpCode_Raise // https://www.postgresql.org/docs/15/plpgsql-errors-and-messages.html
35-
OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING
36-
OpCode_ScopeBegin // This is used for scope control, specific to Doltgres
37-
OpCode_ScopeEnd // This is used for scope control, specific to Doltgres
38-
OpCode_SelectInto // https://www.postgresql.org/docs/15/plpgsql-statements.html
39-
OpCode_UpdateInto // https://www.postgresql.org/docs/15/plpgsql-statements.html
22+
// New OpCode values MUST be added to the END of this list!
23+
// Function OpCodes are persisted to disk, so these values MUST be stable across Doltgres versions.
24+
OpCode_Alias OpCode = 0 // https://www.postgresql.org/docs/15/plpgsql-declarations.html#PLPGSQL-DECLARATION-ALIAS
25+
OpCode_Assign OpCode = 1 // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-ASSIGNMENT
26+
OpCode_Case OpCode = 2 // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS
27+
OpCode_Declare OpCode = 3 // https://www.postgresql.org/docs/15/plpgsql-declarations.html
28+
OpCode_DeleteInto OpCode = 4 // https://www.postgresql.org/docs/15/plpgsql-statements.html
29+
OpCode_Exception OpCode = 5 // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING
30+
OpCode_Execute OpCode = 6 // Executing a standard SQL statement (expects no rows returned unless Target is specified)
31+
OpCode_Get OpCode = 7 // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-DIAGNOSTICS
32+
OpCode_Goto OpCode = 8 // All control-flow structures can be represented using Goto
33+
OpCode_If OpCode = 9 // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS
34+
OpCode_InsertInto OpCode = 10 // https://www.postgresql.org/docs/15/plpgsql-statements.html
35+
OpCode_Perform OpCode = 11 // https://www.postgresql.org/docs/15/plpgsql-statements.html
36+
OpCode_Raise OpCode = 12 // https://www.postgresql.org/docs/15/plpgsql-errors-and-messages.html
37+
OpCode_Return OpCode = 13 // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING
38+
OpCode_ScopeBegin OpCode = 14 // This is used for scope control, specific to Doltgres
39+
OpCode_ScopeEnd OpCode = 15 // This is used for scope control, specific to Doltgres
40+
OpCode_SelectInto OpCode = 16 // https://www.postgresql.org/docs/15/plpgsql-statements.html
41+
OpCode_UpdateInto OpCode = 17 // https://www.postgresql.org/docs/15/plpgsql-statements.html
42+
OpCode_ReturnQuery OpCode = 18 // https://www.postgresql.org/docs/current/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING-RETURN-NEXT
43+
// New OpCode values MUST be added to the END of this list!
44+
// Function OpCodes are persisted to disk, so these values MUST be stable across Doltgres versions.
4045
)
4146

4247
// InterpreterOperation is an operation that will be performed by the interpreter.

server/plpgsql/interpreter_stack.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ type InterpreterStack struct {
5555
stack *utils.Stack[*InterpreterScopeDetails]
5656
runner sql.StatementRunner
5757
labelID int
58+
59+
// returnQueryBuffer buffers results from RETURN QUERY statements
60+
returnQueryBuffer [][]pgtypes.RecordValue
5861
}
5962

6063
// NewInterpreterStack creates a new InterpreterStack.
@@ -235,3 +238,15 @@ func (is *InterpreterStack) SetAnonymousLabel() {
235238
is.stack.Peek().label = fmt.Sprintf("\t%d", is.labelID)
236239
is.labelID++
237240
}
241+
242+
// BufferReturnQueryResults buffers |results| from a RETURN QUERY statement so that they can be returned when
243+
// the function exits. If results from a previous RETURN QUERY call have already been buffered, |results| will
244+
// be appended.
245+
func (is *InterpreterStack) BufferReturnQueryResults(results [][]pgtypes.RecordValue) {
246+
is.returnQueryBuffer = append(is.returnQueryBuffer, results...)
247+
}
248+
249+
// ReturnQueryResults returns the buffered results from a RETURN QUERY statement.
250+
func (is *InterpreterStack) ReturnQueryResults() [][]pgtypes.RecordValue {
251+
return is.returnQueryBuffer
252+
}

server/plpgsql/json.go

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ type plpgSQL_stmt_return struct {
203203
LineNumber int32 `json:"lineno"`
204204
}
205205

206+
// plpgSQL_stmt_return_query exists to match the expected JSON format.
207+
type plpgSQL_stmt_return_query struct {
208+
Query expr `json:"query"`
209+
LineNumber int32 `json:"lineno"`
210+
}
211+
206212
// plpgSQL_stmt_while exists to match the expected JSON format.
207213
type plpgSQL_stmt_while struct {
208214
Condition cond `json:"cond"`
@@ -231,17 +237,18 @@ type sqlstmt struct {
231237
// statement exists to match the expected JSON format. Unlike other structs, this is used like a union rather than
232238
// having a singular expected implementation.
233239
type statement struct {
234-
Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"`
235-
Case *plpgSQL_stmt_case `json:"PLpgSQL_stmt_case"`
236-
ExecSQL *plpgSQL_stmt_execsql `json:"PLpgSQL_stmt_execsql"`
237-
Exit *plpgSQL_stmt_exit `json:"PLpgSQL_stmt_exit"`
238-
If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"`
239-
Loop *plpgSQL_stmt_loop `json:"PLpgSQL_stmt_loop"`
240-
Perform *plpgSQL_stmt_perform `json:"PLpgSQL_stmt_perform"`
241-
Raise *plpgSQL_stmt_raise `json:"PLpgSQL_stmt_raise"`
242-
Return *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"`
243-
When *plpgSQL_case_when `json:"PLpgSQL_case_when"`
244-
While *plpgSQL_stmt_while `json:"PLpgSQL_stmt_while"`
240+
Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"`
241+
Case *plpgSQL_stmt_case `json:"PLpgSQL_stmt_case"`
242+
ExecSQL *plpgSQL_stmt_execsql `json:"PLpgSQL_stmt_execsql"`
243+
Exit *plpgSQL_stmt_exit `json:"PLpgSQL_stmt_exit"`
244+
If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"`
245+
Loop *plpgSQL_stmt_loop `json:"PLpgSQL_stmt_loop"`
246+
Perform *plpgSQL_stmt_perform `json:"PLpgSQL_stmt_perform"`
247+
Raise *plpgSQL_stmt_raise `json:"PLpgSQL_stmt_raise"`
248+
Return *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"`
249+
ReturnQuery *plpgSQL_stmt_return_query `json:"PLpgSQL_stmt_return_query"`
250+
When *plpgSQL_case_when `json:"PLpgSQL_case_when"`
251+
While *plpgSQL_stmt_while `json:"PLpgSQL_stmt_while"`
245252
}
246253

247254
// Convert converts the JSON statement into its output form.
@@ -522,6 +529,13 @@ func (stmt *plpgSQL_stmt_return) Convert() Return {
522529
}
523530
}
524531

532+
// Convert converts the JSON statement into its output form.
533+
func (stmt *plpgSQL_stmt_return_query) Convert() ReturnQuery {
534+
return ReturnQuery{
535+
Query: stmt.Query.Expression.Query,
536+
}
537+
}
538+
525539
// Convert converts the JSON statement into its output form.
526540
func (stmt *plpgSQL_stmt_while) Convert() (block Block, err error) {
527541
// Convert the body of the loop first so we can determine the GOTO offsets

server/plpgsql/json_convert.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ func jsonConvertStatement(stmt statement) (Statement, error) {
8686
return stmt.Raise.Convert(), nil
8787
case stmt.Return != nil:
8888
return stmt.Return.Convert(), nil
89+
case stmt.ReturnQuery != nil:
90+
return stmt.ReturnQuery.Convert(), nil
8991
case stmt.While != nil:
9092
return stmt.While.Convert()
9193
default:

server/plpgsql/statements.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,33 @@ type Record struct {
298298
Fields []string
299299
}
300300

301+
// ReturnQuery represents a RETURN QUERY statement.
302+
type ReturnQuery struct {
303+
Query string
304+
}
305+
306+
var _ Statement = ReturnQuery{}
307+
308+
// OperationSize implements the interface Statement.
309+
func (r ReturnQuery) OperationSize() int32 {
310+
return 1
311+
}
312+
313+
// AppendOperations implements the interface Statement.
314+
func (r ReturnQuery) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error {
315+
query, referencedVariables, err := substituteVariableReferences(r.Query, stack)
316+
if err != nil {
317+
return err
318+
}
319+
320+
*ops = append(*ops, InterpreterOperation{
321+
OpCode: OpCode_ReturnQuery,
322+
PrimaryData: query,
323+
SecondaryData: referencedVariables,
324+
})
325+
return nil
326+
}
327+
301328
// Return represents a RETURN statement.
302329
type Return struct {
303330
Expression string

0 commit comments

Comments
 (0)