@@ -20,6 +20,7 @@ import (
2020
2121 "github.com/cockroachdb/errors"
2222 "github.com/dolthub/go-mysql-server/sql"
23+ "github.com/dolthub/go-mysql-server/sql/plan"
2324
2425 "github.com/dolthub/doltgresql/core/triggers"
2526 pgexprs "github.com/dolthub/doltgresql/server/expression"
@@ -97,6 +98,19 @@ func (te *TriggerExecution) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder
9798 }
9899 }
99100 }
101+
102+ tgOp := ""
103+ switch te .Source .(type ) {
104+ case * plan.InsertInto :
105+ tgOp = "INSERT"
106+ case * plan.Update :
107+ tgOp = "UPDATE"
108+ case * plan.DeleteFrom :
109+ tgOp = "DELETE"
110+ case * plan.Truncate :
111+ tgOp = "TRUNCATE"
112+ }
113+
100114 return & triggerExecutionIter {
101115 functions : trigFuncs ,
102116 whens : whens ,
@@ -105,6 +119,7 @@ func (te *TriggerExecution) BuildRowIter(ctx *sql.Context, b sql.NodeExecBuilder
105119 runner : te .Runner .Runner ,
106120 sch : te .Sch ,
107121 source : sourceIter ,
122+ tgOp : tgOp ,
108123 }, nil
109124}
110125
@@ -171,6 +186,7 @@ type triggerExecutionIter struct {
171186 runner sql.StatementRunner
172187 sch sql.Schema
173188 source sql.RowIter
189+ tgOp string
174190}
175191
176192var _ sql.RowIter = (* triggerExecutionIter )(nil )
@@ -195,9 +211,16 @@ func (t *triggerExecutionIter) Next(ctx *sql.Context) (sql.Row, error) {
195211 case TriggerExecutionRowHandling_New :
196212 newRow = nextRow
197213 }
214+
215+ // TODO: handle other special variables
216+ triggerVars := make (map [string ]any )
217+ if t .tgOp != "" {
218+ triggerVars ["TG_OP" ] = t .tgOp
219+ }
220+
198221 for funcIdx , function := range t .functions {
199222 if t .whens [funcIdx ].ID .IsValid () {
200- whenValue , err := plpgsql .TriggerCall (ctx , t .whens [funcIdx ], t .runner , t .sch , oldRow , newRow )
223+ whenValue , err := plpgsql .TriggerCall (ctx , t .whens [funcIdx ], t .runner , t .sch , oldRow , newRow , triggerVars )
201224 if err != nil {
202225 if strings .Contains (err .Error (), "no valid cast for return value" ) {
203226 // TODO: this error should technically be caught during parsing, but interpreted functions don't
@@ -214,7 +237,8 @@ func (t *triggerExecutionIter) Next(ctx *sql.Context) (sql.Row, error) {
214237 continue
215238 }
216239 }
217- returnedValue , err := plpgsql .TriggerCall (ctx , function , t .runner , t .sch , oldRow , newRow )
240+
241+ returnedValue , err := plpgsql .TriggerCall (ctx , function , t .runner , t .sch , oldRow , newRow , triggerVars )
218242 if err != nil {
219243 return nil , err
220244 }
0 commit comments