From 601dfbba89a41682e402f9a1e700432ba1c78e1a Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Fri, 13 Oct 2023 11:39:50 +0200 Subject: [PATCH] Added bitfields --- rust/trexio/src/back_end.rs | 1 + rust/trexio/src/bitfield.rs | 66 +++++++++++++++++++++--------------- rust/trexio/src/exit_code.rs | 1 + 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/rust/trexio/src/back_end.rs b/rust/trexio/src/back_end.rs index 15c981c..b26ce7e 100644 --- a/rust/trexio/src/back_end.rs +++ b/rust/trexio/src/back_end.rs @@ -2,6 +2,7 @@ use crate::c; /// Possible back ends #[derive(Debug)] +#[derive(PartialEq)] pub enum BackEnd { Text, Hdf5, diff --git a/rust/trexio/src/bitfield.rs b/rust/trexio/src/bitfield.rs index 8b1b6b1..fee35da 100644 --- a/rust/trexio/src/bitfield.rs +++ b/rust/trexio/src/bitfield.rs @@ -1,4 +1,5 @@ #[derive(Debug)] +#[derive(PartialEq)] pub struct Bitfield { data: Vec, } @@ -10,7 +11,7 @@ impl Bitfield { /// Creates a new bitfield , using a number of i64 elements consistent /// with the number of MOs in the TREXIO file. - pub fn from(n_int: usize, orb_list: &[usize]) -> Result<(Self, f64), ExitCode> { + pub fn from(n_int: usize, orb_list: &[usize]) -> (Self, f64) { let orb_list: Vec = orb_list.iter().map(|&x| x as i32).collect(); let occ_num = orb_list.len().try_into().expect("try_into failed in Bitfield::from"); @@ -28,31 +29,31 @@ impl Bitfield { let result = Bitfield { data }; match ExitCode::from(rc) { - ExitCode::Success => Ok( (result, 1.0) ), - ExitCode::PhaseChange=> Ok( (result,-1.0) ), - x => return Err(x), + ExitCode::Success => (result, 1.0), + ExitCode::PhaseChange=> (result,-1.0), + x => panic!("TREXIO Error {}", x) } } - pub fn from_alpha_beta(alpha: Bitfield, beta: Bitfield) -> Result { + pub fn from_alpha_beta(alpha: &Bitfield, beta: &Bitfield) -> Bitfield { if alpha.data.len() != beta.data.len() { - return Err(ExitCode::InvalidArg2) + panic!("alpha and beta parts have different lengths"); }; let mut data = alpha.data.clone(); data.extend_from_slice(&beta.data); - Ok(Bitfield { data }) + Bitfield { data } } /// Returns the alpha part - pub fn alpha(&self) -> &[i64] { + pub fn alpha(&self) -> Bitfield { let n_int = self.data.len()/2; - &self.data[0..n_int] + Bitfield { data: (&self.data[0..n_int]).to_vec() } } /// Returns the beta part - pub fn beta(&self) -> &[i64] { + pub fn beta(&self) -> Bitfield { let n_int = self.data.len()/2; - &self.data[n_int..2*n_int] + Bitfield { data: (&self.data[n_int..2*n_int]).to_vec() } } /// Converts to a format usable in the C library @@ -70,7 +71,7 @@ impl Bitfield { } /// Converts the bitfield into a list of orbital indices (0-based) - pub fn to_orbital_list(&self) -> Result< Vec, ExitCode> { + pub fn to_orbital_list(&self) -> Vec { let n_int : i32 = self.data.len().try_into().expect("try_into failed in to_orbital_list"); let d1 = self.as_ptr(); @@ -84,7 +85,7 @@ impl Bitfield { let rc = unsafe { c::trexio_to_orbital_list(n_int, d1, list_c, &mut occ_num) }; match ExitCode::from(rc) { ExitCode::Success => (), - x => return Err(x) + x => panic!("TREXIO Error {}", x) }; let occ_num = occ_num as usize; @@ -94,11 +95,11 @@ impl Bitfield { for i in list.iter() { result.push( *i as usize ); } - Ok(result) + result } /// Converts the determinant into a list of orbital indices (0-based) - pub fn to_orbital_list_up_dn(&self) -> Result< (Vec, Vec), ExitCode> { + pub fn to_orbital_list_up_dn(&self) -> (Vec, Vec) { let n_int : i32 = (self.data.len()/2).try_into().expect("try_into failed in to_orbital_list"); let d1 = self.as_ptr(); @@ -116,7 +117,7 @@ impl Bitfield { let rc = unsafe { c::trexio_to_orbital_list_up_dn(n_int, d1, list_up_c, list_dn_c, &mut occ_num_up, &mut occ_num_dn) }; match ExitCode::from(rc) { ExitCode::Success => (), - x => return Err(x) + x => panic!("TREXIO Error {}", x) }; let occ_num_up = occ_num_up as usize; @@ -133,7 +134,7 @@ impl Bitfield { for i in list_dn.iter() { result_dn.push( *i as usize ); } - Ok( (result_up, result_dn) ) + (result_up, result_dn) } } @@ -150,17 +151,17 @@ mod tests { let list1 = vec![0, 1, 2, 4, 3]; let list2 = vec![0, 1, 4, 2, 3]; - let (alpha0, phase0) = Bitfield::from(2, &list0).unwrap(); - let list = alpha0.to_orbital_list().unwrap(); + let (alpha0, phase0) = Bitfield::from(2, &list0); + let list = alpha0.to_orbital_list(); assert_eq!(list, list0); - let (alpha1, phase1) = Bitfield::from(2, &list1).unwrap(); - let list = alpha1.to_orbital_list().unwrap(); + let (alpha1, phase1) = Bitfield::from(2, &list1); + let list = alpha1.to_orbital_list(); assert_eq!(list, list0); assert_eq!(phase1, -phase0); - let (alpha2, phase2) = Bitfield::from(2, &list2).unwrap(); - let list = alpha2.to_orbital_list().unwrap(); + let (alpha2, phase2) = Bitfield::from(2, &list2); + let list = alpha2.to_orbital_list(); assert_eq!(list, list0); assert_eq!(phase2, phase0); @@ -168,11 +169,20 @@ mod tests { #[test] fn creation_alpha_beta() { - - let (alpha, _) = Bitfield::from(2, &[0, 1, 2, 3, 4]).unwrap(); - let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]).unwrap(); - let det = Bitfield::from_alpha_beta(alpha, beta).unwrap(); - let list = det.to_orbital_list().unwrap(); + let (alpha, _) = Bitfield::from(2, &[0, 1, 2, 3, 4]); + let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]); + let det = Bitfield::from_alpha_beta(&alpha, &beta); + let list = det.to_orbital_list(); assert_eq!(list, [0,1,2,3,4,128,129,130,132,133]); + assert_eq!(det.alpha(), alpha); + assert_eq!(det.beta(), beta); + } + + #[test] + #[should_panic] + fn creation_alpha_beta_with_different_nint() { + let (alpha, _) = Bitfield::from(1, &[0, 1, 2, 3, 4]); + let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]); + let _ = Bitfield::from_alpha_beta(&alpha, &beta); } } diff --git a/rust/trexio/src/exit_code.rs b/rust/trexio/src/exit_code.rs index 74cc08d..2e20958 100644 --- a/rust/trexio/src/exit_code.rs +++ b/rust/trexio/src/exit_code.rs @@ -2,6 +2,7 @@ use crate::c; /// Exit codes #[derive(Debug)] +#[derive(PartialEq)] pub enum ExitCode { Failure, Success,