Skip to content

Commit b4381d1

Browse files
authored
Merge pull request #2009 from dolthub/daylon/table-type
Tables as types
2 parents b4a352e + 74ad3f9 commit b4381d1

29 files changed

Lines changed: 508 additions & 139 deletions

core/init.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/dolthub/doltgresql/core/conflicts"
2222
"github.com/dolthub/doltgresql/core/id"
23+
"github.com/dolthub/doltgresql/core/typecollection"
2324
"github.com/dolthub/doltgresql/server/plpgsql"
2425

2526
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -35,4 +36,6 @@ func Init() {
3536
conflicts.ClearContextValues = ClearContextValues
3637
plpgsql.GetTypesCollectionFromContext = GetTypesCollectionFromContext
3738
id.RegisterListener(sequenceIDListener{}, id.Section_Table)
39+
typecollection.GetSqlTableFromContext = GetSqlTableFromContext
40+
typecollection.GetSchemaName = GetSchemaName
3841
}

core/typecollection/typecollection.go

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/dolthub/dolt/go/store/hash"
2626
"github.com/dolthub/dolt/go/store/prolly"
2727
"github.com/dolthub/dolt/go/store/prolly/tree"
28+
"github.com/dolthub/go-mysql-server/sql"
2829

2930
"github.com/dolthub/doltgresql/core/id"
3031
"github.com/dolthub/doltgresql/core/rootobject/objinterface"
@@ -101,8 +102,10 @@ func (pgs *TypeCollection) DropType(ctx context.Context, names ...id.Type) (err
101102
}
102103

103104
// GetAllTypes returns a map containing all types in the collection, grouped by the schema they're contained in.
104-
// Each type array is also sorted by the type name. It includes built-in types.
105+
// Each type array is also sorted by the type name. It includes built-in types, but does not include types referring to
106+
// a table's row type.
105107
func (pgs *TypeCollection) GetAllTypes(ctx context.Context) (typeMap map[string][]*pgtypes.DoltgresType, schemaNames []string, totalCount int, err error) {
108+
// TODO: this should probably get tables as well since tables create composite types matching their rows
106109
schemaNamesMap := make(map[string]struct{})
107110
typeMap = make(map[string][]*pgtypes.DoltgresType)
108111
err = pgs.IterateTypes(ctx, func(t *pgtypes.DoltgresType) (stop bool, err error) {
@@ -158,9 +161,21 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes.
158161
}
159162
// The initial load is from the internal map
160163
h, err := pgs.underlyingMap.Get(ctx, string(name))
161-
if err != nil || h.IsEmpty() {
164+
if err != nil {
162165
return nil, err
163166
}
167+
if h.IsEmpty() {
168+
// If it's not a built-in type or created type, then check if it's a composite table row type
169+
sqlCtx, ok := ctx.(*sql.Context)
170+
if !ok {
171+
return nil, nil
172+
}
173+
tbl, schema, err := pgs.getTable(sqlCtx, name.SchemaName(), name.TypeName())
174+
if err != nil || tbl == nil {
175+
return nil, err
176+
}
177+
return pgs.tableToType(sqlCtx, tbl, schema)
178+
}
164179
data, err := pgs.ns.ReadBytes(ctx, h)
165180
if err != nil {
166181
return nil, err
@@ -180,19 +195,26 @@ func (pgs *TypeCollection) HasType(ctx context.Context, name id.Type) bool {
180195
if _, ok := pgtypes.IDToBuiltInDoltgresType[name]; ok {
181196
return true
182197
}
183-
198+
// Now we'll check our created types
184199
if _, ok := pgs.accessedMap[name]; ok {
185200
return true
186201
}
187202
ok, err := pgs.underlyingMap.Has(ctx, string(name))
188203
if err == nil && ok {
189204
return true
190205
}
191-
return false
206+
// If it's not a built-in type or created type, then check if it's a composite table row type
207+
sqlCtx, ok := ctx.(*sql.Context)
208+
if !ok {
209+
return false
210+
}
211+
tbl, _, err := pgs.getTable(sqlCtx, name.SchemaName(), name.TypeName())
212+
return err == nil && tbl != nil
192213
}
193214

194215
// resolveName returns the fully resolved name of the given type. Returns an error if the name is ambiguous.
195216
func (pgs *TypeCollection) resolveName(ctx context.Context, schemaName string, typeName string) (id.Type, error) {
217+
// TODO: this should probably check table names as well since tables create composite types matching their rows
196218
// First check for an exact match in the built-in types
197219
inputID := id.NewType(schemaName, typeName)
198220
if _, ok := pgtypes.IDToBuiltInDoltgresType[inputID]; ok {
@@ -251,6 +273,7 @@ func (pgs *TypeCollection) resolveName(ctx context.Context, schemaName string, t
251273

252274
// IterateTypes iterates over all types in the collection.
253275
func (pgs *TypeCollection) IterateTypes(ctx context.Context, f func(typ *pgtypes.DoltgresType) (stop bool, err error)) error {
276+
// TODO: this should probably iterate tables as well since tables create composite types matching their rows
254277
// We can iterate the built-in types first
255278
for _, t := range pgtypes.GetAllBuitInTypes() {
256279
stop, err := f(t)
@@ -368,3 +391,50 @@ func (pgs *TypeCollection) writeCache(ctx context.Context) (err error) {
368391
clear(pgs.accessedMap)
369392
return nil
370393
}
394+
395+
// getTable returns the SQL table that matches the given schema and table name. Returns a nil table if one is not found.
396+
// This is intended for use with tableToType.
397+
func (*TypeCollection) getTable(ctx *sql.Context, schema string, tblName string) (tbl sql.Table, actualSchema string, err error) {
398+
actualSchema, err = GetSchemaName(ctx, nil, schema)
399+
if err != nil {
400+
return nil, "", err
401+
}
402+
tbl, err = GetSqlTableFromContext(ctx, "", doltdb.TableName{
403+
Name: tblName,
404+
Schema: actualSchema,
405+
})
406+
if err != nil || tbl == nil {
407+
return nil, "", err
408+
}
409+
if schTbl, ok := tbl.(sql.DatabaseSchemaTable); ok {
410+
actualSchema = schTbl.DatabaseSchema().SchemaName()
411+
}
412+
return tbl, actualSchema, nil
413+
}
414+
415+
// tableToType handles type creation related to a table's composite row type.
416+
// https://www.postgresql.org/docs/15/sql-createtable.html
417+
func (*TypeCollection) tableToType(ctx *sql.Context, tbl sql.Table, schema string) (*pgtypes.DoltgresType, error) {
418+
tblName := tbl.Name()
419+
tblSch := tbl.Schema()
420+
typeID := id.NewType(schema, tblName)
421+
relID := id.NewTable(schema, tblName).AsId()
422+
arrayID := id.NewType(schema, "_"+tblName)
423+
attrs := make([]pgtypes.CompositeAttribute, len(tblSch))
424+
for i, col := range tblSch {
425+
collation := "" // TODO: what should we use for the collation?
426+
colType, ok := col.Type.(*pgtypes.DoltgresType)
427+
if !ok {
428+
// TODO: perhaps we should use a better error message stating that it uses a non-Doltgres type?
429+
return nil, pgtypes.ErrTypeDoesNotExist.New(tblName)
430+
}
431+
attrs[i] = pgtypes.NewCompositeAttribute(ctx, relID, col.Name, colType.ID, int16(i+1), collation)
432+
}
433+
return pgtypes.NewCompositeType(ctx, relID, arrayID, typeID, attrs), nil
434+
}
435+
436+
// GetSqlTableFromContext is a forward declaration to get around import cycles
437+
var GetSqlTableFromContext func(ctx *sql.Context, databaseName string, tableName doltdb.TableName) (sql.Table, error)
438+
439+
// GetSchemaName is a forward declaration to get around import cycles
440+
var GetSchemaName func(ctx *sql.Context, db sql.Database, schemaName string) (string, error)

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ require (
66
github.com/PuerkitoBio/goquery v1.8.1
77
github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a
88
github.com/cockroachdb/errors v1.7.5
9-
github.com/dolthub/dolt/go v0.40.5-0.20251211214546-0b9999455622
9+
github.com/dolthub/dolt/go v0.40.5-0.20251212132525-3f50c72565d4
1010
github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca
1111
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
12-
github.com/dolthub/go-mysql-server v0.20.1-0.20251211205836-45695e02d2b6
12+
github.com/dolthub/go-mysql-server v0.20.1-0.20251212100505-44ab12c341e7
1313
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
1414
github.com/dolthub/vitess v0.0.0-20251210200925-1d33d416d162
1515
github.com/fatih/color v1.13.0

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I
228228
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo=
229229
github.com/dolthub/dolt-mcp v0.2.2 h1:bpROmam74n95uU4EA3BpOIVlTDT0pzeFMBwe/YRq2mI=
230230
github.com/dolthub/dolt-mcp v0.2.2/go.mod h1:S++DJ4QWTAXq+0TNzFa7Oq3IhoT456DJHwAINFAHgDQ=
231-
github.com/dolthub/dolt/go v0.40.5-0.20251211214546-0b9999455622 h1:M0FRziRmHiiv4osP2f5AR+nhY1KnIiclax4gXb3hlGg=
232-
github.com/dolthub/dolt/go v0.40.5-0.20251211214546-0b9999455622/go.mod h1:+g40yZ9gyg0JIgaiwCpfjAG9OAbp35DMseO9W7OISQg=
231+
github.com/dolthub/dolt/go v0.40.5-0.20251212132525-3f50c72565d4 h1:jUnjruGfdVZhefoqmFV6NihbmW3xkHdXtHSH5yFFp/s=
232+
github.com/dolthub/dolt/go v0.40.5-0.20251212132525-3f50c72565d4/go.mod h1:8ezocTxz4FEGVZx5uMW7mU/2Fka5Ik23b3leFa4Doyw=
233233
github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca h1:BGFz/0OlKIuC6qHIZQbvPapFvdAJkeEyGXWVgL5clmE=
234234
github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca/go.mod h1:CoDLfgPqHyBtth0Cp+fi/CmC4R81zJNX4wPjShdZ+Bw=
235235
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
@@ -238,8 +238,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
238238
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
239239
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790 h1:zxMsH7RLiG+dlZ/y0LgJHTV26XoiSJcuWq+em6t6VVc=
240240
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE=
241-
github.com/dolthub/go-mysql-server v0.20.1-0.20251211205836-45695e02d2b6 h1:pLvSf/YvXYyCaWFso9QVWMaRSR4/+hIPnx5QtEd78p8=
242-
github.com/dolthub/go-mysql-server v0.20.1-0.20251211205836-45695e02d2b6/go.mod h1:NjewWKoa5bVSLdKwL7fg7eAfrcIxDybWUKoWEHWRTw4=
241+
github.com/dolthub/go-mysql-server v0.20.1-0.20251212100505-44ab12c341e7 h1:NVQXrGTC8a5dWPgL18ZO0iVstek4fO48nWYwQHA6ReY=
242+
github.com/dolthub/go-mysql-server v0.20.1-0.20251212100505-44ab12c341e7/go.mod h1:NjewWKoa5bVSLdKwL7fg7eAfrcIxDybWUKoWEHWRTw4=
243243
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
244244
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
245245
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=

server/analyzer/create_function.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"github.com/dolthub/go-mysql-server/sql/planbuilder"
2222
"github.com/dolthub/go-mysql-server/sql/transform"
2323

24-
psql "github.com/dolthub/doltgresql/postgres/parser/parser/sql"
2524
"github.com/dolthub/doltgresql/server/node"
2625
)
2726

@@ -37,8 +36,7 @@ func ValidateCreateFunction(ctx *sql.Context, a *analyzer.Analyzer, n sql.Node,
3736
return n, transform.SameTree, nil
3837
}
3938

40-
parser := psql.NewPostgresParser()
41-
builder := planbuilder.New(ctx, a.Catalog, nil, parser)
39+
builder := planbuilder.New(ctx, a.Catalog, nil)
4240
_, _, err := builder.BindOnly(ct.SqlDefParsed, ct.SqlDef, nil)
4341
if err != nil {
4442
return nil, transform.SameTree, err

server/analyzer/domain_constraints.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func getDomainDefault(ctx *sql.Context, a *analyzer.Analyzer, defExpr, tblName s
8484
if defExpr == "" {
8585
return nil, nil
8686
}
87-
parsed, err := sql.GlobalParser.ParseSimple(fmt.Sprintf("select %s from %s", defExpr, tblName))
87+
parsed, err := a.Parser.ParseSimple(fmt.Sprintf("select %s from %s", defExpr, tblName))
8888
if err != nil {
8989
return nil, err
9090
}
@@ -97,7 +97,7 @@ func getDomainDefault(ctx *sql.Context, a *analyzer.Analyzer, defExpr, tblName s
9797
if !ok {
9898
return nil, sql.ErrInvalidColumnDefaultValue.New(defExpr)
9999
}
100-
builder := planbuilder.New(ctx, a.Catalog, nil, sql.GlobalParser)
100+
builder := planbuilder.New(ctx, a.Catalog, nil)
101101
return builder.BuildColumnDefaultValueWithTable(ae.Expr, selectStmt.From[0], typ, nullable), nil
102102
}
103103

@@ -228,7 +228,7 @@ func parseAndReplaceDomainCheckConstraint(ctx *sql.Context, a *analyzer.Analyzer
228228
return nil, sql.ErrInvalidCheckConstraint.New(checkExpr)
229229
}
230230

231-
builder := planbuilder.New(ctx, a.Catalog, nil, sql.GlobalParser)
231+
builder := planbuilder.New(ctx, a.Catalog, nil)
232232
var tblExpr vitess.TableExpr
233233
if len(convertedSelectStmt.From) == 1 {
234234
tblExpr = convertedSelectStmt.From[0]

server/auth/database.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ import (
2222

2323
"github.com/dolthub/dolt/go/libraries/doltcore/env"
2424
"github.com/dolthub/dolt/go/libraries/utils/filesys"
25-
26-
doltgresservercfg "github.com/dolthub/doltgresql/servercfg"
2725
)
2826

2927
// authFileName is the name of the file that contains all authorization-related data.
@@ -131,7 +129,7 @@ func LockWrite(f func()) {
131129

132130
// dbInit handle the global database initialization. Panics if an error occurs, since it points to something going
133131
// terribly wrong.
134-
func dbInit(dEnv *env.DoltEnv, cfg *doltgresservercfg.DoltgresConfig) {
132+
func dbInit(dEnv *env.DoltEnv, cfg Config) {
135133
globalDatabase = Database{
136134
rolesByName: make(map[string]RoleID),
137135
rolesByID: make(map[RoleID]Role),
@@ -143,8 +141,8 @@ func dbInit(dEnv *env.DoltEnv, cfg *doltgresservercfg.DoltgresConfig) {
143141
globalLock = &sync.RWMutex{}
144142
if dEnv != nil {
145143
if _, ok := dEnv.FS.(*filesys.InMemFS); !ok {
146-
if cfg != nil && cfg.AuthFile != nil && len(*cfg.AuthFile) > 0 {
147-
authFileName = *cfg.AuthFile
144+
if cfg != nil && len(cfg.AuthFilePath()) > 0 {
145+
authFileName = cfg.AuthFilePath()
148146
}
149147
fileSystem = dEnv.FS
150148
authData, err := fileSystem.ReadFile(authFileName)

server/auth/init.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ import (
1919

2020
"github.com/dolthub/dolt/go/libraries/doltcore/env"
2121
"github.com/dolthub/go-mysql-server/sql"
22-
23-
doltgresservercfg "github.com/dolthub/doltgresql/servercfg"
2422
)
2523

2624
// doltgresPasswordEnvVar is the name of the environment variable that can be used to set the password for the
@@ -31,8 +29,13 @@ const doltgresPasswordEnvVar = "DOLTGRES_PASSWORD"
3129
// default user.
3230
const doltgresUserEnvVar = "DOLTGRES_USER"
3331

32+
// Config is an interface that exists as pulling the actual config package would cause a cyclical dependency.
33+
type Config interface {
34+
AuthFilePath() string
35+
}
36+
3437
// Init handles all initialization needs in this package.
35-
func Init(dEnv *env.DoltEnv, cfg *doltgresservercfg.DoltgresConfig) {
38+
func Init(dEnv *env.DoltEnv, cfg Config) {
3639
dbInit(dEnv, cfg)
3740
sql.SetAuthorizationHandlerFactory(AuthorizationHandlerFactory{})
3841
}

server/connection_handler.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ func (h *ConnectionHandler) chooseInitialParameters(startupMessage *pgproto3.Sta
302302
db = h.mysqlConn.User
303303
}
304304
useStmt := fmt.Sprintf("SET database TO '%s';", db)
305-
parsed, err := sql.GlobalParser.ParseSimple(useStmt)
305+
postgresParser := psql.PostgresParser{}
306+
parsed, err := postgresParser.ParseSimple(useStmt)
306307
if err != nil {
307308
return err
308309
}
@@ -748,7 +749,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(copyState *copyFromStdinState,
748749
}
749750

750751
// we build an insert node to use for the full insert plan, for which the copy from node will be the row source
751-
builder := planbuilder.New(sqlCtx, h.doltgresHandler.e.Analyzer.Catalog, nil, psql.NewPostgresParser())
752+
builder := planbuilder.New(sqlCtx, h.doltgresHandler.e.Analyzer.Catalog, nil)
752753
node, flags, err := builder.BindOnly(copyFromStdinNode.InsertStub, "", nil)
753754
if err != nil {
754755
return false, false, err

server/expression/explicit_cast.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/dolthub/go-mysql-server/sql/expression"
2424
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2525

26+
"github.com/dolthub/doltgresql/core"
2627
"github.com/dolthub/doltgresql/server/functions/framework"
2728
pgtypes "github.com/dolthub/doltgresql/server/types"
2829
)
@@ -98,6 +99,49 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
9899
if castFunction == nil {
99100
if fromType.ID == pgtypes.Unknown.ID {
100101
castFunction = framework.UnknownLiteralCast
102+
} else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too?
103+
// Casting to a record type will always work for any composite type.
104+
// TODO: is the above statement true for all cases?
105+
// When casting to a composite type, then we must match the arity and have valid casts for every position.
106+
if c.castToType.IsRecordType() {
107+
castFunction = framework.IdentityCast
108+
} else {
109+
castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
110+
vals, ok := val.([]pgtypes.RecordValue)
111+
if !ok {
112+
// TODO: better error message
113+
return nil, errors.New("casting input error from record type")
114+
}
115+
if len(targetType.CompositeAttrs) != len(vals) {
116+
return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name())
117+
}
118+
typeCollection, err := core.GetTypesCollectionFromContext(ctx)
119+
if err != nil {
120+
return nil, err
121+
}
122+
outputVals := make([]pgtypes.RecordValue, len(vals))
123+
for i := range vals {
124+
valType, ok := vals[i].Type.(*pgtypes.DoltgresType)
125+
if !ok {
126+
// TODO: if this is a GMS type, then we should cast to a Doltgres type here
127+
return nil, errors.New("cannot cast record containing GMS type")
128+
}
129+
outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID)
130+
if err != nil {
131+
return nil, err
132+
}
133+
innerExplicit := ExplicitCast{
134+
sqlChild: NewUnsafeLiteral(vals[i].Value, valType),
135+
castToType: outputVals[i].Type.(*pgtypes.DoltgresType),
136+
}
137+
outputVals[i].Value, err = innerExplicit.Eval(ctx, nil)
138+
if err != nil {
139+
return nil, err
140+
}
141+
}
142+
return outputVals, nil
143+
}
144+
}
101145
} else {
102146
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
103147
fromType.String(), c.castToType.String(), c.sqlChild.String())

0 commit comments

Comments
 (0)