Skip to content

Commit 8f66d1a

Browse files
committed
add plpgsql-dynexecute statement
1 parent 1430450 commit 8f66d1a

12 files changed

Lines changed: 214 additions & 8 deletions
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ast
16+
17+
import (
18+
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
19+
20+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
21+
)
22+
23+
// nodeAlterDefaultPrivileges handles *tree.AlterDefaultPrivileges nodes.
24+
func nodeAlterDefaultPrivileges(ctx *Context, node *tree.AlterDefaultPrivileges) (vitess.Statement, error) {
25+
return NotYetSupportedError("ALTER DEFAULT PRIVILEGES statement is not yet supported")
26+
}

server/ast/convert.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ func Convert(postgresStmt parser.Statement) (vitess.Statement, error) {
3131
return nodeAlterAggregate(ctx, stmt)
3232
case *tree.AlterDatabase:
3333
return nodeAlterDatabase(ctx, stmt)
34+
case *tree.AlterDefaultPrivileges:
35+
return nodeAlterDefaultPrivileges(ctx, stmt)
3436
case *tree.AlterFunction:
3537
return nodeAlterFunction(ctx, stmt)
3638
case *tree.AlterIndex:

server/functions/framework/interpreted_function.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func (InterpretedFunction) ApplyBindings(ctx *sql.Context, stack plpgsql.Interpr
211211
}
212212
if enforceType {
213213
switch variable.Type.TypCategory {
214-
case pgtypes.TypeCategory_ArrayTypes, pgtypes.TypeCategory_DateTimeTypes, pgtypes.TypeCategory_StringTypes:
214+
case pgtypes.TypeCategory_ArrayTypes, pgtypes.TypeCategory_DateTimeTypes, pgtypes.TypeCategory_StringTypes, pgtypes.TypeCategory_UserDefinedTypes:
215215
formattedVar = pq.QuoteLiteral(formattedVar)
216216
}
217217
}

server/node/trigger_execution.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/cockroachdb/errors"
2222
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/plan"
2324

2425
"github.com/dolthub/doltgresql/core/triggers"
2526
pgexprs "github.com/dolthub/doltgresql/server/expression"
@@ -97,6 +98,19 @@ func (te *TriggerExecution) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder
9798
}
9899
}
99100
}
101+
102+
tgOp := ""
103+
switch te.Source.(type) {
104+
case *plan.InsertInto:
105+
tgOp = "INSERT"
106+
case *plan.Update:
107+
tgOp = "UPDATE"
108+
case *plan.DeleteFrom:
109+
tgOp = "DELETE"
110+
case *plan.Truncate:
111+
tgOp = "TRUNCATE"
112+
}
113+
100114
return &triggerExecutionIter{
101115
functions: trigFuncs,
102116
whens: whens,
@@ -105,6 +119,7 @@ func (te *TriggerExecution) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder
105119
runner: te.Runner.Runner,
106120
sch: te.Sch,
107121
source: sourceIter,
122+
tgOp: tgOp,
108123
}, nil
109124
}
110125

@@ -171,6 +186,7 @@ type triggerExecutionIter struct {
171186
runner sql.StatementRunner
172187
sch sql.Schema
173188
source sql.RowIter
189+
tgOp string
174190
}
175191

176192
var _ sql.RowIter = (*triggerExecutionIter)(nil)
@@ -195,9 +211,16 @@ func (t *triggerExecutionIter) Next(ctx *sql.Context) (sql.Row, error) {
195211
case TriggerExecutionRowHandling_New:
196212
newRow = nextRow
197213
}
214+
215+
// TODO: handle other special variables
216+
triggerVars := make(map[string]any)
217+
if t.tgOp != "" {
218+
triggerVars["TG_OP"] = t.tgOp
219+
}
220+
198221
for funcIdx, function := range t.functions {
199222
if t.whens[funcIdx].ID.IsValid() {
200-
whenValue, err := plpgsql.TriggerCall(ctx, t.whens[funcIdx], t.runner, t.sch, oldRow, newRow)
223+
whenValue, err := plpgsql.TriggerCall(ctx, t.whens[funcIdx], t.runner, t.sch, oldRow, newRow, triggerVars)
201224
if err != nil {
202225
if strings.Contains(err.Error(), "no valid cast for return value") {
203226
// TODO: this error should technically be caught during parsing, but interpreted functions don't
@@ -214,7 +237,8 @@ func (t *triggerExecutionIter) Next(ctx *sql.Context) (sql.Row, error) {
214237
continue
215238
}
216239
}
217-
returnedValue, err := plpgsql.TriggerCall(ctx, function, t.runner, t.sch, oldRow, newRow)
240+
241+
returnedValue, err := plpgsql.TriggerCall(ctx, function, t.runner, t.sch, oldRow, newRow, triggerVars)
218242
if err != nil {
219243
return nil, err
220244
}

server/plpgsql/interpreter_logic.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,19 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner sql.StatementRunne
6666
}
6767

6868
// TriggerCall runs the contained trigger operations on the given runner.
69-
func TriggerCall(ctx *sql.Context, iFunc InterpretedFunction, runner sql.StatementRunner, sch sql.Schema, oldRow sql.Row, newRow sql.Row) (any, error) {
69+
func TriggerCall(ctx *sql.Context, iFunc InterpretedFunction, runner sql.StatementRunner, sch sql.Schema, oldRow sql.Row, newRow sql.Row, trigVars map[string]any) (any, error) {
7070
// Set up the initial state of the function
7171
stack := NewInterpreterStack(runner)
7272
// Add the special variables
73-
// TODO: there are way more than just NEW and OLD -> https://www.postgresql.org/docs/15/plpgsql-trigger.html
7473
stack.NewRecord("OLD", sch, oldRow)
7574
stack.NewRecord("NEW", sch, newRow)
75+
for varName, val := range trigVars {
76+
varType, ok := triggerSpecialVariables[varName]
77+
if !ok {
78+
return nil, fmt.Errorf("unknown variable %s for trigger", varName)
79+
}
80+
stack.NewVariableWithValue(varName, varType, val)
81+
}
7682
return call(ctx, iFunc, stack)
7783
}
7884

@@ -388,3 +394,21 @@ func evaluteNoticeMessage(ctx *sql.Context, iFunc InterpretedFunction,
388394
}
389395
return message, nil
390396
}
397+
398+
// triggerSpecialVariables are the list of special variables for triggers.
399+
// https://www.postgresql.org/docs/15/plpgsql-trigger.html
400+
// TODO: NEW and OLD variables are handled separately using `InterpreterStack.NewRecord` function.
401+
var triggerSpecialVariables = map[string]*pgtypes.DoltgresType{
402+
//"NEW":
403+
//"OLD":
404+
"TG_NAME": pgtypes.Name,
405+
"TG_WHEN": pgtypes.Text,
406+
"TG_LEVEL": pgtypes.Text,
407+
"TG_OP": pgtypes.Text,
408+
"TG_RELID": pgtypes.Oid,
409+
"TG_RELNAME": pgtypes.Name,
410+
"TG_TABLE_NAME": pgtypes.Name,
411+
"TG_TABLE_SCHEMA": pgtypes.Name,
412+
"TG_NARGS": pgtypes.Int32,
413+
"TG_ARGV[]": pgtypes.TextArray,
414+
}

server/plpgsql/json.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ type plpgSQL_stmt_case struct {
132132
Else []statement `json:"else_stmts"`
133133
}
134134

135+
// plpgSQL_stmt_dynexecute exists to match the expected JSON format.
136+
type plpgSQL_stmt_dynexecute struct {
137+
LineNumber int32 `json:"lineno"`
138+
Into bool `json:"into"`
139+
Query expr `json:"query"`
140+
Target datum `json:"target"`
141+
Params []sqlstmt `json:"params"`
142+
}
143+
135144
// plpgSQL_case_when exists to match the expected JSON format.
136145
type plpgSQL_case_when struct {
137146
LineNumber int32 `json:"lineno"`
@@ -239,6 +248,7 @@ type sqlstmt struct {
239248
type statement struct {
240249
Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"`
241250
Case *plpgSQL_stmt_case `json:"PLpgSQL_stmt_case"`
251+
DynExec *plpgSQL_stmt_dynexecute `json:"PLpgSQL_stmt_dynexecute"`
242252
ExecSQL *plpgSQL_stmt_execsql `json:"PLpgSQL_stmt_execsql"`
243253
Exit *plpgSQL_stmt_exit `json:"PLpgSQL_stmt_exit"`
244254
If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"`
@@ -358,6 +368,34 @@ func (stmt *plpgSQL_stmt_case) Convert() (block Block, err error) {
358368
return block, nil
359369
}
360370

371+
// Convert converts the JSON statement into its output form.
372+
func (stmt *plpgSQL_stmt_dynexecute) Convert() (DynamicExecute, error) {
373+
var params []string
374+
for _, param := range stmt.Params {
375+
params = append(params, param.Expr.Query)
376+
}
377+
var target string
378+
if stmt.Into {
379+
switch {
380+
case stmt.Target.Row != nil:
381+
if len(stmt.Target.Row.Fields) != 1 {
382+
return DynamicExecute{}, errors.New("record types are not yet supported")
383+
}
384+
target = stmt.Target.Row.Fields[0].Name
385+
case stmt.Target.Variable != nil:
386+
target = stmt.Target.Variable.RefName
387+
default:
388+
return DynamicExecute{}, errors.Errorf("unhandled datum type: %T", stmt.Target)
389+
}
390+
}
391+
query := strings.TrimSuffix(strings.TrimPrefix(stmt.Query.Expression.Query, "'"), "'")
392+
return DynamicExecute{
393+
Query: query,
394+
Params: params,
395+
Target: target,
396+
}, nil
397+
}
398+
361399
// Convert converts the JSON statement into its output form.
362400
func (stmt *plpgSQL_stmt_execsql) Convert() (ExecuteSQL, error) {
363401
var target string

server/plpgsql/json_convert.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ func jsonConvertStatement(stmt statement) (Statement, error) {
8989
return stmt.Assignment.Convert()
9090
case stmt.Case != nil:
9191
return stmt.Case.Convert()
92+
case stmt.DynExec != nil:
93+
return stmt.DynExec.Convert()
9294
case stmt.ExecSQL != nil:
9395
return stmt.ExecSQL.Convert()
9496
case stmt.Exit != nil:

server/plpgsql/statements.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,31 @@ func (stmt ExecuteSQL) AppendOperations(ops *[]InterpreterOperation, stack *Inte
165165
return nil
166166
}
167167

168+
// DynamicExecute represents a dynamic SQL statement's execution.
169+
type DynamicExecute struct {
170+
Query string
171+
Params []string
172+
Target string
173+
}
174+
175+
var _ Statement = DynamicExecute{}
176+
177+
// OperationSize implements the interface Statement.
178+
func (DynamicExecute) OperationSize() int32 {
179+
return 1
180+
}
181+
182+
// AppendOperations implements the interface Statement.
183+
func (stmt DynamicExecute) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error {
184+
*ops = append(*ops, InterpreterOperation{
185+
OpCode: OpCode_Execute,
186+
PrimaryData: stmt.Query,
187+
SecondaryData: stmt.Params,
188+
Target: stmt.Target,
189+
})
190+
return nil
191+
}
192+
168193
// Goto jumps to the counter at the given offset.
169194
type Goto struct {
170195
Offset int32
@@ -400,6 +425,9 @@ func substituteVariableReferences(expression string, stack *InterpreterStack) (n
400425
} else {
401426
newExpression += substring + " "
402427
}
428+
} else if _, ok := triggerSpecialVariables[substring]; ok {
429+
referencedVars = append(referencedVars, substring)
430+
newExpression += fmt.Sprintf("$%d ", len(referencedVars))
403431
} else {
404432
newExpression += substring + " "
405433
}

testing/dumps/intercept.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,12 @@ StartupLoop:
142142
return err
143143
}
144144
break StartupLoop
145+
case *pgproto3.GSSEncRequest:
146+
// we don't support GSSAPI
147+
_, err = clientConn.Write([]byte("N"))
148+
if err != nil {
149+
return err
150+
}
145151
default:
146152
t.Fatalf("unexpected startup message: %v", startupMessage)
147153
}

testing/dumps/sql/Ansh-Rathod_Musive-backend-2.0.sql

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ create table public."Collections"(
5252
name text not null,
5353
username varchar(28) not null,
5454
total_tracks integer DEFAULT 0
55-
)
55+
);
5656

5757
alter table public."Collections" add constraint user_id FOREIGN KEY(username)
5858
REFERENCES public."Users"(username) match full on update CASCADE on delete cascade;
@@ -75,12 +75,12 @@ CREATE OR REPLACE FUNCTION update_collections()
7575
DECLARE
7676
BEGIN
7777
IF TG_OP = 'INSERT' THEN
78-
EXECUTE 'update public."Collections" set total_tracks=total_tracks+1 where id = $1;'
78+
EXECUTE 'update public."Collections" set total_tracks=total_tracks+1 where id = $1;'
7979
USING NEW.collection_id;
8080
END IF;
8181

8282
IF TG_OP = 'DELETE' THEN
83-
EXECUTE 'update public."Collections" set total_tracks=total_tracks-1 where id = $1;'
83+
EXECUTE 'update public."Collections" set total_tracks=total_tracks-1 where id = $1;'
8484
USING OLD.collection_id;
8585
END IF;
8686

0 commit comments

Comments
 (0)