Skip to content

Commit 09d90ef

Browse files
authored
allow running multiple statements in function (#2501)
1 parent 98153b0 commit 09d90ef

24 files changed

Lines changed: 743 additions & 280 deletions

core/functions/collection.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ type Function struct {
4848
ReturnType id.Type
4949
ParameterNames []string
5050
ParameterTypes []id.Type
51+
ParameterDefaults []string
5152
Variadic bool
5253
IsNonDeterministic bool
5354
Strict bool

core/functions/serialization.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (function Function) Serialize(ctx context.Context) ([]byte, error) {
3232

3333
// Write all of the functions to the writer
3434
writer := utils.NewWriter(256)
35-
writer.VariableUint(2) // Version
35+
writer.VariableUint(3) // Version
3636
// Write the function data
3737
writer.Id(function.ID.AsId())
3838
writer.Id(function.ReturnType.AsId())
@@ -58,6 +58,8 @@ func (function Function) Serialize(ctx context.Context) ([]byte, error) {
5858
// Write version 2 data
5959
writer.String(function.SQLDefinition)
6060
writer.Bool(function.SetOf)
61+
// Write version 3 data
62+
writer.StringSlice(function.ParameterDefaults)
6163
// Returns the data
6264
return writer.Data(), nil
6365
}
@@ -70,7 +72,7 @@ func DeserializeFunction(ctx context.Context, data []byte) (Function, error) {
7072
}
7173
reader := utils.NewReader(data)
7274
version := reader.VariableUint()
73-
if version > 2 {
75+
if version > 3 {
7476
return Function{}, errors.Errorf("version %d of functions is not supported, please upgrade the server", version)
7577
}
7678

@@ -101,10 +103,13 @@ func DeserializeFunction(ctx context.Context, data []byte) (Function, error) {
101103
f.ExtensionName = reader.String()
102104
f.ExtensionSymbol = reader.String()
103105
}
104-
if version == 2 {
106+
if version >= 2 {
105107
f.SQLDefinition = reader.String()
106108
f.SetOf = reader.Bool()
107109
}
110+
if version >= 3 {
111+
f.ParameterDefaults = reader.StringSlice()
112+
}
108113
if !reader.IsEmpty() {
109114
return Function{}, errors.Errorf("extra data found while deserializing a function")
110115
}

core/procedures/collection.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,16 @@ type Collection struct {
5454

5555
// Procedure represents a created procedure.
5656
type Procedure struct {
57-
ID id.Procedure
58-
ParameterNames []string
59-
ParameterTypes []id.Type
60-
ParameterModes []ParameterMode
61-
Definition string
62-
ExtensionName string // Only used when this is an extension procedure
63-
ExtensionSymbol string // Only used when this is an extension procedure
64-
Operations []plpgsql.InterpreterOperation // Only used when this is a plpgsql language
65-
SQLDefinition string // Only used when this is a sql language
57+
ID id.Procedure
58+
ParameterNames []string
59+
ParameterTypes []id.Type
60+
ParameterModes []ParameterMode
61+
ParameterDefaults []string
62+
Definition string
63+
ExtensionName string // Only used when this is an extension procedure
64+
ExtensionSymbol string // Only used when this is an extension procedure
65+
Operations []plpgsql.InterpreterOperation // Only used when this is a plpgsql language
66+
SQLDefinition string // Only used when this is a sql language
6667
}
6768

6869
var _ objinterface.Collection = (*Collection)(nil)

core/procedures/serialization.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (procedure Procedure) Serialize(ctx context.Context) ([]byte, error) {
3232

3333
// Write all of the procedures to the writer
3434
writer := utils.NewWriter(256)
35-
writer.VariableUint(0) // Version
35+
writer.VariableUint(1) // Version
3636
// Write the procedure data
3737
writer.Id(procedure.ID.AsId())
3838
writer.StringSlice(procedure.ParameterNames)
@@ -56,6 +56,8 @@ func (procedure Procedure) Serialize(ctx context.Context) ([]byte, error) {
5656
writer.Int32(int32(op.Index))
5757
writer.StringMap(op.Options)
5858
}
59+
// Write version 1 data
60+
writer.StringSlice(procedure.ParameterDefaults)
5961
// Returns the data
6062
return writer.Data(), nil
6163
}
@@ -68,8 +70,8 @@ func DeserializeProcedure(ctx context.Context, data []byte) (Procedure, error) {
6870
}
6971
reader := utils.NewReader(data)
7072
version := reader.VariableUint()
71-
if version > 0 {
72-
return Procedure{}, errors.Errorf("version %d of functions is not supported, please upgrade the server", version)
73+
if version > 1 {
74+
return Procedure{}, errors.Errorf("version %d of procedures is not supported, please upgrade the server", version)
7375
}
7476

7577
// Read from the reader
@@ -100,6 +102,9 @@ func DeserializeProcedure(ctx context.Context, data []byte) (Procedure, error) {
100102
op.Options = reader.StringMap()
101103
p.Operations[opIdx] = op
102104
}
105+
if version >= 1 {
106+
p.ParameterDefaults = reader.StringSlice()
107+
}
103108
if !reader.IsEmpty() {
104109
return Procedure{}, errors.New("extra data found while deserializing a procedure")
105110
}

server/analyzer/create_function.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ func ValidateCreateFunction(ctx *sql.Context, a *analyzer.Analyzer, n sql.Node,
4545
}
4646

4747
builder := planbuilder.New(ctx, a.Catalog, nil)
48-
_, _, err = builder.BindOnly(ct.SqlDefParsed, ct.SqlDef, nil)
49-
if err != nil {
50-
return nil, transform.SameTree, err
48+
for _, parsed := range ct.SqlDefParsedStmts {
49+
_, _, err = builder.BindOnly(parsed, ct.SqlDef, nil)
50+
if err != nil {
51+
return nil, transform.SameTree, err
52+
}
5153
}
5254

5355
return n, transform.SameTree, nil

server/analyzer/init.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ const (
4848
ruleId_ValidateCreateTable // validateCreateTable
4949
ruleId_ValidateCreateSchema // validateCreateSchema
5050
ruleId_ResolveAlterColumn // resolveAlterColumn
51-
ruleId_ValidateCreateFunction
52-
ruleId_ResolveValuesTypes // resolveValuesTypes
51+
ruleId_ValidateCreateFunction // validateCreateFunction
52+
ruleId_ResolveValuesTypes // resolveValuesTypes
53+
ruleId_ResolveProcedureDefaults // resolveProcedureDefaults
5354
)
5455

5556
// Init adds additional rules to the analyzer to handle Doltgres-specific functionality.
@@ -66,6 +67,7 @@ func Init() {
6667
analyzer.Rule{Id: ruleId_AssignTriggers, Apply: AssignTriggers},
6768
analyzer.Rule{Id: ruleId_ValidateCreateFunction, Apply: ValidateCreateFunction},
6869
analyzer.Rule{Id: ruleId_ValidateCreateSchema, Apply: ValidateCreateSchema},
70+
analyzer.Rule{Id: ruleId_ResolveProcedureDefaults, Apply: ResolveProcedureDefaults},
6971
)
7072

7173
analyzer.OnceBeforeDefault = append([]analyzer.Rule{

server/analyzer/optimize_functions.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
package analyzer
1616

1717
import (
18+
"fmt"
19+
1820
"github.com/cockroachdb/errors"
1921
"github.com/dolthub/go-mysql-server/sql"
2022
"github.com/dolthub/go-mysql-server/sql/analyzer"
23+
"github.com/dolthub/go-mysql-server/sql/expression"
2124
"github.com/dolthub/go-mysql-server/sql/plan"
25+
"github.com/dolthub/go-mysql-server/sql/planbuilder"
2226
"github.com/dolthub/go-mysql-server/sql/transform"
2327

2428
"github.com/dolthub/doltgresql/server/functions/framework"
@@ -58,6 +62,13 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
5862
if quickFunction := compiledFunction.GetQuickFunction(); quickFunction != nil {
5963
return quickFunction, transform.NewTree, nil
6064
}
65+
66+
// fill in default exprs if applicable
67+
if err := compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) {
68+
return getDefaultExpr(ctx, a.Catalog, defExpr)
69+
}); err != nil {
70+
return nil, transform.SameTree, err
71+
}
6172
}
6273
if v, ok := in.(*plan.Values); ok {
6374
hasMultipleExpressionTuples = len(v.ExpressionTuples) > 1
@@ -92,6 +103,13 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
92103
return nil, transform.SameTree, err
93104
}
94105
}
106+
107+
// fill in default exprs if applicablea
108+
if err = compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) {
109+
return getDefaultExpr(ctx, a.Catalog, defExpr)
110+
}); err != nil {
111+
return nil, transform.SameTree, err
112+
}
95113
}
96114
return expr, transform.SameTree, nil
97115
})
@@ -113,3 +131,17 @@ func OptimizeFunctions(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
113131
return projectNode, sameNode && sameExprs, err
114132
})
115133
}
134+
135+
// getDefaultExpr takes the default value definition, parses, builds and returns sql.ColumnDefaultValue.
136+
func getDefaultExpr(ctx *sql.Context, c sql.Catalog, defExpr string) (sql.Expression, error) {
137+
builder := planbuilder.New(ctx, c, nil)
138+
proj, _, _, _, err := builder.Parse(fmt.Sprintf("select %s", defExpr), nil, false)
139+
if err != nil {
140+
return nil, err
141+
}
142+
parsedExpr := proj.(*plan.Project).Projections[0]
143+
if a, ok := parsedExpr.(*expression.Alias); ok {
144+
parsedExpr = a.Child
145+
}
146+
return parsedExpr, nil
147+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package analyzer
16+
17+
import (
18+
"strings"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/analyzer"
22+
"github.com/dolthub/go-mysql-server/sql/plan"
23+
"github.com/dolthub/go-mysql-server/sql/transform"
24+
25+
"github.com/dolthub/doltgresql/core"
26+
"github.com/dolthub/doltgresql/core/extensions"
27+
"github.com/dolthub/doltgresql/core/id"
28+
"github.com/dolthub/doltgresql/server/functions"
29+
"github.com/dolthub/doltgresql/server/functions/framework"
30+
pgnodes "github.com/dolthub/doltgresql/server/node"
31+
pgtypes "github.com/dolthub/doltgresql/server/types"
32+
)
33+
34+
// ResolveProcedureDefaults resolves default expressions of routines that are in string format by parsing it into sql.Expression.
35+
// This function retrieves the procedure overloads and sets CompiledFunction in the Call node.
36+
func ResolveProcedureDefaults(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
37+
switch n := node.(type) {
38+
case *pgnodes.Call:
39+
procCollection, err := core.GetProceduresCollectionFromContext(ctx)
40+
if err != nil {
41+
return nil, transform.SameTree, err
42+
}
43+
typesCollection, err := core.GetTypesCollectionFromContext(ctx)
44+
if err != nil {
45+
return nil, transform.SameTree, err
46+
}
47+
schemaName, err := core.GetSchemaName(ctx, nil, n.SchemaName)
48+
if err != nil {
49+
return nil, transform.SameTree, err
50+
}
51+
procName := id.NewProcedure(schemaName, n.ProcedureName)
52+
overloads, err := procCollection.GetProcedureOverloads(ctx, procName)
53+
if err != nil {
54+
return nil, transform.SameTree, err
55+
}
56+
if len(overloads) == 0 {
57+
if strings.HasPrefix(n.ProcedureName, "dolt_") {
58+
return nil, transform.SameTree, functions.ErrDoltProcedureSelectOnly
59+
}
60+
return nil, transform.SameTree, sql.ErrStoredProcedureDoesNotExist.New(n.ProcedureName)
61+
}
62+
63+
same := transform.SameTree
64+
overloadTree := framework.NewOverloads()
65+
for _, overload := range overloads {
66+
paramTypes := make([]*pgtypes.DoltgresType, len(overload.ParameterTypes))
67+
for i, paramType := range overload.ParameterTypes {
68+
paramTypes[i], err = typesCollection.GetType(ctx, paramType)
69+
if err != nil || paramTypes[i] == nil {
70+
return nil, transform.SameTree, err
71+
}
72+
}
73+
// TODO: we should probably have procedure equivalents instead of converting these to functions
74+
// probably fine for now since we don't implement/support the differing functionality between the two just yet
75+
if len(overload.ExtensionName) > 0 {
76+
if err = overloadTree.Add(framework.CFunction{
77+
ID: id.Function(overload.ID),
78+
ReturnType: pgtypes.Void,
79+
ParameterTypes: paramTypes,
80+
Variadic: false,
81+
IsNonDeterministic: true,
82+
Strict: false,
83+
ExtensionName: extensions.LibraryIdentifier(overload.ExtensionName),
84+
ExtensionSymbol: overload.ExtensionSymbol,
85+
}); err != nil {
86+
return nil, transform.SameTree, err
87+
}
88+
} else if len(overload.SQLDefinition) > 0 {
89+
if err = overloadTree.Add(framework.SQLFunction{
90+
ID: id.Function(overload.ID),
91+
ReturnType: pgtypes.Void,
92+
ParameterNames: overload.ParameterNames,
93+
ParameterTypes: paramTypes,
94+
ParameterDefaults: overload.ParameterDefaults,
95+
Variadic: false,
96+
IsNonDeterministic: true,
97+
Strict: false,
98+
SqlStatement: overload.SQLDefinition,
99+
SetOf: false,
100+
}); err != nil {
101+
return nil, transform.SameTree, err
102+
}
103+
} else {
104+
if err = overloadTree.Add(framework.InterpretedFunction{
105+
ID: id.Function(overload.ID),
106+
ReturnType: pgtypes.Void,
107+
ParameterNames: overload.ParameterNames,
108+
ParameterTypes: paramTypes,
109+
Variadic: false,
110+
IsNonDeterministic: true,
111+
Strict: false,
112+
Statements: overload.Operations,
113+
}); err != nil {
114+
return nil, transform.SameTree, err
115+
}
116+
}
117+
}
118+
compiledFunction := framework.NewCompiledFunction(n.ProcedureName, n.Exprs, overloadTree, false)
119+
// fill in default exprs if applicable
120+
if err := compiledFunction.ResolveDefaultValues(func(defExpr string) (sql.Expression, error) {
121+
return getDefaultExpr(ctx, a.Catalog, defExpr)
122+
}); err != nil {
123+
return nil, transform.SameTree, err
124+
}
125+
n.CompiledFunc = compiledFunction
126+
return node, same, nil
127+
default:
128+
return node, transform.SameTree, nil
129+
}
130+
}

server/analyzer/resolve_type.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,22 @@ func ResolveTypeForNodes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node,
7070
if err != nil {
7171
return nil, transform.NewTree, err
7272
}
73-
paramTypes := make([]*pgtypes.DoltgresType, len(n.ParameterTypes))
74-
for i := range n.ParameterTypes {
75-
paramTypes[i], err = resolveType(ctx, db, n.ParameterTypes[i])
73+
for i := range n.Parameters {
74+
n.Parameters[i].Type, err = resolveType(ctx, db, n.Parameters[i].Type)
7675
if err != nil {
7776
return nil, transform.NewTree, err
7877
}
7978
}
8079
n.ReturnType = retType
81-
n.ParameterTypes = paramTypes
8280
return node, transform.NewTree, nil
8381
case *pgnodes.CreateProcedure:
84-
paramTypes := make([]*pgtypes.DoltgresType, len(n.ParameterTypes))
85-
for i := range n.ParameterTypes {
82+
for i := range n.Parameters {
8683
var err error
87-
paramTypes[i], err = resolveType(ctx, db, n.ParameterTypes[i])
84+
n.Parameters[i].Type, err = resolveType(ctx, db, n.Parameters[i].Type)
8885
if err != nil {
8986
return nil, transform.NewTree, err
9087
}
9188
}
92-
n.ParameterTypes = paramTypes
9389
return node, transform.NewTree, nil
9490
case *plan.CreateTable:
9591
for _, col := range n.TargetSchema() {

server/analyzer/validate_column_defaults.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import (
2424
pgnode "github.com/dolthub/doltgresql/server/node"
2525
)
2626

27-
// validateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
27+
// ValidateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
2828
// column, various other business logic checks to match MySQL's logic.
2929
func ValidateColumnDefaults(ctx *sql.Context, _ *analyzer.Analyzer, n sql.Node, _ *plan.Scope, _ analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
3030
span, ctx := ctx.Span("validateColumnDefaults")

0 commit comments

Comments
 (0)