From 5f5267e34d920ce52e951512cd301cd584ddd87c Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Wed, 18 Oct 2023 11:04:39 +0200 Subject: [PATCH] Tests pass --- rust/trexio/build.py | 6 +++++- rust/trexio/src/lib.rs | 16 ++++++++++------ rust/trexio/tests/read_write.rs | 17 ++++++++--------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/rust/trexio/build.py b/rust/trexio/build.py index ecd83f5..3675b11 100755 --- a/rust/trexio/build.py +++ b/rust/trexio/build.py @@ -333,7 +333,11 @@ pub fn read_{group_l}_{element_l}(&self, offset: usize, buffer_size:usize) -> Re let rc = unsafe { c::trexio_read_safe_{group}_{element}(self.ptr, offset, &mut buffer_size_read, idx_ptr, buffer_size_read, val_ptr, buffer_size_read) }; - let buffer_size_read: usize = buffer_size_read.try_into().expect("try_into failed in read_{group}_{element} (buffer_size)"); + let rc = match ExitCode::from(rc) { + ExitCode::End => ExitCode::to_c(&ExitCode::Success), + _ => rc + }; + let buffer_size_read: usize = buffer_size_read.try_into().expect("try_into failed in read_{group}_{element} (buffer_size)"); unsafe { idx.set_len({size}*buffer_size_read) }; unsafe { val.set_len(buffer_size_read) }; let idx: Vec::<&[i32]> = idx.chunks({size}).collect(); diff --git a/rust/trexio/src/lib.rs b/rust/trexio/src/lib.rs index 8a4bbc4..486346a 100644 --- a/rust/trexio/src/lib.rs +++ b/rust/trexio/src/lib.rs @@ -123,17 +123,21 @@ impl File { pub fn read_determinant_list(&self, offset_file: usize, buffer_size: usize) -> Result, ExitCode> { let n_int = self.get_int64_num()?; - let mut one_d_array: Vec = Vec::with_capacity(buffer_size * n_int); + let mut one_d_array: Vec = Vec::with_capacity(buffer_size * 2* n_int); let one_d_array_ptr = one_d_array.as_ptr() as *mut i64; let rc = unsafe { let offset_file: i64 = offset_file.try_into().expect("try_into failed in read_determinant_list (offset_file)"); let mut buffer_size_read: i64 = buffer_size.try_into().expect("try_into failed in read_determinant_list (buffer_size)"); let rc = c::trexio_read_determinant_list(self.ptr, offset_file, &mut buffer_size_read, one_d_array_ptr); - one_d_array.set_len(buffer_size_read.try_into() - .expect("try_into failed in read_determinant_list (buffer_size_read)")); - rc - }; - let result: Vec:: = one_d_array.chunks(n_int) + let buffer_size_read: usize = buffer_size_read.try_into().expect("try_into failed in read_determinant_list (buffer_size)"); + one_d_array.set_len(n_int*2usize*buffer_size_read); + match ExitCode::from(rc) { + ExitCode::End => ExitCode::to_c(&ExitCode::Success), + ExitCode::Success => { assert_eq!(buffer_size_read, buffer_size); rc } + _ => rc + } + }; + let result: Vec:: = one_d_array.chunks(2*n_int) .collect::>() .iter() .map(|x| (Bitfield::from_vec(&x))) diff --git a/rust/trexio/tests/read_write.rs b/rust/trexio/tests/read_write.rs index 8b834ad..fead812 100644 --- a/rust/trexio/tests/read_write.rs +++ b/rust/trexio/tests/read_write.rs @@ -198,10 +198,9 @@ fn read(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { let n_buffers = 8; let bufsize = nmax/n_buffers+10; -/* TODO: check from here */ for i in 0..100 { // Quadruplet of indices + value - let data = (4*i, 4*i+1, 4*i+2, 4*i+3, 3.13 + (i as f64)); + let data = (4*i, 4*i+1, 4*i+2, 4*i+3, 3.14 + (i as f64)); ao_2e_int_eri_ref.push(data); } @@ -209,7 +208,7 @@ fn read(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { let mut ao_2e_int_eri = Vec::<(usize,usize,usize,usize,f64)>::with_capacity(nmax); for _ in 0..n_buffers { let buffer = trex_file.read_ao_2e_int_eri(offset, bufsize)?; - offset += bufsize; + offset += buffer.len(); ao_2e_int_eri.extend(buffer); } assert_eq!(ao_2e_int_eri_ref, ao_2e_int_eri); @@ -229,13 +228,13 @@ fn read(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { } let n_buffers = 8; - let bufsize = det_num/n_buffers + 10; + let bufsize = det_num/n_buffers + 20; let mut offset = 0; let mut det_list: Vec = Vec::with_capacity(det_num); for _ in 0..n_buffers { let buffer = trex_file.read_determinant_list(offset, bufsize)?; + offset += buffer.len(); det_list.extend(buffer); - offset += bufsize; } assert_eq!(det_list_ref, det_list); @@ -253,15 +252,15 @@ use std::fs; #[test] pub fn text_backend() { - let _ = write("tmp/test_write.dir", trexio::BackEnd::Text); - let _ = read("tmp/test_write.dir", trexio::BackEnd::Text); + let _ = write("tmp/test_write.dir", trexio::BackEnd::Text).unwrap(); + let _ = read("tmp/test_write.dir", trexio::BackEnd::Text).unwrap(); fs::remove_dir_all("tmp/test_write.dir").unwrap() } #[test] pub fn hdf5_backend() { - let _ = write("tmp/test_write.hdf5", trexio::BackEnd::Hdf5); - let _ = read("tmp/test_write.hdf5", trexio::BackEnd::Hdf5); + let _ = write("tmp/test_write.hdf5", trexio::BackEnd::Hdf5).unwrap(); + let _ = read("tmp/test_write.hdf5", trexio::BackEnd::Hdf5).unwrap(); fs::remove_file("tmp/test_write.hdf5").unwrap() }