1use std::ffi::CStr;
2use std::ptr::NonNull;
3use std::cell::RefCell;
4
5use crate::c_api::{mts_status_t, MTS_SUCCESS, mts_last_error};
6
7const RUST_FUNCTION_FAILED_ERROR_CODE: i32 = -4242;
9
10thread_local! {
11 pub static LAST_RUST_ERROR: RefCell<Error> = RefCell::new(Error {code: None, message: String::new()});
13}
14
15pub use metatensor_sys::Error;
16
17pub fn check_status(status: mts_status_t) -> Result<(), Error> {
19 if status == MTS_SUCCESS {
20 return Ok(())
21 } else if status > 0 {
22 let message = unsafe {
23 CStr::from_ptr(mts_last_error())
24 };
25 let message = message.to_str().expect("invalid UTF8");
26
27 return Err(Error { code: Some(status), message: message.to_owned() });
28 } else if status == RUST_FUNCTION_FAILED_ERROR_CODE {
29 return Err(LAST_RUST_ERROR.with(|e| e.borrow().clone()));
30 } else {
31 return Err(Error { code: Some(status), message: "external function call failed".into() });
32 }
33}
34
35pub fn check_ptr<T>(ptr: *mut T) -> Result<NonNull<T>, Error> {
37 if let Some(ptr) = NonNull::new(ptr) {
38 return Ok(ptr);
39 } else {
40 let message = unsafe {
41 CStr::from_ptr(mts_last_error())
42 };
43 let message = message.to_str().expect("invalid UTF8");
44
45 return Err(Error { code: None, message: message.to_owned() });
46 }
47}
48
49
50pub(crate) fn catch_unwind<F>(function: F) -> mts_status_t where F: FnOnce() + std::panic::UnwindSafe {
53 match std::panic::catch_unwind(function) {
54 Ok(()) => MTS_SUCCESS,
55 Err(e) => {
56 LAST_RUST_ERROR.with(|last_error| {
59 let mut last_error = last_error.borrow_mut();
60 *last_error = e.into();
61 });
62
63 RUST_FUNCTION_FAILED_ERROR_CODE
64 }
65 }
66}