From 1e7dfc8ddd068de5a33f22a74a6255e071918a8b Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Fri, 13 Oct 2023 11:17:13 +0200 Subject: [PATCH] Tests in bitfields --- rust/trexio/src/bitfield.rs | 66 +++++++++++++++++++++++++++++-------- rust/trexio/src/lib.rs | 14 -------- rust/trexio/src/test.rs | 5 --- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/rust/trexio/src/bitfield.rs b/rust/trexio/src/bitfield.rs index 5303c44..8b1b6b1 100644 --- a/rust/trexio/src/bitfield.rs +++ b/rust/trexio/src/bitfield.rs @@ -1,18 +1,16 @@ #[derive(Debug)] pub struct Bitfield { data: Vec, - n_int: usize } use crate::c; -use crate::File; use crate::ExitCode; impl Bitfield { /// Creates a new bitfield , using a number of i64 elements consistent /// with the number of MOs in the TREXIO file. - pub fn from(n_int: usize, orb_list: Vec) -> Result<(Self, f64), ExitCode> { + pub fn from(n_int: usize, orb_list: &[usize]) -> Result<(Self, f64), ExitCode> { let orb_list: Vec = orb_list.iter().map(|&x| x as i32).collect(); let occ_num = orb_list.len().try_into().expect("try_into failed in Bitfield::from"); @@ -27,7 +25,7 @@ impl Bitfield { let data = unsafe { Vec::from_raw_parts(bit_list, n_int, n_int) }; - let result = Bitfield { data, n_int }; + let result = Bitfield { data }; match ExitCode::from(rc) { ExitCode::Success => Ok( (result, 1.0) ), @@ -36,19 +34,24 @@ impl Bitfield { } } - /// Number of i64 needed to represent a spin sector - pub fn n_int(&self) -> usize { - self.n_int + pub fn from_alpha_beta(alpha: Bitfield, beta: Bitfield) -> Result { + if alpha.data.len() != beta.data.len() { + return Err(ExitCode::InvalidArg2) + }; + let mut data = alpha.data.clone(); + data.extend_from_slice(&beta.data); + Ok(Bitfield { data }) } /// Returns the alpha part pub fn alpha(&self) -> &[i64] { - &self.data[0..self.n_int] + let n_int = self.data.len()/2; + &self.data[0..n_int] } /// Returns the beta part pub fn beta(&self) -> &[i64] { - let n_int = self.n_int; + let n_int = self.data.len()/2; &self.data[n_int..2*n_int] } @@ -69,9 +72,9 @@ impl Bitfield { /// Converts the bitfield into a list of orbital indices (0-based) pub fn to_orbital_list(&self) -> Result< Vec, ExitCode> { - let n_int : i32 = self.n_int.try_into().expect("try_into failed in to_orbital_list"); + let n_int : i32 = self.data.len().try_into().expect("try_into failed in to_orbital_list"); let d1 = self.as_ptr(); - let cap = self.n_int * 64; + let cap = self.data.len() * 64; let mut list = vec![ 0i32 ; cap ]; let list_c = list.as_mut_ptr() as *mut i32; std::mem::forget(list); @@ -97,9 +100,9 @@ impl Bitfield { /// Converts the determinant into a list of orbital indices (0-based) pub fn to_orbital_list_up_dn(&self) -> Result< (Vec, Vec), ExitCode> { - let n_int : i32 = (self.n_int/2).try_into().expect("try_into failed in to_orbital_list"); + let n_int : i32 = (self.data.len()/2).try_into().expect("try_into failed in to_orbital_list"); let d1 = self.as_ptr(); - let cap = self.n_int * 64; + let cap = self.data.len()/2 * 64; let mut b = vec![ 0i32 ; cap ]; let list_up_c = b.as_mut_ptr() as *mut i32; std::mem::forget(b); @@ -136,3 +139,40 @@ impl Bitfield { } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn creation_from_list() { + + let list0 = vec![0, 1, 2, 3, 4]; + let list1 = vec![0, 1, 2, 4, 3]; + let list2 = vec![0, 1, 4, 2, 3]; + + let (alpha0, phase0) = Bitfield::from(2, &list0).unwrap(); + let list = alpha0.to_orbital_list().unwrap(); + assert_eq!(list, list0); + + let (alpha1, phase1) = Bitfield::from(2, &list1).unwrap(); + let list = alpha1.to_orbital_list().unwrap(); + assert_eq!(list, list0); + assert_eq!(phase1, -phase0); + + let (alpha2, phase2) = Bitfield::from(2, &list2).unwrap(); + let list = alpha2.to_orbital_list().unwrap(); + assert_eq!(list, list0); + assert_eq!(phase2, phase0); + + } + + #[test] + fn creation_alpha_beta() { + + let (alpha, _) = Bitfield::from(2, &[0, 1, 2, 3, 4]).unwrap(); + let (beta , _) = Bitfield::from(2, &[0, 1, 2, 4, 5]).unwrap(); + let det = Bitfield::from_alpha_beta(alpha, beta).unwrap(); + let list = det.to_orbital_list().unwrap(); + assert_eq!(list, [0,1,2,3,4,128,129,130,132,133]); + } +} diff --git a/rust/trexio/src/lib.rs b/rust/trexio/src/lib.rs index 9956bb7..89f704f 100644 --- a/rust/trexio/src/lib.rs +++ b/rust/trexio/src/lib.rs @@ -113,17 +113,3 @@ impl File { include!("generated.rs"); -#[cfg(test)] -mod tests { - use super::*; - use std::mem; - use c::*; - - #[test] - fn read_trexio_file() { - println!("============================================"); - println!(" TREXIO MAJOR VERSION : {}", TREXIO_VERSION_MAJOR); - println!("============================================"); - - } -} diff --git a/rust/trexio/src/test.rs b/rust/trexio/src/test.rs index ac0d4e4..766fe45 100644 --- a/rust/trexio/src/test.rs +++ b/rust/trexio/src/test.rs @@ -98,11 +98,6 @@ pub fn test_write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::Exit } trex_file.write_mo_spin(spin)?; - let (det, phase) = trexio::Bitfield::from(4, vec![0, 1, 2, 3, 4, 5, 6, 151, 152, 153, 154])?; - println!("{} {:?}", phase, det); - println!("{:?}", det.to_orbital_list().unwrap()); - println!("{:?}", det.to_orbital_list().unwrap()); - println!("{:?}", det.to_orbital_list_up_dn().unwrap()); trex_file.close()