Skip to content

Commit fa8d60f

Browse files
committed
error handling poc
1 parent 39915de commit fa8d60f

1 file changed

Lines changed: 71 additions & 30 deletions

File tree

Modules/_base64/src/lib.rs

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use cpython_sys::PyBuffer_Release;
1212
use cpython_sys::PyBytes_AsString;
1313
use cpython_sys::PyBytes_FromStringAndSize;
1414
use cpython_sys::PyErr_NoMemory;
15+
use cpython_sys::PyErr_SetNone;
16+
use cpython_sys::PyErr_SetObject;
1517
use cpython_sys::PyErr_SetString;
1618
use cpython_sys::PyExc_TypeError;
1719
use cpython_sys::PyMethodDef;
@@ -22,6 +24,59 @@ use cpython_sys::PyModuleDef_Init;
2224
use cpython_sys::PyObject;
2325
use cpython_sys::PyObject_GetBuffer;
2426

27+
// Error Handling Abstraction
28+
29+
/// Zero-sized type indicating that a Python exception has been set.
30+
/// Using this type will ensure `Result<&PyObject, ExecutedErr>` and `Result<PyRc, ExecutedErr>`
31+
/// to be same size as `*mut PyObject`.
32+
#[derive(Debug, Clone, Copy)]
33+
pub struct ExecutedErr;
34+
35+
/// Enum representing different ways to set a Python exception.
36+
///
37+
/// This type is NOT stored in Result - it's immediately converted to
38+
/// `ExecutedErr` via `.into()`, which triggers the actual C API call.
39+
pub enum MakeErr {
40+
SetString(*mut PyObject, *const c_char),
41+
SetObject(*mut PyObject, *mut PyObject),
42+
SetNone(*mut PyObject),
43+
NoMemory,
44+
}
45+
46+
impl MakeErr {
47+
fn execute(self) -> ExecutedErr {
48+
match self {
49+
MakeErr::SetString(exc_type, msg) => {
50+
unsafe { PyErr_SetString(exc_type, msg) };
51+
}
52+
MakeErr::SetObject(exc_type, value) => {
53+
unsafe { PyErr_SetObject(exc_type, value) };
54+
}
55+
MakeErr::SetNone(exc_type) => {
56+
unsafe { PyErr_SetNone(exc_type) };
57+
}
58+
MakeErr::NoMemory => {
59+
unsafe { PyErr_NoMemory() };
60+
}
61+
}
62+
ExecutedErr
63+
}
64+
65+
#[inline]
66+
pub fn type_error(msg: *const c_char) -> Self {
67+
Self::SetString(unsafe { PyExc_TypeError }, msg)
68+
}
69+
}
70+
71+
impl From<MakeErr> for ExecutedErr {
72+
#[inline]
73+
fn from(exc: MakeErr) -> Self {
74+
exc.execute()
75+
}
76+
}
77+
78+
pub type PyResult<T> = Result<T, ExecutedErr>;
79+
2580
const PYBUF_SIMPLE: c_int = 0;
2681
const PAD_BYTE: u8 = b'=';
2782
const ENCODE_TABLE: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
@@ -82,11 +137,12 @@ struct BorrowedBuffer {
82137
}
83138

84139
impl BorrowedBuffer {
85-
fn from_object(obj: &PyObject) -> Result<Self, ()> {
140+
fn from_object(obj: &PyObject) -> PyResult<Self> {
86141
let mut view = MaybeUninit::<Py_buffer>::uninit();
87142
let buffer = unsafe {
88143
if PyObject_GetBuffer(obj.as_raw(), view.as_mut_ptr(), PYBUF_SIMPLE) != 0 {
89-
return Err(());
144+
// PyObject_GetBuffer already set the exception
145+
return Err(ExecutedErr);
90146
}
91147
Self {
92148
view: view.assume_init(),
@@ -122,12 +178,7 @@ pub unsafe extern "C" fn standard_b64encode(
122178
nargs: Py_ssize_t,
123179
) -> *mut PyObject {
124180
if nargs != 1 {
125-
unsafe {
126-
PyErr_SetString(
127-
PyExc_TypeError,
128-
c"standard_b64encode() takes exactly one argument".as_ptr(),
129-
);
130-
}
181+
MakeErr::type_error(c"standard_b64encode() takes exactly one argument".as_ptr()).execute();
131182
return ptr::null_mut();
132183
}
133184

@@ -140,51 +191,41 @@ pub unsafe extern "C" fn standard_b64encode(
140191
}
141192
}
142193

143-
fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> {
144-
let buffer = match BorrowedBuffer::from_object(source) {
145-
Ok(buf) => buf,
146-
Err(_) => return Err(()),
147-
};
194+
fn standard_b64encode_impl(source: &PyObject) -> PyResult<*mut PyObject> {
195+
let buffer = BorrowedBuffer::from_object(source)?;
148196

149197
let view_len = buffer.len();
150198
if view_len < 0 {
151-
unsafe {
152-
PyErr_SetString(
153-
PyExc_TypeError,
154-
c"standard_b64encode() argument has negative length".as_ptr(),
155-
);
156-
}
157-
return Err(());
199+
return Err(MakeErr::type_error(
200+
c"standard_b64encode() argument has negative length".as_ptr(),
201+
)
202+
.into());
158203
}
159204

160205
let input_len = view_len as usize;
161206
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };
162207

163208
let Some(output_len) = encoded_output_len(input_len) else {
164-
unsafe {
165-
PyErr_NoMemory();
166-
}
167-
return Err(());
209+
return Err(MakeErr::NoMemory.into());
168210
};
169211

170212
if output_len > isize::MAX as usize {
171-
unsafe {
172-
PyErr_NoMemory();
173-
}
174-
return Err(());
213+
return Err(MakeErr::NoMemory.into());
175214
}
176215

177216
let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) };
178217
if result.is_null() {
179-
return Err(());
218+
// PyBytes_FromStringAndSize already set the exception
219+
return Err(ExecutedErr);
180220
}
181221

182222
let dest_ptr = unsafe { PyBytes_AsString(result) };
183223
if dest_ptr.is_null() {
184224
unsafe {
185225
Py_DecRef(result);
186226
}
187-
return Err(());
227+
// PyBytes_AsString already set the exception
228+
return Err(ExecutedErr);
188229
}
189230
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };
190231

0 commit comments

Comments
 (0)