metatensor/
errors.rs

1use std::ffi::CStr;
2use std::ptr::NonNull;
3use std::cell::RefCell;
4
5use crate::c_api::{mts_status_t, MTS_SUCCESS, mts_last_error};
6
/// Status code returned by `catch_unwind` when the Rust function it runs
/// fails; `check_status` recognizes this value and returns the stored error.
const RUST_FUNCTION_FAILED_ERROR_CODE: i32 = -4_242;
9
10thread_local! {
11    /// Storage for the last error coming from a Rust function
12    pub static LAST_RUST_ERROR: RefCell<Error> = RefCell::new(Error {code: None, message: String::new()});
13}
14
15pub use metatensor_sys::Error;
16
17/// Check an `mts_status_t`, returning an error if is it not `MTS_SUCCESS`
18pub fn check_status(status: mts_status_t) -> Result<(), Error> {
19    if status == MTS_SUCCESS {
20        return Ok(())
21    } else if status > 0 {
22        let message = unsafe {
23            CStr::from_ptr(mts_last_error())
24        };
25        let message = message.to_str().expect("invalid UTF8");
26
27        return Err(Error { code: Some(status), message: message.to_owned() });
28    } else if status == RUST_FUNCTION_FAILED_ERROR_CODE {
29        return Err(LAST_RUST_ERROR.with(|e| e.borrow().clone()));
30    } else {
31        return Err(Error { code: Some(status), message: "external function call failed".into() });
32    }
33}
34
35/// Check a pointer allocated by metatensor-core, returning an error if is null
36pub fn check_ptr<T>(ptr: *mut T) -> Result<NonNull<T>, Error> {
37    if let Some(ptr) = NonNull::new(ptr) {
38        return Ok(ptr);
39    } else {
40        let message = unsafe {
41            CStr::from_ptr(mts_last_error())
42        };
43        let message = message.to_str().expect("invalid UTF8");
44
45        return Err(Error { code: None, message: message.to_owned() });
46    }
47}
48
49
50/// An alternative to `std::panic::catch_unwind` that automatically transform
51/// the error into `mts_status_t`.
52pub(crate) fn catch_unwind<F>(function: F) -> mts_status_t where F: FnOnce() + std::panic::UnwindSafe {
53    match std::panic::catch_unwind(function) {
54        Ok(()) => MTS_SUCCESS,
55        Err(e) => {
56            // Store the error in LAST_RUST_ERROR, we will extract it later
57            // in `check_status`
58            LAST_RUST_ERROR.with(|last_error| {
59                let mut last_error = last_error.borrow_mut();
60                *last_error = e.into();
61            });
62
63            RUST_FUNCTION_FAILED_ERROR_CODE
64        }
65    }
66}