Skip to content

Commit 4969f87

Browse files
authored
Merge pull request #2198 from dolthub/fulghum/returns-table
Support for UDFs with `RETURNS TABLE`
2 parents 797357d + 92b19e9 commit 4969f87

15 files changed

Lines changed: 547 additions & 116 deletions

File tree

core/typecollection/typecollection.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ import (
3232
pgtypes "github.com/dolthub/doltgresql/server/types"
3333
)
3434

35+
// anonymousCompositePrefix is the prefix for anonymous composite type names. These types are not stored on
36+
// disk, but instead are created dynamically as needed.
37+
const anonymousCompositePrefix = "table("
38+
39+
// anonymousCompositeSuffix is the suffix for anonymous composite type names.
40+
const anonymousCompositeSuffix = ")"
41+
3542
// TypeCollection is a collection of all types (both built-in and user defined).
3643
type TypeCollection struct {
3744
accessedMap map[id.Type]*pgtypes.DoltgresType
@@ -165,6 +172,11 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes.
165172
return nil, err
166173
}
167174
if h.IsEmpty() {
175+
// If this is an anonymous composite type, create it dynamically
176+
if isAnonymousCompositeType(name) {
177+
return createAnonymousCompositeType(ctx, name)
178+
}
179+
168180
// If it's not a built-in type or created type, then check if it's a composite table row type
169181
sqlCtx, ok := ctx.(*sql.Context)
170182
if !ok {
@@ -189,6 +201,34 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes.
189201
return pgt, nil
190202
}
191203

204+
// isAnonymousCompositeType return true if |returnType| represents an anonymous composite return type
205+
// for a function (i.e. the function was declared as "RETURNS TABLE(...)").
206+
func isAnonymousCompositeType(returnType id.Type) bool {
207+
typeName := returnType.TypeName()
208+
return strings.HasPrefix(typeName, anonymousCompositePrefix) &&
209+
strings.HasSuffix(typeName, anonymousCompositeSuffix)
210+
}
211+
212+
// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return type for a function,
213+
// as represented by |returnType|.
214+
func createAnonymousCompositeType(ctx context.Context, returnType id.Type) (*pgtypes.DoltgresType, error) {
215+
typeName := returnType.TypeName()
216+
attributeTypes := typeName[len(anonymousCompositePrefix) : len(typeName)-len(anonymousCompositeSuffix)]
217+
attributeTypesSlice := strings.Split(attributeTypes, ",")
218+
219+
attrs := make([]pgtypes.CompositeAttribute, len(attributeTypesSlice))
220+
for i, attributeNameAndType := range attributeTypesSlice {
221+
split := strings.Split(attributeNameAndType, ":")
222+
if len(split) != 2 {
223+
return nil, fmt.Errorf("unexpected anonymous composite type attribute syntax: %s", attributeNameAndType)
224+
}
225+
226+
typeId := id.NewType("", split[1])
227+
attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, split[0], typeId, int16(i), "")
228+
}
229+
return pgtypes.NewCompositeType(ctx, id.Null, id.NullType, returnType, attrs), nil
230+
}
231+
192232
// HasType checks if a type exists with given schema and type name.
193233
func (pgs *TypeCollection) HasType(ctx context.Context, name id.Type) bool {
194234
// We can check the built-in types first

postgres/parser/parser/sql.y

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4313,11 +4313,11 @@ create_function_stmt:
43134313
}
43144314
| CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list
43154315
{
4316-
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()}
4316+
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()}
43174317
}
43184318
| CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list
43194319
{
4320-
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()}
4320+
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsTable: true, RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()}
43214321
}
43224322
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list create_function_option_list
43234323
{
@@ -4329,11 +4329,11 @@ create_function_stmt:
43294329
}
43304330
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list
43314331
{
4332-
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()}
4332+
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()}
43334333
}
43344334
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list
43354335
{
4336-
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()}
4336+
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsTable: true, RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()}
43374337
}
43384338

43394339
opt_returns_table_col_def_list:

postgres/parser/sem/tree/create_function.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ var _ Statement = &CreateFunction{}
2525

2626
// CreateFunction represents a CREATE FUNCTION statement.
2727
type CreateFunction struct {
28-
Name *UnresolvedObjectName
29-
Replace bool
30-
Args RoutineArgs
31-
SetOf bool
32-
RetType []SimpleColumnDef
33-
Options []RoutineOption
28+
Name *UnresolvedObjectName
29+
Replace bool
30+
Args RoutineArgs
31+
ReturnsSetOf bool
32+
ReturnsTable bool
33+
RetType []SimpleColumnDef
34+
Options []RoutineOption
3435
}
3536

3637
// Format implements the NodeFormatter interface.
@@ -47,9 +48,9 @@ func (node *CreateFunction) Format(ctx *FmtCtx) {
4748
ctx.WriteString(" )")
4849
}
4950
if node.RetType != nil {
50-
if len(node.RetType) == 1 && node.RetType[0].Name == "" {
51+
if !node.ReturnsTable {
5152
ctx.WriteString("RETURNS ")
52-
if node.SetOf {
53+
if node.ReturnsSetOf {
5354
ctx.WriteString("SETOF ")
5455
}
5556
ctx.WriteString(node.RetType[0].Type.SQLString())

server/ast/create_function.go

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
package ast
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"strings"
2021

2122
"github.com/cockroachdb/errors"
2223
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2324

25+
"github.com/dolthub/doltgresql/core/id"
2426
"github.com/dolthub/doltgresql/postgres/parser/parser"
2527
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
2628
"github.com/dolthub/doltgresql/postgres/parser/types"
@@ -38,27 +40,41 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
3840
}
3941
// Grab the general information that we'll need to create the function
4042
tableName := node.Name.ToTableName()
41-
retType := pgtypes.Void
42-
if len(node.RetType) == 1 {
43+
var retType *pgtypes.DoltgresType
44+
if len(node.RetType) == 0 {
45+
retType = pgtypes.Void
46+
} else if !node.ReturnsTable { // Return types may specify "trigger", but this doesn't apply elsewhere
4347
switch typ := node.RetType[0].Type.(type) {
4448
case *types.T:
4549
retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name()))
46-
default:
47-
sqlString := strings.ToLower(typ.SQLString())
48-
if sqlString == "trigger" {
50+
case *tree.UnresolvedObjectName:
51+
if typ.NumParts == 1 && typ.SQLString() == "trigger" {
4952
retType = pgtypes.Trigger
5053
} else {
51-
retType = pgtypes.NewUnresolvedDoltgresType("", sqlString)
54+
_, retType, err = nodeResolvableTypeReference(ctx, typ)
55+
if err != nil {
56+
return nil, err
57+
}
5258
}
59+
default:
60+
return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
5361
}
62+
} else {
63+
retType = createAnonymousCompositeType(node.RetType)
5464
}
65+
5566
paramNames := make([]string, len(node.Args))
5667
paramTypes := make([]*pgtypes.DoltgresType, len(node.Args))
5768
for i, arg := range node.Args {
5869
paramNames[i] = arg.Name.String()
5970
switch argType := arg.Type.(type) {
6071
case *types.T:
6172
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name()))
73+
case *tree.UnresolvedObjectName:
74+
_, paramTypes[i], err = nodeResolvableTypeReference(ctx, argType)
75+
if err != nil {
76+
return nil, err
77+
}
6278
default:
6379
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString()))
6480
}
@@ -121,11 +137,38 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
121137
parsedBody,
122138
sqlDef,
123139
sqlDefParsed,
124-
node.SetOf,
140+
node.ReturnsSetOf,
125141
),
126142
}, nil
127143
}
128144

145+
// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return
146+
// type for a function, as represented by the |fieldTypes| that were specified in the function
147+
// definition.
148+
func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.DoltgresType {
149+
attrs := make([]pgtypes.CompositeAttribute, len(fieldTypes))
150+
for i, fieldType := range fieldTypes {
151+
attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, fieldType.Name.String(),
152+
id.NewType("", fieldType.Type.SQLString()), int16(i), "")
153+
}
154+
155+
typeIdString := "table("
156+
for i, attr := range attrs {
157+
if i > 0 {
158+
typeIdString += ","
159+
}
160+
typeIdString += attr.Name
161+
typeIdString += ":"
162+
typeIdString += attr.TypeID.TypeName()
163+
}
164+
typeIdString += ")"
165+
166+
// NOTE: there is no schema needed, since these types are anonymous and can't be directly referenced
167+
typeId := id.NewType("", typeIdString)
168+
169+
return pgtypes.NewCompositeType(context.Background(), id.Null, id.NullType, typeId, attrs)
170+
}
171+
129172
// handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
130173
func handleLanguageSQL(definition string, paramNames []string, paramTypes []*pgtypes.DoltgresType) (string, vitess.Statement, error) {
131174
stmt, err := parser.ParseOne(definition)

server/expression/array.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) {
8282
// We always cast the element, as there may be parameter restrictions in place
8383
castFunc := framework.GetImplicitCast(doltgresType, resultTyp)
8484
if castFunc == nil {
85-
if doltgresType.ID == pgtypes.Unknown.ID {
86-
castFunc = framework.UnknownLiteralCast
87-
} else {
88-
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
89-
}
85+
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
9086
}
9187

9288
values[i], err = castFunc(ctx, val, resultTyp)
@@ -175,7 +171,7 @@ func (array *Array) getTargetType(children ...sql.Expression) (*pgtypes.Doltgres
175171
childrenTypes = append(childrenTypes, childType)
176172
}
177173
}
178-
targetType, err := framework.FindCommonType(childrenTypes)
174+
targetType, _, err := framework.FindCommonType(childrenTypes)
179175
if err != nil {
180176
return nil, errors.Errorf("ARRAY %s", err.Error())
181177
}

server/expression/assignment_cast.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,8 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
5656
}
5757
castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType)
5858
if castFunc == nil {
59-
if ac.fromType.ID == pgtypes.Unknown.ID {
60-
castFunc = framework.UnknownLiteralCast
61-
} else {
62-
return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s",
63-
ac.toType.String(), ac.fromType.String(), ac.expr.String())
64-
}
59+
return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s",
60+
ac.toType.String(), ac.fromType.String(), ac.expr.String())
6561
}
6662
return castFunc(ctx, val, ac.toType)
6763
}

server/expression/explicit_cast.go

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"github.com/dolthub/go-mysql-server/sql/expression"
2424
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2525

26-
"github.com/dolthub/doltgresql/core"
2726
"github.com/dolthub/doltgresql/server/functions/framework"
2827
pgtypes "github.com/dolthub/doltgresql/server/types"
2928
)
@@ -97,55 +96,8 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
9796
baseCastToType := checkForDomainType(c.castToType)
9897
castFunction := framework.GetExplicitCast(fromType, baseCastToType)
9998
if castFunction == nil {
100-
if fromType.ID == pgtypes.Unknown.ID {
101-
castFunction = framework.UnknownLiteralCast
102-
} else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too?
103-
// Casting to a record type will always work for any composite type.
104-
// TODO: is the above statement true for all cases?
105-
// When casting to a composite type, then we must match the arity and have valid casts for every position.
106-
if c.castToType.IsRecordType() {
107-
castFunction = framework.IdentityCast
108-
} else {
109-
castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
110-
vals, ok := val.([]pgtypes.RecordValue)
111-
if !ok {
112-
// TODO: better error message
113-
return nil, errors.New("casting input error from record type")
114-
}
115-
if len(targetType.CompositeAttrs) != len(vals) {
116-
return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name())
117-
}
118-
typeCollection, err := core.GetTypesCollectionFromContext(ctx)
119-
if err != nil {
120-
return nil, err
121-
}
122-
outputVals := make([]pgtypes.RecordValue, len(vals))
123-
for i := range vals {
124-
valType, ok := vals[i].Type.(*pgtypes.DoltgresType)
125-
if !ok {
126-
// TODO: if this is a GMS type, then we should cast to a Doltgres type here
127-
return nil, errors.New("cannot cast record containing GMS type")
128-
}
129-
outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID)
130-
if err != nil {
131-
return nil, err
132-
}
133-
innerExplicit := ExplicitCast{
134-
sqlChild: NewUnsafeLiteral(vals[i].Value, valType),
135-
castToType: outputVals[i].Type.(*pgtypes.DoltgresType),
136-
}
137-
outputVals[i].Value, err = innerExplicit.Eval(ctx, nil)
138-
if err != nil {
139-
return nil, err
140-
}
141-
}
142-
return outputVals, nil
143-
}
144-
}
145-
} else {
146-
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
147-
fromType.String(), c.castToType.String(), c.sqlChild.String())
148-
}
99+
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
100+
fromType.String(), c.castToType.String(), c.sqlChild.String())
149101
}
150102
castResult, err := castFunction(ctx, val, c.castToType)
151103
if err != nil {

0 commit comments

Comments
 (0)