Skip to content

Commit 76eb02c

Browse files
Cleanup threading/ownership semantics of PlatformBridgeCertValidator (#2713)
Cleanup threading/ownership semantics of PlatformBridgeCertValidator by only running static methods on the worker thread. The worker thread is owned by the Validator which join()s the thread before deleting it. So by making the worker thread only operate on data copied (or moved) into the the thread, we avoid multi-threaded access. Risk Level: Low - Not in production Testing: Existing unit Docs Changes: N/A Release Notes: N/A Signed-off-by: Ryan Hamilton <rch@google.com>
1 parent 3e6af4e commit 76eb02c

2 files changed

Lines changed: 94 additions & 73 deletions

File tree

library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ PlatformBridgeCertValidator::PlatformBridgeCertValidator(
2626

2727
PlatformBridgeCertValidator::~PlatformBridgeCertValidator() {
2828
// Wait for validation threads to finish.
29-
for (auto& [id, thread] : validation_threads_) {
30-
if (thread.joinable()) {
31-
thread.join();
29+
for (auto& [id, job] : validation_jobs_) {
30+
if (job.validation_thread_.joinable()) {
31+
job.validation_thread_.join();
3232
}
3333
}
3434
}
@@ -87,33 +87,35 @@ ValidationResults PlatformBridgeCertValidator::doVerifyCertChain(
8787
subject_alt_names = {std::string(hostname)};
8888
}
8989

90-
auto validation = std::make_unique<PendingValidation>(
91-
*this, std::move(certs), host, std::move(subject_alt_names), std::move(callback));
92-
PendingValidation* validation_ptr = validation.get();
93-
validations_.insert(std::move(validation));
94-
std::thread verification_thread(&PendingValidation::verifyCertsByPlatform, validation_ptr);
95-
std::thread::id thread_id = verification_thread.get_id();
96-
validation_threads_[thread_id] = std::move(verification_thread);
90+
ValidationJob job;
91+
job.result_callback_ = std::move(callback);
92+
job.validation_thread_ = std::thread(&verifyCertChainByPlatform, platform_validator_,
93+
&(job.result_callback_->dispatcher()), std::move(certs),
94+
std::string(host), std::move(subject_alt_names), this);
95+
std::thread::id thread_id = job.validation_thread_.get_id();
96+
validation_jobs_[thread_id] = std::move(job);
97+
9798
return {ValidationResults::ValidationStatus::Pending, absl::nullopt, absl::nullopt};
9899
}
99100

100101
void PlatformBridgeCertValidator::verifyCertChainByPlatform(
101-
const std::vector<envoy_data>& cert_chain, const std::string& hostname,
102-
const std::vector<std::string>& subject_alt_names, PendingValidation& pending_validation) {
102+
const envoy_cert_validator* platform_validator, Event::Dispatcher* dispatcher,
103+
std::vector<envoy_data> cert_chain, std::string hostname,
104+
std::vector<std::string> subject_alt_names, PlatformBridgeCertValidator* parent) {
103105
ASSERT(!cert_chain.empty());
104106
ENVOY_LOG(trace, "Start verifyCertChainByPlatform for host {}", hostname);
105107
// This is running in a stand alone thread other than the engine thread.
106108
envoy_data leaf_cert_der = cert_chain[0];
107109
bssl::UniquePtr<X509> leaf_cert(d2i_X509(
108110
nullptr, const_cast<const unsigned char**>(&leaf_cert_der.bytes), leaf_cert_der.length));
109111
envoy_cert_validation_result result =
110-
platform_validator_->validate_cert(cert_chain.data(), cert_chain.size(), hostname.c_str());
112+
platform_validator->validate_cert(cert_chain.data(), cert_chain.size(), hostname.c_str());
111113
bool success = result.result == ENVOY_SUCCESS;
112114
if (!success) {
113115
ENVOY_LOG(debug, result.error_details);
114-
pending_validation.postVerifyResultAndCleanUp(/*success=*/allow_untrusted_certificate_,
115-
result.error_details, result.tls_alert,
116-
makeOptRef(stats_.fail_verify_error_));
116+
postVerifyResultAndCleanUp(success, std::move(hostname), result.error_details, result.tls_alert,
117+
ValidationFailureType::FAIL_VERIFY_ERROR, platform_validator,
118+
dispatcher, parent);
117119
return;
118120
}
119121

@@ -123,49 +125,69 @@ void PlatformBridgeCertValidator::verifyCertChainByPlatform(
123125
if (!success) {
124126
error_details = "PlatformBridgeCertValidator_verifySubjectAltName failed: SNI mismatch.";
125127
ENVOY_LOG(debug, error_details);
126-
pending_validation.postVerifyResultAndCleanUp(/*success=*/allow_untrusted_certificate_,
127-
error_details, SSL_AD_BAD_CERTIFICATE,
128-
makeOptRef(stats_.fail_verify_san_));
128+
postVerifyResultAndCleanUp(success, std::move(hostname), error_details, SSL_AD_BAD_CERTIFICATE,
129+
ValidationFailureType::FAIL_VERIFY_SAN, platform_validator,
130+
dispatcher, parent);
129131
return;
130132
}
131-
pending_validation.postVerifyResultAndCleanUp(success, error_details, SSL_AD_CERTIFICATE_UNKNOWN,
132-
{});
133-
}
134-
135-
void PlatformBridgeCertValidator::PendingValidation::verifyCertsByPlatform() {
136-
parent_.verifyCertChainByPlatform(certs_, hostname_, subject_alt_names_, *this);
133+
postVerifyResultAndCleanUp(success, std::move(hostname), error_details,
134+
SSL_AD_CERTIFICATE_UNKNOWN, ValidationFailureType::SUCCESS,
135+
platform_validator, dispatcher, parent);
137136
}
138137

139-
void PlatformBridgeCertValidator::PendingValidation::postVerifyResultAndCleanUp(
140-
bool success, absl::string_view error_details, uint8_t tls_alert,
141-
OptRef<Stats::Counter> error_counter) {
138+
void PlatformBridgeCertValidator::postVerifyResultAndCleanUp(
139+
bool success, std::string hostname, absl::string_view error_details, uint8_t tls_alert,
140+
ValidationFailureType failure_type, const envoy_cert_validator* platform_validator,
141+
Event::Dispatcher* dispatcher, PlatformBridgeCertValidator* parent) {
142142
ENVOY_LOG(trace,
143143
"Finished platform cert validation for {}, post result callback to network thread",
144-
hostname_);
144+
hostname);
145145

146-
if (parent_.platform_validator_->validation_cleanup) {
147-
parent_.platform_validator_->validation_cleanup();
146+
if (platform_validator->validation_cleanup) {
147+
platform_validator->validation_cleanup();
148148
}
149-
std::weak_ptr<size_t> weak_alive_indicator(parent_.alive_indicator_);
149+
std::weak_ptr<size_t> weak_alive_indicator(parent->alive_indicator_);
150150

151-
// Once this task runs, `this` will be deleted so this must be the last statement in the file.
152-
result_callback_->dispatcher().post([this, weak_alive_indicator, success,
153-
error = std::string(error_details), tls_alert, error_counter,
154-
thread_id = std::this_thread::get_id()]() {
151+
dispatcher->post([weak_alive_indicator, success, hostname = std::move(hostname),
152+
error = std::string(error_details), tls_alert, failure_type,
153+
thread_id = std::this_thread::get_id(), parent]() {
155154
if (weak_alive_indicator.expired()) {
156155
return;
157156
}
158-
ENVOY_LOG(trace, "Got validation result for {} from platform", hostname_);
159-
parent_.validation_threads_[thread_id].join();
160-
parent_.validation_threads_.erase(thread_id);
161-
if (error_counter.has_value()) {
162-
const_cast<Stats::Counter&>(error_counter.ref()).inc();
163-
}
164-
result_callback_->onCertValidationResult(success, error, tls_alert);
165-
parent_.validations_.erase(this);
157+
parent->onVerificationComplete(thread_id, hostname, success, error, tls_alert, failure_type);
166158
});
167159
}
168160

161+
void PlatformBridgeCertValidator::onVerificationComplete(std::thread::id thread_id,
162+
std::string hostname, bool success,
163+
std::string error, uint8_t tls_alert,
164+
ValidationFailureType failure_type) {
165+
ENVOY_LOG(trace, "Got validation result for {} from platform", hostname);
166+
167+
auto job_handle = validation_jobs_.extract(thread_id);
168+
if (job_handle.empty()) {
169+
IS_ENVOY_BUG("No job found for thread");
170+
return;
171+
}
172+
ValidationJob& job = job_handle.mapped();
173+
job.validation_thread_.join();
174+
175+
switch (failure_type) {
176+
case ValidationFailureType::SUCCESS:
177+
break;
178+
case ValidationFailureType::FAIL_VERIFY_ERROR:
179+
stats_.fail_verify_error_.inc();
180+
case ValidationFailureType::FAIL_VERIFY_SAN:
181+
stats_.fail_verify_san_.inc();
182+
}
183+
184+
job.result_callback_->onCertValidationResult(allow_untrusted_certificate_ || success, error,
185+
tls_alert);
186+
ENVOY_LOG(trace,
187+
"Finished platform cert validation for {}, post result callback to network thread",
188+
hostname);
189+
}
190+
169191
} // namespace Tls
170192
} // namespace TransportSockets
171193
} // namespace Extensions

library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -64,45 +64,44 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable<Logge
6464
}
6565

6666
private:
67-
class PendingValidation {
68-
public:
69-
PendingValidation(PlatformBridgeCertValidator& parent, std::vector<envoy_data> certs,
70-
absl::string_view hostname, std::vector<std::string> subject_alt_names,
71-
Ssl::ValidateResultCallbackPtr result_callback)
72-
: parent_(parent), certs_(std::move(certs)), hostname_(hostname),
73-
subject_alt_names_(std::move(subject_alt_names)),
74-
result_callback_(std::move(result_callback)) {}
67+
enum class ValidationFailureType {
68+
SUCCESS,
69+
FAIL_VERIFY_ERROR,
70+
FAIL_VERIFY_SAN,
71+
};
7572

76-
// Ensure that this class is never moved or copied to guarantee pointer stability.
77-
PendingValidation(const PendingValidation&) = delete;
78-
PendingValidation(PendingValidation&&) = delete;
73+
// Calls into platform APIs in a stand-alone thread to verify the given certs.
74+
// Once the validation is done, the result will be posted back to the current
75+
// thread to trigger callback and update verify stats.
76+
// Must be called on the validation thread.
77+
static void verifyCertChainByPlatform(const envoy_cert_validator* platform_validator,
78+
Event::Dispatcher* dispatcher,
79+
std::vector<envoy_data> cert_chain, std::string hostname,
80+
std::vector<std::string> subject_alt_names,
81+
PlatformBridgeCertValidator* parent);
7982

80-
// Calls into platform APIs in a stand-alone thread to verify the given certs.
81-
// Once the validation is done, the result will be posted back to the current
82-
// thread to trigger callback and update verify stats.
83-
void verifyCertsByPlatform();
83+
// Must be called on the validation thread.
84+
static void postVerifyResultAndCleanUp(bool success, std::string hostname,
85+
absl::string_view error_details, uint8_t tls_alert,
86+
ValidationFailureType failure_type,
87+
const envoy_cert_validator* platform_validator,
88+
Event::Dispatcher* dispatcher,
89+
PlatformBridgeCertValidator* parent);
8490

85-
void postVerifyResultAndCleanUp(bool success, absl::string_view error_details,
86-
uint8_t tls_alert, OptRef<Stats::Counter> error_counter);
91+
// Called when a pending verification completes. Must be invoked on the main thread.
92+
void onVerificationComplete(std::thread::id thread_id, std::string hostname, bool success,
93+
std::string error_details, uint8_t tls_alert,
94+
ValidationFailureType failure_type);
8795

88-
private:
89-
PlatformBridgeCertValidator& parent_;
90-
const std::vector<envoy_data> certs_;
91-
const std::string hostname_;
92-
const std::vector<std::string> subject_alt_names_;
96+
struct ValidationJob {
9397
Ssl::ValidateResultCallbackPtr result_callback_;
98+
std::thread validation_thread_;
9499
};
95100

96-
void verifyCertChainByPlatform(const std::vector<envoy_data>& cert_chain,
97-
const std::string& hostname,
98-
const std::vector<std::string>& subject_alt_names,
99-
PendingValidation& pending_validation);
100-
101101
const bool allow_untrusted_certificate_;
102102
const envoy_cert_validator* platform_validator_;
103103
SslStats& stats_;
104-
absl::flat_hash_map<std::thread::id, std::thread> validation_threads_;
105-
absl::flat_hash_set<std::unique_ptr<PendingValidation>> validations_;
104+
absl::flat_hash_map<std::thread::id, ValidationJob> validation_jobs_;
106105
std::shared_ptr<size_t> alive_indicator_{new size_t(1)};
107106
};
108107

0 commit comments

Comments
 (0)