@@ -16,9 +16,8 @@ package functions
1616
1717import (
1818 "fmt"
19- "strings"
2019
21- "github.com/cockroachdb/errors "
20+ "github.com/dolthub/doltgresql/postgres/parser/sem/tree "
2221 "github.com/dolthub/go-mysql-server/sql"
2322
2423 "github.com/dolthub/doltgresql/server/functions/framework"
@@ -42,38 +41,21 @@ var bitin = framework.Function3{
4241 Return : pgtypes .Bit ,
4342 Parameters : [3 ]* pgtypes.DoltgresType {pgtypes .Cstring , pgtypes .Oid , pgtypes .Int32 },
4443 Strict : true ,
45- Callable : func (ctx * sql.Context , _ [4 ]* pgtypes.DoltgresType , val1 , val2 , val3 any ) (any , error ) {
44+ Callable : func (ctx * sql.Context , _ [4 ]* pgtypes.DoltgresType , val1 , _ , val3 any ) (any , error ) {
4645 input := val1 .(string )
4746 typmod := val3 .(int32 )
48-
49- // Parse bit string - remove leading 'B' or 'b' prefix if present
50- bitStr := strings .TrimSpace (input )
51- if len (bitStr ) > 0 && (bitStr [0 ] == 'B' || bitStr [0 ] == 'b' ) {
52- bitStr = bitStr [1 :]
53- // Remove quotes if present
54- if len (bitStr ) > 0 && (bitStr [0 ] == '\'' || bitStr [0 ] == '"' ) {
55- if len (bitStr ) > 1 && bitStr [len (bitStr )- 1 ] == bitStr [0 ] {
56- bitStr = bitStr [1 : len (bitStr )- 1 ]
57- }
58- }
59- }
60-
61- // Validate that all characters are '0' or '1'
62- for _ , r := range bitStr {
63- if r != '0' && r != '1' {
64- return nil , pgtypes .ErrInvalidSyntaxForType .New ("bit" , input )
65- }
47+
48+ array , err := tree .ParseDBitArray (input )
49+ if err != nil {
50+ return nil , err
6651 }
67-
68- // Check length against typmod
69- if typmod != - 1 {
70- expectedLength := pgtypes .GetCharLengthFromTypmod (typmod )
71- if int32 (len (bitStr )) != expectedLength {
72- return nil , pgtypes .ErrInvalidSyntaxForType .New ("bit" , input )
73- }
52+
53+ expectedLength := pgtypes .GetCharLengthFromTypmod (typmod )
54+ if array .BitLen () != uint (expectedLength ) {
55+ return nil , pgtypes .ErrInvalidSyntaxForType .New ("bit" , input )
7456 }
75-
76- return bitStr , nil
57+
58+ return array , nil
7759 },
7860}
7961
@@ -84,19 +66,8 @@ var bitout = framework.Function1{
8466 Parameters : [1 ]* pgtypes.DoltgresType {pgtypes .Bit },
8567 Strict : true ,
8668 Callable : func (ctx * sql.Context , t [2 ]* pgtypes.DoltgresType , val any ) (any , error ) {
87- bitStr := val .(string )
88- typ := t [0 ]
89- tm := typ .GetAttTypMod ()
90- if tm != - 1 {
91- expectedLength := pgtypes .GetCharLengthFromTypmod (tm )
92- // Pad with zeros if needed (shouldn't happen for fixed-length bit)
93- if int32 (len (bitStr )) < expectedLength {
94- bitStr = bitStr + strings .Repeat ("0" , int (expectedLength - int32 (len (bitStr ))))
95- } else if int32 (len (bitStr )) > expectedLength {
96- bitStr = bitStr [:expectedLength ]
97- }
98- }
99- return bitStr , nil
69+ bitStr := val .(tree.DBitArray )
70+ return bitStr .String (), nil
10071 },
10172}
10273
@@ -112,7 +83,7 @@ var bitrecv = framework.Function3{
11283 return nil , nil
11384 }
11485 reader := utils .NewReader (data )
115- return reader .String (), nil
86+ return tree . ParseDBitArray ( reader .String ())
11687 },
11788}
11889
@@ -123,10 +94,8 @@ var bitsend = framework.Function1{
12394 Parameters : [1 ]* pgtypes.DoltgresType {pgtypes .Bit },
12495 Strict : true ,
12596 Callable : func (ctx * sql.Context , _ [2 ]* pgtypes.DoltgresType , val any ) (any , error ) {
126- bitStr := val .(string )
127- writer := utils .NewWriter (uint64 (len (bitStr ) + 4 ))
128- writer .String (bitStr )
129- return writer .Data (), nil
97+ bitStr := val .(tree.DBitArray )
98+ return bitStr .String (), nil
13099 },
131100}
132101
0 commit comments