1111from http .client import HTTPConnection
1212from urllib .parse import urlparse
1313
14+ from cryptography .hazmat ._oid import ExtensionOID
15+ from cryptography .hazmat .bindings ._rust import ObjectIdentifier
16+ from cryptography .hazmat .primitives .serialization import Encoding , PrivateFormat , NoEncryption , load_pem_private_key
17+
1418from cryptography import x509
19+ from cryptography .x509 import DNSName , ExtensionNotFound
1520
1621SEC_PER_DAY = 24 * 60 * 60
1722
2126
2227class MDCertUtil (object ):
2328 # Utility class for inspecting certificates in test cases
24- # Uses PyOpenSSL: https://pyopenssl.org/en/stable/index.html
2529
2630 @classmethod
2731 def load_server_cert (cls , host_ip , host_port , host_name , tls = None , ciphers = None ):
@@ -42,12 +46,12 @@ def load_server_cert(cls, host_ip, host_port, host_name, tls=None, ciphers=None)
4246 connection .setblocking (1 )
4347 connection .set_tlsext_host_name (host_name .encode ('utf-8' ))
4448 connection .do_handshake ()
45- peer_cert = connection .get_peer_certificate ()
46- return MDCertUtil (None , cert = peer_cert )
49+ ossl_cert = connection .get_peer_certificate ()
50+ return MDCertUtil (None , cert = ossl_cert . to_cryptography () )
4751
4852 @classmethod
4953 def parse_pem_cert (cls , text ):
50- cert = OpenSSL . crypto . load_certificate ( OpenSSL . crypto . FILETYPE_PEM , text .encode ('utf-8' ))
54+ cert = x509 . load_pem_x509_certificate ( text .encode ('utf-8' ))
5155 return MDCertUtil (None , cert = cert )
5256
5357 @classmethod
@@ -72,46 +76,47 @@ def get_plain(cls, url, timeout):
7276 return None
7377
7478 def __init__ (self , cert_path , cert = None ):
79+ self .cert = cert
80+ self .privkey = None
7581 if cert_path is not None :
7682 self .cert_path = cert_path
7783 # load certificate and private key
7884 if cert_path .startswith ("http" ):
79- cert_data = self . get_plain ( cert_path , 1 )
80- else :
81- cert_data = MDCertUtil . _load_binary_file (cert_path )
82-
83- for file_type in ( OpenSSL . crypto . FILETYPE_PEM , OpenSSL . crypto . FILETYPE_ASN1 ) :
84- try :
85- self . cert = OpenSSL . crypto . load_certificate ( file_type , cert_data )
86- except Exception as error :
87- self .error = error
88- if cert is not None :
89- self . cert = cert
90-
91- if self . cert is None :
92- raise self .error
85+ assert False
86+ try :
87+ with open (cert_path ) as fd :
88+ cert = x509 . load_pem_x509_certificate ( "" . join ( fd . readlines ()). encode ())
89+ except Exception as error :
90+ self . error = error
91+ if cert is not None :
92+ self . cert = cert
93+ if self .cert is None :
94+ raise self . error
95+
96+ def add_privkey ( self , path , password = None ):
97+ with open ( path ) as fd :
98+ self .privkey = load_pem_private_key ( "" . join ( fd . readlines ()). encode (), password = password )
9399
94100 def get_issuer (self ):
95101 return self .cert .get_issuer ()
96102
97103 def get_serial (self ):
98104 # the string representation of a serial number is not unique. Some
99105 # add leading 0s to align with word boundaries.
100- return ("%lx" % (self .cert .get_serial_number () )).upper ()
106+ return ("%lx" % (self .cert .serial_number )).upper ()
101107
102108 @staticmethod
103109 def _get_serial (cert ) -> int :
104110 if isinstance (cert , x509 .Certificate ):
105111 return cert .serial_number
106112 if isinstance (cert , MDCertUtil ):
107- return cert .get_serial_number ()
108- elif isinstance (cert , OpenSSL .crypto .X509 ):
109- return cert .get_serial_number ()
113+ return cert .cert .serial_number
110114 elif isinstance (cert , str ):
111115 # assume a hex number
112116 return int (cert , 16 )
113117 elif isinstance (cert , int ):
114118 return cert
119+ assert False , f'{ cert } '
115120 return 0
116121
117122 def get_serial_number (self ):
@@ -121,89 +126,33 @@ def same_serial_as(self, other):
121126 return self ._get_serial (self .cert ) == self ._get_serial (other )
122127
123128 def get_not_before (self ):
124- tsp = self .cert .get_notBefore ()
125- return self ._parse_tsp (tsp )
129+ try :
130+ return self .cert .not_valid_before_utc
131+ except AttributeError :
132+ return self .cert .not_valid_before
126133
127134 def get_not_after (self ):
128- tsp = self .cert .get_notAfter ()
129- return self ._parse_tsp (tsp )
130-
131- def get_cn (self ):
132- return self .cert .get_subject ().CN
135+ try :
136+ return self .cert .not_valid_after_utc
137+ except AttributeError :
138+ return self .cert .not_valid_after
133139
134140 def get_key_length (self ):
135- return self .cert .get_pubkey ().bits ()
141+ return self .cert .public_key ().key_size
136142
137143 def get_san_list (self ):
138- text = OpenSSL .crypto .dump_certificate (OpenSSL .crypto .FILETYPE_TEXT , self .cert ).decode ("utf-8" )
139- m = re .search (r"X509v3 Subject Alternative Name:(\s+critical)?\s*(.*)" , text )
140- sans_list = []
141- if m :
142- sans_list = m .group (2 ).split ("," )
143-
144- def _strip_prefix (s ):
145- return s .split (":" )[1 ] if s .strip ().startswith ("DNS:" ) else s .strip ()
146- return list (map (_strip_prefix , sans_list ))
144+ sans = self .cert .extensions .get_extension_for_class (x509 .SubjectAlternativeName )
145+ return sans .value .get_values_for_type (DNSName )
147146
148147 def get_must_staple (self ):
149- text = OpenSSL .crypto .dump_certificate (OpenSSL .crypto .FILETYPE_TEXT , self .cert ).decode ("utf-8" )
150- m = re .search (r"1.3.6.1.5.5.7.1.24:\s*\n\s*0...." , text )
151- if not m :
152- # Newer openssl versions print this differently
153- m = re .search (r"TLS Feature:\s*\n\s*status_request\s*\n" , text )
154- return m is not None
148+ try :
149+ self .cert .extensions .get_extension_for_oid (ExtensionOID .TLS_FEATURE )
150+ return True
151+ except ExtensionNotFound :
152+ return False
155153
156154 @classmethod
157155 def validate_privkey (cls , privkey_path , passphrase = None ):
158- privkey_data = cls ._load_binary_file (privkey_path )
159- if passphrase :
160- privkey = OpenSSL .crypto .load_privatekey (OpenSSL .crypto .FILETYPE_PEM , privkey_data , passphrase )
161- else :
162- privkey = OpenSSL .crypto .load_privatekey (OpenSSL .crypto .FILETYPE_PEM , privkey_data )
163- return privkey .check ()
164-
165- def validate_cert_matches_priv_key (self , privkey_path ):
166- # Verifies that the private key and cert match.
167- privkey_data = MDCertUtil ._load_binary_file (privkey_path )
168- privkey = OpenSSL .crypto .load_privatekey (OpenSSL .crypto .FILETYPE_PEM , privkey_data )
169- context = OpenSSL .SSL .Context (OpenSSL .SSL .SSLv23_METHOD )
170- context .use_privatekey (privkey )
171- context .use_certificate (self .cert )
172- context .check_privatekey ()
173-
174- # --------- _utils_ ---------
175-
176- def astr (self , s ):
177- return s .decode ('utf-8' )
178-
179- def _parse_tsp (self , tsp ):
180- # timestampss returned by PyOpenSSL are bytes
181- # parse date and time part
182- s = ("%s-%s-%s %s:%s:%s" % (self .astr (tsp [0 :4 ]), self .astr (tsp [4 :6 ]), self .astr (tsp [6 :8 ]),
183- self .astr (tsp [8 :10 ]), self .astr (tsp [10 :12 ]), self .astr (tsp [12 :14 ])))
184- timestamp = datetime .strptime (s , '%Y-%m-%d %H:%M:%S' )
185- # adjust timezone
186- tz_h , tz_m = 0 , 0
187- m = re .match (r"([+\-]\d{2})(\d{2})" , self .astr (tsp [14 :]))
188- if m :
189- tz_h , tz_m = int (m .group (1 )), int (m .group (2 )) if tz_h > 0 else - 1 * int (m .group (2 ))
190- return timestamp .replace (tzinfo = self .FixedOffset (60 * tz_h + tz_m ))
191-
192- @classmethod
193- def _load_binary_file (cls , path ):
194- with open (path , mode = "rb" ) as file :
195- return file .read ()
196-
197- class FixedOffset (tzinfo ):
198-
199- def __init__ (self , offset ):
200- self .__offset = timedelta (minutes = offset )
201-
202- def utcoffset (self , dt ):
203- return self .__offset
204-
205- def tzname (self , dt ):
206- return None
207-
208- def dst (self , dt ):
209- return timedelta (0 )
156+ with open (privkey_path ) as fd :
157+ privkey = load_pem_private_key ("" .join (fd .readlines ()).encode (), password = passphrase )
158+ return privkey is not None
0 commit comments