|
| 1 | +#include <memory> |
| 2 | +#include <string> |
| 3 | + |
| 4 | +#include "source/common/buffer/buffer_impl.h" |
| 5 | +#include "source/common/crypto/crypto_impl.h" |
| 6 | +#include "source/common/crypto/utility.h" |
| 7 | +#include "source/common/network/transport_socket_options_impl.h" |
| 8 | +#include "source/extensions/transport_sockets/tls/cert_validator/default_validator.h" |
| 9 | +#include "source/extensions/transport_sockets/tls/cert_validator/san_matcher.h" |
| 10 | +#include "source/extensions/transport_sockets/tls/stats.h" |
| 11 | + |
| 12 | +#include "test/common/stats/stat_test_utility.h" |
| 13 | +#include "test/extensions/transport_sockets/tls/cert_validator/test_common.h" |
| 14 | +#include "test/extensions/transport_sockets/tls/ssl_test_utility.h" |
| 15 | +#include "test/extensions/transport_sockets/tls/test_data/san_dns2_cert_info.h" |
| 16 | +#include "test/mocks/event/mocks.h" |
| 17 | +#include "test/mocks/ssl/mocks.h" |
| 18 | +#include "test/test_common/environment.h" |
| 19 | +#include "test/test_common/test_runtime.h" |
| 20 | +#include "test/test_common/utility.h" |
| 21 | + |
| 22 | +#include "gmock/gmock.h" |
| 23 | +#include "gtest/gtest.h" |
| 24 | +#include "library/common/data/utility.h" |
| 25 | +#include "library/common/extensions/cert_validator/platform_bridge/config.h" |
| 26 | +#include "openssl/ssl.h" |
| 27 | +#include "openssl/x509v3.h" |
| 28 | + |
| 29 | +using SSLContextPtr = Envoy::CSmartPtr<SSL_CTX, SSL_CTX_free>; |
| 30 | + |
| 31 | +using envoy::extensions::transport_sockets::tls::v3::CertificateValidationContext; |
| 32 | + |
| 33 | +using testing::_; |
| 34 | +using testing::NiceMock; |
| 35 | +using testing::Return; |
| 36 | +using testing::ReturnRef; |
| 37 | +using testing::StrEq; |
| 38 | + |
| 39 | +namespace Envoy { |
| 40 | +namespace Extensions { |
| 41 | +namespace TransportSockets { |
| 42 | +namespace Tls { |
| 43 | + |
| 44 | +class MockValidateResultCallback : public Ssl::ValidateResultCallback { |
| 45 | +public: |
| 46 | + ~MockValidateResultCallback() override = default; |
| 47 | + |
| 48 | + MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); |
| 49 | + MOCK_METHOD(void, onCertValidationResult, (bool, const std::string&, uint8_t)); |
| 50 | +}; |
| 51 | + |
| 52 | +class MockValidator { |
| 53 | +public: |
| 54 | + MOCK_METHOD(void, cleanup, ()); |
| 55 | + MOCK_METHOD(envoy_cert_validation_result, validate, (const envoy_data*, uint8_t, const char*)); |
| 56 | +}; |
| 57 | + |
| 58 | +class PlatformBridgeCertValidatorTest |
| 59 | + : public testing::TestWithParam<CertificateValidationContext::TrustChainVerification> { |
| 60 | +protected: |
| 61 | + PlatformBridgeCertValidatorTest() |
| 62 | + : api_(Api::createApiForTest()), dispatcher_(api_->allocateDispatcher("test_thread")), |
| 63 | + stats_(generateSslStats(test_store_)), ssl_ctx_(SSL_CTX_new(TLS_method())), |
| 64 | + callback_(std::make_unique<MockValidateResultCallback>()), is_server_(false) { |
| 65 | + mock_validator_ = std::make_unique<MockValidator>(); |
| 66 | + main_thread_id_ = std::this_thread::get_id(); |
| 67 | + |
| 68 | + platform_validator_.validate_cert = (PlatformBridgeCertValidatorTest::validate); |
| 69 | + platform_validator_.validation_cleanup = (PlatformBridgeCertValidatorTest::cleanup); |
| 70 | + } |
| 71 | + |
| 72 | + void initializeConfig() { |
| 73 | + EXPECT_CALL(config_, caCert()).WillOnce(ReturnRef(empty_string_)); |
| 74 | + EXPECT_CALL(config_, certificateRevocationList()).WillOnce(ReturnRef(empty_string_)); |
| 75 | + EXPECT_CALL(config_, trustChainVerification()).WillOnce(Return(GetParam())); |
| 76 | + } |
| 77 | + |
| 78 | + ~PlatformBridgeCertValidatorTest() { |
| 79 | + mock_validator_.reset(); |
| 80 | + main_thread_id_ = std::thread::id(); |
| 81 | + Envoy::Assert::resetEnvoyBugCountersForTest(); |
| 82 | + } |
| 83 | + |
| 84 | + ABSL_MUST_USE_RESULT bool waitForDispatcherToExit() { |
| 85 | + Event::TimerPtr timer(dispatcher_->createTimer([this]() -> void { dispatcher_->exit(); })); |
| 86 | + timer->enableTimer(std::chrono::milliseconds(100)); |
| 87 | + dispatcher_->run(Event::Dispatcher::RunType::RunUntilExit); |
| 88 | + return !timer->enabled(); |
| 89 | + } |
| 90 | + |
| 91 | + bool acceptInvalidCertificates() { |
| 92 | + return GetParam() == CertificateValidationContext::ACCEPT_UNTRUSTED; |
| 93 | + } |
| 94 | + |
| 95 | + static envoy_cert_validation_result validate(const envoy_data* certs, uint8_t size, |
| 96 | + const char* hostname) { |
| 97 | + // Validate must be called on the worker thread, not the main thread. |
| 98 | + EXPECT_NE(main_thread_id_, std::this_thread::get_id()); |
| 99 | + |
| 100 | + // Make sure the cert was converted correctly. |
| 101 | + const Buffer::InstancePtr buffer = Data::Utility::toInternalData(*certs); |
| 102 | + const auto digest = Common::Crypto::UtilitySingleton::get().getSha256Digest(*buffer); |
| 103 | + EXPECT_EQ(TEST_SAN_DNS2_CERT_256_HASH, Hex::encode(digest)); |
| 104 | + return mock_validator_->validate(certs, size, hostname); |
| 105 | + } |
| 106 | + |
| 107 | + static void cleanup() { |
| 108 | + // Validate must be called on the worker thread, not the main thread. |
| 109 | + EXPECT_NE(main_thread_id_, std::this_thread::get_id()); |
| 110 | + mock_validator_->cleanup(); |
| 111 | + } |
| 112 | + |
| 113 | + static std::unique_ptr<MockValidator> mock_validator_; |
| 114 | + static std::thread::id main_thread_id_; |
| 115 | + |
| 116 | + Api::ApiPtr api_; |
| 117 | + Event::DispatcherPtr dispatcher_; |
| 118 | + Stats::TestUtil::TestStore test_store_; |
| 119 | + SslStats stats_; |
| 120 | + Ssl::MockCertificateValidationContextConfig config_; |
| 121 | + std::string empty_string_; |
| 122 | + SSLContextPtr ssl_ctx_; |
| 123 | + TestSslExtendedSocketInfo ssl_extended_info_; |
| 124 | + CertValidator::ExtraValidationContext validation_context_; |
| 125 | + Network::TransportSocketOptionsConstSharedPtr transport_socket_options_; |
| 126 | + std::unique_ptr<MockValidateResultCallback> callback_; |
| 127 | + bool is_server_; |
| 128 | + envoy_cert_validator platform_validator_; |
| 129 | +}; |
| 130 | + |
| 131 | +std::unique_ptr<MockValidator> PlatformBridgeCertValidatorTest::mock_validator_; |
| 132 | +std::thread::id PlatformBridgeCertValidatorTest::main_thread_id_; |
| 133 | + |
| 134 | +INSTANTIATE_TEST_SUITE_P(TrustMode, PlatformBridgeCertValidatorTest, |
| 135 | + testing::ValuesIn({CertificateValidationContext::VERIFY_TRUST_CHAIN, |
| 136 | + CertificateValidationContext::ACCEPT_UNTRUSTED})); |
| 137 | + |
| 138 | +TEST_P(PlatformBridgeCertValidatorTest, NoConfig) { |
| 139 | + EXPECT_ENVOY_BUG( |
| 140 | + { PlatformBridgeCertValidator validator(nullptr, stats_, &platform_validator_); }, |
| 141 | + "Invalid certificate validation context config."); |
| 142 | +} |
| 143 | + |
| 144 | +TEST_P(PlatformBridgeCertValidatorTest, NonEmptyCaCert) { |
| 145 | + std::string ca_cert = "xyz"; |
| 146 | + EXPECT_CALL(config_, caCert()).WillRepeatedly(ReturnRef(ca_cert)); |
| 147 | + EXPECT_CALL(config_, certificateRevocationList()).WillRepeatedly(ReturnRef(empty_string_)); |
| 148 | + EXPECT_CALL(config_, trustChainVerification()).WillRepeatedly(Return(GetParam())); |
| 149 | + |
| 150 | + EXPECT_ENVOY_BUG( |
| 151 | + { PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); }, |
| 152 | + "Invalid certificate validation context config."); |
| 153 | +} |
| 154 | + |
| 155 | +TEST_P(PlatformBridgeCertValidatorTest, NonEmptyRevocationList) { |
| 156 | + std::string revocation_list = "xyz"; |
| 157 | + EXPECT_CALL(config_, caCert()).WillRepeatedly(ReturnRef(empty_string_)); |
| 158 | + EXPECT_CALL(config_, certificateRevocationList()).WillRepeatedly(ReturnRef(revocation_list)); |
| 159 | + EXPECT_CALL(config_, trustChainVerification()).WillRepeatedly(Return(GetParam())); |
| 160 | + |
| 161 | + EXPECT_ENVOY_BUG( |
| 162 | + { PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); }, |
| 163 | + "Invalid certificate validation context config."); |
| 164 | +} |
| 165 | + |
| 166 | +TEST_P(PlatformBridgeCertValidatorTest, NoCallback) { |
| 167 | + initializeConfig(); |
| 168 | + PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); |
| 169 | + |
| 170 | + bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(TestEnvironment::substitute( |
| 171 | + "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/san_dns2_cert.pem")); |
| 172 | + std::string hostname = "www.example.com"; |
| 173 | + |
| 174 | + EXPECT_ENVOY_BUG( |
| 175 | + { |
| 176 | + validator.doVerifyCertChain(*cert_chain, Ssl::ValidateResultCallbackPtr(), |
| 177 | + &ssl_extended_info_, transport_socket_options_, *ssl_ctx_, |
| 178 | + validation_context_, is_server_, hostname); |
| 179 | + }, |
| 180 | + "No callback specified"); |
| 181 | +} |
| 182 | + |
| 183 | +TEST_P(PlatformBridgeCertValidatorTest, EmptyCertChain) { |
| 184 | + initializeConfig(); |
| 185 | + PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); |
| 186 | + |
| 187 | + bssl::UniquePtr<STACK_OF(X509)> cert_chain(sk_X509_new_null()); |
| 188 | + std::string hostname = "www.example.com"; |
| 189 | + |
| 190 | + ValidationResults results = validator.doVerifyCertChain( |
| 191 | + *cert_chain, std::move(callback_), &ssl_extended_info_, transport_socket_options_, *ssl_ctx_, |
| 192 | + validation_context_, is_server_, hostname); |
| 193 | + EXPECT_EQ(ValidationResults::ValidationStatus::Failed, results.status); |
| 194 | + EXPECT_FALSE(results.tls_alert.has_value()); |
| 195 | + ASSERT_TRUE(results.error_details.has_value()); |
| 196 | + EXPECT_EQ("verify cert chain failed: empty cert chain.", results.error_details.value()); |
| 197 | + EXPECT_EQ(Envoy::Ssl::ClientValidationStatus::NotValidated, |
| 198 | + ssl_extended_info_.certificateValidationStatus()); |
| 199 | +} |
| 200 | + |
| 201 | +TEST_P(PlatformBridgeCertValidatorTest, ValidCertificate) { |
| 202 | + initializeConfig(); |
| 203 | + PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); |
| 204 | + |
| 205 | + std::string hostname = "server1.example.com"; |
| 206 | + bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(TestEnvironment::substitute( |
| 207 | + "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/san_dns2_cert.pem")); |
| 208 | + envoy_cert_validation_result result = {ENVOY_SUCCESS, 0, NULL}; |
| 209 | + EXPECT_CALL(*mock_validator_, validate(_, _, _)).WillOnce(Return(result)); |
| 210 | + EXPECT_CALL(*mock_validator_, cleanup()); |
| 211 | + auto& callback_ref = *callback_; |
| 212 | + EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); |
| 213 | + |
| 214 | + ValidationResults results = validator.doVerifyCertChain( |
| 215 | + *cert_chain, std::move(callback_), &ssl_extended_info_, transport_socket_options_, *ssl_ctx_, |
| 216 | + validation_context_, is_server_, hostname); |
| 217 | + EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); |
| 218 | + |
| 219 | + EXPECT_CALL(callback_ref, onCertValidationResult(true, "", 46)).WillOnce(Invoke([this]() { |
| 220 | + EXPECT_EQ(main_thread_id_, std::this_thread::get_id()); |
| 221 | + dispatcher_->exit(); |
| 222 | + })); |
| 223 | + EXPECT_FALSE(waitForDispatcherToExit()); |
| 224 | +} |
| 225 | + |
| 226 | +TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateButInvalidSni) { |
| 227 | + initializeConfig(); |
| 228 | + PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); |
| 229 | + |
| 230 | + std::string hostname = "server2.example.com"; |
| 231 | + bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(TestEnvironment::substitute( |
| 232 | + "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/san_dns2_cert.pem")); |
| 233 | + envoy_cert_validation_result result = {ENVOY_SUCCESS, 0, NULL}; |
| 234 | + EXPECT_CALL(*mock_validator_, validate(_, _, _)).WillOnce(Return(result)); |
| 235 | + EXPECT_CALL(*mock_validator_, cleanup()); |
| 236 | + auto& callback_ref = *callback_; |
| 237 | + EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); |
| 238 | + |
| 239 | + ValidationResults results = validator.doVerifyCertChain( |
| 240 | + *cert_chain, std::move(callback_), &ssl_extended_info_, transport_socket_options_, *ssl_ctx_, |
| 241 | + validation_context_, is_server_, hostname); |
| 242 | + EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); |
| 243 | + |
| 244 | + EXPECT_CALL(callback_ref, |
| 245 | + onCertValidationResult( |
| 246 | + acceptInvalidCertificates(), |
| 247 | + "PlatformBridgeCertValidator_verifySubjectAltName failed: SNI mismatch.", |
| 248 | + SSL_AD_BAD_CERTIFICATE)) |
| 249 | + .WillOnce(Invoke([this]() { dispatcher_->exit(); })); |
| 250 | + EXPECT_FALSE(waitForDispatcherToExit()); |
| 251 | +} |
| 252 | + |
| 253 | +TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateSniOverride) { |
| 254 | + initializeConfig(); |
| 255 | + PlatformBridgeCertValidator validator(&config_, stats_, &platform_validator_); |
| 256 | + |
| 257 | + std::vector<std::string> subject_alt_names = {"server1.example.com"}; |
| 258 | + |
| 259 | + std::string hostname = "server2.example.com"; |
| 260 | + bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(TestEnvironment::substitute( |
| 261 | + "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/san_dns2_cert.pem")); |
| 262 | + envoy_cert_validation_result result = {ENVOY_SUCCESS, 0, NULL}; |
| 263 | + EXPECT_CALL(*mock_validator_, validate(_, _, StrEq(subject_alt_names[0].c_str()))) |
| 264 | + .WillOnce(Return(result)); |
| 265 | + EXPECT_CALL(*mock_validator_, cleanup()); |
| 266 | + auto& callback_ref = *callback_; |
| 267 | + EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); |
| 268 | + transport_socket_options_ = |
| 269 | + std::make_shared<Network::TransportSocketOptionsImpl>("", std::move(subject_alt_names)); |
| 270 | + |
| 271 | + ValidationResults results = validator.doVerifyCertChain( |
| 272 | + *cert_chain, std::move(callback_), &ssl_extended_info_, transport_socket_options_, *ssl_ctx_, |
| 273 | + validation_context_, is_server_, hostname); |
| 274 | + EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); |
| 275 | + |
| 276 | + // The cert will be validated against the overridden name not the invalid name "server2". |
| 277 | + EXPECT_CALL(callback_ref, onCertValidationResult(true, "", 46)).WillOnce(Invoke([this]() { |
| 278 | + dispatcher_->exit(); |
| 279 | + })); |
| 280 | + EXPECT_FALSE(waitForDispatcherToExit()); |
| 281 | +} |
| 282 | + |
| 283 | +TEST_P(PlatformBridgeCertValidatorTest, DeletedWithValidationPending) { |
| 284 | + initializeConfig(); |
| 285 | + auto validator = |
| 286 | + std::make_unique<PlatformBridgeCertValidator>(&config_, stats_, &platform_validator_); |
| 287 | + |
| 288 | + std::string hostname = "server1.example.com"; |
| 289 | + bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(TestEnvironment::substitute( |
| 290 | + "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/san_dns2_cert.pem")); |
| 291 | + envoy_cert_validation_result result = {ENVOY_SUCCESS, 0, NULL}; |
| 292 | + EXPECT_CALL(*mock_validator_, validate(_, _, _)).WillOnce(Return(result)); |
| 293 | + EXPECT_CALL(*mock_validator_, cleanup()); |
| 294 | + auto& callback_ref = *callback_; |
| 295 | + EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); |
| 296 | + |
| 297 | + ValidationResults results = validator->doVerifyCertChain( |
| 298 | + *cert_chain, std::move(callback_), &ssl_extended_info_, transport_socket_options_, *ssl_ctx_, |
| 299 | + validation_context_, is_server_, hostname); |
| 300 | + EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); |
| 301 | + |
| 302 | + validator.reset(); |
| 303 | + |
| 304 | + // Since the validator was deleted, the callback should not be invoked and |
| 305 | + // so the dispatcher will not exit until the alarm fires. |
| 306 | + EXPECT_TRUE(waitForDispatcherToExit()); |
| 307 | +} |
| 308 | + |
| 309 | +} // namespace Tls |
| 310 | +} // namespace TransportSockets |
| 311 | +} // namespace Extensions |
| 312 | +} // namespace Envoy |
0 commit comments