Skip to content

Commit a299e0c

Browse files
committed
Noah + Copilot feedback
1 parent b483a48 commit a299e0c

1 file changed

Lines changed: 13 additions & 51 deletions

File tree

test/test_ocsp_support.py

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import sys
1919
from datetime import datetime, timedelta, timezone
20+
from typing import cast
2021
from unittest.mock import MagicMock, Mock, patch
2122

2223
import pytest
@@ -27,6 +28,9 @@
2728

2829
pytestmark = pytest.mark.ocsp
2930

31+
pytest.importorskip("cryptography")
32+
pytest.importorskip("requests")
33+
3034
from cryptography.exceptions import InvalidSignature
3135
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey
3236
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
@@ -36,14 +40,14 @@
3640
from cryptography.x509 import (
3741
AuthorityInformationAccess,
3842
ExtensionNotFound,
43+
Name,
3944
TLSFeature,
4045
TLSFeatureType,
4146
)
4247
from cryptography.x509.ocsp import OCSPCertStatus, OCSPResponseStatus
4348
from cryptography.x509.oid import AuthorityInformationAccessOID, ExtendedKeyUsageOID
4449

4550
from pymongo.ocsp_support import (
46-
_build_ocsp_request,
4751
_get_certs_by_key_hash,
4852
_get_certs_by_name,
4953
_get_extension,
@@ -111,40 +115,22 @@ def test_dsa_valid(self):
111115
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
112116
key.verify.assert_called_once()
113117

114-
def test_dsa_invalid(self):
115-
key = MagicMock(spec=DSAPublicKey)
116-
key.verify.side_effect = InvalidSignature()
117-
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
118-
119118
def test_ec_valid(self):
120119
key = MagicMock(spec=EllipticCurvePublicKey)
121120
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
122121
key.verify.assert_called_once()
123122

124-
def test_ec_invalid(self):
125-
key = MagicMock(spec=EllipticCurvePublicKey)
126-
key.verify.side_effect = InvalidSignature()
127-
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
128-
129123
def test_x25519_skips_verify(self):
130124
key = MagicMock(spec=X25519PublicKey)
131-
# X25519 is for key exchange only; verify is not called, returns 1
132125
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
133126

134127
def test_x448_skips_verify(self):
135128
key = MagicMock(spec=X448PublicKey)
136-
# X448 is for key exchange only; verify is not called, returns 1
137129
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
138130

139131
def test_other_key_valid(self):
140132
key = Mock()
141133
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
142-
key.verify.assert_called_once_with(b"sig", b"data")
143-
144-
def test_other_key_invalid(self):
145-
key = Mock()
146-
key.verify.side_effect = InvalidSignature()
147-
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
148134

149135

150136
class TestGetExtension(unittest.TestCase):
@@ -167,7 +153,6 @@ def test_rsa(self):
167153
cert = Mock()
168154
cert.public_key.return_value = key
169155
result = _public_key_hash(cert)
170-
self.assertIsInstance(result, bytes)
171156
self.assertEqual(len(result), 20) # SHA-1 digest
172157

173158
def test_ec(self):
@@ -176,23 +161,20 @@ def test_ec(self):
176161
cert = Mock()
177162
cert.public_key.return_value = key
178163
result = _public_key_hash(cert)
179-
self.assertIsInstance(result, bytes)
180164
self.assertEqual(len(result), 20)
181165

182166
def test_other_key_type(self):
183-
# Covers the else branch (Ed25519, Ed448, etc.)
184167
key = Mock()
185168
key.public_bytes.return_value = b"other_key_bytes"
186169
cert = Mock()
187170
cert.public_key.return_value = key
188171
result = _public_key_hash(cert)
189-
self.assertIsInstance(result, bytes)
190172
self.assertEqual(len(result), 20)
191173

192174

193-
class TestGetCertsByKeyHash(unittest.TestCase):
175+
class TestGetCerts(unittest.TestCase):
194176
@patch("pymongo.ocsp_support._public_key_hash")
195-
def test_match(self, mock_hash):
177+
def test_by_key_hash_match(self, mock_hash):
196178
issuer = Mock()
197179
issuer.subject = "issuer_subject"
198180
cert1 = Mock()
@@ -205,7 +187,7 @@ def test_match(self, mock_hash):
205187
self.assertEqual(result, [cert1])
206188

207189
@patch("pymongo.ocsp_support._public_key_hash")
208-
def test_no_match(self, mock_hash):
190+
def test_by_key_hash_no_match(self, mock_hash):
209191
issuer = Mock()
210192
issuer.subject = "issuer_subject"
211193
cert = Mock()
@@ -215,9 +197,7 @@ def test_no_match(self, mock_hash):
215197
result = _get_certs_by_key_hash([cert], issuer, b"expected_hash")
216198
self.assertEqual(result, [])
217199

218-
219-
class TestGetCertsByName(unittest.TestCase):
220-
def test_match(self):
200+
def test_by_name_match(self):
221201
issuer = Mock()
222202
issuer.subject = "issuer"
223203
cert1 = Mock()
@@ -227,36 +207,20 @@ def test_match(self):
227207
cert2.subject = "other"
228208
cert2.issuer = "issuer"
229209

230-
result = _get_certs_by_name([cert1, cert2], issuer, "responder")
210+
result = _get_certs_by_name([cert1, cert2], issuer, cast(Name, "responder"))
231211
self.assertEqual(result, [cert1])
232212

233-
def test_no_match(self):
213+
def test_by_name_no_match(self):
234214
issuer = Mock()
235215
issuer.subject = "issuer"
236216
cert = Mock()
237217
cert.subject = "other"
238218
cert.issuer = "issuer"
239219

240-
result = _get_certs_by_name([cert], issuer, "responder")
220+
result = _get_certs_by_name([cert], issuer, cast(Name, "responder"))
241221
self.assertEqual(result, [])
242222

243223

244-
class TestBuildOcspRequest(unittest.TestCase):
245-
@patch("pymongo.ocsp_support._OCSPRequestBuilder")
246-
def test_builds_request(self, mock_builder_class):
247-
mock_builder = Mock()
248-
mock_builder.add_certificate.return_value = mock_builder
249-
mock_request = Mock()
250-
mock_builder.build.return_value = mock_request
251-
mock_builder_class.return_value = mock_builder
252-
253-
result = _build_ocsp_request(Mock(), Mock())
254-
255-
self.assertEqual(result, mock_request)
256-
mock_builder.add_certificate.assert_called_once()
257-
mock_builder.build.assert_called_once()
258-
259-
260224
class TestVerifyResponseSignature(unittest.TestCase):
261225
@patch("pymongo.ocsp_support._verify_signature")
262226
def test_responder_is_issuer_by_name(self, mock_verify_sig):
@@ -515,7 +479,7 @@ def test_serial_number_mismatch(self, mock_build, mock_post, mock_load, _):
515479
mock_post.return_value = http_resp
516480
ocsp_resp = Mock()
517481
ocsp_resp.response_status = OCSPResponseStatus.SUCCESSFUL
518-
ocsp_resp.serial_number = 99999 # Mismatch
482+
ocsp_resp.serial_number = 99999
519483
mock_load.return_value = ocsp_resp
520484

521485
result = _get_ocsp_response(Mock(), Mock(), "http://ocsp.example.com", cache)
@@ -788,11 +752,9 @@ def test_stapled_good(self, _, mock_issuer, mock_load, __, mock_build):
788752
@patch("pymongo.ocsp_support._get_issuer_cert", return_value=None)
789753
@patch("pymongo.ocsp_support._get_extension", return_value=None)
790754
def test_uses_peer_cert_chain_fallback(self, _, __):
791-
# conn without get_verified_chain triggers the fallback path
792755
conn = self._setup_conn(has_verified_chain=False)
793756
user_data = self._setup_user_data()
794757
user_data.trusted_ca_certs = []
795-
# No AIA (_get_extension returns None) → soft fail → True
796758
self.assertTrue(_ocsp_callback(conn, b"", user_data))
797759

798760

0 commit comments

Comments
 (0)