Skip to content

Commit ea4e3e9

Browse files
committed
working tests for varbit
1 parent e656885 commit ea4e3e9

3 files changed

Lines changed: 29 additions & 50 deletions

File tree

server/functions/varbit.go

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@ package functions
1616

1717
import (
1818
"fmt"
19-
"strings"
2019

21-
"github.com/cockroachdb/errors"
20+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
2221
"github.com/dolthub/go-mysql-server/sql"
2322

2423
"github.com/dolthub/doltgresql/server/functions/framework"
@@ -42,37 +41,23 @@ var varbitin = framework.Function3{
4241
Return: pgtypes.VarBit,
4342
Parameters: [3]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid, pgtypes.Int32},
4443
Strict: true,
45-
Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) {
44+
Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, _, val3 any) (any, error) {
4645
input := val1.(string)
4746
typmod := val3.(int32)
48-
49-
// Parse bit string - remove leading 'B' or 'b' prefix if present
50-
bitStr := strings.TrimSpace(input)
51-
if len(bitStr) > 0 && (bitStr[0] == 'B' || bitStr[0] == 'b') {
52-
bitStr = bitStr[1:]
53-
// Remove quotes if present
54-
if len(bitStr) > 0 && (bitStr[0] == '\'' || bitStr[0] == '"') {
55-
if len(bitStr) > 1 && bitStr[len(bitStr)-1] == bitStr[0] {
56-
bitStr = bitStr[1 : len(bitStr)-1]
57-
}
58-
}
59-
}
60-
61-
// Validate that all characters are '0' or '1'
62-
for _, r := range bitStr {
63-
if r != '0' && r != '1' {
64-
return nil, pgtypes.ErrInvalidSyntaxForType.New("varbit", input)
65-
}
47+
48+
bitStr, err := tree.ParseDBitArray(input)
49+
if err != nil {
50+
return nil, err
6651
}
67-
52+
6853
// Check length against typmod (varbit allows up to typmod length)
6954
if typmod != -1 {
7055
maxLength := pgtypes.GetCharLengthFromTypmod(typmod)
71-
if int32(len(bitStr)) > maxLength {
72-
return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, fmt.Sprintf("bit string length %d exceeds maximum length %d for type varbit(%d)", len(bitStr), maxLength, maxLength))
56+
if int32(bitStr.BitLen()) > maxLength {
57+
return nil, pgtypes.ErrVarBitLengthExceeded.New(maxLength)
7358
}
7459
}
75-
60+
7661
return bitStr, nil
7762
},
7863
}
@@ -84,17 +69,8 @@ var varbitout = framework.Function1{
8469
Parameters: [1]*pgtypes.DoltgresType{pgtypes.VarBit},
8570
Strict: true,
8671
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
87-
bitStr := val.(string)
88-
typ := t[0]
89-
tm := typ.GetAttTypMod()
90-
if tm != -1 {
91-
maxLength := pgtypes.GetCharLengthFromTypmod(tm)
92-
// Truncate if needed (shouldn't happen normally, but handle it)
93-
if int32(len(bitStr)) > maxLength {
94-
bitStr = bitStr[:maxLength]
95-
}
96-
}
97-
return bitStr, nil
72+
bitStr := val.(*tree.DBitArray)
73+
return tree.AsStringWithFlags(bitStr, tree.FmtPgwireText), nil
9874
},
9975
}
10076

@@ -110,7 +86,7 @@ var varbitrecv = framework.Function3{
11086
return nil, nil
11187
}
11288
reader := utils.NewReader(data)
113-
return reader.String(), nil
89+
return tree.ParseDBitArray(reader.String())
11490
},
11591
}
11692

@@ -121,9 +97,9 @@ var varbitsend = framework.Function1{
12197
Parameters: [1]*pgtypes.DoltgresType{pgtypes.VarBit},
12298
Strict: true,
12399
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) {
124-
bitStr := val.(string)
125-
writer := utils.NewWriter(uint64(len(bitStr) + 4))
126-
writer.String(bitStr)
100+
bitStr := val.(*tree.DBitArray)
101+
writer := utils.NewWriter(uint64(bitStr.BitLen() + 4))
102+
writer.String(tree.AsStringWithFlags(bitStr, tree.FmtPgwireText))
127103
return writer.Data(), nil
128104
},
129105
}
@@ -135,7 +111,7 @@ var varbittypmodin = framework.Function1{
135111
Parameters: [1]*pgtypes.DoltgresType{pgtypes.CstringArray},
136112
Strict: true,
137113
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) {
138-
return getTypModFromStringArr("bit", val.([]any))
114+
return getTypModFromStringArr("bit varying", val.([]any))
139115
},
140116
}
141117

server/types/varbit.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@ package types
1616

1717
import (
1818
"github.com/dolthub/doltgresql/core/id"
19+
"gopkg.in/src-d/go-errors.v1"
1920
)
2021

22+
// ErrVarBitLengthExceeded is returned when a varbit value exceeds the defined length.
23+
var ErrVarBitLengthExceeded = errors.NewKind(`bit string too long for type bit varying(%d)`)
24+
2125
// VarBit is a varying-length bit string.
2226
var VarBit = &DoltgresType{
2327
ID: toInternal("varbit"),
@@ -33,7 +37,7 @@ var VarBit = &DoltgresType{
3337
Elem: id.NullType,
3438
Array: toInternal("_varbit"),
3539
InputFunc: toFuncID("varbit_in", toInternal("cstring"), toInternal("oid"), toInternal("int4")),
36-
OutputFunc: toFuncID("varbit_out", toInternal("bit")),
40+
OutputFunc: toFuncID("varbit_out", toInternal("varbit")),
3741
ReceiveFunc: toFuncID("varbit_recv", toInternal("internal"), toInternal("oid"), toInternal("int4")),
3842
SendFunc: toFuncID("varbit_send", toInternal("varbit")),
3943
ModInFunc: toFuncID("varbittypmodin", toInternal("varbit")),
@@ -54,8 +58,8 @@ var VarBit = &DoltgresType{
5458
CompareFunc: toFuncID("bttextcmp", toInternal("text"), toInternal("text")),
5559
}
5660

57-
// NewBitType returns a Bit type with type modifier set
58-
// representing the number of bits in the string.
61+
// NewVarBitType returns a VarBit type with type modifier set
62+
// representing the max number of bits in the string.
5963
func NewVarBitType(width int32) (*DoltgresType, error) {
6064
typmod, err := GetTypModFromCharLength("bit", width)
6165
if err != nil {

testing/go/types_test.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ var typesTests = []ScriptTest{
7979
},
8080
},
8181
{
82-
Name: "Bit type",
83-
Focus: true,
82+
Name: "Bit type",
8483
SetUpScript: []string{
8584
"CREATE TABLE t_bit (id INTEGER primary key, v1 BIT(8), v2 BIT(3));",
8685
"INSERT INTO t_bit VALUES (1, B'11011010', '101'), (2, B'00101011', '000');",
@@ -303,8 +302,8 @@ var typesTests = []ScriptTest{
303302
},
304303
},
305304
{
306-
Name: "Bit varying type",
307-
Skip: true,
305+
Name: "Bit varying type",
306+
Focus: true,
308307
SetUpScript: []string{
309308
"CREATE TABLE t_bit_varying (id INTEGER primary key, v1 BIT VARYING(16));",
310309
"INSERT INTO t_bit_varying VALUES (1, B'1101101010101010'), (2, B'0010101101010101');",
@@ -313,8 +312,8 @@ var typesTests = []ScriptTest{
313312
{
314313
Query: "SELECT * FROM t_bit_varying ORDER BY id;",
315314
Expected: []sql.Row{
316-
{1, []byte{0xDA, 0xAA}},
317-
{2, []byte{0x2B, 0xA5}},
315+
{1, pgtype.Bits{Bytes: []uint8{0xda, 0xaa}, Len: 16, Valid: true}},
316+
{2, pgtype.Bits{Bytes: []uint8{0x2b, 0x55}, Len: 16, Valid: true}},
318317
},
319318
},
320319
},

0 commit comments

Comments
 (0)