Skip to content

Commit 1759e8b

Browse files
authored
Merge pull request #2022 from dolthub/zachmu/prepared
Bug fixes for null handling in prepared statements
2 parents d3226c2 + f590e62 commit 1759e8b

4 files changed

Lines changed: 179 additions & 28 deletions

File tree

server/doltgres_handler.go

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,15 @@ func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32
281281
if !ok {
282282
return nil, errors.Errorf("unhandled oid type: %v", types[i])
283283
}
284-
v, err := pgTyp.IoInput(ctx, bindVarString)
285-
if err != nil {
286-
return nil, err
284+
if bindVarString == nil {
285+
bindings[fmt.Sprintf("v%d", i+1)] = sqlparser.InjectedExpr{Expression: pgexprs.NewUnsafeLiteral(nil, pgTyp)}
286+
} else {
287+
v, err := pgTyp.IoInput(ctx, *bindVarString)
288+
if err != nil {
289+
return nil, err
290+
}
291+
bindings[fmt.Sprintf("v%d", i+1)] = sqlparser.InjectedExpr{Expression: pgexprs.NewUnsafeLiteral(v, pgTyp)}
287292
}
288-
bindings[fmt.Sprintf("v%d", i+1)] = sqlparser.InjectedExpr{Expression: pgexprs.NewUnsafeLiteral(v, pgTyp)}
289293
}
290294
return bindings, nil
291295
}
@@ -299,50 +303,75 @@ func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32
299303
// This function relies on the pgtype library to decode values, in text and binary formats,
300304
// however, a few types cannot be scanned directly into strings from the binary format by this
301305
// library, so there is special handling for them.
302-
func (h *DoltgresHandler) convertBindParameterToString(typ uint32, value []byte, formatCode int16) (bindVarString string, err error) {
306+
func (h *DoltgresHandler) convertBindParameterToString(typ uint32, value []byte, formatCode int16) (bindVarString *string, err error) {
303307
isBinaryFormat := formatCode == pgtype.BinaryFormatCode
304308

305309
switch {
306310
case (typ == pgtype.TimestampOID || typ == pgtype.TimestamptzOID) && isBinaryFormat:
307-
var t time.Time
311+
var t *time.Time
308312
if err := h.pgTypeMap.Scan(typ, formatCode, value, &t); err != nil {
309-
return "", err
313+
return nil, err
314+
}
315+
if t != nil {
316+
format := t.Format("2006-01-02 15:04:05")
317+
bindVarString = &format
310318
}
311-
bindVarString = t.Format("2006-01-02 15:04:05")
312319
case typ == pgtype.DateOID && isBinaryFormat:
313-
var d pgtype.Date
320+
var d *pgtype.Date
314321
if err := h.pgTypeMap.Scan(typ, formatCode, value, &d); err != nil {
315-
return "", err
322+
return nil, err
323+
}
324+
if d != nil {
325+
format := d.Time.Format("2006-01-02")
326+
bindVarString = &format
316327
}
317-
bindVarString = d.Time.Format("2006-01-02")
318328
case typ == pgtype.BoolOID && isBinaryFormat:
319-
var b bool
329+
var b *bool
320330
if err := h.pgTypeMap.Scan(typ, formatCode, value, &b); err != nil {
321-
return "", err
331+
return nil, err
322332
}
323-
if b {
324-
bindVarString = "true"
325-
} else {
326-
bindVarString = "false"
333+
if b != nil {
334+
if *b {
335+
var t = "true"
336+
bindVarString = &t
337+
} else {
338+
var f = "false"
339+
bindVarString = &f
340+
}
327341
}
328342
case typ == pgtype.ByteaOID && isBinaryFormat:
329-
bindVarString = `\x` + hex.EncodeToString(value)
343+
if value != nil {
344+
s := `\x` + hex.EncodeToString(value)
345+
bindVarString = &s
346+
}
330347
case typ == pgtype.Int2OID && isBinaryFormat:
331-
bindVarString = strconv.FormatInt(int64(binary.BigEndian.Uint16(value)), 10)
348+
if value != nil {
349+
formatInt := strconv.FormatInt(int64(binary.BigEndian.Uint16(value)), 10)
350+
bindVarString = &formatInt
351+
}
332352
case typ == pgtype.Int4OID && isBinaryFormat:
333-
bindVarString = strconv.FormatInt(int64(binary.BigEndian.Uint32(value)), 10)
353+
if value != nil {
354+
formatInt := strconv.FormatInt(int64(binary.BigEndian.Uint32(value)), 10)
355+
bindVarString = &formatInt
356+
}
334357
case typ == pgtype.Int8OID && isBinaryFormat:
335-
bindVarString = strconv.FormatInt(int64(binary.BigEndian.Uint64(value)), 10)
358+
if value != nil {
359+
formatInt := strconv.FormatInt(int64(binary.BigEndian.Uint64(value)), 10)
360+
bindVarString = &formatInt
361+
}
336362
case typ == pgtype.UUIDOID && isBinaryFormat:
337-
u, err := uuid.FromBytes(value)
338-
if err != nil {
339-
return "", err
363+
if value != nil {
364+
u, err := uuid.FromBytes(value)
365+
if err != nil {
366+
return nil, err
367+
}
368+
s := u.String()
369+
bindVarString = &s
340370
}
341-
bindVarString = u.String()
342371
default:
343372
// For text format or types that can handle binary-to-string conversion
344373
if err := h.pgTypeMap.Scan(typ, formatCode, value, &bindVarString); err != nil {
345-
return "", err
374+
return nil, err
346375
}
347376
}
348377

testing/go/framework.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ func NormalizeExpectedRow(fds []pgconn.FieldDescription, rows []sql.Row) []sql.R
486486
// try using text type
487487
dt = types.Text
488488
}
489-
if dt.ID == types.Json.ID {
489+
if dt.ID == types.Json.ID && row[i] != nil {
490490
newRow[i] = UnmarshalAndMarshalJsonString(row[i].(string))
491491
} else if dt.IsArrayType() && dt.ArrayBaseType().ID == types.Json.ID {
492492
// TODO: need to have valid sql.Context
@@ -538,6 +538,10 @@ func UnmarshalAndMarshalJsonString(val string) string {
538538
// There are an infinite number of ways to represent the same value in-memory,
539539
// so we must at least normalize Numeric values.
540540
func NormalizeValToString(dt *types.DoltgresType, v any) any {
541+
if v == nil {
542+
return nil
543+
}
544+
541545
switch dt.ID {
542546
case types.Json.ID:
543547
str, err := json.Marshal(v)

testing/go/prepared_statement_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,36 @@ var preparedStatementTests = []ScriptTest{
186186
{3, 4},
187187
},
188188
},
189+
{
190+
Query: "INSERT INTO test VALUES ($1, $2) returning *",
191+
BindVars: []any{2, nil},
192+
Expected: []sql.Row{
193+
{2, nil},
194+
},
195+
},
196+
},
197+
},
198+
{
199+
Name: "Integer types",
200+
SetUpScript: []string{
201+
"drop table if exists test",
202+
"CREATE TABLE test (pk BIGINT PRIMARY KEY, v2 SMALLINT, v4 INTEGER, v5 BIGINT);",
203+
},
204+
Assertions: []ScriptTestAssertion{
205+
{
206+
Query: "INSERT INTO test VALUES ($1, $2, $3, $4) returning *",
207+
BindVars: []any{1, 10, 100, 1000},
208+
Expected: []sql.Row{
209+
{1, 10, 100, 1000},
210+
},
211+
},
212+
{
213+
Query: "INSERT INTO test VALUES ($1, $2, $3, $4) returning *",
214+
BindVars: []any{2, nil, nil, nil},
215+
Expected: []sql.Row{
216+
{2, nil, nil, nil},
217+
},
218+
},
189219
},
190220
},
191221
{
@@ -274,6 +304,13 @@ var preparedStatementTests = []ScriptTest{
274304
{1, "hello"},
275305
},
276306
},
307+
{
308+
Query: "INSERT INTO test VALUES ($1, $2) returning *",
309+
BindVars: []any{2, nil},
310+
Expected: []sql.Row{
311+
{2, nil},
312+
},
313+
},
277314
},
278315
},
279316
{
@@ -503,6 +540,13 @@ var preparedStatementTests = []ScriptTest{
503540
{3, "2024-04-01"},
504541
},
505542
},
543+
{
544+
Query: "INSERT INTO test VALUES ($1, $2) returning *;",
545+
BindVars: []any{"5", nil},
546+
Expected: []sql.Row{
547+
{5, nil},
548+
},
549+
},
506550
},
507551
},
508552
{
@@ -607,6 +651,13 @@ var preparedStatementTests = []ScriptTest{
607651
{2, "2024-12-25 09:15:30"},
608652
},
609653
},
654+
{
655+
Query: "INSERT INTO test VALUES ($1, $2) returning *;",
656+
BindVars: []any{"3", nil},
657+
Expected: []sql.Row{
658+
{3, nil},
659+
},
660+
},
610661
},
611662
},
612663
{
@@ -860,6 +911,13 @@ var preparedStatementTests = []ScriptTest{
860911
{2, "6ba7b810-9dad-11d1-80b4-00c04fd430c8"},
861912
},
862913
},
914+
{
915+
Query: "INSERT INTO test VALUES ($1, $2) returning *;",
916+
BindVars: []any{"3", nil},
917+
Expected: []sql.Row{
918+
{3, nil},
919+
},
920+
},
863921
},
864922
},
865923
{
@@ -958,6 +1016,13 @@ var preparedStatementTests = []ScriptTest{
9581016
{2, Numeric("999.99"), Numeric("12.345")},
9591017
},
9601018
},
1019+
{
1020+
Query: "INSERT INTO test VALUES ($1, $2, $3) returning *;",
1021+
BindVars: []any{"3", nil, nil},
1022+
Expected: []sql.Row{
1023+
{3, nil, nil},
1024+
},
1025+
},
9611026
},
9621027
},
9631028
{
@@ -1066,6 +1131,13 @@ var preparedStatementTests = []ScriptTest{
10661131
{3, "t"},
10671132
},
10681133
},
1134+
{
1135+
Query: "INSERT INTO test VALUES ($1, $2) returning *;",
1136+
BindVars: []any{"4", nil},
1137+
Expected: []sql.Row{
1138+
{4, nil},
1139+
},
1140+
},
10691141
},
10701142
},
10711143
{
@@ -1213,6 +1285,13 @@ var preparedStatementTests = []ScriptTest{
12131285
{4, []byte{0xC0, 0xFF, 0xEE}},
12141286
},
12151287
},
1288+
{
1289+
Query: "INSERT INTO t_bytea VALUES ($1, $2) returning *;",
1290+
BindVars: []any{5, nil},
1291+
Expected: []sql.Row{
1292+
{5, nil},
1293+
},
1294+
},
12161295
},
12171296
},
12181297
}

testing/go/types_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,25 @@ var typesTests = []ScriptTest{
10661066
{4, `{"key1": {"key": [2,3]}}`},
10671067
},
10681068
},
1069+
{
1070+
Query: "Insert into t_json values (100, null) returning *",
1071+
Expected: []sql.Row{
1072+
{100, nil},
1073+
},
1074+
},
1075+
{
1076+
Query: "select * from t_json where id = 100",
1077+
Expected: []sql.Row{
1078+
{100, nil},
1079+
},
1080+
},
1081+
{
1082+
Query: "Insert into t_json values ($1, $2) returning *",
1083+
BindVars: []any{"101", nil},
1084+
Expected: []sql.Row{
1085+
{101, nil},
1086+
},
1087+
},
10691088
{
10701089
Query: "SELECT '5'::json;",
10711090
Expected: []sql.Row{
@@ -1085,6 +1104,13 @@ var typesTests = []ScriptTest{
10851104
},
10861105
},
10871106
{
1107+
Query: `SELECT null::json;`,
1108+
Expected: []sql.Row{
1109+
{nil},
1110+
},
1111+
},
1112+
{
1113+
Skip: true, // https://github.com/jackc/pgx/issues/2430
10881114
Query: `SELECT 'null'::json;`,
10891115
Expected: []sql.Row{
10901116
{`null`},
@@ -1135,6 +1161,19 @@ var typesTests = []ScriptTest{
11351161
{2, `{"num": 42}`},
11361162
},
11371163
},
1164+
{
1165+
Query: "insert into t_jsonb values (3, null) returning *",
1166+
Expected: []sql.Row{
1167+
{3, nil},
1168+
},
1169+
},
1170+
{
1171+
Query: "insert into t_jsonb values ($1, $2) returning *",
1172+
BindVars: []any{"4", nil},
1173+
Expected: []sql.Row{
1174+
{4, nil},
1175+
},
1176+
},
11381177
{
11391178
Query: `SELECT '{"bar": "baz", "balance": 7.77, "active":false}'::jsonb;`,
11401179
Expected: []sql.Row{
@@ -1287,7 +1326,7 @@ var typesTests = []ScriptTest{
12871326
{
12881327
Query: "SELECT * FROM t_jsonb ORDER BY v1;",
12891328
Expected: []sql.Row{
1290-
{`null`},
1329+
{nil}, // should be "null", but https://github.com/jackc/pgx/issues/2430
12911330
{`"random string"`},
12921331
{`789.123`},
12931332
{`123456`},

0 commit comments

Comments
 (0)