Skip to content

Commit 56713aa

Browse files
committed
Updated cast logic to be more accurate
1 parent 4e0f8c3 commit 56713aa

8 files changed

Lines changed: 235 additions & 59 deletions

File tree

server/cast/bit.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package cast
16+
17+
import (
18+
"github.com/cockroachdb/errors"
19+
"github.com/dolthub/go-mysql-server/sql"
20+
21+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
22+
23+
"github.com/dolthub/doltgresql/server/functions/framework"
24+
pgtypes "github.com/dolthub/doltgresql/server/types"
25+
)
26+
27+
// initBit registers every built-in cast whose source ("From") type is bit. It does so by calling bitExplicit followed
// by bitImplicit; casts that merely target bit are not registered here.
28+
func initBit() {
29+
bitExplicit()
30+
bitImplicit()
31+
}
32+
33+
// bitExplicit registers all explicit casts. This comprises only the "From" types.
34+
func bitExplicit() {
35+
framework.MustAddExplicitTypeCast(framework.TypeCast{
36+
FromType: pgtypes.Bit,
37+
ToType: pgtypes.Int32,
38+
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
39+
array, err := tree.ParseDBitArray(val.(string))
40+
if err != nil {
41+
return nil, err
42+
}
43+
if array.BitLen() > 32 {
44+
return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range")
45+
}
46+
return int32(array.AsInt64(32)), nil
47+
},
48+
})
49+
framework.MustAddExplicitTypeCast(framework.TypeCast{
50+
FromType: pgtypes.Bit,
51+
ToType: pgtypes.Int64,
52+
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
53+
array, err := tree.ParseDBitArray(val.(string))
54+
if err != nil {
55+
return nil, err
56+
}
57+
if array.BitLen() > 64 {
58+
return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range")
59+
}
60+
return array.AsInt64(64), nil
61+
},
62+
})
63+
}
64+
65+
// bitImplicit registers all implicit casts. This comprises only the "From" types.
66+
func bitImplicit() {
67+
framework.MustAddImplicitTypeCast(framework.TypeCast{
68+
FromType: pgtypes.Bit,
69+
ToType: pgtypes.Bit,
70+
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
71+
input := val.(string)
72+
array, err := tree.ParseDBitArray(input)
73+
if err != nil {
74+
return nil, err
75+
}
76+
expectedLength := pgtypes.GetCharLengthFromTypmod(targetType.GetAttTypMod())
77+
if array.BitLen() != uint(expectedLength) {
78+
return nil, pgtypes.ErrWrongLengthBit.New(len(input), expectedLength)
79+
}
80+
return tree.AsStringWithFlags(array, tree.FmtPgwireText), nil
81+
},
82+
})
83+
framework.MustAddImplicitTypeCast(framework.TypeCast{
84+
FromType: pgtypes.Bit,
85+
ToType: pgtypes.VarBit,
86+
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
87+
input := val.(string)
88+
array, err := tree.ParseDBitArray(input)
89+
if err != nil {
90+
return nil, err
91+
}
92+
atttypmod := targetType.GetAttTypMod()
93+
if atttypmod != -1 {
94+
maxLength := pgtypes.GetCharLengthFromTypmod(atttypmod)
95+
if int32(array.BitLen()) > maxLength {
96+
return nil, pgtypes.ErrVarBitLengthExceeded.New(maxLength)
97+
}
98+
}
99+
return tree.AsStringWithFlags(array, tree.FmtPgwireText), nil
100+
},
101+
})
102+
}

server/cast/init.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
// Init initializes all casts in this package.
2323
func Init() {
24+
initBit()
2425
initBool()
2526
initChar()
2627
initDate()
@@ -44,6 +45,7 @@ func Init() {
4445
initTimestamp()
4546
initTimestampTZ()
4647
initTimeTZ()
48+
initVarBit()
4749
initVarChar()
4850

4951
// This is a hack to get around import cycles. The types package needs these references for type conversions in

server/cast/text.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,4 @@ func textImplicit() {
7575
return handleStringCast(val.(string), targetType)
7676
},
7777
})
78-
framework.MustAddImplicitTypeCast(framework.TypeCast{
79-
FromType: pgtypes.Text,
80-
ToType: pgtypes.VarBit,
81-
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
82-
return targetType.IoInput(ctx, val.(string))
83-
},
84-
})
85-
framework.MustAddImplicitTypeCast(framework.TypeCast{
86-
FromType: pgtypes.Text,
87-
ToType: pgtypes.Bit,
88-
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
89-
return targetType.IoInput(ctx, val.(string))
90-
},
91-
})
9278
}

server/cast/varbit.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package cast
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
20+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
21+
22+
"github.com/dolthub/doltgresql/server/functions/framework"
23+
pgtypes "github.com/dolthub/doltgresql/server/types"
24+
)
25+
26+
// initVarBit registers every built-in cast whose source ("From") type is varbit. Only implicit casts exist for this
// type, so it simply calls varBitImplicit.
27+
func initVarBit() {
28+
varBitImplicit()
29+
}
30+
31+
// varBitImplicit registers all implicit casts. This comprises only the "From" types.
32+
func varBitImplicit() {
33+
framework.MustAddImplicitTypeCast(framework.TypeCast{
34+
FromType: pgtypes.VarBit,
35+
ToType: pgtypes.Bit,
36+
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
37+
input := val.(string)
38+
array, err := tree.ParseDBitArray(input)
39+
if err != nil {
40+
return nil, err
41+
}
42+
expectedLength := pgtypes.GetCharLengthFromTypmod(targetType.GetAttTypMod())
43+
if array.BitLen() != uint(expectedLength) {
44+
return nil, pgtypes.ErrWrongLengthBit.New(len(input), expectedLength)
45+
}
46+
return tree.AsStringWithFlags(array, tree.FmtPgwireText), nil
47+
},
48+
})
49+
framework.MustAddImplicitTypeCast(framework.TypeCast{
50+
FromType: pgtypes.VarBit,
51+
ToType: pgtypes.VarBit,
52+
Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
53+
input := val.(string)
54+
array, err := tree.ParseDBitArray(input)
55+
if err != nil {
56+
return nil, err
57+
}
58+
atttypmod := targetType.GetAttTypMod()
59+
if atttypmod != -1 {
60+
maxLength := pgtypes.GetCharLengthFromTypmod(atttypmod)
61+
if int32(array.BitLen()) > maxLength {
62+
return nil, pgtypes.ErrVarBitLengthExceeded.New(maxLength)
63+
}
64+
}
65+
return tree.AsStringWithFlags(array, tree.FmtPgwireText), nil
66+
},
67+
})
68+
}

server/functions/framework/cast.go

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,9 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
130130
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil {
131131
return tcf
132132
}
133-
// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
134-
// parameters). If one of the types are a string type, then we do not use the identity, and use the I/O conversions
135-
// below.
136-
if fromType.ID == toType.ID && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes {
137-
return IdentityCast
133+
// We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user.
134+
if cast := getSizingOrIdentityCast(fromType, toType, true); cast != nil {
135+
return cast
138136
}
139137
// All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html
140138
if fromType.TypCategory == pgtypes.TypeCategory_StringTypes {
@@ -172,15 +170,11 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT
172170
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil {
173171
return tcf
174172
}
175-
// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
176-
// parameters). If the "to" type is a string type, then we do not use the identity, and use the I/O conversion below.
177-
if fromType.ID == toType.ID && fromType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_BitStringTypes {
178-
return IdentityCast
173+
// We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user.
174+
if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil {
175+
return cast
179176
}
180-
181177
// All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html
182-
// This is also where length checks occur for types like char(n), varchar(n), bit(n), etc., which is not great
183-
// TODO: move length checks to their own analyzer step
184178
if toType.TypCategory == pgtypes.TypeCategory_StringTypes {
185179
return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
186180
if val == nil {
@@ -192,17 +186,6 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT
192186
}
193187
return targetType.IoInput(ctx, str)
194188
}
195-
} else if toType.TypCategory == pgtypes.TypeCategory_BitStringTypes {
196-
return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
197-
if val == nil {
198-
return nil, nil
199-
}
200-
str, err := fromType.IoOutput(ctx, val)
201-
if err != nil {
202-
return nil, err
203-
}
204-
return targetType.IoInput(ctx, str)
205-
}
206189
}
207190
return nil
208191
}
@@ -213,10 +196,9 @@ func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
213196
if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetImplicitCast); tcf != nil {
214197
return tcf
215198
}
216-
// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
217-
// parameters).
218-
if fromType.ID == toType.ID {
219-
return IdentityCast
199+
// We check for the identity and sizing casts after checking the maps, as the identity may be overridden by a user.
200+
if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil {
201+
return cast
220202
}
221203
return nil
222204
}
@@ -296,6 +278,40 @@ func getCast(mutex *sync.RWMutex,
296278
return nil
297279
}
298280

281+
// getSizingOrIdentityCast returns an identity cast if the two types are exactly the same, and a sizing cast if they
282+
// only differ in their atttypmod values. Returns nil if no functions are matched. This mirrors the behavior as described in:
283+
// https://www.postgresql.org/docs/15/typeconv-query.html
284+
func getSizingOrIdentityCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, isExplicitCast bool) pgtypes.TypeCastFunction {
285+
// If we receive different types, then we can return immediately
286+
if fromType.ID != toType.ID {
287+
return nil
288+
}
289+
// If we have different atttypmod values, then we need to do a sizing cast only if one exists
290+
if fromType.GetAttTypMod() != toType.GetAttTypMod() {
291+
// TODO: We don't have any sizing cast functions implemented, so for now we'll approximate using output to input.
292+
// We can use the query below to find all implemented sizing cast functions. It's also detailed in the link above.
293+
// Lastly, not all sizing functions accept a boolean, but for those that do, we need to see whether true is
294+
// used for explicit casts, or whether true is used for implicit casts.
295+
// SELECT
296+
// format_type(c.castsource, NULL) AS source,
297+
// format_type(c.casttarget, NULL) AS target,
298+
// p.oid::regprocedure AS func
299+
// FROM pg_cast c JOIN pg_proc p ON p.oid = c.castfunc WHERE c.castsource = c.casttarget ORDER BY 1,2;
300+
return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
301+
if val == nil {
302+
return nil, nil
303+
}
304+
str, err := fromType.IoOutput(ctx, val)
305+
if err != nil {
306+
return nil, err
307+
}
308+
return targetType.IoInput(ctx, str)
309+
}
310+
}
311+
// If there is no sizing cast, then we simply use the identity cast
312+
return IdentityCast
313+
}
314+
299315
// IdentityCast returns the input value.
300316
func IdentityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
301317
return val, nil

server/types/bit.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ var Bit = &DoltgresType{
4141
OutputFunc: toFuncID("bit_out", toInternal("bit")),
4242
ReceiveFunc: toFuncID("bit_recv", toInternal("internal"), toInternal("oid"), toInternal("int4")),
4343
SendFunc: toFuncID("bit_send", toInternal("bit")),
44-
ModInFunc: toFuncID("bittypmodin", toInternal("bit")),
45-
ModOutFunc: toFuncID("bittypmodout", toInternal("bit")),
44+
ModInFunc: toFuncID("bittypmodin", toInternal("_cstring")),
45+
ModOutFunc: toFuncID("bittypmodout", toInternal("int4")),
4646
AnalyzeFunc: toFuncID("-"),
4747
Align: TypeAlignment_Int,
4848
Storage: TypeStorage_Extended,

server/types/varbit.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ var VarBit = &DoltgresType{
4141
OutputFunc: toFuncID("varbit_out", toInternal("varbit")),
4242
ReceiveFunc: toFuncID("varbit_recv", toInternal("internal"), toInternal("oid"), toInternal("int4")),
4343
SendFunc: toFuncID("varbit_send", toInternal("varbit")),
44-
ModInFunc: toFuncID("varbittypmodin", toInternal("varbit")),
45-
ModOutFunc: toFuncID("varbittypmodout", toInternal("varbit")),
44+
ModInFunc: toFuncID("varbittypmodin", toInternal("_cstring")),
45+
ModOutFunc: toFuncID("varbittypmodout", toInternal("int4")),
4646
AnalyzeFunc: toFuncID("-"),
4747
Align: TypeAlignment_Int,
4848
Storage: TypeStorage_Extended,

testing/go/adaptive_encoding_test.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@ func TestAdaptiveEncodingVarbit(t *testing.T) {
236236
SetUpScript: setup.SetupScript{
237237
fmt.Sprintf(`create table blobt (i char(1) primary key, b %s);`, columnType),
238238
fmt.Sprintf(`create table blobt2 (i char(2) primary key, b1 %s, b2 %s);`, columnType, columnType),
239-
`insert into blobt values
240-
('F', LOAD_FILE('testdata/fullSizeVarbit')),
241-
('H', LOAD_FILE('testdata/halfSizeVarbit')),
242-
('T', LOAD_FILE('testdata/tinyFileVarbit'))`,
239+
fmt.Sprintf(`insert into blobt values
240+
('F', LOAD_FILE('testdata/fullSizeVarbit')::%s),
241+
('H', LOAD_FILE('testdata/halfSizeVarbit')::%s),
242+
('T', LOAD_FILE('testdata/tinyFileVarbit')::%s)`, columnType, columnType, columnType),
243243
},
244244
Assertions: []ScriptTestAssertion{
245245
{
@@ -267,16 +267,18 @@ func TestAdaptiveEncodingVarbit(t *testing.T) {
267267
Name: "Adaptive Encoding With Two Columns",
268268
SetUpScript: setup.SetupScript{
269269
fmt.Sprintf(`create table blobt2 (i char(2) primary key, b1 %s, b2 %s);`, columnType, columnType),
270-
`insert into blobt2 values
271-
('FF', LOAD_FILE('testdata/fullSizeVarbit'), LOAD_FILE('testdata/fullSizeVarbit')),
272-
('HF', LOAD_FILE('testdata/halfSizeVarbit'), LOAD_FILE('testdata/fullSizeVarbit')),
273-
('TF', LOAD_FILE('testdata/tinyFileVarbit'), LOAD_FILE('testdata/fullSizeVarbit')),
274-
('FH', LOAD_FILE('testdata/fullSizeVarbit'), LOAD_FILE('testdata/halfSizeVarbit')),
275-
('HH', LOAD_FILE('testdata/halfSizeVarbit'), LOAD_FILE('testdata/halfSizeVarbit')),
276-
('TH', LOAD_FILE('testdata/tinyFileVarbit'), LOAD_FILE('testdata/halfSizeVarbit')),
277-
('FT', LOAD_FILE('testdata/fullSizeVarbit'), LOAD_FILE('testdata/tinyFileVarbit')),
278-
('HT', LOAD_FILE('testdata/halfSizeVarbit'), LOAD_FILE('testdata/tinyFileVarbit')),
279-
('TT', LOAD_FILE('testdata/tinyFileVarbit'), LOAD_FILE('testdata/tinyFileVarbit'))`,
270+
fmt.Sprintf(`insert into blobt2 values
271+
('FF', LOAD_FILE('testdata/fullSizeVarbit')::%s, LOAD_FILE('testdata/fullSizeVarbit')::%s),
272+
('HF', LOAD_FILE('testdata/halfSizeVarbit')::%s, LOAD_FILE('testdata/fullSizeVarbit')::%s),
273+
('TF', LOAD_FILE('testdata/tinyFileVarbit')::%s, LOAD_FILE('testdata/fullSizeVarbit')::%s),
274+
('FH', LOAD_FILE('testdata/fullSizeVarbit')::%s, LOAD_FILE('testdata/halfSizeVarbit')::%s),
275+
('HH', LOAD_FILE('testdata/halfSizeVarbit')::%s, LOAD_FILE('testdata/halfSizeVarbit')::%s),
276+
('TH', LOAD_FILE('testdata/tinyFileVarbit')::%s, LOAD_FILE('testdata/halfSizeVarbit')::%s),
277+
('FT', LOAD_FILE('testdata/fullSizeVarbit')::%s, LOAD_FILE('testdata/tinyFileVarbit')::%s),
278+
('HT', LOAD_FILE('testdata/halfSizeVarbit')::%s, LOAD_FILE('testdata/tinyFileVarbit')::%s),
279+
('TT', LOAD_FILE('testdata/tinyFileVarbit')::%s, LOAD_FILE('testdata/tinyFileVarbit')::%s)`, columnType, columnType,
280+
columnType, columnType, columnType, columnType, columnType, columnType, columnType, columnType,
281+
columnType, columnType, columnType, columnType, columnType, columnType, columnType, columnType),
280282
},
281283
Assertions: []ScriptTestAssertion{
282284
{

0 commit comments

Comments
 (0)