From 2afc6160f00d6308b935e2b3a1d579e9019e9b77 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Thu, 12 Oct 2023 19:09:30 +0200 Subject: [PATCH] Added bitfields --- rust/trexio/build.py | 18 ++--- rust/trexio/src/bitfield.rs | 138 ++++++++++++++++++++++++++++++++++++ rust/trexio/src/lib.rs | 53 +++++++++++++- rust/trexio/src/test.rs | 19 ++++- 4 files changed, 213 insertions(+), 15 deletions(-) create mode 100644 rust/trexio/src/bitfield.rs diff --git a/rust/trexio/build.py b/rust/trexio/build.py index 2c0782f..cf6746e 100755 --- a/rust/trexio/build.py +++ b/rust/trexio/build.py @@ -128,13 +128,13 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> { let rc = c::trexio_read_{group}_{element}_64(self.ptr, &mut data_c); (rc, data_c.try_into().expect("try_into failed in read_{group_l}_{element_l}")) }; - rc_return(rc, data) + rc_return(data, rc) } pub fn write_{group_l}_{element_l}(&self, data: {type_r}) -> Result<(), ExitCode> { let data: {type_c} = data.try_into().expect("try_into failed in write_{group_l}_{element_l}"); let rc = unsafe { c::trexio_write_{group}_{element}_64(self.ptr, data) }; - rc_return(rc, ()) + rc_return((), rc) } """ .replace("{type_c}",type_c) @@ -153,7 +153,7 @@ pub fn read_{group_l}_{element_l}(&self, capacity: usize) -> Result Result<(), ExitCode> { @@ -161,7 +161,7 @@ pub fn write_{group_l}_{element_l}(&self, data: &str) -> Result<(), ExitCode> { let data = string_to_c(data); let data = data.as_ptr() as *const c_char; let rc = unsafe { c::trexio_write_{group}_{element}(self.ptr, data, size) }; - rc_return(rc, ()) + rc_return((), rc) } """ .replace("{type_c}",type_c) @@ -180,7 +180,7 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> { let rc = c::trexio_read_{group}_{element}_64(self.ptr, &mut data_c); (rc, data_c.try_into().expect("try_into failed in read_{group_l}_{element_l}")) }; - rc_return(rc, data) + rc_return(data, rc) } """ .replace("{type_r}",type_r) @@ -211,7 +211,7 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> { let rc = c::trexio_read_safe_{group}_{element}_64(self.ptr, data_c, size.try_into().expect("try_into failed in read_{group}_{element}")); (rc, data) }; - rc_return(rc, data) + rc_return(data, rc) } """ ] r += [ '\n'.join(t) @@ -227,7 +227,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi let size: i64 = data.len().try_into().expect("try_into failed in write_{group_l}_{element_l}"); let data = data.as_ptr() as *const {type_c}; let rc = unsafe { c::trexio_write_safe_{group}_{element}_64(self.ptr, data, size) }; - rc_return(rc, ()) + rc_return((), rc) } """ .replace("{type_c}",type_c) @@ -256,7 +256,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi let rc = c::trexio_read_{group}_{element}(self.ptr, data_c, capacity.try_into().expect("try_into failed in read_{group}_{element}") ); (rc, data) }; - rc_return(rc, data) + rc_return(data, rc) } """ ] r += [ '\n'.join(t) @@ -280,7 +280,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<&str>) -> Result<(), ExitCod let size : i32 = size.try_into().expect("try_into failed in write_{group}_{element}"); let data_c = data_c.as_ptr() as *mut *const c_char; let rc = unsafe { c::trexio_write_{group}_{element}(self.ptr, data_c, size) }; - rc_return(rc, ()) + rc_return((), rc) } """ .replace("{type_c}",type_c) diff --git a/rust/trexio/src/bitfield.rs b/rust/trexio/src/bitfield.rs new file mode 100644 index 0000000..5303c44 --- /dev/null +++ b/rust/trexio/src/bitfield.rs @@ -0,0 +1,138 @@ +#[derive(Debug)] +pub struct Bitfield { + data: Vec, + n_int: usize +} + +use crate::c; +use crate::File; +use crate::ExitCode; + +impl Bitfield { + + /// Creates a new bitfield , using a number of i64 elements consistent + /// with the number of MOs in the TREXIO file. + pub fn from(n_int: usize, orb_list: Vec) -> Result<(Self, f64), ExitCode> { + + let orb_list: Vec = orb_list.iter().map(|&x| x as i32).collect(); + let occ_num = orb_list.len().try_into().expect("try_into failed in Bitfield::from"); + let orb_list_ptr = orb_list.as_ptr() as *const i32; + let n_int32: i32 = n_int.try_into().expect("try_into failed in Bitfield::from"); + let mut b = vec![0i64 ; n_int]; + let bit_list = b.as_mut_ptr() as *mut c::bitfield_t; + std::mem::forget(b); + let rc = unsafe { + c::trexio_to_bitfield_list(orb_list_ptr, occ_num, bit_list, n_int32) + }; + + let data = unsafe { Vec::from_raw_parts(bit_list, n_int, n_int) }; + + let result = Bitfield { data, n_int }; + + match ExitCode::from(rc) { + ExitCode::Success => Ok( (result, 1.0) ), + ExitCode::PhaseChange=> Ok( (result,-1.0) ), + x => return Err(x), + } + } + + /// Number of i64 needed to represent a spin sector + pub fn n_int(&self) -> usize { + self.n_int + } + + /// Returns the alpha part + pub fn alpha(&self) -> &[i64] { + &self.data[0..self.n_int] + } + + /// Returns the beta part + pub fn beta(&self) -> &[i64] { + let n_int = self.n_int; + &self.data[n_int..2*n_int] + } + + /// Converts to a format usable in the C library + pub fn as_ptr(&self) -> *const c::bitfield_t { + let len = self.data.len(); + let result = &self.data[0..len]; + result.as_ptr() as *const c::bitfield_t + } + + /// Converts to a format usable in the C library + pub fn as_mut_ptr(&mut self) -> *mut c::bitfield_t { + let len = self.data.len(); + let result = &mut self.data[0..len]; + result.as_mut_ptr() as *mut c::bitfield_t + } + + /// Converts the bitfield into a list of orbital indices (0-based) + pub fn to_orbital_list(&self) -> Result< Vec, ExitCode> { + + let n_int : i32 = self.n_int.try_into().expect("try_into failed in to_orbital_list"); + let d1 = self.as_ptr(); + let cap = self.n_int * 64; + let mut list = vec![ 0i32 ; cap ]; + let list_c = list.as_mut_ptr() as *mut i32; + std::mem::forget(list); + + let mut occ_num : i32 = 0; + + let rc = unsafe { c::trexio_to_orbital_list(n_int, d1, list_c, &mut occ_num) }; + match ExitCode::from(rc) { + ExitCode::Success => (), + x => return Err(x) + }; + + let occ_num = occ_num as usize; + let list = unsafe { Vec::from_raw_parts(list_c, occ_num, cap) }; + + let mut result: Vec = Vec::with_capacity(occ_num); + for i in list.iter() { + result.push( *i as usize ); + } + Ok(result) + } + + /// Converts the determinant into a list of orbital indices (0-based) + pub fn to_orbital_list_up_dn(&self) -> Result< (Vec, Vec), ExitCode> { + + let n_int : i32 = (self.n_int/2).try_into().expect("try_into failed in to_orbital_list"); + let d1 = self.as_ptr(); + let cap = self.n_int * 64; + let mut b = vec![ 0i32 ; cap ]; + let list_up_c = b.as_mut_ptr() as *mut i32; + std::mem::forget(b); + let mut b = vec![ 0i32 ; cap ]; + let list_dn_c = b.as_mut_ptr() as *mut i32; + std::mem::forget(b); + + let mut occ_num_up : i32 = 0; + let mut occ_num_dn : i32 = 0; + + let rc = unsafe { c::trexio_to_orbital_list_up_dn(n_int, d1, list_up_c, list_dn_c, &mut occ_num_up, &mut occ_num_dn) }; + match ExitCode::from(rc) { + ExitCode::Success => (), + x => return Err(x) + }; + + let occ_num_up = occ_num_up as usize; + let occ_num_dn = occ_num_dn as usize; + let list_up = unsafe { Vec::from_raw_parts(list_up_c, occ_num_up, cap) }; + let list_dn = unsafe { Vec::from_raw_parts(list_dn_c, occ_num_dn, cap) }; + + let mut result_up: Vec = Vec::with_capacity(occ_num_up); + for i in list_up.iter() { + result_up.push( *i as usize ); + } + + let mut result_dn: Vec = Vec::with_capacity(occ_num_dn); + for i in list_dn.iter() { + result_dn.push( *i as usize ); + } + Ok( (result_up, result_dn) ) + } + +} + + diff --git a/rust/trexio/src/lib.rs b/rust/trexio/src/lib.rs index 0cd70d2..9956bb7 100644 --- a/rust/trexio/src/lib.rs +++ b/rust/trexio/src/lib.rs @@ -11,9 +11,13 @@ pub use exit_code::ExitCode; pub mod back_end; pub use back_end::BackEnd; +/// Bit fields +pub mod bitfield; +pub use bitfield::Bitfield; + pub const PACKAGE_VERSION : &str = unsafe { std::str::from_utf8_unchecked(c::TREXIO_PACKAGE_VERSION) }; -fn rc_return(rc : c::trexio_exit_code, result: T) -> Result { +fn rc_return(result: T, rc : c::trexio_exit_code) -> Result { let rc = ExitCode::from(rc); match rc { ExitCode::Success => Ok(result), @@ -26,6 +30,11 @@ fn string_to_c(s: &str) -> std::ffi::CString { } +pub fn info() -> Result<(),ExitCode> { + let rc = unsafe { c::trexio_info() }; + rc_return((), rc) +} + /// Type for a TREXIO file pub struct File { @@ -43,12 +52,12 @@ impl File { let rc: *mut c::trexio_exit_code = &mut c::TREXIO_SUCCESS.clone(); let result = unsafe { c::trexio_open(file_name_c, mode, back_end, rc) }; let rc = unsafe { *rc }; - rc_return(rc, File { ptr: result }) + rc_return(File { ptr: result }, rc) } pub fn close(self) -> Result<(), ExitCode> { let rc = unsafe { c::trexio_close(self.ptr) }; - rc_return(rc, ()) + rc_return((), rc) } pub fn inquire(file_name: &str) -> Result { @@ -62,6 +71,44 @@ impl File { } } + pub fn get_state(&self) -> Result { + let mut num = 0i32; + let rc = unsafe { c::trexio_get_state(self.ptr, &mut num) }; + let result: usize = num.try_into().expect("try_into failed in get_state"); + rc_return(result, rc) + } + + pub fn set_state(&self, num: usize) -> Result<(), ExitCode> { + let num: i32 = num.try_into().expect("try_into failed in set_state"); + let rc = unsafe { c::trexio_set_state(self.ptr, num) }; + rc_return((), rc) + } + + pub fn set_one_based(&self) -> Result<(), ExitCode> { + let rc = unsafe { c::trexio_set_one_based(self.ptr) }; + rc_return((), rc) + } + + pub fn get_int64_num(&self) -> Result { + let mut num = 0i32; + let rc = unsafe { + c::trexio_get_int64_num(self.ptr, &mut num) + }; + let num:usize = num.try_into().expect("try_into failed in get_int64_num"); + rc_return(num, rc) + } + +/* + pub fn read_determinant_list(&self, offset_file: usize, dset: Vec) -> Result { + let rc = unsafe { + let offset_file: i64 = offset_file; + let buffer_size: *mut i64 = dset.len().try_into().expect("try_into failed in read_determinant_list"); + let dset: *mut i64 = dset.to_c().as_mut_ptr(); + c::trexio_read_determinant_list(self.ptr, offset_file, buffer_size, dset) + }; + } + */ + } include!("generated.rs"); diff --git a/rust/trexio/src/test.rs b/rust/trexio/src/test.rs index 8c32ebd..ac0d4e4 100644 --- a/rust/trexio/src/test.rs +++ b/rust/trexio/src/test.rs @@ -2,6 +2,8 @@ use trexio::back_end::BackEnd; pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { + let () = trexio::info()?; + // Prepare data to be written let n_buffers = 5; let buf_size_sparse = 100/n_buffers; @@ -10,9 +12,8 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit let mut index_sparse_ao_2e_int_eri = vec![0i32 ; 400]; for i in 0..100 { let i : usize = i; - let j : i32 = i.try_into().unwrap(); - let fj : f64 = j.try_into().unwrap(); - value_sparse_ao_2e_int_eri[i] = 3.14 + fj; + let j : i32 = i as i32; + value_sparse_ao_2e_int_eri[i] = 3.14 + (j as f64); index_sparse_ao_2e_int_eri[4*i + 0] = 4*j - 3; index_sparse_ao_2e_int_eri[4*i + 1] = 4*j+1 - 3; index_sparse_ao_2e_int_eri[4*i + 2] = 4*j+2 - 3; @@ -91,6 +92,18 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit } trex_file.write_mo_energy(energy)?; + let mut spin = vec![0 ; mo_num]; + for i in mo_num/2..mo_num { + spin[i] = 1; + } + trex_file.write_mo_spin(spin)?; + + let (det, phase) = trexio::Bitfield::from(4, vec![0, 1, 2, 3, 4, 5, 6, 151, 152, 153, 154])?; + println!("{} {:?}", phase, det); + println!("{:?}", det.to_orbital_list().unwrap()); + println!("{:?}", det.to_orbital_list().unwrap()); + println!("{:?}", det.to_orbital_list_up_dn().unwrap()); + trex_file.close() }