Skip to content

Commit 69d8a95

Browse files
authored
Merge pull request #2199 from dolthub/daylon/issue-2175
Support multiple statements
2 parents e726eb2 + b8830f0 commit 69d8a95

2 files changed

Lines changed: 291 additions & 33 deletions

File tree

server/connection_handler.go

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,18 @@ func (h *ConnectionHandler) sendClientStartupMessages() error {
276276
}); err != nil {
277277
return err
278278
}
279+
if err := h.send(&pgproto3.ParameterStatus{
280+
Name: "standard_conforming_strings",
281+
Value: "on",
282+
}); err != nil {
283+
return err
284+
}
285+
if err := h.send(&pgproto3.ParameterStatus{
286+
Name: "in_hot_standby",
287+
Value: "off",
288+
}); err != nil {
289+
return err
290+
}
279291
return h.send(&pgproto3.BackendKeyData{
280292
ProcessID: processID,
281293
SecretKey: 0, // TODO: this should represent an ID that can uniquely identify this connection, so that CancelRequest will work
@@ -435,7 +447,7 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages
435447
return true, err
436448
}
437449

438-
query, err := h.convertQuery(message.String)
450+
queries, err := h.convertQuery(message.String)
439451
if err != nil {
440452
if printErrorStackTraces {
441453
fmt.Printf("Error parsing query: %+v\n", err)
@@ -447,18 +459,32 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages
447459
delete(h.preparedStatements, "")
448460
delete(h.portals, "")
449461

450-
// Certain statement types get handled directly by the handler instead of being passed to the engine
451-
handled, endOfMessages, err = h.handleQueryOutsideEngine(query)
452-
if handled {
453-
return endOfMessages, err
462+
if len(queries) == 1 {
463+
// empty query special case
464+
if queries[0].AST == nil {
465+
return true, h.send(&pgproto3.EmptyQueryResponse{})
466+
}
467+
handled, endOfMessages, err = h.handleQueryOutsideEngine(queries[0])
468+
if handled {
469+
return endOfMessages, err
470+
}
471+
return true, h.query(queries[0])
454472
}
455473

456-
// empty query special case
457-
if query.AST == nil {
458-
return true, h.send(&pgproto3.EmptyQueryResponse{})
474+
for _, query := range queries {
475+
handled, _, err = h.handleQueryOutsideEngine(query)
476+
if err != nil {
477+
return true, err
478+
}
479+
if handled {
480+
continue
481+
}
482+
err = h.query(query)
483+
if err != nil {
484+
return true, err
485+
}
459486
}
460-
461-
return true, h.query(query)
487+
return true, nil
462488
}
463489

464490
// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being
@@ -498,13 +524,17 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error {
498524
h.waitForSync = true
499525

500526
// TODO: "Named prepared statements must be explicitly closed before they can be redefined by another Parse message, but this is not required for the unnamed statement"
501-
query, err := h.convertQuery(message.Query)
527+
queries, err := h.convertQuery(message.Query)
502528
if err != nil {
503529
if printErrorStackTraces {
504530
fmt.Printf("Error parsing query: %+v\n", err)
505531
}
506532
return err
507533
}
534+
if len(queries) != 1 {
535+
return errors.Errorf("cannot insert multiple commands into a prepared statement")
536+
}
537+
query := queries[0]
508538

509539
if query.AST == nil {
510540
// special case: empty query
@@ -1085,33 +1115,36 @@ func (h *ConnectionHandler) sendError(err error) {
10851115
}
10861116

10871117
// convertQuery takes the given Postgres query, and converts it as an ast.ConvertedQuery that will work with the handler.
1088-
func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) {
1118+
// If the query string contains multiple queries, then multiple ConvertedQuery will be returned.
1119+
func (h *ConnectionHandler) convertQuery(query string) ([]ConvertedQuery, error) {
10891120
s, err := parser.Parse(query)
10901121
if err != nil {
1091-
return ConvertedQuery{}, err
1092-
}
1093-
if len(s) > 1 {
1094-
return ConvertedQuery{}, errors.Errorf("only a single statement at a time is currently supported")
1122+
return nil, err
10951123
}
10961124
if len(s) == 0 {
1097-
return ConvertedQuery{String: query}, nil
1125+
return []ConvertedQuery{{String: query}}, nil
10981126
}
1099-
vitessAST, err := ast.Convert(s[0])
1100-
stmtTag := s[0].AST.StatementTag()
1101-
if err != nil {
1102-
return ConvertedQuery{}, err
1103-
}
1104-
if vitessAST == nil {
1105-
return ConvertedQuery{
1106-
String: s[0].AST.String(),
1107-
StatementTag: stmtTag,
1108-
}, nil
1109-
}
1110-
return ConvertedQuery{
1111-
String: query,
1112-
AST: vitessAST,
1113-
StatementTag: stmtTag,
1114-
}, nil
1127+
converted := make([]ConvertedQuery, len(s))
1128+
for i := range s {
1129+
vitessAST, err := ast.Convert(s[i])
1130+
stmtTag := s[i].AST.StatementTag()
1131+
if err != nil {
1132+
return nil, err
1133+
}
1134+
if vitessAST == nil {
1135+
converted[i] = ConvertedQuery{
1136+
String: s[i].AST.String(),
1137+
StatementTag: stmtTag,
1138+
}
1139+
} else {
1140+
converted[i] = ConvertedQuery{
1141+
String: query,
1142+
AST: vitessAST,
1143+
StatementTag: stmtTag,
1144+
}
1145+
}
1146+
}
1147+
return converted, nil
11151148
}
11161149

11171150
// discardAll handles the DISCARD ALL command

testing/go/multi_statement_test.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package _go
16+
17+
import (
18+
"context"
19+
"testing"
20+
21+
"github.com/dolthub/dolt/go/libraries/utils/svcs"
22+
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/jackc/pgx/v5"
24+
"github.com/jackc/pgx/v5/pgconn"
25+
"github.com/stretchr/testify/assert"
26+
"github.com/stretchr/testify/require"
27+
)
28+
29+
// TestMultipleStatements is a test for: https://github.com/dolthub/doltgresql/issues/2175
30+
func TestMultipleStatements(t *testing.T) {
31+
ctx := context.Background()
32+
var conn *Connection
33+
if runOnPostgres {
34+
pgxConn, err := pgx.Connect(ctx, "postgres://postgres:password@127.0.0.1:5432/postgres?sslmode=disable")
35+
require.NoError(t, err)
36+
conn = &Connection{
37+
Default: pgxConn,
38+
Current: pgxConn,
39+
}
40+
require.NoError(t, pgxConn.Ping(ctx))
41+
defer func() {
42+
conn.Close(ctx)
43+
}()
44+
} else {
45+
var controller *svcs.Controller
46+
ctx, conn, controller = CreateServer(t, "postgres")
47+
defer func() {
48+
conn.Close(ctx)
49+
controller.Stop()
50+
err := controller.WaitForStop()
51+
require.NoError(t, err)
52+
}()
53+
}
54+
queries := []string{
55+
`BEGIN;`,
56+
`DROP TABLE IF EXISTS migrations;`,
57+
`DROP TABLE IF EXISTS animals;`,
58+
`CREATE TABLE IF NOT EXISTS migrations (file_name TEXT NOT NULL, file_hash TEXT NOT NULL);`,
59+
`CREATE TABLE IF NOT EXISTS animals (id SERIAL PRIMARY KEY NOT NULL, name TEXT NOT NULL);`,
60+
`;`, // This should be ignored in the output
61+
`INSERT INTO migrations (file_name, file_hash) VALUES ('2021-09-07T154500-create-animals-table.sql', '42331f4277227d09e9bb32eeaf7e04d9c7fe320160e05372ed0ef010cfbf666b');`,
62+
`INSERT INTO animals(name) VALUES('Alpaca');`,
63+
`INSERT INTO animals(name) VALUES('Highland cow');`,
64+
`INSERT INTO animals(name) VALUES('Aardvark');`,
65+
`INSERT INTO migrations (file_name, file_hash) VALUES ('2021-09-07T154700-insert-animals.sql', '3223d0deb6fb7fb2accf6abffc0667ebe4503379987c472d10a585a553f9b3b6');`,
66+
`SELECT * FROM migrations ORDER BY file_name;`,
67+
`SELECT * FROM animals ORDER BY id;`,
68+
`COMMIT;`,
69+
}
70+
combinedQueries := ""
71+
for _, query := range queries {
72+
// We do this just to homogenize the queries, even though we're adding the delimiter right back
73+
query = sql.RemoveSpaceAndDelimiter(query, ';')
74+
combinedQueries += query + ";"
75+
}
76+
// First we'll check all invalid modes that fail immediately
77+
invalidModes := []pgx.QueryExecMode{
78+
pgx.QueryExecModeCacheStatement,
79+
pgx.QueryExecModeCacheDescribe,
80+
pgx.QueryExecModeDescribeExec,
81+
pgx.QueryExecModeExec,
82+
}
83+
for _, mode := range invalidModes {
84+
rows, err := conn.Current.Query(ctx, combinedQueries, mode)
85+
if mode == pgx.QueryExecModeExec {
86+
// This mode requires reading from the returned rows to find the error, rather than erroring immediately
87+
require.NoError(t, err)
88+
_ = rows.Next()
89+
err = rows.Err()
90+
} else {
91+
require.Error(t, err)
92+
}
93+
require.Contains(t, err.Error(), "cannot insert multiple commands into a prepared statement")
94+
}
95+
// Then we'll check the singular valid mode
96+
rows, err := conn.Current.Query(ctx, combinedQueries, pgx.QueryExecModeSimpleProtocol)
97+
require.NoError(t, err)
98+
require.False(t, rows.Next()) // Simple mode doesn't return results with multiple statements
99+
rows.Close()
100+
// Now we'll use the underlying connection to verify all returned results
101+
mrr := conn.Current.PgConn().Exec(ctx, combinedQueries)
102+
results, err := mrr.ReadAll()
103+
require.NoError(t, err)
104+
if assert.Len(t, results, len(testMultipleStatementsResults)) {
105+
for resultIdx, expected := range testMultipleStatementsResults {
106+
result := results[resultIdx]
107+
if assert.Equal(t, len(expected.FieldDescriptions), len(result.FieldDescriptions)) {
108+
for fieldIdx, expectedField := range expected.FieldDescriptions {
109+
resultField := result.FieldDescriptions[fieldIdx]
110+
assert.Equal(t, expectedField.Name, resultField.Name)
111+
assert.Equal(t, expectedField.DataTypeOID, resultField.DataTypeOID)
112+
assert.Equal(t, expectedField.DataTypeSize, resultField.DataTypeSize)
113+
assert.Equal(t, expectedField.TypeModifier, resultField.TypeModifier)
114+
assert.Equal(t, expectedField.Format, resultField.Format)
115+
}
116+
}
117+
if assert.Equal(t, len(expected.Rows), len(result.Rows)) {
118+
for rowIdx, expectedRow := range expected.Rows {
119+
resultRow := result.Rows[rowIdx]
120+
for columnIdx, expectedCol := range expectedRow {
121+
assert.Equal(t, expectedCol, resultRow[columnIdx])
122+
}
123+
}
124+
}
125+
assert.Equal(t, expected.CommandTag, result.CommandTag)
126+
}
127+
}
128+
require.NoError(t, mrr.Close())
129+
130+
// Now we'll ensure that errors are properly handled within multiple statements
131+
queries = []string{
132+
`INSERT INTO animals(name) VALUES('Pigeon');`,
133+
`SELECT * FROM non_existent;`,
134+
`INSERT INTO animals(name) VALUES('Elephant');`,
135+
}
136+
combinedQueries = ""
137+
for _, query := range queries {
138+
query = sql.RemoveSpaceAndDelimiter(query, ';')
139+
combinedQueries += query + ";"
140+
}
141+
mrr = conn.Current.PgConn().Exec(ctx, combinedQueries)
142+
results, err = mrr.ReadAll()
143+
require.Error(t, err)
144+
require.Contains(t, err.Error(), "non_existent")
145+
if assert.Len(t, results, 1) {
146+
assert.Equal(t, results[0].CommandTag, pgconn.NewCommandTag("INSERT 0 1"))
147+
}
148+
}
149+
150+
// testMultipleStatementsResults are used within TestMultipleStatements
151+
var testMultipleStatementsResults = []pgconn.Result{
152+
{CommandTag: pgconn.NewCommandTag("BEGIN")},
153+
{CommandTag: pgconn.NewCommandTag("DROP TABLE")},
154+
{CommandTag: pgconn.NewCommandTag("DROP TABLE")},
155+
{CommandTag: pgconn.NewCommandTag("CREATE TABLE")},
156+
{CommandTag: pgconn.NewCommandTag("CREATE TABLE")},
157+
{CommandTag: pgconn.NewCommandTag("INSERT 0 1")},
158+
{CommandTag: pgconn.NewCommandTag("INSERT 0 1")},
159+
{CommandTag: pgconn.NewCommandTag("INSERT 0 1")},
160+
{CommandTag: pgconn.NewCommandTag("INSERT 0 1")},
161+
{CommandTag: pgconn.NewCommandTag("INSERT 0 1")},
162+
{
163+
FieldDescriptions: []pgconn.FieldDescription{
164+
{
165+
Name: "file_name",
166+
DataTypeOID: 25,
167+
DataTypeSize: -1,
168+
TypeModifier: -1,
169+
Format: 0,
170+
},
171+
{
172+
Name: "file_hash",
173+
DataTypeOID: 25,
174+
DataTypeSize: -1,
175+
TypeModifier: -1,
176+
Format: 0,
177+
},
178+
},
179+
Rows: [][][]byte{
180+
{
181+
[]byte("2021-09-07T154500-create-animals-table.sql"),
182+
[]byte("42331f4277227d09e9bb32eeaf7e04d9c7fe320160e05372ed0ef010cfbf666b"),
183+
},
184+
{
185+
[]byte("2021-09-07T154700-insert-animals.sql"),
186+
[]byte("3223d0deb6fb7fb2accf6abffc0667ebe4503379987c472d10a585a553f9b3b6"),
187+
},
188+
},
189+
CommandTag: pgconn.NewCommandTag("SELECT 2"),
190+
},
191+
{
192+
FieldDescriptions: []pgconn.FieldDescription{
193+
{
194+
Name: "id",
195+
DataTypeOID: 23,
196+
DataTypeSize: 4,
197+
TypeModifier: -1,
198+
Format: 0,
199+
},
200+
{
201+
Name: "name",
202+
DataTypeOID: 25,
203+
DataTypeSize: -1,
204+
TypeModifier: -1,
205+
Format: 0,
206+
},
207+
},
208+
Rows: [][][]byte{
209+
{
210+
[]byte("1"),
211+
[]byte("Alpaca"),
212+
},
213+
{
214+
[]byte("2"),
215+
[]byte("Highland cow"),
216+
},
217+
{
218+
[]byte("3"),
219+
[]byte("Aardvark"),
220+
},
221+
},
222+
CommandTag: pgconn.NewCommandTag("SELECT 3"),
223+
},
224+
{CommandTag: pgconn.NewCommandTag("COMMIT")},
225+
}

0 commit comments

Comments
 (0)