Skip to content

Commit 797357d

Browse files
authored
Merge pull request #2231 from dolthub/fulghum/psycopg2
Add client tests for psycopg2 python client
2 parents e585dec + e77c0a3 commit 797357d

3 files changed

Lines changed: 121 additions & 1 deletion

File tree

testing/PostgresDockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ RUN apt update -y && \
1515
apt install -y \
1616
python3.8 \
1717
python3-pip \
18+
python3-psycopg2 \
1819
curl \
1920
wget \
2021
pkg-config \

testing/postgres-client-tests/postgres-client-tests.bats

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ teardown() {
4545
node $BATS_TEST_DIRNAME/node/workbench.js $USER $PORT $DOLTGRES_VERSION $BATS_TEST_DIRNAME/node/testdata
4646
}
4747

48-
4948
@test "perl DBI:Pg client" {
5049
perl $BATS_TEST_DIRNAME/perl/postgres-test.pl $USER $PORT
5150
}
@@ -68,3 +67,8 @@ teardown() {
6867
(cd $BATS_TEST_DIRNAME/c; make clean; make)
6968
$BATS_TEST_DIRNAME/c/postgres-c-connector-test $USER $PORT
7069
}
70+
71+
@test "python postgres: psycopg2 client" {
72+
cd $BATS_TEST_DIRNAME/python
73+
python3 psycopg2_test.py $USER $PORT
74+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#!/usr/bin/env python3
2+
import os
3+
import sys
4+
import traceback
5+
import psycopg2
6+
7+
# ---------------------------------------------------------------------------
8+
# Query list (kept at top for consistency with other tests)
9+
# ---------------------------------------------------------------------------
10+
11+
TEST_QUERIES = [
12+
"DROP TABLE IF EXISTS test",
13+
"create table test (pk int, value int, d1 decimal(9, 3), f1 float, c1 char(10), t1 text, primary key(pk))",
14+
"select * from test",
15+
"insert into test (pk, value, d1, f1, c1, t1) values (0,0,0.0,0.0,'abc','a1')",
16+
"select * from test",
17+
"select dolt_add('-A');",
18+
"select dolt_commit('-m', 'my commit')",
19+
"select COUNT(*) FROM dolt.log",
20+
"select dolt_checkout('-b', 'mybranch')",
21+
"insert into test (pk, value, d1, f1, c1, t1) values (10,10, 123456.789, 420.42,'example','some text')",
22+
"select dolt_commit('-a', '-m', 'my commit2')",
23+
"select dolt_checkout('main')",
24+
"select dolt_merge('mybranch')",
25+
"select COUNT(*) FROM dolt.log",
26+
]
27+
28+
# ---------------------------------------------------------------------------
29+
30+
def env(name, default=None):
31+
return os.getenv(name, default)
32+
33+
34+
def connect(user: str, port: int):
35+
conn = psycopg2.connect(
36+
host=env("PGHOST", "localhost"),
37+
port=port,
38+
dbname="postgres",
39+
user=user,
40+
password=env("PGPASSWORD", "password"),
41+
connect_timeout=int(env("PGCONNECT_TIMEOUT", "10")),
42+
sslmode=env("PGSSLMODE"),
43+
)
44+
conn.autocommit = True
45+
return conn
46+
47+
48+
def run(cur, q):
49+
print(f"SQL> {q}", flush=True)
50+
cur.execute(q)
51+
if cur.description is not None:
52+
cur.fetchall() # drain result set
53+
54+
# load_test creates a table with |n_rows| and asserts that all rows are correctly returned.
55+
def load_test(cur, n_rows=1000):
56+
print("\n=== Part 1: Load test ===", flush=True)
57+
58+
rows = max(1000, int(n_rows))
59+
60+
run(cur, "DROP TABLE IF EXISTS load_test")
61+
run(cur, "CREATE TABLE load_test (id INT PRIMARY KEY, val INT NOT NULL)")
62+
63+
data = [(i, i * 10) for i in range(rows)]
64+
cur.executemany(
65+
"INSERT INTO load_test (id, val) VALUES (%s, %s)",
66+
data,
67+
)
68+
69+
cur.execute("SELECT COUNT(*) FROM load_test")
70+
cnt = cur.fetchone()[0]
71+
if cnt != rows:
72+
raise AssertionError(f"COUNT(*) mismatch: expected {rows}, got {cnt}")
73+
74+
cur.execute("SELECT id FROM load_test ORDER BY id")
75+
got = cur.fetchall()
76+
if len(got) != rows:
77+
raise AssertionError(f"fetchall mismatch: expected {rows}, got {len(got)}")
78+
79+
print(f"Inserted and selected {rows} rows OK.", flush=True)
80+
81+
82+
def compliance_test(cur):
83+
print("\n=== Part 2: Test Queries ===", flush=True)
84+
for q in TEST_QUERIES:
85+
run(cur, q)
86+
print("Compliance queries executed OK.", flush=True)
87+
88+
89+
def main():
90+
if len(sys.argv) != 3:
91+
print("Usage: python3 psycopg2_test.py <user> <port>")
92+
return 2
93+
94+
user = sys.argv[1]
95+
port = int(sys.argv[2])
96+
load_rows = int(env("LOAD_ROWS", "1000"))
97+
98+
try:
99+
with connect(user, port) as conn:
100+
with conn.cursor() as cur:
101+
load_test(cur, load_rows)
102+
compliance_test(cur)
103+
104+
print("\n✅ All tests passed.", flush=True)
105+
return 0
106+
107+
except Exception as e:
108+
print("\n❌ Test failed.", flush=True)
109+
print(f"Error: {e}", flush=True)
110+
traceback.print_exc()
111+
return 1
112+
113+
114+
if __name__ == "__main__":
115+
sys.exit(main())

0 commit comments

Comments
 (0)