mirror of
https://github.com/TREX-CoE/trexio.git
synced 2024-12-22 20:35:44 +01:00
Added bitfields
This commit is contained in:
parent
06be52c0d1
commit
2afc6160f0
@ -128,13 +128,13 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> {
|
||||
let rc = c::trexio_read_{group}_{element}_64(self.ptr, &mut data_c);
|
||||
(rc, data_c.try_into().expect("try_into failed in read_{group_l}_{element_l}"))
|
||||
};
|
||||
rc_return(rc, data)
|
||||
rc_return(data, rc)
|
||||
}
|
||||
|
||||
pub fn write_{group_l}_{element_l}(&self, data: {type_r}) -> Result<(), ExitCode> {
|
||||
let data: {type_c} = data.try_into().expect("try_into failed in write_{group_l}_{element_l}");
|
||||
let rc = unsafe { c::trexio_write_{group}_{element}_64(self.ptr, data) };
|
||||
rc_return(rc, ())
|
||||
rc_return((), rc)
|
||||
}
|
||||
"""
|
||||
.replace("{type_c}",type_c)
|
||||
@ -153,7 +153,7 @@ pub fn read_{group_l}_{element_l}(&self, capacity: usize) -> Result<String, Exit
|
||||
let rc = c::trexio_read_{group}_{element}(self.ptr, data_c, capacity.try_into().expect("try_into failed in read_{group_l}_{element_l}"));
|
||||
(rc, String::from_raw_parts(data_c as *mut u8, capacity, capacity))
|
||||
};
|
||||
rc_return(rc, data)
|
||||
rc_return(data, rc)
|
||||
}
|
||||
|
||||
pub fn write_{group_l}_{element_l}(&self, data: &str) -> Result<(), ExitCode> {
|
||||
@ -161,7 +161,7 @@ pub fn write_{group_l}_{element_l}(&self, data: &str) -> Result<(), ExitCode> {
|
||||
let data = string_to_c(data);
|
||||
let data = data.as_ptr() as *const c_char;
|
||||
let rc = unsafe { c::trexio_write_{group}_{element}(self.ptr, data, size) };
|
||||
rc_return(rc, ())
|
||||
rc_return((), rc)
|
||||
}
|
||||
"""
|
||||
.replace("{type_c}",type_c)
|
||||
@ -180,7 +180,7 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> {
|
||||
let rc = c::trexio_read_{group}_{element}_64(self.ptr, &mut data_c);
|
||||
(rc, data_c.try_into().expect("try_into failed in read_{group_l}_{element_l}"))
|
||||
};
|
||||
rc_return(rc, data)
|
||||
rc_return(data, rc)
|
||||
}
|
||||
"""
|
||||
.replace("{type_r}",type_r)
|
||||
@ -211,7 +211,7 @@ pub fn read_{group_l}_{element_l}(&self) -> Result<{type_r}, ExitCode> {
|
||||
let rc = c::trexio_read_safe_{group}_{element}_64(self.ptr, data_c, size.try_into().expect("try_into failed in read_{group}_{element}"));
|
||||
(rc, data)
|
||||
};
|
||||
rc_return(rc, data)
|
||||
rc_return(data, rc)
|
||||
}
|
||||
""" ]
|
||||
r += [ '\n'.join(t)
|
||||
@ -227,7 +227,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi
|
||||
let size: i64 = data.len().try_into().expect("try_into failed in write_{group_l}_{element_l}");
|
||||
let data = data.as_ptr() as *const {type_c};
|
||||
let rc = unsafe { c::trexio_write_safe_{group}_{element}_64(self.ptr, data, size) };
|
||||
rc_return(rc, ())
|
||||
rc_return((), rc)
|
||||
}
|
||||
"""
|
||||
.replace("{type_c}",type_c)
|
||||
@ -256,7 +256,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi
|
||||
let rc = c::trexio_read_{group}_{element}(self.ptr, data_c, capacity.try_into().expect("try_into failed in read_{group}_{element}") );
|
||||
(rc, data)
|
||||
};
|
||||
rc_return(rc, data)
|
||||
rc_return(data, rc)
|
||||
}
|
||||
""" ]
|
||||
r += [ '\n'.join(t)
|
||||
@ -280,7 +280,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<&str>) -> Result<(), ExitCod
|
||||
let size : i32 = size.try_into().expect("try_into failed in write_{group}_{element}");
|
||||
let data_c = data_c.as_ptr() as *mut *const c_char;
|
||||
let rc = unsafe { c::trexio_write_{group}_{element}(self.ptr, data_c, size) };
|
||||
rc_return(rc, ())
|
||||
rc_return((), rc)
|
||||
}
|
||||
"""
|
||||
.replace("{type_c}",type_c)
|
||||
|
138
rust/trexio/src/bitfield.rs
Normal file
138
rust/trexio/src/bitfield.rs
Normal file
@ -0,0 +1,138 @@
|
||||
#[derive(Debug)]
|
||||
pub struct Bitfield {
|
||||
data: Vec<i64>,
|
||||
n_int: usize
|
||||
}
|
||||
|
||||
use crate::c;
|
||||
use crate::File;
|
||||
use crate::ExitCode;
|
||||
|
||||
impl Bitfield {
|
||||
|
||||
/// Creates a new bitfield , using a number of i64 elements consistent
|
||||
/// with the number of MOs in the TREXIO file.
|
||||
pub fn from(n_int: usize, orb_list: Vec<usize>) -> Result<(Self, f64), ExitCode> {
|
||||
|
||||
let orb_list: Vec<i32> = orb_list.iter().map(|&x| x as i32).collect();
|
||||
let occ_num = orb_list.len().try_into().expect("try_into failed in Bitfield::from");
|
||||
let orb_list_ptr = orb_list.as_ptr() as *const i32;
|
||||
let n_int32: i32 = n_int.try_into().expect("try_into failed in Bitfield::from");
|
||||
let mut b = vec![0i64 ; n_int];
|
||||
let bit_list = b.as_mut_ptr() as *mut c::bitfield_t;
|
||||
std::mem::forget(b);
|
||||
let rc = unsafe {
|
||||
c::trexio_to_bitfield_list(orb_list_ptr, occ_num, bit_list, n_int32)
|
||||
};
|
||||
|
||||
let data = unsafe { Vec::from_raw_parts(bit_list, n_int, n_int) };
|
||||
|
||||
let result = Bitfield { data, n_int };
|
||||
|
||||
match ExitCode::from(rc) {
|
||||
ExitCode::Success => Ok( (result, 1.0) ),
|
||||
ExitCode::PhaseChange=> Ok( (result,-1.0) ),
|
||||
x => return Err(x),
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of i64 needed to represent a spin sector
|
||||
pub fn n_int(&self) -> usize {
|
||||
self.n_int
|
||||
}
|
||||
|
||||
/// Returns the alpha part
|
||||
pub fn alpha(&self) -> &[i64] {
|
||||
&self.data[0..self.n_int]
|
||||
}
|
||||
|
||||
/// Returns the beta part
|
||||
pub fn beta(&self) -> &[i64] {
|
||||
let n_int = self.n_int;
|
||||
&self.data[n_int..2*n_int]
|
||||
}
|
||||
|
||||
/// Converts to a format usable in the C library
|
||||
pub fn as_ptr(&self) -> *const c::bitfield_t {
|
||||
let len = self.data.len();
|
||||
let result = &self.data[0..len];
|
||||
result.as_ptr() as *const c::bitfield_t
|
||||
}
|
||||
|
||||
/// Converts to a format usable in the C library
|
||||
pub fn as_mut_ptr(&mut self) -> *mut c::bitfield_t {
|
||||
let len = self.data.len();
|
||||
let result = &mut self.data[0..len];
|
||||
result.as_mut_ptr() as *mut c::bitfield_t
|
||||
}
|
||||
|
||||
/// Converts the bitfield into a list of orbital indices (0-based)
|
||||
pub fn to_orbital_list(&self) -> Result< Vec<usize>, ExitCode> {
|
||||
|
||||
let n_int : i32 = self.n_int.try_into().expect("try_into failed in to_orbital_list");
|
||||
let d1 = self.as_ptr();
|
||||
let cap = self.n_int * 64;
|
||||
let mut list = vec![ 0i32 ; cap ];
|
||||
let list_c = list.as_mut_ptr() as *mut i32;
|
||||
std::mem::forget(list);
|
||||
|
||||
let mut occ_num : i32 = 0;
|
||||
|
||||
let rc = unsafe { c::trexio_to_orbital_list(n_int, d1, list_c, &mut occ_num) };
|
||||
match ExitCode::from(rc) {
|
||||
ExitCode::Success => (),
|
||||
x => return Err(x)
|
||||
};
|
||||
|
||||
let occ_num = occ_num as usize;
|
||||
let list = unsafe { Vec::from_raw_parts(list_c, occ_num, cap) };
|
||||
|
||||
let mut result: Vec<usize> = Vec::with_capacity(occ_num);
|
||||
for i in list.iter() {
|
||||
result.push( *i as usize );
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Converts the determinant into a list of orbital indices (0-based)
|
||||
pub fn to_orbital_list_up_dn(&self) -> Result< (Vec<usize>, Vec<usize>), ExitCode> {
|
||||
|
||||
let n_int : i32 = (self.n_int/2).try_into().expect("try_into failed in to_orbital_list");
|
||||
let d1 = self.as_ptr();
|
||||
let cap = self.n_int * 64;
|
||||
let mut b = vec![ 0i32 ; cap ];
|
||||
let list_up_c = b.as_mut_ptr() as *mut i32;
|
||||
std::mem::forget(b);
|
||||
let mut b = vec![ 0i32 ; cap ];
|
||||
let list_dn_c = b.as_mut_ptr() as *mut i32;
|
||||
std::mem::forget(b);
|
||||
|
||||
let mut occ_num_up : i32 = 0;
|
||||
let mut occ_num_dn : i32 = 0;
|
||||
|
||||
let rc = unsafe { c::trexio_to_orbital_list_up_dn(n_int, d1, list_up_c, list_dn_c, &mut occ_num_up, &mut occ_num_dn) };
|
||||
match ExitCode::from(rc) {
|
||||
ExitCode::Success => (),
|
||||
x => return Err(x)
|
||||
};
|
||||
|
||||
let occ_num_up = occ_num_up as usize;
|
||||
let occ_num_dn = occ_num_dn as usize;
|
||||
let list_up = unsafe { Vec::from_raw_parts(list_up_c, occ_num_up, cap) };
|
||||
let list_dn = unsafe { Vec::from_raw_parts(list_dn_c, occ_num_dn, cap) };
|
||||
|
||||
let mut result_up: Vec<usize> = Vec::with_capacity(occ_num_up);
|
||||
for i in list_up.iter() {
|
||||
result_up.push( *i as usize );
|
||||
}
|
||||
|
||||
let mut result_dn: Vec<usize> = Vec::with_capacity(occ_num_dn);
|
||||
for i in list_dn.iter() {
|
||||
result_dn.push( *i as usize );
|
||||
}
|
||||
Ok( (result_up, result_dn) )
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -11,9 +11,13 @@ pub use exit_code::ExitCode;
|
||||
pub mod back_end;
|
||||
pub use back_end::BackEnd;
|
||||
|
||||
/// Bit fields
|
||||
pub mod bitfield;
|
||||
pub use bitfield::Bitfield;
|
||||
|
||||
pub const PACKAGE_VERSION : &str = unsafe { std::str::from_utf8_unchecked(c::TREXIO_PACKAGE_VERSION) };
|
||||
|
||||
fn rc_return<T>(rc : c::trexio_exit_code, result: T) -> Result<T,ExitCode> {
|
||||
fn rc_return<T>(result: T, rc : c::trexio_exit_code) -> Result<T,ExitCode> {
|
||||
let rc = ExitCode::from(rc);
|
||||
match rc {
|
||||
ExitCode::Success => Ok(result),
|
||||
@ -26,6 +30,11 @@ fn string_to_c(s: &str) -> std::ffi::CString {
|
||||
}
|
||||
|
||||
|
||||
pub fn info() -> Result<(),ExitCode> {
|
||||
let rc = unsafe { c::trexio_info() };
|
||||
rc_return((), rc)
|
||||
}
|
||||
|
||||
|
||||
/// Type for a TREXIO file
|
||||
pub struct File {
|
||||
@ -43,12 +52,12 @@ impl File {
|
||||
let rc: *mut c::trexio_exit_code = &mut c::TREXIO_SUCCESS.clone();
|
||||
let result = unsafe { c::trexio_open(file_name_c, mode, back_end, rc) };
|
||||
let rc = unsafe { *rc };
|
||||
rc_return(rc, File { ptr: result })
|
||||
rc_return(File { ptr: result }, rc)
|
||||
}
|
||||
|
||||
pub fn close(self) -> Result<(), ExitCode> {
|
||||
let rc = unsafe { c::trexio_close(self.ptr) };
|
||||
rc_return(rc, ())
|
||||
rc_return((), rc)
|
||||
}
|
||||
|
||||
pub fn inquire(file_name: &str) -> Result<bool, ExitCode> {
|
||||
@ -62,6 +71,44 @@ impl File {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_state(&self) -> Result<usize, ExitCode> {
|
||||
let mut num = 0i32;
|
||||
let rc = unsafe { c::trexio_get_state(self.ptr, &mut num) };
|
||||
let result: usize = num.try_into().expect("try_into failed in get_state");
|
||||
rc_return(result, rc)
|
||||
}
|
||||
|
||||
pub fn set_state(&self, num: usize) -> Result<(), ExitCode> {
|
||||
let num: i32 = num.try_into().expect("try_into failed in set_state");
|
||||
let rc = unsafe { c::trexio_set_state(self.ptr, num) };
|
||||
rc_return((), rc)
|
||||
}
|
||||
|
||||
pub fn set_one_based(&self) -> Result<(), ExitCode> {
|
||||
let rc = unsafe { c::trexio_set_one_based(self.ptr) };
|
||||
rc_return((), rc)
|
||||
}
|
||||
|
||||
pub fn get_int64_num(&self) -> Result<usize, ExitCode> {
|
||||
let mut num = 0i32;
|
||||
let rc = unsafe {
|
||||
c::trexio_get_int64_num(self.ptr, &mut num)
|
||||
};
|
||||
let num:usize = num.try_into().expect("try_into failed in get_int64_num");
|
||||
rc_return(num, rc)
|
||||
}
|
||||
|
||||
/*
|
||||
pub fn read_determinant_list(&self, offset_file: usize, dset: Vec<Bitfield>) -> Result<usize, ExitCode> {
|
||||
let rc = unsafe {
|
||||
let offset_file: i64 = offset_file;
|
||||
let buffer_size: *mut i64 = dset.len().try_into().expect("try_into failed in read_determinant_list");
|
||||
let dset: *mut i64 = dset.to_c().as_mut_ptr();
|
||||
c::trexio_read_determinant_list(self.ptr, offset_file, buffer_size, dset)
|
||||
};
|
||||
}
|
||||
*/
|
||||
|
||||
}
|
||||
include!("generated.rs");
|
||||
|
||||
|
@ -2,6 +2,8 @@ use trexio::back_end::BackEnd;
|
||||
|
||||
pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> {
|
||||
|
||||
let () = trexio::info()?;
|
||||
|
||||
// Prepare data to be written
|
||||
let n_buffers = 5;
|
||||
let buf_size_sparse = 100/n_buffers;
|
||||
@ -10,9 +12,8 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit
|
||||
let mut index_sparse_ao_2e_int_eri = vec![0i32 ; 400];
|
||||
for i in 0..100 {
|
||||
let i : usize = i;
|
||||
let j : i32 = i.try_into().unwrap();
|
||||
let fj : f64 = j.try_into().unwrap();
|
||||
value_sparse_ao_2e_int_eri[i] = 3.14 + fj;
|
||||
let j : i32 = i as i32;
|
||||
value_sparse_ao_2e_int_eri[i] = 3.14 + (j as f64);
|
||||
index_sparse_ao_2e_int_eri[4*i + 0] = 4*j - 3;
|
||||
index_sparse_ao_2e_int_eri[4*i + 1] = 4*j+1 - 3;
|
||||
index_sparse_ao_2e_int_eri[4*i + 2] = 4*j+2 - 3;
|
||||
@ -91,6 +92,18 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit
|
||||
}
|
||||
trex_file.write_mo_energy(energy)?;
|
||||
|
||||
let mut spin = vec![0 ; mo_num];
|
||||
for i in mo_num/2..mo_num {
|
||||
spin[i] = 1;
|
||||
}
|
||||
trex_file.write_mo_spin(spin)?;
|
||||
|
||||
let (det, phase) = trexio::Bitfield::from(4, vec![0, 1, 2, 3, 4, 5, 6, 151, 152, 153, 154])?;
|
||||
println!("{} {:?}", phase, det);
|
||||
println!("{:?}", det.to_orbital_list().unwrap());
|
||||
println!("{:?}", det.to_orbital_list().unwrap());
|
||||
println!("{:?}", det.to_orbital_list_up_dn().unwrap());
|
||||
|
||||
trex_file.close()
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user