1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-12-23 04:43:57 +01:00

Added bitfields

This commit is contained in:
Anthony Scemama 2023-10-12 19:09:30 +02:00
parent 06be52c0d1
commit 2afc6160f0
4 changed files with 213 additions and 15 deletions

View File

@ -128,13 +128,13 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> {
let rc = c::trexio_read_{group}_{element}_64(self.ptr, &mut data_c);
(rc, data_c.try_into().expect("try_into failed in read_{group_l}_{element_l}"))
};
rc_return(rc, data)
rc_return(data, rc)
}
pub fn write_{group_l}_{element_l}(&self, data: {type_r}) -> Result<(), ExitCode> {
let data: {type_c} = data.try_into().expect("try_into failed in write_{group_l}_{element_l}");
let rc = unsafe { c::trexio_write_{group}_{element}_64(self.ptr, data) };
rc_return(rc, ())
rc_return((), rc)
}
"""
.replace("{type_c}",type_c)
@ -153,7 +153,7 @@ pub fn read_{group_l}_{element_l}(&self, capacity: usize) -> Result<String, Exit
let rc = c::trexio_read_{group}_{element}(self.ptr, data_c, capacity.try_into().expect("try_into failed in read_{group_l}_{element_l}"));
(rc, String::from_raw_parts(data_c as *mut u8, capacity, capacity))
};
rc_return(rc, data)
rc_return(data, rc)
}
pub fn write_{group_l}_{element_l}(&self, data: &str) -> Result<(), ExitCode> {
@ -161,7 +161,7 @@ pub fn write_{group_l}_{element_l}(&self, data: &str) -> Result<(), ExitCode> {
let data = string_to_c(data);
let data = data.as_ptr() as *const c_char;
let rc = unsafe { c::trexio_write_{group}_{element}(self.ptr, data, size) };
rc_return(rc, ())
rc_return((), rc)
}
"""
.replace("{type_c}",type_c)
@ -180,7 +180,7 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> {
let rc = c::trexio_read_{group}_{element}_64(self.ptr, &mut data_c);
(rc, data_c.try_into().expect("try_into failed in read_{group_l}_{element_l}"))
};
rc_return(rc, data)
rc_return(data, rc)
}
"""
.replace("{type_r}",type_r)
@ -211,7 +211,7 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> {
let rc = c::trexio_read_safe_{group}_{element}_64(self.ptr, data_c, size.try_into().expect("try_into failed in read_{group}_{element}"));
(rc, data)
};
rc_return(rc, data)
rc_return(data, rc)
}
""" ]
r += [ '\n'.join(t)
@ -227,7 +227,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi
let size: i64 = data.len().try_into().expect("try_into failed in write_{group_l}_{element_l}");
let data = data.as_ptr() as *const {type_c};
let rc = unsafe { c::trexio_write_safe_{group}_{element}_64(self.ptr, data, size) };
rc_return(rc, ())
rc_return((), rc)
}
"""
.replace("{type_c}",type_c)
@ -256,7 +256,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi
let rc = c::trexio_read_{group}_{element}(self.ptr, data_c, capacity.try_into().expect("try_into failed in read_{group}_{element}") );
(rc, data)
};
rc_return(rc, data)
rc_return(data, rc)
}
""" ]
r += [ '\n'.join(t)
@ -280,7 +280,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<&str>) -> Result<(), ExitCod
let size : i32 = size.try_into().expect("try_into failed in write_{group}_{element}");
let data_c = data_c.as_ptr() as *mut *const c_char;
let rc = unsafe { c::trexio_write_{group}_{element}(self.ptr, data_c, size) };
rc_return(rc, ())
rc_return((), rc)
}
"""
.replace("{type_c}",type_c)

138
rust/trexio/src/bitfield.rs Normal file
View File

@ -0,0 +1,138 @@
#[derive(Debug)]
pub struct Bitfield {
data: Vec<i64>,
n_int: usize
}
use crate::c;
use crate::File;
use crate::ExitCode;
impl Bitfield {
/// Creates a new bitfield , using a number of i64 elements consistent
/// with the number of MOs in the TREXIO file.
pub fn from(n_int: usize, orb_list: Vec<usize>) -> Result<(Self, f64), ExitCode> {
let orb_list: Vec<i32> = orb_list.iter().map(|&x| x as i32).collect();
let occ_num = orb_list.len().try_into().expect("try_into failed in Bitfield::from");
let orb_list_ptr = orb_list.as_ptr() as *const i32;
let n_int32: i32 = n_int.try_into().expect("try_into failed in Bitfield::from");
let mut b = vec![0i64 ; n_int];
let bit_list = b.as_mut_ptr() as *mut c::bitfield_t;
std::mem::forget(b);
let rc = unsafe {
c::trexio_to_bitfield_list(orb_list_ptr, occ_num, bit_list, n_int32)
};
let data = unsafe { Vec::from_raw_parts(bit_list, n_int, n_int) };
let result = Bitfield { data, n_int };
match ExitCode::from(rc) {
ExitCode::Success => Ok( (result, 1.0) ),
ExitCode::PhaseChange=> Ok( (result,-1.0) ),
x => return Err(x),
}
}
/// Number of i64 needed to represent a spin sector
pub fn n_int(&self) -> usize {
self.n_int
}
/// Returns the alpha part
pub fn alpha(&self) -> &[i64] {
&self.data[0..self.n_int]
}
/// Returns the beta part
pub fn beta(&self) -> &[i64] {
let n_int = self.n_int;
&self.data[n_int..2*n_int]
}
/// Converts to a format usable in the C library
pub fn as_ptr(&self) -> *const c::bitfield_t {
let len = self.data.len();
let result = &self.data[0..len];
result.as_ptr() as *const c::bitfield_t
}
/// Converts to a format usable in the C library
pub fn as_mut_ptr(&mut self) -> *mut c::bitfield_t {
let len = self.data.len();
let result = &mut self.data[0..len];
result.as_mut_ptr() as *mut c::bitfield_t
}
/// Converts the bitfield into a list of orbital indices (0-based)
pub fn to_orbital_list(&self) -> Result< Vec<usize>, ExitCode> {
let n_int : i32 = self.n_int.try_into().expect("try_into failed in to_orbital_list");
let d1 = self.as_ptr();
let cap = self.n_int * 64;
let mut list = vec![ 0i32 ; cap ];
let list_c = list.as_mut_ptr() as *mut i32;
std::mem::forget(list);
let mut occ_num : i32 = 0;
let rc = unsafe { c::trexio_to_orbital_list(n_int, d1, list_c, &mut occ_num) };
match ExitCode::from(rc) {
ExitCode::Success => (),
x => return Err(x)
};
let occ_num = occ_num as usize;
let list = unsafe { Vec::from_raw_parts(list_c, occ_num, cap) };
let mut result: Vec<usize> = Vec::with_capacity(occ_num);
for i in list.iter() {
result.push( *i as usize );
}
Ok(result)
}
/// Converts the determinant into a list of orbital indices (0-based)
pub fn to_orbital_list_up_dn(&self) -> Result< (Vec<usize>, Vec<usize>), ExitCode> {
let n_int : i32 = (self.n_int/2).try_into().expect("try_into failed in to_orbital_list");
let d1 = self.as_ptr();
let cap = self.n_int * 64;
let mut b = vec![ 0i32 ; cap ];
let list_up_c = b.as_mut_ptr() as *mut i32;
std::mem::forget(b);
let mut b = vec![ 0i32 ; cap ];
let list_dn_c = b.as_mut_ptr() as *mut i32;
std::mem::forget(b);
let mut occ_num_up : i32 = 0;
let mut occ_num_dn : i32 = 0;
let rc = unsafe { c::trexio_to_orbital_list_up_dn(n_int, d1, list_up_c, list_dn_c, &mut occ_num_up, &mut occ_num_dn) };
match ExitCode::from(rc) {
ExitCode::Success => (),
x => return Err(x)
};
let occ_num_up = occ_num_up as usize;
let occ_num_dn = occ_num_dn as usize;
let list_up = unsafe { Vec::from_raw_parts(list_up_c, occ_num_up, cap) };
let list_dn = unsafe { Vec::from_raw_parts(list_dn_c, occ_num_dn, cap) };
let mut result_up: Vec<usize> = Vec::with_capacity(occ_num_up);
for i in list_up.iter() {
result_up.push( *i as usize );
}
let mut result_dn: Vec<usize> = Vec::with_capacity(occ_num_dn);
for i in list_dn.iter() {
result_dn.push( *i as usize );
}
Ok( (result_up, result_dn) )
}
}

View File

@ -11,9 +11,13 @@ pub use exit_code::ExitCode;
pub mod back_end;
pub use back_end::BackEnd;
/// Bit fields
pub mod bitfield;
pub use bitfield::Bitfield;
pub const PACKAGE_VERSION : &str = unsafe { std::str::from_utf8_unchecked(c::TREXIO_PACKAGE_VERSION) };
fn rc_return<T>(rc : c::trexio_exit_code, result: T) -> Result<T,ExitCode> {
fn rc_return<T>(result: T, rc : c::trexio_exit_code) -> Result<T,ExitCode> {
let rc = ExitCode::from(rc);
match rc {
ExitCode::Success => Ok(result),
@ -26,6 +30,11 @@ fn string_to_c(s: &str) -> std::ffi::CString {
}
pub fn info() -> Result<(),ExitCode> {
let rc = unsafe { c::trexio_info() };
rc_return((), rc)
}
/// Type for a TREXIO file
pub struct File {
@ -43,12 +52,12 @@ impl File {
let rc: *mut c::trexio_exit_code = &mut c::TREXIO_SUCCESS.clone();
let result = unsafe { c::trexio_open(file_name_c, mode, back_end, rc) };
let rc = unsafe { *rc };
rc_return(rc, File { ptr: result })
rc_return(File { ptr: result }, rc)
}
pub fn close(self) -> Result<(), ExitCode> {
let rc = unsafe { c::trexio_close(self.ptr) };
rc_return(rc, ())
rc_return((), rc)
}
pub fn inquire(file_name: &str) -> Result<bool, ExitCode> {
@ -62,6 +71,44 @@ impl File {
}
}
pub fn get_state(&self) -> Result<usize, ExitCode> {
let mut num = 0i32;
let rc = unsafe { c::trexio_get_state(self.ptr, &mut num) };
let result: usize = num.try_into().expect("try_into failed in get_state");
rc_return(result, rc)
}
pub fn set_state(&self, num: usize) -> Result<(), ExitCode> {
let num: i32 = num.try_into().expect("try_into failed in set_state");
let rc = unsafe { c::trexio_set_state(self.ptr, num) };
rc_return((), rc)
}
pub fn set_one_based(&self) -> Result<(), ExitCode> {
let rc = unsafe { c::trexio_set_one_based(self.ptr) };
rc_return((), rc)
}
pub fn get_int64_num(&self) -> Result<usize, ExitCode> {
let mut num = 0i32;
let rc = unsafe {
c::trexio_get_int64_num(self.ptr, &mut num)
};
let num:usize = num.try_into().expect("try_into failed in get_int64_num");
rc_return(num, rc)
}
/*
pub fn read_determinant_list(&self, offset_file: usize, dset: Vec<Bitfield>) -> Result<usize, ExitCode> {
let rc = unsafe {
let offset_file: i64 = offset_file;
let buffer_size: *mut i64 = dset.len().try_into().expect("try_into failed in read_determinant_list");
let dset: *mut i64 = dset.to_c().as_mut_ptr();
c::trexio_read_determinant_list(self.ptr, offset_file, buffer_size, dset)
};
}
*/
}
include!("generated.rs");

View File

@ -2,6 +2,8 @@ use trexio::back_end::BackEnd;
pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> {
let () = trexio::info()?;
// Prepare data to be written
let n_buffers = 5;
let buf_size_sparse = 100/n_buffers;
@ -10,9 +12,8 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit
let mut index_sparse_ao_2e_int_eri = vec![0i32 ; 400];
for i in 0..100 {
let i : usize = i;
let j : i32 = i.try_into().unwrap();
let fj : f64 = j.try_into().unwrap();
value_sparse_ao_2e_int_eri[i] = 3.14 + fj;
let j : i32 = i as i32;
value_sparse_ao_2e_int_eri[i] = 3.14 + (j as f64);
index_sparse_ao_2e_int_eri[4*i + 0] = 4*j - 3;
index_sparse_ao_2e_int_eri[4*i + 1] = 4*j+1 - 3;
index_sparse_ao_2e_int_eri[4*i + 2] = 4*j+2 - 3;
@ -91,6 +92,18 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit
}
trex_file.write_mo_energy(energy)?;
let mut spin = vec![0 ; mo_num];
for i in mo_num/2..mo_num {
spin[i] = 1;
}
trex_file.write_mo_spin(spin)?;
let (det, phase) = trexio::Bitfield::from(4, vec![0, 1, 2, 3, 4, 5, 6, 151, 152, 153, 154])?;
println!("{} {:?}", phase, det);
println!("{:?}", det.to_orbital_list().unwrap());
println!("{:?}", det.to_orbital_list().unwrap());
println!("{:?}", det.to_orbital_list_up_dn().unwrap());
trex_file.close()
}