Skip to content

Commit cad6907

Browse files
committed
Make the functions FapiContext::set_callbacks() and FapiContext::clear_callbacks() return the previous callbacks instance, if any was set.
1 parent 5b849fb commit cad6907

21 files changed

Lines changed: 134 additions & 165 deletions

src/callback.rs

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
use super::{HashAlgorithm, fapi_sys::TPM2_ALG_ID, memory::CStringHolder};
88
use log::trace;
9-
use std::{borrow::Cow, ffi::CStr, fmt::Debug};
9+
use std::{any::Any, borrow::Cow, ffi::CStr, fmt::Debug};
1010

1111
// ==========================================================================
1212
// Callback parameters
@@ -132,46 +132,62 @@ impl<'a> PolicyActionCbParam<'a> {
132132
/// }
133133
/// }
134134
/// ```
135-
pub trait FapiCallbacks: Debug + Send {
135+
pub trait FapiCallbacks: Any + Send + Debug {
136136
/// A callback function that allows the FAPI to request authorization values.
137+
///
138+
/// The default implementation of this function returns `None`. Please override as needed!
139+
///
140+
/// *See also:* [`Fapi_SetAuthCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_auth_c_b.html)
137141
fn auth_cb(&self, _param: AuthCbParam) -> Option<Cow<'static, str>> {
138142
None
139143
}
140144

141145
/// A callback function that allows the FAPI to request signatures.
142146
///
143147
/// Signatures are requested for authorizing TPM objects.
148+
///
149+
/// The default implementation of this function returns `None`. Please override as needed!
150+
///
151+
/// *See also:* [`Fapi_SetSignCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_sign_c_b.html)
144152
fn sign_cb(&self, _param: SignCbParam) -> Option<Vec<u8>> {
145153
None
146154
}
147155

148156
/// A callback function that allows the FAPI to request branch choices.
149157
///
150158
/// It is usually called during policy evaluation.
159+
///
160+
/// The default implementation of this function returns `None`. Please override as needed!
161+
///
162+
/// *See also:* [`Fapi_SetBranchCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_sign_c_b.html)
151163
fn branch_cb(&self, _param: BranchCbParam) -> Option<usize> {
152164
None
153165
}
154166

155167
/// A callback function that allows the FAPI to notify the application.
156168
///
157169
/// It is usually called to announce policy actions.
170+
///
171+
/// The default implementation of this function returns `false`. Please override as needed!
172+
///
173+
/// *See also:* [`Fapi_SetPolicyActionCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_sign_c_b.html)
158174
fn policy_action_cb(&self, _param: PolicyActionCbParam) -> bool {
159175
false
160176
}
161177
}
162178

163179
// ==========================================================================
164-
// Callbacks wrapper
180+
// Callbacks manager
165181
// ==========================================================================
166182

167183
#[derive(Debug)]
168-
pub struct Callbacks {
184+
pub struct CallbackManager {
169185
inner: Box<dyn FapiCallbacks>,
170186
auth_value: Option<CStringHolder>,
171187
sign_data: Option<Vec<u8>>,
172188
}
173189

174-
impl Callbacks {
190+
impl CallbackManager {
175191
pub fn new(callbacks: impl FapiCallbacks + 'static) -> Self {
176192
Self { inner: Box::new(callbacks), auth_value: None, sign_data: None }
177193
}
@@ -185,6 +201,10 @@ impl Callbacks {
185201
}
186202
}
187203

204+
pub fn into_inner(self) -> Box<dyn FapiCallbacks> {
205+
self.inner
206+
}
207+
188208
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
189209
// Callback functions
190210
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -301,7 +321,7 @@ pub mod entry_point {
301321
},
302322
memory::{ptr_to_cstr_vec, ptr_to_opt_cstr},
303323
},
304-
Callbacks,
324+
CallbackManager,
305325
};
306326
use std::{
307327
ffi::{CStr, c_char, c_void},
@@ -321,7 +341,7 @@ pub mod entry_point {
321341
return mk_fapi_rc!(TSS2_BASE_RC_BAD_VALUE);
322342
}
323343
unsafe {
324-
match (*(user_data as *mut Callbacks)).auth_cb(CStr::from_ptr(object_path), ptr_to_opt_cstr(description)) {
344+
match (*(user_data as *mut CallbackManager)).auth_cb(CStr::from_ptr(object_path), ptr_to_opt_cstr(description)) {
325345
Some(auth_value) => {
326346
*auth = auth_value.as_ptr();
327347
TSS2_RC_SUCCESS
@@ -356,7 +376,7 @@ pub mod entry_point {
356376
return mk_fapi_rc!(TSS2_BASE_RC_BAD_VALUE);
357377
}
358378
unsafe {
359-
match (*(user_data as *mut Callbacks)).sign_cb(
379+
match (*(user_data as *mut CallbackManager)).sign_cb(
360380
CStr::from_ptr(object_path),
361381
ptr_to_opt_cstr(description),
362382
CStr::from_ptr(public_key),
@@ -387,7 +407,7 @@ pub mod entry_point {
387407
return mk_fapi_rc!(TSS2_BASE_RC_BAD_VALUE);
388408
}
389409
unsafe {
390-
match (*(user_data as *mut Callbacks)).branch_cb(
410+
match (*(user_data as *mut CallbackManager)).branch_cb(
391411
CStr::from_ptr(object_path),
392412
ptr_to_opt_cstr(description),
393413
&ptr_to_cstr_vec(branch_names, num_branches)[..],
@@ -407,7 +427,7 @@ pub mod entry_point {
407427
return mk_fapi_rc!(TSS2_BASE_RC_BAD_VALUE);
408428
}
409429
unsafe {
410-
match (*(user_data as *mut Callbacks)).policy_action_cb(CStr::from_ptr(object_path), ptr_to_opt_cstr(action)) {
430+
match (*(user_data as *mut CallbackManager)).policy_action_cb(CStr::from_ptr(object_path), ptr_to_opt_cstr(action)) {
411431
true => TSS2_RC_SUCCESS,
412432
_ => mk_fapi_rc!(TSS2_BASE_RC_GENERAL_FAILURE),
413433
}

src/context.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{ffi::c_char, fmt::Display, num::NonZeroUsize, os::raw::c_void, ptr, sy
99
use crate::{
1010
BaseErrorCode, BlobType, ErrorCode, FapiCallbacks, ImportData, InternalError, KeyFlags, NvFlags, PaddingFlags, QuoteFlags, QuoteResult, SealFlags,
1111
SignResult, TpmBlobs,
12-
callback::{Callbacks, entry_point},
12+
callback::{CallbackManager, entry_point},
1313
fapi_sys::{self, FAPI_CONTEXT, TPM2_RC, TSS2_RC, constants::TSS2_RC_SUCCESS},
1414
flags::Flags,
1515
json::{self, JsonValue},
@@ -80,7 +80,7 @@ type TctiOpaqueContextBlob = *mut [u8; 0];
8080
#[derive(Debug)]
8181
pub struct FapiContext {
8282
native_holder: NativeContextHolder,
83-
callbacks: Option<Callbacks>,
83+
callbacks: Option<CallbackManager>,
8484
}
8585

8686
/// A struct that wraps the native C pointer to the underlying FAPI_CONTEXT instance
@@ -148,37 +148,41 @@ impl FapiContext {
148148
// Callback setters
149149
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
150150

151-
/// This function registers application-defined callback functions with the FAPI context.
151+
/// This function registers application-defined callback functions with the FAPI context, replacing any previously registered callbacks.
152152
///
153153
/// The callback functions are implemented via the [`FapiCallbacks`](crate::FapiCallbacks) trait.
154154
///
155+
/// If successful, the function retruns the previously registered callback functions, which may be `None`.
156+
///
155157
/// *See also:*
156158
/// - [`Fapi_SetAuthCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_auth_c_b.html)
157159
/// - [`Fapi_SetSignCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_sign_c_b.html)
158160
/// - [`Fapi_SetBranchCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_sign_c_b.html)
159161
/// - [`Fapi_SetPolicyActionCB()`](https://tpm2-tss.readthedocs.io/en/stable/group___fapi___set_sign_c_b.html)
160-
pub fn set_callbacks(&mut self, callbacks: impl FapiCallbacks + 'static) -> Result<(), ErrorCode> {
161-
let _previous = self.callbacks.replace(Callbacks::new(callbacks));
162-
let callbacks_ptr = self.callbacks.as_mut().unwrap() as *mut Callbacks as *mut c_void;
162+
pub fn set_callbacks(&mut self, callbacks: impl FapiCallbacks + 'static) -> Result<Option<Box<dyn FapiCallbacks>>, ErrorCode> {
163+
let previous = self.callbacks.replace(CallbackManager::new(callbacks));
164+
let callbacks_ptr = self.callbacks.as_mut().unwrap() as *mut CallbackManager as *mut c_void;
163165
let results = [
164166
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetAuthCB(context, Some(entry_point::auth_cb), callbacks_ptr) }),
165167
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetSignCB(context, Some(entry_point::sign_cb), callbacks_ptr) }),
166168
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetBranchCB(context, Some(entry_point::branch_cb), callbacks_ptr) }),
167169
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetPolicyActionCB(context, Some(entry_point::policy_action_cb), callbacks_ptr) }),
168170
];
169-
results.iter().try_fold((), |_acc, &result| result)
171+
results.iter().try_fold(previous.map(|cb| cb.into_inner()), |acc, ret| ret.map(|_| acc))
170172
}
171173

172-
/// This function un-registers any application-defined callback functions that have been registers via the [`set_callbacks()`](FapiContext::set_callbacks) function.
173-
pub fn clear_callbacks(&mut self) -> Result<(), ErrorCode> {
174-
let _previous = self.callbacks.take();
174+
/// This function un-registers the application-defined callback functions that have been registers via [`set_callbacks()`](FapiContext::set_callbacks).
175+
///
176+
/// If successful, the function retruns the previously registered callback functions, which may be `None`.
177+
pub fn clear_callbacks(&mut self) -> Result<Option<Box<dyn FapiCallbacks>>, ErrorCode> {
178+
let previous = self.callbacks.take();
175179
let results = [
176180
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetAuthCB(context, None, ptr::null_mut()) }),
177181
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetSignCB(context, None, ptr::null_mut()) }),
178182
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetBranchCB(context, None, ptr::null_mut()) }),
179183
self.fapi_call(false, |context| unsafe { fapi_sys::Fapi_SetPolicyActionCB(context, None, ptr::null_mut()) }),
180184
];
181-
results.iter().try_fold((), |_acc, &result| result)
185+
results.iter().try_fold(previous.map(|cb| cb.into_inner()), |acc, ret| ret.map(|_| acc))
182186
}
183187

184188
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

tests/03_provision_test.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ fn test_provision() {
3333
};
3434

3535
// Set up auth callback
36-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
37-
if let Err(error) = context.set_callbacks(callbacks) {
36+
if let Err(error) = context.set_callbacks(MyCallbacks::new(PASSWORD, None)) {
3837
panic!("Setting up the callback has failed: {:?}", error)
3938
}
4039

@@ -72,7 +71,7 @@ fn test_to_destruction() {
7271
Ok(fpai_ctx) => fpai_ctx,
7372
Err(error) => panic!("Failed to create context: {:?}", error),
7473
};
75-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
74+
let callbacks = MyCallbacks::new(PASSWORD, None);
7675
tpm_initialize!(context, PASSWORD, callbacks);
7776
}
7877

tests/04_get_random_test.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ fn test_get_random() {
3434
};
3535

3636
// Initialize TPM, if not already initialized
37-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
38-
tpm_initialize!(context, PASSWORD, callbacks);
37+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
3938

4039
// Fetch random data
4140
let random_data = match context.get_random(NonZeroUsize::new(128usize).unwrap()) {

tests/05_key_test.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ fn test_create_key() {
4242
};
4343

4444
// Initialize TPM, if not already initialized
45-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
46-
tpm_initialize!(context, PASSWORD, callbacks);
45+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
4746

4847
// Create new key, if not already created
4948
match context.create_key(key_path, Some(KEY_FLAGS), None, Some(PASSWORD)) {
@@ -70,8 +69,7 @@ fn test_list_keys() {
7069
};
7170

7271
// Initialize TPM, if not already initialized
73-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
74-
tpm_initialize!(context, PASSWORD, callbacks);
72+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
7573

7674
// Create new key, if not already created
7775
for key_path in key_paths {
@@ -116,8 +114,7 @@ fn test_export_key() {
116114
};
117115

118116
// Initialize TPM, if not already initialized
119-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
120-
tpm_initialize!(context, PASSWORD, callbacks);
117+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
121118

122119
// Create new key, if not already created
123120
match context.create_key(key_path, Some(KEY_FLAGS), None, Some(PASSWORD)) {
@@ -164,8 +161,7 @@ fn test_import_key() {
164161
};
165162

166163
// Initialize TPM, if not already initialized
167-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
168-
tpm_initialize!(context, PASSWORD, callbacks);
164+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
169165

170166
// Import the existing public key
171167
match context.import(key_path, ImportData::from_pem(PUBLIC_KEY_DATA).unwrap()) {
@@ -199,8 +195,7 @@ fn test_delete() {
199195
};
200196

201197
// Initialize TPM, if not already initialized
202-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
203-
tpm_initialize!(context, PASSWORD, callbacks);
198+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
204199

205200
// Create new key, if not already created
206201
match context.create_key(key_path, Some(KEY_FLAGS), None, Some(PASSWORD)) {
@@ -240,8 +235,7 @@ fn test_get_tpm_blobs() {
240235
};
241236

242237
// Initialize TPM, if not already initialized
243-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
244-
tpm_initialize!(context, PASSWORD, callbacks);
238+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
245239

246240
// Create new key, if not already created
247241
match context.create_key(key_path, Some(KEY_FLAGS), None, Some(PASSWORD)) {
@@ -283,8 +277,7 @@ fn test_get_tpm_blobs_with_private() {
283277
};
284278

285279
// Initialize TPM, if not already initialized
286-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
287-
tpm_initialize!(context, PASSWORD, callbacks);
280+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
288281

289282
// Create new key, if not already created
290283
match context.create_key(key_path, Some(KEY_FLAGS), None, Some(PASSWORD)) {
@@ -331,8 +324,7 @@ fn test_get_esys_blob() {
331324
};
332325

333326
// Initialize TPM, if not already initialized
334-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
335-
tpm_initialize!(context, PASSWORD, callbacks);
327+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
336328

337329
// Create new key, if not already created
338330
match context.create_key(key_path, Some(KEY_FLAGS), None, Some(PASSWORD)) {

tests/06_encrypt_decrypt_test.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ fn test_encrypt() {
4747
};
4848

4949
// Initialize TPM, if not already initialized
50-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
51-
tpm_initialize!(context, PASSWORD, callbacks);
50+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
5251

5352
// Create new key, if not already created
5453
match context.create_key(key_path, Some(KEY_FLAGS_ENCR), None, Some(PASSWORD)) {
@@ -92,8 +91,7 @@ fn test_decrypt() {
9291
};
9392

9493
// Initialize TPM, if not already initialized
95-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
96-
tpm_initialize!(context, PASSWORD, callbacks);
94+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
9795

9896
// Create new key, if not already created
9997
match context.create_key(key_path, Some(KEY_FLAGS_ENCR), None, Some(PASSWORD)) {

tests/07_signature_test.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ fn test_sign() {
4949
};
5050

5151
// Initialize TPM, if not already initialized
52-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
53-
tpm_initialize!(context, PASSWORD, callbacks);
52+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
5453

5554
// Create new key, if not already created
5655
match context.create_key(key_path, Some(KEY_FLAGS_SIGN), None, Some(PASSWORD)) {
@@ -103,8 +102,7 @@ fn test_sign_with_pubkey() {
103102
};
104103

105104
// Initialize TPM, if not already initialized
106-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
107-
tpm_initialize!(context, PASSWORD, callbacks);
105+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
108106

109107
// Create new key, if not already created
110108
match context.create_key(key_path, Some(KEY_FLAGS_SIGN), None, Some(PASSWORD)) {
@@ -160,8 +158,7 @@ fn test_verify_signature() {
160158
};
161159

162160
// Initialize TPM, if not already initialized
163-
let (callbacks, _logger) = MyCallbacks::new(PASSWORD, None);
164-
tpm_initialize!(context, PASSWORD, callbacks);
161+
tpm_initialize!(context, PASSWORD, MyCallbacks::new(PASSWORD, None));
165162

166163
// Create new key, if not already created
167164
match context.create_key(key_path, Some(KEY_FLAGS_SIGN), None, Some(PASSWORD)) {

0 commit comments

Comments
 (0)