1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-12-22 20:35:44 +01:00

Added bitfields

This commit is contained in:
Anthony Scemama 2023-10-13 11:39:50 +02:00
parent 1e7dfc8ddd
commit 601dfbba89
3 changed files with 40 additions and 28 deletions

View File

@ -2,6 +2,7 @@ use crate::c;
/// Possible back ends
#[derive(Debug)]
#[derive(PartialEq)]
pub enum BackEnd {
Text,
Hdf5,

View File

@ -1,4 +1,5 @@
#[derive(Debug)]
#[derive(PartialEq)]
pub struct Bitfield {
data: Vec<i64>,
}
@ -10,7 +11,7 @@ impl Bitfield {
/// Creates a new bitfield , using a number of i64 elements consistent
/// with the number of MOs in the TREXIO file.
pub fn from(n_int: usize, orb_list: &[usize]) -> Result<(Self, f64), ExitCode> {
pub fn from(n_int: usize, orb_list: &[usize]) -> (Self, f64) {
let orb_list: Vec<i32> = orb_list.iter().map(|&x| x as i32).collect();
let occ_num = orb_list.len().try_into().expect("try_into failed in Bitfield::from");
@ -28,31 +29,31 @@ impl Bitfield {
let result = Bitfield { data };
match ExitCode::from(rc) {
ExitCode::Success => Ok( (result, 1.0) ),
ExitCode::PhaseChange=> Ok( (result,-1.0) ),
x => return Err(x),
ExitCode::Success => (result, 1.0),
ExitCode::PhaseChange=> (result,-1.0),
x => panic!("TREXIO Error {}", x)
}
}
pub fn from_alpha_beta(alpha: Bitfield, beta: Bitfield) -> Result<Bitfield, ExitCode> {
pub fn from_alpha_beta(alpha: &Bitfield, beta: &Bitfield) -> Bitfield {
if alpha.data.len() != beta.data.len() {
return Err(ExitCode::InvalidArg2)
panic!("alpha and beta parts have different lengths");
};
let mut data = alpha.data.clone();
data.extend_from_slice(&beta.data);
Ok(Bitfield { data })
Bitfield { data }
}
/// Returns the alpha part
pub fn alpha(&self) -> &[i64] {
pub fn alpha(&self) -> Bitfield {
let n_int = self.data.len()/2;
&self.data[0..n_int]
Bitfield { data: (&self.data[0..n_int]).to_vec() }
}
/// Returns the beta part
pub fn beta(&self) -> &[i64] {
pub fn beta(&self) -> Bitfield {
let n_int = self.data.len()/2;
&self.data[n_int..2*n_int]
Bitfield { data: (&self.data[n_int..2*n_int]).to_vec() }
}
/// Converts to a format usable in the C library
@ -70,7 +71,7 @@ impl Bitfield {
}
/// Converts the bitfield into a list of orbital indices (0-based)
pub fn to_orbital_list(&self) -> Result< Vec<usize>, ExitCode> {
pub fn to_orbital_list(&self) -> Vec<usize> {
let n_int : i32 = self.data.len().try_into().expect("try_into failed in to_orbital_list");
let d1 = self.as_ptr();
@ -84,7 +85,7 @@ impl Bitfield {
let rc = unsafe { c::trexio_to_orbital_list(n_int, d1, list_c, &mut occ_num) };
match ExitCode::from(rc) {
ExitCode::Success => (),
x => return Err(x)
x => panic!("TREXIO Error {}", x)
};
let occ_num = occ_num as usize;
@ -94,11 +95,11 @@ impl Bitfield {
for i in list.iter() {
result.push( *i as usize );
}
Ok(result)
result
}
/// Converts the determinant into a list of orbital indices (0-based)
pub fn to_orbital_list_up_dn(&self) -> Result< (Vec<usize>, Vec<usize>), ExitCode> {
pub fn to_orbital_list_up_dn(&self) -> (Vec<usize>, Vec<usize>) {
let n_int : i32 = (self.data.len()/2).try_into().expect("try_into failed in to_orbital_list");
let d1 = self.as_ptr();
@ -116,7 +117,7 @@ impl Bitfield {
let rc = unsafe { c::trexio_to_orbital_list_up_dn(n_int, d1, list_up_c, list_dn_c, &mut occ_num_up, &mut occ_num_dn) };
match ExitCode::from(rc) {
ExitCode::Success => (),
x => return Err(x)
x => panic!("TREXIO Error {}", x)
};
let occ_num_up = occ_num_up as usize;
@ -133,7 +134,7 @@ impl Bitfield {
for i in list_dn.iter() {
result_dn.push( *i as usize );
}
Ok( (result_up, result_dn) )
(result_up, result_dn)
}
}
@ -150,17 +151,17 @@ mod tests {
let list1 = vec![0, 1, 2, 4, 3];
let list2 = vec![0, 1, 4, 2, 3];
let (alpha0, phase0) = Bitfield::from(2, &list0).unwrap();
let list = alpha0.to_orbital_list().unwrap();
let (alpha0, phase0) = Bitfield::from(2, &list0);
let list = alpha0.to_orbital_list();
assert_eq!(list, list0);
let (alpha1, phase1) = Bitfield::from(2, &list1).unwrap();
let list = alpha1.to_orbital_list().unwrap();
let (alpha1, phase1) = Bitfield::from(2, &list1);
let list = alpha1.to_orbital_list();
assert_eq!(list, list0);
assert_eq!(phase1, -phase0);
let (alpha2, phase2) = Bitfield::from(2, &list2).unwrap();
let list = alpha2.to_orbital_list().unwrap();
let (alpha2, phase2) = Bitfield::from(2, &list2);
let list = alpha2.to_orbital_list();
assert_eq!(list, list0);
assert_eq!(phase2, phase0);
@ -168,11 +169,20 @@ mod tests {
#[test]
fn creation_alpha_beta() {
let (alpha, _) = Bitfield::from(2, &[0, 1, 2, 3, 4]).unwrap();
let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]).unwrap();
let det = Bitfield::from_alpha_beta(alpha, beta).unwrap();
let list = det.to_orbital_list().unwrap();
let (alpha, _) = Bitfield::from(2, &[0, 1, 2, 3, 4]);
let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]);
let det = Bitfield::from_alpha_beta(&alpha, &beta);
let list = det.to_orbital_list();
assert_eq!(list, [0,1,2,3,4,128,129,130,132,133]);
assert_eq!(det.alpha(), alpha);
assert_eq!(det.beta(), beta);
}
#[test]
#[should_panic]
fn creation_alpha_beta_with_different_nint() {
let (alpha, _) = Bitfield::from(1, &[0, 1, 2, 3, 4]);
let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]);
let _ = Bitfield::from_alpha_beta(&alpha, &beta);
}
}

View File

@ -2,6 +2,7 @@ use crate::c;
/// Exit codes
#[derive(Debug)]
#[derive(PartialEq)]
pub enum ExitCode {
Failure,
Success,