Skip to content

Commit b040612

Browse files
authored
Merge pull request #2346 from dolthub/daylon/wire-tests
Added test framework for testing wire data
2 parents 1425697 + a911a7a commit b040612

1 file changed

Lines changed: 364 additions & 0 deletions

File tree

testing/go/wire_test.go

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
// Copyright 2026 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package _go
16+
17+
import (
18+
"fmt"
19+
"net"
20+
"os"
21+
"testing"
22+
"time"
23+
24+
"github.com/cockroachdb/errors"
25+
"github.com/dolthub/go-mysql-server/sql"
26+
"github.com/jackc/pgx/v5/pgproto3"
27+
"github.com/stretchr/testify/assert"
28+
"github.com/stretchr/testify/require"
29+
30+
"github.com/dolthub/doltgresql/server"
31+
)
32+
33+
// TestWireImplementation allows us to directly test what is received on the wire, ensuring that the wire protocol is
34+
// correctly implemented.
35+
func TestWireImplementation(t *testing.T) {
36+
RunWireScripts(t, []WireScriptTest{
37+
{
38+
Name: "Smoke Test",
39+
SetUpScript: []string{
40+
"CREATE TABLE test (pk INT4 PRIMARY KEY);",
41+
"INSERT INTO test VALUES (7);",
42+
},
43+
Assertions: []WireScriptTestAssertion{
44+
{
45+
Send: []pgproto3.FrontendMessage{
46+
&pgproto3.Query{String: "SELECT * FROM test;"},
47+
},
48+
Receive: []pgproto3.BackendMessage{
49+
&pgproto3.RowDescription{
50+
Fields: []pgproto3.FieldDescription{
51+
{
52+
Name: []byte("pk"),
53+
TableOID: 0,
54+
TableAttributeNumber: 0,
55+
DataTypeOID: 23,
56+
DataTypeSize: 4,
57+
TypeModifier: -1,
58+
Format: 0,
59+
},
60+
},
61+
},
62+
&pgproto3.DataRow{Values: [][]byte{[]byte("7")}},
63+
&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")},
64+
&pgproto3.ReadyForQuery{TxStatus: 'I'},
65+
},
66+
},
67+
},
68+
},
69+
})
70+
}
71+
72+
// IgnoreMessageParameters is used to ignore certain fields within a backend message, as they may not yet be implemented
73+
// and therefore will return incorrect results (or variable results, such as with non-stable OIDs).
74+
func IgnoreMessageParameters(message pgproto3.BackendMessage) pgproto3.BackendMessage {
75+
switch message := message.(type) {
76+
case *pgproto3.RowDescription:
77+
for i := range message.Fields {
78+
message.Fields[i].TableOID = 0
79+
message.Fields[i].TableAttributeNumber = 0
80+
}
81+
return message
82+
default:
83+
return message
84+
}
85+
}
86+
87+
// WireScriptTest is used to test wire messages, ensuring that our wire protocol behaves as expected.
88+
type WireScriptTest struct {
89+
// Name of the script.
90+
Name string
91+
// The database to create and use. If not provided, then it defaults to "postgres".
92+
Database string
93+
// The SQL statements to execute as setup, in order. Results are not checked, but statements must not error.
94+
SetUpScript []string
95+
// The set of assertions to make after setup, in order
96+
Assertions []WireScriptTestAssertion
97+
// When using RunScripts, setting this on one (or more) tests causes RunWireScripts to ignore all tests that have
98+
// this set to false (which is the default value). This allows a developer to easily "focus" on a specific test
99+
// without having to comment out other tests, pull it into a different function, etc. In addition, CI ensures that
100+
// this is false before passing, meaning this prevents the commented-out situation where the developer forgets to
101+
// uncomment their code.
102+
Focus bool
103+
// Skip is used to completely skip a test
104+
Skip bool
105+
}
106+
107+
// WireScriptTestAssertion are the assertions upon which the script executes its main "testing" logic.
108+
type WireScriptTestAssertion struct {
109+
// These are sent as a batch to the server
110+
Send []pgproto3.FrontendMessage
111+
// These are the expected results that are received from the server, and must match in both contents and order
112+
Receive []pgproto3.BackendMessage
113+
// This functions the same as Focus on WireScriptTest, except that it applies to the assertion
114+
Focus bool
115+
// This is used to skip an assertion
116+
Skip bool
117+
}
118+
119+
// RawWireConnection is a connection that allows us to directly send and receive messages to a server.
120+
type RawWireConnection struct {
121+
frontend *pgproto3.Frontend
122+
connection net.Conn
123+
network string
124+
timeout time.Duration
125+
startup *pgproto3.StartupMessage
126+
errChan chan error
127+
}
128+
129+
// NewRawWireConnection returns a new RawWireConnection.
130+
func NewRawWireConnection(t *testing.T, host string, port int, timeout time.Duration) *RawWireConnection {
131+
network := net.JoinHostPort(host, fmt.Sprintf("%d", port))
132+
connection, err := (&net.Dialer{}).Dial("tcp", network)
133+
require.NoError(t, err)
134+
c := &RawWireConnection{
135+
frontend: pgproto3.NewFrontend(connection, connection),
136+
connection: connection,
137+
network: network,
138+
timeout: timeout,
139+
startup: nil,
140+
errChan: make(chan error),
141+
}
142+
c.init(t)
143+
return c
144+
}
145+
146+
// Close closes the internal connection.
147+
func (c *RawWireConnection) Close() {
148+
_ = c.connection.Close()
149+
close(c.errChan)
150+
}
151+
152+
// EmptyReceiveBuffer empties the buffer used by Receive. Returns an error if the buffer was not empty.
153+
func (c *RawWireConnection) EmptyReceiveBuffer() error {
154+
if c.frontend.ReadBufferLen() > 0 {
155+
for c.frontend.ReadBufferLen() > 0 {
156+
_, _ = c.frontend.Receive()
157+
}
158+
return errors.New("Doltgres sent additional messages after ReadyForQuery")
159+
}
160+
return nil
161+
}
162+
163+
// Receive returns the next message from the backend.
164+
func (c *RawWireConnection) Receive(t *testing.T) (pgproto3.BackendMessage, error) {
165+
var message pgproto3.BackendMessage
166+
go func() {
167+
var err error
168+
message, err = c.frontend.Receive()
169+
c.errChan <- err
170+
}()
171+
return message, c.handleErrorChannel(t, false)
172+
}
173+
174+
// Send sends the given messages over the wire. If an error is returned, then the connection has been closed, and a new
175+
// one should be opened.
176+
func (c *RawWireConnection) Send(t *testing.T, messages ...pgproto3.FrontendMessage) error {
177+
if len(messages) == 0 {
178+
return nil
179+
}
180+
hasMessage := false
181+
for _, message := range messages {
182+
if message == nil {
183+
continue
184+
}
185+
hasMessage = true
186+
if startupMessage, ok := message.(*pgproto3.StartupMessage); ok {
187+
c.startup = startupMessage
188+
}
189+
c.frontend.Send(message)
190+
}
191+
if !hasMessage {
192+
return nil
193+
}
194+
go func() {
195+
c.errChan <- c.frontend.Flush()
196+
}()
197+
return c.handleErrorChannel(t, true)
198+
}
199+
200+
// init handles the startup message and initial messages from the server.
201+
func (c *RawWireConnection) init(t *testing.T) {
202+
err := c.Send(t, &pgproto3.StartupMessage{
203+
ProtocolVersion: 196608,
204+
Parameters: map[string]string{
205+
"timezone": "PST8PDT",
206+
"user": "postgres",
207+
"database": "postgres",
208+
"options": " -c intervalstyle=postgres_verbose",
209+
"application_name": "pg_regress",
210+
"client_encoding": "WIN1252",
211+
"datestyle": "Postgres, MDY",
212+
},
213+
})
214+
require.NoError(t, err)
215+
StartupLoop:
216+
for {
217+
postgresMessage, err := c.Receive(t)
218+
require.NoError(t, err)
219+
switch response := postgresMessage.(type) {
220+
case *pgproto3.AuthenticationOk:
221+
case *pgproto3.BackendKeyData:
222+
case *pgproto3.ErrorResponse:
223+
t.Log(response.Message)
224+
t.FailNow()
225+
case *pgproto3.ParameterStatus:
226+
case *pgproto3.ReadyForQuery:
227+
break StartupLoop
228+
default:
229+
t.Logf("unknown StartupMessage response type: %T", response)
230+
t.FailNow()
231+
}
232+
}
233+
}
234+
235+
// handleErrorChannel handles errors while ensuring that stuck queries do not cause an infinite loop via a timeout.
236+
func (c *RawWireConnection) handleErrorChannel(t *testing.T, isSend bool) error {
237+
var err error
238+
select {
239+
case err = <-c.errChan:
240+
case <-time.After(c.timeout):
241+
if isSend {
242+
err = errors.New("timeout during Send")
243+
} else {
244+
err = errors.New("timeout during Receive")
245+
}
246+
}
247+
// On error, we must create a new connection since we cut the old one
248+
if err != nil {
249+
_ = c.connection.Close()
250+
connection, nErr := (&net.Dialer{}).Dial("tcp", c.network)
251+
if nErr != nil {
252+
panic(fmt.Errorf("Unable to create a new connection:\n%s\n\nOriginal error:\n%s", nErr.Error(), err.Error()))
253+
}
254+
c.connection = connection
255+
c.frontend = pgproto3.NewFrontend(connection, connection)
256+
c.init(t)
257+
}
258+
return err
259+
}
260+
261+
// RunWireScripts runs the given collection of scripts.
262+
func RunWireScripts(t *testing.T, scripts []WireScriptTest) {
263+
// First, we'll run through the scripts to check for the Focus variable. If it's true, then append it to the new slice.
264+
focusScripts := make([]WireScriptTest, 0, len(scripts))
265+
for _, script := range scripts {
266+
if script.Focus {
267+
// If this is running in GitHub Actions, then we'll panic, because someone forgot to disable it before committing
268+
if _, ok := os.LookupEnv("GITHUB_ACTION"); ok {
269+
panic(fmt.Sprintf("The wire script `%s` has Focus set to `true`. GitHub Actions requires that "+
270+
"all tests are run, which Focus circumvents, leading to this error. Please disable Focus on "+
271+
"all tests.", script.Name))
272+
}
273+
focusScripts = append(focusScripts, script)
274+
}
275+
}
276+
// If we have scripts with Focus set, then we replace the normal script slice with the new slice.
277+
if len(focusScripts) > 0 {
278+
scripts = focusScripts
279+
}
280+
// TODO: for now, our wire handler can't authenticate itself, so we disable it for these tests.
281+
// This prevents things such as testing multiple users, so it should be implemented at some point.
282+
server.EnableAuthentication = false
283+
defer func() {
284+
server.EnableAuthentication = true
285+
}()
286+
287+
for _, script := range scripts {
288+
t.Run(script.Name, func(t *testing.T) {
289+
if script.Skip {
290+
t.Skip()
291+
}
292+
293+
scriptDatabase := script.Database
294+
if len(scriptDatabase) == 0 {
295+
scriptDatabase = "postgres"
296+
}
297+
port, err := sql.GetEmptyPort()
298+
require.NoError(t, err)
299+
ctx, conn, controller := CreateServerWithPort(t, scriptDatabase, port)
300+
defer func() {
301+
controller.Stop()
302+
err := controller.WaitForStop()
303+
require.NoError(t, err)
304+
}()
305+
for _, query := range script.SetUpScript {
306+
_, err = conn.Exec(ctx, query)
307+
require.NoError(t, err, "error running setup query: %s", query)
308+
}
309+
conn.Close(ctx)
310+
rawConn := NewRawWireConnection(t, "localhost", port, 10*time.Second)
311+
defer rawConn.Close()
312+
313+
// With everything set up, we now check for Focus on the assertions
314+
assertions := script.Assertions
315+
// First, we'll run through the scripts to check for the Focus variable. If it's true, then append it to the new slice.
316+
focusAssertions := make([]WireScriptTestAssertion, 0, len(script.Assertions))
317+
for _, assertion := range script.Assertions {
318+
if assertion.Focus {
319+
// If this is running in GitHub Actions, then we'll panic, because someone forgot to disable it before committing
320+
if _, ok := os.LookupEnv("GITHUB_ACTION"); ok {
321+
panic("A wire assertion has Focus set to `true`. GitHub Actions requires that " +
322+
"all non-skipped assertions are run, which Focus circumvents, leading to this error. " +
323+
"Please disable Focus on all wire assertions.")
324+
}
325+
focusAssertions = append(focusAssertions, assertion)
326+
}
327+
}
328+
// If we have assertions with Focus set, then we replace the original slice with the new slice.
329+
if len(focusAssertions) > 0 {
330+
assertions = focusAssertions
331+
}
332+
333+
// Run the assertions
334+
for assertionIdx, assertion := range assertions {
335+
t.Run(fmt.Sprintf("%d", assertionIdx), func(t *testing.T) {
336+
if assertion.Skip {
337+
t.Skip("Skip has been set in the assertion")
338+
}
339+
err = rawConn.Send(t, assertion.Send...)
340+
require.NoError(t, err)
341+
for _, expectedMessage := range assertion.Receive {
342+
actualMessage, err := rawConn.Receive(t)
343+
require.NoError(t, err)
344+
if !assert.Equal(t, IgnoreMessageParameters(expectedMessage), IgnoreMessageParameters(actualMessage)) {
345+
// If the assertion fails, then we have to sync to the ReadyForQuery message
346+
if _, ok := actualMessage.(*pgproto3.ReadyForQuery); !ok {
347+
for {
348+
actualMessage, err := rawConn.Receive(t)
349+
require.NoError(t, err)
350+
if _, ok = actualMessage.(*pgproto3.ReadyForQuery); ok {
351+
return
352+
}
353+
}
354+
}
355+
}
356+
}
357+
// We then ensure that there are no other messages that were not accounted for by the assertion
358+
// (which we consider an error)
359+
_ = assert.NoError(t, rawConn.EmptyReceiveBuffer())
360+
})
361+
}
362+
})
363+
}
364+
}

0 commit comments

Comments
 (0)