Skip to content

Commit 7d8b791

Browse files
committed
support RETURN stmt
1 parent 79b3e8e commit 7d8b791

7 files changed

Lines changed: 184 additions & 11 deletions

File tree

postgres/parser/parser/sql.y

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4611,6 +4611,10 @@ begin_end_block:
46114611
{
46124612
$$.val = &tree.BeginEndBlock{Statements: $3.stmts()}
46134613
}
4614+
| BEGIN ATOMIC RETURN a_expr ';' END
4615+
{
4616+
$$.val = &tree.BeginEndBlock{Statements: []tree.Statement{&tree.Return{Expr: $4.expr()}}}
4617+
}
46144618

46154619
opt_schema:
46164620
/* EMPTY */

server/ast/convert.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ func Convert(postgresStmt parser.Statement) (vitess.Statement, error) {
191191
return nodeReparentDatabase(ctx, stmt)
192192
case *tree.Restore:
193193
return nodeRestore(ctx, stmt)
194+
case *tree.Return:
195+
return nodeReturn(ctx, stmt)
194196
case *tree.Revoke:
195197
return nodeRevoke(ctx, stmt)
196198
case *tree.RevokeRole:

server/ast/return.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ast
16+
17+
import (
18+
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
19+
20+
pgnodes "github.com/dolthub/doltgresql/server/node"
21+
22+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
23+
)
24+
25+
// nodeReturn handles *tree.Return nodes.
26+
func nodeReturn(ctx *Context, node *tree.Return) (vitess.Statement, error) {
27+
if node == nil {
28+
return nil, nil
29+
}
30+
31+
expr, err := nodeExpr(ctx, node.Expr)
32+
if err != nil {
33+
return nil, err
34+
}
35+
36+
return vitess.InjectedStatement{
37+
Statement: pgnodes.NewReturn(node.Expr.String()),
38+
Children: []vitess.Expr{expr},
39+
}, nil
40+
}

server/functions/framework/sql_function.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package framework
1616

1717
import (
1818
"fmt"
19+
"strings"
1920

2021
"github.com/cockroachdb/errors"
2122
"github.com/dolthub/go-mysql-server/sql"
@@ -110,6 +111,10 @@ func CallSqlFunction(ctx *sql.Context, f SQLFunction, runner sql.StatementRunner
110111
}
111112
}
112113

114+
if lower := strings.ToLower(f.SqlStatement); strings.HasPrefix(lower, "return") {
115+
f.SqlStatement = fmt.Sprintf("SELECT%s", f.SqlStatement[6:])
116+
}
117+
113118
parseds, err := parser.Parse(f.SqlStatement)
114119
if err != nil {
115120
return "", err
@@ -248,6 +253,9 @@ func ReplaceFunctionColumn(parsedAST tree.Statement, params map[string]*ParamTyp
248253
return nil
249254
case *tree.Truncate:
250255
return nil
256+
case *tree.Return:
257+
s.Expr = ReplaceUnresolvedToFunctionColumn(params, s.Expr)
258+
return nil
251259
default:
252260
return errors.Errorf("unsupported statement defined in function: %T", parsedAST)
253261
}

server/node/return.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package node
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
22+
)
23+
24+
// Return represents the statement RETURN statement.
25+
type Return struct {
26+
Expr sql.Expression
27+
exprStmt string
28+
}
29+
30+
var _ sql.ExecSourceRel = (*Return)(nil)
31+
var _ vitess.Injectable = (*Return)(nil)
32+
33+
// NewReturn creates a new *Return node.
34+
func NewReturn(exprStmt string) *Return {
35+
return &Return{
36+
Expr: nil,
37+
exprStmt: exprStmt,
38+
}
39+
}
40+
41+
// Children implements the interface sql.ExecSourceRel.
42+
func (r *Return) Children() []sql.Node {
43+
return nil
44+
}
45+
46+
// IsReadOnly implements the interface sql.ExecSourceRel.
47+
func (r *Return) IsReadOnly() bool {
48+
return true
49+
}
50+
51+
// Resolved implements the interface sql.ExecSourceRel.
52+
func (r *Return) Resolved() bool {
53+
if r.Expr == nil {
54+
return false
55+
}
56+
return !r.Expr.Resolved()
57+
}
58+
59+
// RowIter implements the interface sql.ExecSourceRel.
60+
func (r *Return) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
61+
// TODO: this cannot be called as we replace RETURN with SELECT to be able to parse the expression
62+
//val, err := r.Expr.Eval(ctx, row)
63+
//if err != nil {
64+
// return nil, err
65+
//}
66+
return sql.RowsToRowIter(), nil
67+
}
68+
69+
// String implements the interface sql.ExecSourceRel.
70+
func (r *Return) String() string {
71+
if r.Expr == nil {
72+
return fmt.Sprintf("RETURN %s", r.exprStmt)
73+
}
74+
return fmt.Sprintf("RETURN %s", r.Expr.String())
75+
}
76+
77+
// Schema implements the interface sql.ExecSourceRel.
78+
func (r *Return) Schema() sql.Schema {
79+
return sql.Schema{
80+
{Name: r.Expr.String(), Type: r.Expr.Type(), Source: ""},
81+
}
82+
}
83+
84+
// WithChildren implements the interface sql.ExecSourceRel.
85+
func (r *Return) WithChildren(children ...sql.Node) (sql.Node, error) {
86+
if len(children) != 0 {
87+
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
88+
}
89+
return r, nil
90+
}
91+
92+
// WithResolvedChildren implements the interface sql.ExecSourceRel.
93+
func (r *Return) WithResolvedChildren(children []any) (any, error) {
94+
if len(children) != 1 {
95+
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 1)
96+
}
97+
98+
nr := *r
99+
nr.Expr = children[0].(sql.Expression)
100+
return &nr, nil
101+
}

testing/go/create_function_sql_test.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ func TestCreateFunctionsLanguageSQL(t *testing.T) {
362362
},
363363
},
364364
{
365-
Name: "use BEGIN ATOMIC ... END in sql_body",
365+
Name: "use sql statements in BEGIN ATOMIC ... END in sql_body",
366366
SetUpScript: []string{},
367367
Assertions: []ScriptTestAssertion{
368368
{
@@ -374,7 +374,7 @@ func TestCreateFunctionsLanguageSQL(t *testing.T) {
374374
Expected: []sql.Row{},
375375
},
376376
{
377-
Skip: true, // TODO
377+
Skip: true, // TODO support json_build_object() function
378378
Query: `SELECT public.match_default();`,
379379
Expected: []sql.Row{{`{"k": 6, "m": 2048, "tokenizer": {"kind": "ngram", "token_length": 3}, "token_filters": [{"kind": "downcase"}], "include_original": true}`}},
380380
},
@@ -392,5 +392,23 @@ func TestCreateFunctionsLanguageSQL(t *testing.T) {
392392
},
393393
},
394394
},
395+
{
396+
Name: "use RETURN in BEGIN ATOMIC ... END in sql_body",
397+
SetUpScript: []string{},
398+
Assertions: []ScriptTestAssertion{
399+
{
400+
Query: `CREATE FUNCTION return1() RETURNS text
401+
LANGUAGE sql
402+
BEGIN ATOMIC
403+
RETURN 1::text || 'one';
404+
END;`,
405+
Expected: []sql.Row{},
406+
},
407+
{
408+
Query: `SELECT return1();`,
409+
Expected: []sql.Row{{"1one"}},
410+
},
411+
},
412+
},
395413
})
396414
}

testing/go/import_dumps_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,16 @@ import (
3333

3434
// TestImportingDumps are regression tests against dumps taken from various sources.
3535
func TestImportingDumps(t *testing.T) {
36-
t.Skip("The majority fail for now")
36+
//t.Skip("The majority fail for now")
3737
RunImportTests(t, []ImportTest{
38-
{
39-
Name: "Scrubbed-1",
40-
SetUpScript: []string{
41-
"CREATE USER behfjgnf WITH SUPERUSER PASSWORD 'password';",
42-
},
43-
SkipQueries: []string{"CREATE UNIQUE INDEX dawkmezfehakyikllr"},
44-
SQLFilename: "scrubbed-1.sql",
45-
},
38+
//{
39+
// Name: "Scrubbed-1",
40+
// SetUpScript: []string{
41+
// "CREATE USER behfjgnf WITH SUPERUSER PASSWORD 'password';",
42+
// },
43+
// SkipQueries: []string{"CREATE UNIQUE INDEX dawkmezfehakyikllr"},
44+
// SQLFilename: "scrubbed-1.sql",
45+
//},
4646
{
4747
Name: "A-lang209/Salon-Appointment-Scheduler",
4848
Skip: true, // Database creation uses unsupported params then attempts to connect, hangs indefinitely

0 commit comments

Comments
 (0)