Skip to content

Commit df3db0a

Browse files
authored
Merge pull request #2128 from dolthub/daylon/delete-table-hooks
Added some hooks for DROP TABLE
2 parents b21726c + 0751267 commit df3db0a

15 files changed

Lines changed: 7743 additions & 6817 deletions

File tree

core/table_to_dolt.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package core
16+
17+
import (
18+
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
19+
"github.com/dolthub/go-mysql-server/sql"
20+
"github.com/dolthub/go-mysql-server/sql/plan"
21+
)
22+
23+
// SQLNodeToDoltTable takes a sql.Node and returns a *sqle.DoltTable if either the node is a Dolt table, or it is a
24+
// wrapper or container that holds a Dolt table. Returns nil if a Dolt table could not be found. If the node is not a
25+
// sql.Table, then this will return nil.
26+
func SQLNodeToDoltTable(n sql.Node) *sqle.DoltTable {
27+
tbl, ok := n.(sql.Table)
28+
if !ok {
29+
return nil
30+
}
31+
return SQLTableToDoltTable(tbl)
32+
}
33+
34+
// SQLTableToDoltTable takes a sql.Table and returns a *sqle.DoltTable if either the table is a Dolt table, or it is a
35+
// wrapper or container that holds a Dolt table. Returns nil if a Dolt table could not be found.
36+
func SQLTableToDoltTable(tbl sql.Table) *sqle.DoltTable {
37+
switch t := tbl.(type) {
38+
case *plan.ResolvedTable:
39+
return SQLTableToDoltTable(t.Table)
40+
case *plan.ProcessTable:
41+
return SQLTableToDoltTable(t.Table)
42+
case *plan.IndexedTableAccess:
43+
return SQLTableToDoltTable(t.Table)
44+
case *plan.ProcedureResolvedTable:
45+
return SQLTableToDoltTable(t.ResolvedTable.Table)
46+
case *sqle.WritableIndexedDoltTable:
47+
return t.WritableDoltTable.DoltTable
48+
case *sqle.IndexedDoltTable:
49+
return t.DoltTable
50+
case *sqle.AlterableDoltTable:
51+
return t.WritableDoltTable.DoltTable
52+
case *sqle.WritableDoltTable:
53+
return t.DoltTable
54+
case *sqle.DoltTable:
55+
return t
56+
default:
57+
if wrapper, ok := tbl.(sql.TableWrapper); ok {
58+
return SQLTableToDoltTable(wrapper.Underlying())
59+
}
60+
return nil
61+
}
62+
}

server/ast/call.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func nodeCall(ctx *Context, node *tree.Call) (vitess.Statement, error) {
3737
if node.Procedure.WindowDef != nil {
3838
return nil, errors.Errorf("procedure window definitions are not yet supported")
3939
}
40-
if node.Procedure.AggType != tree.GeneralAgg {
40+
if node.Procedure.AggType == tree.OrderedSetAgg {
4141
return nil, errors.Errorf("procedure aggregation is not yet supported")
4242
}
4343
if len(node.Procedure.OrderBy) > 0 {

server/ast/convert.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ func Convert(postgresStmt parser.Statement) (vitess.Statement, error) {
129129
return nodeDropFunction(ctx, stmt)
130130
case *tree.DropIndex:
131131
return nodeDropIndex(ctx, stmt)
132+
case *tree.DropProcedure:
133+
return nodeDropProcedure(ctx, stmt)
132134
case *tree.DropRole:
133135
return nodeDropRole(ctx, stmt)
134136
case *tree.DropSchema:

server/ast/drop_procedure.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ast
16+
17+
import (
18+
"fmt"
19+
20+
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
21+
22+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
23+
pgnodes "github.com/dolthub/doltgresql/server/node"
24+
)
25+
26+
// nodeDropProcedure handles *tree.DropProcedure nodes.
27+
func nodeDropProcedure(_ *Context, node *tree.DropProcedure) (vitess.Statement, error) {
28+
if node == nil {
29+
return nil, nil
30+
}
31+
32+
if node.DropBehavior == tree.DropCascade {
33+
return nil, fmt.Errorf("DROP PROCEDURE with CASCADE is not supported yet")
34+
}
35+
36+
if len(node.Procedures) == 0 {
37+
return nil, fmt.Errorf("no function name specified for DROP PROCEDURE")
38+
}
39+
40+
return vitess.InjectedStatement{
41+
Statement: pgnodes.NewDropProcedure(
42+
node.IfExists,
43+
node.Procedures,
44+
node.DropBehavior == tree.DropCascade),
45+
Children: nil,
46+
}, nil
47+
}

server/functions/record.go

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
package functions
1616

1717
import (
18-
"fmt"
19-
18+
"github.com/cockroachdb/errors"
2019
"github.com/dolthub/go-mysql-server/sql"
2120

21+
"github.com/dolthub/doltgresql/core"
22+
"github.com/dolthub/doltgresql/core/id"
23+
"github.com/dolthub/doltgresql/utils"
24+
2225
"github.com/dolthub/doltgresql/server/functions/framework"
2326
pgtypes "github.com/dolthub/doltgresql/server/types"
2427
)
@@ -40,7 +43,7 @@ var record_in = framework.Function3{
4043
Parameters: [3]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid, pgtypes.Int32},
4144
Strict: true,
4245
Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) {
43-
return nil, fmt.Errorf("record_in not implemented")
46+
return nil, errors.Errorf("record_in not implemented")
4447
},
4548
}
4649

@@ -53,24 +56,66 @@ var record_out = framework.Function1{
5356
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
5457
values, ok := val.([]pgtypes.RecordValue)
5558
if !ok {
56-
return nil, fmt.Errorf("expected []RecordValue, but got %T", val)
59+
return nil, errors.Errorf("expected []RecordValue, but got %T", val)
5760
}
5861
return pgtypes.RecordToString(ctx, values)
5962
},
6063
}
6164

62-
// record_recv represents the PostgreSQL function of record type IO receive.
65+
// record_recv represents the PostgreSQL function of record type IO receive. The input of this function is expected to
66+
// be the output of record_send.
6367
var record_recv = framework.Function3{
6468
Name: "record_recv",
6569
Return: pgtypes.Record,
6670
Parameters: [3]*pgtypes.DoltgresType{pgtypes.Internal, pgtypes.Oid, pgtypes.Int32},
6771
Strict: true,
6872
Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) {
69-
return nil, fmt.Errorf("record_recv not implemented")
73+
data, ok := val1.([]byte)
74+
if !ok {
75+
return nil, errors.Errorf("expected []byte, but got `%T`", val1)
76+
}
77+
typeColl, err := core.GetTypesCollectionFromContext(ctx)
78+
if err != nil {
79+
return nil, err
80+
}
81+
reader := utils.NewReader(data)
82+
version := reader.Byte()
83+
switch version {
84+
case 0:
85+
valuesLen := reader.VariableUint()
86+
values := make([]pgtypes.RecordValue, valuesLen)
87+
for i := uint64(0); i < valuesLen; i++ {
88+
typeId := id.Type(reader.Id())
89+
valueData := reader.ByteSlice()
90+
dgtype, err := typeColl.GetType(ctx, typeId)
91+
if err != nil {
92+
return nil, err
93+
}
94+
if dgtype == nil {
95+
return nil, errors.Errorf("record_recv encountered type `%s.%s` which could not be found",
96+
typeId.SchemaName(), typeId.TypeName())
97+
}
98+
value, err := dgtype.DeserializeValue(ctx, valueData)
99+
if err != nil {
100+
return nil, err
101+
}
102+
values[i] = pgtypes.RecordValue{
103+
Value: value,
104+
Type: dgtype,
105+
}
106+
}
107+
if reader.RemainingBytes() > 0 {
108+
return nil, errors.New("record_recv encountered extra data during deserialization")
109+
}
110+
return values, nil
111+
default:
112+
return nil, errors.Errorf("version %d of record serialization is not supported, please upgrade the server", version)
113+
}
70114
},
71115
}
72116

73-
// record_send represents the PostgreSQL function of record type IO send.
117+
// record_send represents the PostgreSQL function of record type IO send. The output of this function is expected to
118+
// be the input of record_recv.
74119
var record_send = framework.Function1{
75120
Name: "record_send",
76121
Return: pgtypes.Bytea,
@@ -79,16 +124,24 @@ var record_send = framework.Function1{
79124
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
80125
values, ok := val.([]pgtypes.RecordValue)
81126
if !ok {
82-
return nil, fmt.Errorf("expected []RecordValue, but got %T", val)
127+
return nil, errors.Errorf("expected []RecordValue, but got %T", val)
83128
}
84-
// TODO: converting from a string back to the record doesn't work as we lose type information, so we need to
85-
// figure out how to retain this information
86-
output, err := pgtypes.RecordToString(ctx, values)
87-
if err != nil {
88-
return nil, err
129+
writer := utils.NewWriter(uint64(16 * len(values)))
130+
writer.Byte(0) // Version
131+
writer.VariableUint(uint64(len(values)))
132+
for _, value := range values {
133+
dgtype, ok := value.Type.(*pgtypes.DoltgresType)
134+
if !ok {
135+
return nil, errors.Errorf("record_send only supports Doltgres types, but received `%T`", value.Type)
136+
}
137+
valBytes, err := dgtype.SerializeValue(ctx, value.Value)
138+
if err != nil {
139+
return nil, err
140+
}
141+
writer.Id(dgtype.ID.AsId())
142+
writer.ByteSlice(valBytes)
89143
}
90-
91-
return []byte(output.(string)), nil
144+
return writer.Data(), nil
92145
},
93146
}
94147

0 commit comments

Comments
 (0)