|
17 | 17 |
|
18 | 18 | use std::future::Future; |
19 | 19 | use std::ptr::NonNull; |
20 | | -use std::sync::{Arc, OnceLock}; |
| 20 | +use std::sync::{Arc, OnceLock, RwLock}; |
21 | 21 | use std::time::Duration; |
22 | 22 |
|
23 | 23 | use datafusion::datasource::TableProvider; |
@@ -59,11 +59,29 @@ pub fn is_ipython_env(py: Python) -> &'static bool { |
59 | 59 | }) |
60 | 60 | } |
61 | 61 |
|
62 | | -/// Utility to get the Global Datafussion CTX |
| 62 | +fn global_ctx_slot() -> &'static RwLock<Arc<SessionContext>> { |
| 63 | + static CTX: OnceLock<RwLock<Arc<SessionContext>>> = OnceLock::new(); |
| 64 | + CTX.get_or_init(|| RwLock::new(Arc::new(SessionContext::new()))) |
| 65 | +} |
| 66 | + |
| 67 | +/// Utility to get the Global DataFusion CTX. |
| 68 | +/// |
| 69 | +/// Returns an owned `Arc<SessionContext>` snapshot. The underlying slot can be |
| 70 | +/// replaced via [`set_global_ctx`]; existing snapshots are unaffected. |
63 | 71 | #[inline] |
64 | | -pub fn get_global_ctx() -> &'static Arc<SessionContext> { |
65 | | - static CTX: OnceLock<Arc<SessionContext>> = OnceLock::new(); |
66 | | - CTX.get_or_init(|| Arc::new(SessionContext::new())) |
| 72 | +pub fn get_global_ctx() -> Arc<SessionContext> { |
| 73 | + global_ctx_slot() |
| 74 | + .read() |
| 75 | + .expect("global SessionContext lock poisoned") |
| 76 | + .clone() |
| 77 | +} |
| 78 | + |
| 79 | +/// Replace the Global DataFusion CTX. Subsequent calls to [`get_global_ctx`] |
| 80 | +/// will return the new context. Already-cloned `Arc`s are not affected. |
| 81 | +pub fn set_global_ctx(ctx: Arc<SessionContext>) { |
| 82 | + *global_ctx_slot() |
| 83 | + .write() |
| 84 | + .expect("global SessionContext lock poisoned") = ctx; |
67 | 85 | } |
68 | 86 |
|
69 | 87 | /// Utility to collect rust futures with GIL released and respond to |
@@ -224,3 +242,40 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound<PyAny>) -> PyResult<FFI_Logic |
224 | 242 |
|
225 | 243 | Ok(codec.clone()) |
226 | 244 | } |
| 245 | + |
| 246 | +#[cfg(test)] |
| 247 | +mod tests { |
| 248 | + use super::*; |
| 249 | + |
| 250 | + /// The global slot must round-trip a custom `SessionContext`. Since the |
| 251 | + /// global is process-wide, this test only asserts identity through a |
| 252 | + /// single set/get cycle and restores the prior value at the end so the |
| 253 | + /// test is independent of ordering with other tests in the binary. |
| 254 | + #[test] |
| 255 | + fn set_global_ctx_replaces_default() { |
| 256 | + let prior = get_global_ctx(); |
| 257 | + let custom = Arc::new(SessionContext::new()); |
| 258 | + let custom_ptr = Arc::as_ptr(&custom); |
| 259 | + |
| 260 | + set_global_ctx(custom.clone()); |
| 261 | + let observed = get_global_ctx(); |
| 262 | + assert_eq!( |
| 263 | + Arc::as_ptr(&observed), |
| 264 | + custom_ptr, |
| 265 | + "get_global_ctx should return the context installed by set_global_ctx", |
| 266 | + ); |
| 267 | + |
| 268 | + // A snapshot taken before the swap should be unaffected after another |
| 269 | + // set_global_ctx call, because get_global_ctx clones the Arc. |
| 270 | + let snapshot = get_global_ctx(); |
| 271 | + let replacement = Arc::new(SessionContext::new()); |
| 272 | + set_global_ctx(replacement); |
| 273 | + assert_eq!( |
| 274 | + Arc::as_ptr(&snapshot), |
| 275 | + custom_ptr, |
| 276 | + "previously cloned snapshots must not be invalidated by set_global_ctx", |
| 277 | + ); |
| 278 | + |
| 279 | + set_global_ctx(prior); |
| 280 | + } |
| 281 | +} |
0 commit comments