Skip to content

Commit d0ee709

Browse files
authored
Merge pull request #2443 from dolthub/daylon/dolt_branch_control
Added proper dolt_branch_control functionality
2 parents 2cbc098 + c7ce7d5 commit d0ee709

9 files changed

Lines changed: 404 additions & 11 deletions

File tree

server/analyzer/assign_insert_casts.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ func AssignInsertCasts(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, sc
4949
for _, col := range insertInto.Destination.Schema() {
5050
colType, ok := col.Type.(*pgtypes.DoltgresType)
5151
if !ok {
52-
return nil, transform.NewTree, errors.Errorf("INSERT: non-Doltgres type found in destination: %s", col.Type.String())
52+
// Only non-Doltgres destination tables will have GMS types (such as system tables), so we don't error here
53+
colType = pgtypes.FromGmsType(col.Type)
5354
}
5455
destinationNameToType[strings.ToLower(col.Name)] = colType
5556
}

server/analyzer/assign_update_casts.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ func assignUpdateFieldCasts(updateExprs []sql.Expression) ([]sql.Expression, err
113113
}
114114
toType, ok := setField.LeftChild.Type().(*pgtypes.DoltgresType)
115115
if !ok {
116-
return nil, errors.Errorf("UPDATE: non-Doltgres type found in destination: %s", setField.LeftChild.String())
116+
// Only non-Doltgres destination tables will have GMS types (such as system tables), so we don't error here
117+
toType = pgtypes.FromGmsType(setField.LeftChild.Type())
117118
}
118119
// We only assign the existing expression if the types perfectly match (same parameters), otherwise we'll cast
119120
if fromType.Equals(toType) {

server/auth/gms_privilege_set.go

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package auth
16+
17+
import "github.com/dolthub/go-mysql-server/sql"
18+
19+
// PrivilegeSetLayer is used to allow some functions that inspect the GMS privilege set (such as branch control) to
20+
// interface with Doltgres' auth system.
21+
type PrivilegeSetLayer struct {
22+
Role RoleID
23+
}
24+
25+
var _ sql.PrivilegeSet = (*PrivilegeSetLayer)(nil)
26+
27+
// NewPrivilegeSetLayer creates a new PrivilegeSetLayer for the user in the given context's session.
28+
func NewPrivilegeSetLayer(ctx *sql.Context) *PrivilegeSetLayer {
29+
return &PrivilegeSetLayer{
30+
Role: GetRole(ctx.Client().User).id,
31+
}
32+
}
33+
34+
// Has implements the interface sql.PrivilegeSet.
35+
func (privSet *PrivilegeSetLayer) Has(privileges ...sql.PrivilegeType) bool {
36+
return IsSuperUser(privSet.Role)
37+
}
38+
39+
// HasPrivileges implements the interface sql.PrivilegeSet.
40+
func (privSet *PrivilegeSetLayer) HasPrivileges() bool {
41+
return IsSuperUser(privSet.Role)
42+
}
43+
44+
// Count implements the interface sql.PrivilegeSet.
45+
func (privSet *PrivilegeSetLayer) Count() int {
46+
if IsSuperUser(privSet.Role) {
47+
return 31 // The current number in GMS
48+
}
49+
return 0
50+
}
51+
52+
// Database implements the interface sql.PrivilegeSet.
53+
func (privSet *PrivilegeSetLayer) Database(dbName string) sql.PrivilegeSetDatabase {
54+
return &PrivilegeSetLayerDatabase{
55+
Db: dbName,
56+
Role: privSet.Role,
57+
}
58+
}
59+
60+
// GetDatabases implements the interface sql.PrivilegeSet.
61+
func (privSet *PrivilegeSetLayer) GetDatabases() []sql.PrivilegeSetDatabase {
62+
return nil
63+
}
64+
65+
// Equals implements the interface sql.PrivilegeSet.
66+
func (privSet *PrivilegeSetLayer) Equals(otherPs sql.PrivilegeSet) bool {
67+
if other, ok := otherPs.(*PrivilegeSetLayer); ok {
68+
return privSet.Role == other.Role
69+
}
70+
return false
71+
}
72+
73+
// ToSlice implements the interface sql.PrivilegeSet.
74+
func (privSet *PrivilegeSetLayer) ToSlice() []sql.PrivilegeType {
75+
if IsSuperUser(privSet.Role) {
76+
return []sql.PrivilegeType{sql.PrivilegeType_Select,
77+
sql.PrivilegeType_Insert,
78+
sql.PrivilegeType_Update,
79+
sql.PrivilegeType_Delete,
80+
sql.PrivilegeType_Create,
81+
sql.PrivilegeType_Drop,
82+
sql.PrivilegeType_Reload,
83+
sql.PrivilegeType_Shutdown,
84+
sql.PrivilegeType_Process,
85+
sql.PrivilegeType_File,
86+
sql.PrivilegeType_GrantOption,
87+
sql.PrivilegeType_References,
88+
sql.PrivilegeType_Index,
89+
sql.PrivilegeType_Alter,
90+
sql.PrivilegeType_ShowDB,
91+
sql.PrivilegeType_Super,
92+
sql.PrivilegeType_CreateTempTable,
93+
sql.PrivilegeType_LockTables,
94+
sql.PrivilegeType_Execute,
95+
sql.PrivilegeType_ReplicationSlave,
96+
sql.PrivilegeType_ReplicationClient,
97+
sql.PrivilegeType_CreateView,
98+
sql.PrivilegeType_ShowView,
99+
sql.PrivilegeType_CreateRoutine,
100+
sql.PrivilegeType_AlterRoutine,
101+
sql.PrivilegeType_CreateUser,
102+
sql.PrivilegeType_Event,
103+
sql.PrivilegeType_Trigger,
104+
sql.PrivilegeType_CreateTablespace,
105+
sql.PrivilegeType_CreateRole,
106+
sql.PrivilegeType_DropRole}
107+
}
108+
return nil
109+
}
110+
111+
// PrivilegeSetLayerDatabase is the database portion of PrivilegeSetLayer.
112+
type PrivilegeSetLayerDatabase struct {
113+
Db string
114+
Role RoleID
115+
}
116+
117+
var _ sql.PrivilegeSetDatabase = (*PrivilegeSetLayerDatabase)(nil)
118+
119+
// Name implements the interface sql.PrivilegeSetDatabase.
120+
func (privSet *PrivilegeSetLayerDatabase) Name() string {
121+
return privSet.Db
122+
}
123+
124+
// Has implements the interface sql.PrivilegeSetDatabase.
125+
func (privSet *PrivilegeSetLayerDatabase) Has(privileges ...sql.PrivilegeType) bool {
126+
return IsSuperUser(privSet.Role)
127+
}
128+
129+
// HasPrivileges implements the interface sql.PrivilegeSetDatabase.
130+
func (privSet *PrivilegeSetLayerDatabase) HasPrivileges() bool {
131+
return IsSuperUser(privSet.Role)
132+
}
133+
134+
// Count implements the interface sql.PrivilegeSetDatabase.
135+
func (privSet *PrivilegeSetLayerDatabase) Count() int {
136+
if IsSuperUser(privSet.Role) {
137+
return 31 // The current number in GMS
138+
}
139+
return 0
140+
}
141+
142+
// Table implements the interface sql.PrivilegeSetDatabase.
143+
func (privSet *PrivilegeSetLayerDatabase) Table(tblName string) sql.PrivilegeSetTable {
144+
panic("Table is not yet implemented for the Doltgres privilege layer")
145+
}
146+
147+
// GetTables implements the interface sql.PrivilegeSetDatabase.
148+
func (privSet *PrivilegeSetLayerDatabase) GetTables() []sql.PrivilegeSetTable {
149+
return nil
150+
}
151+
152+
// Routine implements the interface sql.PrivilegeSetDatabase.
153+
func (privSet *PrivilegeSetLayerDatabase) Routine(routineName string, isProcedure bool) sql.PrivilegeSetRoutine {
154+
panic("Routine is not yet implemented for the Doltgres privilege layer")
155+
}
156+
157+
// GetRoutines implements the interface sql.PrivilegeSetDatabase.
158+
func (privSet *PrivilegeSetLayerDatabase) GetRoutines() []sql.PrivilegeSetRoutine {
159+
return nil
160+
}
161+
162+
// Equals implements the interface sql.PrivilegeSetDatabase.
163+
func (privSet *PrivilegeSetLayerDatabase) Equals(otherPs sql.PrivilegeSetDatabase) bool {
164+
if other, ok := otherPs.(*PrivilegeSetLayerDatabase); ok {
165+
return privSet.Role == other.Role && privSet.Db == other.Db
166+
}
167+
return false
168+
}
169+
170+
// ToSlice implements the interface sql.PrivilegeSetDatabase.
171+
func (privSet *PrivilegeSetLayerDatabase) ToSlice() []sql.PrivilegeType {
172+
if IsSuperUser(privSet.Role) {
173+
return []sql.PrivilegeType{sql.PrivilegeType_Select,
174+
sql.PrivilegeType_Insert,
175+
sql.PrivilegeType_Update,
176+
sql.PrivilegeType_Delete,
177+
sql.PrivilegeType_Create,
178+
sql.PrivilegeType_Drop,
179+
sql.PrivilegeType_Reload,
180+
sql.PrivilegeType_Shutdown,
181+
sql.PrivilegeType_Process,
182+
sql.PrivilegeType_File,
183+
sql.PrivilegeType_GrantOption,
184+
sql.PrivilegeType_References,
185+
sql.PrivilegeType_Index,
186+
sql.PrivilegeType_Alter,
187+
sql.PrivilegeType_ShowDB,
188+
sql.PrivilegeType_Super,
189+
sql.PrivilegeType_CreateTempTable,
190+
sql.PrivilegeType_LockTables,
191+
sql.PrivilegeType_Execute,
192+
sql.PrivilegeType_ReplicationSlave,
193+
sql.PrivilegeType_ReplicationClient,
194+
sql.PrivilegeType_CreateView,
195+
sql.PrivilegeType_ShowView,
196+
sql.PrivilegeType_CreateRoutine,
197+
sql.PrivilegeType_AlterRoutine,
198+
sql.PrivilegeType_CreateUser,
199+
sql.PrivilegeType_Event,
200+
sql.PrivilegeType_Trigger,
201+
sql.PrivilegeType_CreateTablespace,
202+
sql.PrivilegeType_CreateRole,
203+
sql.PrivilegeType_DropRole}
204+
}
205+
return nil
206+
}

server/doltgres_handler.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import (
4646

4747
"github.com/dolthub/doltgresql/core/id"
4848
"github.com/dolthub/doltgresql/postgres/parser/uuid"
49+
"github.com/dolthub/doltgresql/server/auth"
4950
pgexprs "github.com/dolthub/doltgresql/server/expression"
5051
pgtransform "github.com/dolthub/doltgresql/server/transform"
5152
pgtypes "github.com/dolthub/doltgresql/server/types"
@@ -447,6 +448,7 @@ func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query stri
447448
if err != nil {
448449
return err
449450
}
451+
sqlCtx.SetPrivilegeSet(auth.NewPrivilegeSetLayer(sqlCtx), 1)
450452

451453
start := time.Now()
452454
var queryStrToLog string

server/expression/gms_cast.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
9898
fallthrough
9999
// In Postgres, Int32 is generally the smallest value returned. But we convert int8 and int16 to this type during
100100
// schema conversion, which means we must do so here as well to avoid runtime panics.
101-
case query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_YEAR, query.Type_ENUM:
101+
case query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_YEAR:
102102
newVal, _, err := types.Int32.Convert(ctx, val)
103103
if err != nil {
104104
return nil, err
@@ -107,7 +107,7 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
107107
return nil, errors.Errorf("GMSCast expected type `int32`, got `%T`", val)
108108
}
109109
return newVal, nil
110-
case query.Type_INT64, query.Type_SET, query.Type_BIT, query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32:
110+
case query.Type_INT64, query.Type_BIT, query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32:
111111
newVal, _, err := types.Int64.Convert(ctx, val)
112112
if err != nil {
113113
return nil, err
@@ -163,7 +163,7 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
163163
return val.String(), nil
164164
}
165165
return nil, errors.Errorf("GMSCast expected type `Timespan`, got `%T`", val)
166-
case query.Type_CHAR, query.Type_VARCHAR, query.Type_TEXT, query.Type_BINARY, query.Type_VARBINARY, query.Type_BLOB:
166+
case query.Type_CHAR, query.Type_VARCHAR, query.Type_TEXT, query.Type_BINARY, query.Type_VARBINARY, query.Type_BLOB, query.Type_SET, query.Type_ENUM:
167167
newVal, _, err := types.LongText.Convert(ctx, val)
168168
if err != nil {
169169
return nil, err

server/types/utils.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ func FromGmsTypeToDoltgresType(typ sql.Type) (*DoltgresType, error) {
7676
return Bool, nil
7777
}
7878
return Int32, nil
79-
case query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_YEAR, query.Type_ENUM:
79+
case query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_YEAR:
8080
return Int32, nil
81-
case query.Type_INT64, query.Type_SET, query.Type_BIT, query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32:
81+
case query.Type_INT64, query.Type_BIT, query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32:
8282
return Int64, nil
8383
case query.Type_UINT64:
8484
return Numeric, nil
@@ -94,7 +94,7 @@ func FromGmsTypeToDoltgresType(typ sql.Type) (*DoltgresType, error) {
9494
return Text, nil
9595
case query.Type_DATETIME, query.Type_TIMESTAMP:
9696
return Timestamp, nil
97-
case query.Type_CHAR, query.Type_VARCHAR, query.Type_TEXT, query.Type_BINARY, query.Type_VARBINARY, query.Type_BLOB:
97+
case query.Type_CHAR, query.Type_VARCHAR, query.Type_TEXT, query.Type_BINARY, query.Type_VARBINARY, query.Type_BLOB, query.Type_SET, query.Type_ENUM:
9898
return Text, nil
9999
case query.Type_JSON:
100100
return Json, nil

0 commit comments

Comments
 (0)