1use std::ffi::CStr;
2
3use crate::c_api::{mts_status_t, MTS_SUCCESS, MTS_CALLBACK_ERROR};
4
5pub use metatensor_sys::Error;
6
7fn get_last_error(status: Option<mts_status_t>) -> Error {
8 let mut message = std::ptr::null();
9 let mut origin = std::ptr::null();
10 let mut user_data = std::ptr::null_mut();
11 let last_error_status = unsafe {
12 crate::c_api::mts_last_error(
13 &mut message, &mut origin, &mut user_data
14 )
15 };
16
17 if last_error_status != MTS_SUCCESS {
18 return Error {
19 code: status,
20 message: "INTERNAL ERROR: failed to get the last error".into(),
21 };
22 }
23
24 let message = if message.is_null() {
25 "<no message provided>"
26 } else {
27 unsafe { CStr::from_ptr(message).to_str().unwrap_or("<invalid UTF-8 in error message>") }
28 };
29
30 let origin = if origin.is_null() {
31 "<no origin provided>"
32 } else {
33 unsafe { CStr::from_ptr(origin).to_str().unwrap_or("<invalid UTF-8 in error origin>") }
34 };
35
36 if !user_data.is_null() && origin == "Rust Error" {
37 let rust_error = unsafe {
38 user_data.cast::<Error>().as_ref().expect("should not be null")
39 };
40 return rust_error.clone();
41 }
42
43 return Error {
44 code: status,
45 message: message.to_owned(),
46 };
47}
48
49unsafe extern "C" fn error_deleter(data: *mut std::ffi::c_void) {
50 let _ = unsafe { Box::from_raw(data.cast::<Error>()) };
51}
52
53fn store_last_error(error: Error) -> mts_status_t {
54 let c_message = std::ffi::CString::new(error.message.clone()).expect("found NULL byte in error message");
55 let c_origin = std::ffi::CString::new("Rust Error").expect("found NULL byte in error origin");
56 let status = unsafe {
57 crate::c_api::mts_set_last_error(
58 c_message.as_ptr(),
59 c_origin.as_ptr(),
60 Box::into_raw(Box::new(error)).cast(),
61 Some(error_deleter),
62 )
63 };
64
65 check_status(status).expect("failed to set last error");
66
67 return MTS_CALLBACK_ERROR;
68}
69
70pub fn check_status(status: mts_status_t) -> Result<(), Error> {
72 if status == MTS_SUCCESS {
73 return Ok(())
74 } else {
75 return Err(get_last_error(Some(status)));
76 }
77}
78
79pub fn check_ptr<T>(ptr: *const T) -> Result<(), Error> {
81 if ptr.is_null() {
82 return Err(get_last_error(None));
83 }
84
85 return Ok(())
86}
87
88
89pub(crate) fn catch_unwind<F>(function: F) -> mts_status_t where F: FnOnce() -> Result<(), Error> + std::panic::UnwindSafe {
92 match std::panic::catch_unwind(function) {
93 Ok(Ok(())) => MTS_SUCCESS,
94 Ok(Err(e)) => {
95 return store_last_error(e);
96 },
97 Err(e) => {
98 return store_last_error(e.into());
99 }
100 }
101}