From 38946c8ab7019752a8928dab5aea2ec4aaf24e8c Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Mon, 16 Oct 2023 12:52:04 +0200 Subject: [PATCH] Added sparse write --- rust/trexio/build.py | 37 ++++++++- rust/trexio/tests/{write.rs => read_write.rs} | 80 +++++++++---------- rust/trexio/tmp/.gitignore | 1 + 3 files changed, 74 insertions(+), 44 deletions(-) rename rust/trexio/tests/{write.rs => read_write.rs} (72%) create mode 100644 rust/trexio/tmp/.gitignore diff --git a/rust/trexio/build.py b/rust/trexio/build.py index cf6746e..02ffc15 100755 --- a/rust/trexio/build.py +++ b/rust/trexio/build.py @@ -260,8 +260,6 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi } """ ] r += [ '\n'.join(t) -.replace("{type_c}",type_c) -.replace("{type_r}",type_r) .replace("{group}",group) .replace("{group_l}",group_l) .replace("{element}",element) @@ -270,6 +268,7 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<{type_r}>) -> Result<(), Exi r += [ """ pub fn write_{group_l}_{element_l}(&self, data: Vec<&str>) -> Result<(), ExitCode> { let mut size = 0; + // Find longest string for s in data.iter() { let l = s.len(); size = if l>size {l} else {size}; @@ -283,6 +282,38 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<&str>) -> Result<(), ExitCod rc_return((), rc) } """ +.replace("{group}",group) +.replace("{group_l}",group_l) +.replace("{element}",element) +.replace("{element_l}",element_l) ] + + elif data[group][element][0] in [ "float sparse" ]: + size = len(data[group][element][1]) + typ = "&[(" + ",".join( [ "usize" for _ in range(size) ]) + ", f64)]" + r += [ (""" +pub fn write_{group_l}_{element_l}(&self, offset: usize, data: {typ}) -> Result<(), ExitCode> { + let mut idx = Vec::::with_capacity({size}*data.len()); + let mut val = Vec::::with_capacity(data.len()); + // Array of indices + for d in data.iter() { +""" + +'\n'.join([ f" idx.push(d.{i}.try_into().unwrap());" for i in range(size) ]) + +f"\n val.push(d.{size});" + +""" + } + + let size_max: i64 = data.len().try_into().expect("try_into failed in write_{group}_{element}"); + let buffer_size = size_max; + let idx_ptr = idx.as_ptr() as *const i32; + let val_ptr = val.as_ptr() as *const f64; + let offset: i64 = offset.try_into().expect("try_into failed in write_{group}_{element}"); + let rc = unsafe { c::trexio_write_safe_{group}_{element}(self.ptr, + offset, buffer_size, idx_ptr, size_max, val_ptr, size_max) }; + rc_return((), rc) +} +""") +.replace("{size}",str(size)) +.replace("{typ}",typ) .replace("{type_c}",type_c) .replace("{type_r}",type_r) .replace("{group}",group) @@ -290,8 +321,6 @@ pub fn write_{group_l}_{element_l}(&self, data: Vec<&str>) -> Result<(), ExitCod .replace("{element}",element) .replace("{element_l}",element_l) ] - elif data[group][element][0] in [ "float sparse" ]: - pass diff --git a/rust/trexio/tests/write.rs b/rust/trexio/tests/read_write.rs similarity index 72% rename from rust/trexio/tests/write.rs rename to rust/trexio/tests/read_write.rs index abcaf81..13b072f 100644 --- a/rust/trexio/tests/write.rs +++ b/rust/trexio/tests/read_write.rs @@ -6,20 +6,6 @@ fn write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { // Prepare data to be written - let n_buffers = 5; - let buf_size_sparse = 100/n_buffers; - let mut value_sparse_ao_2e_int_eri = vec![0.0f64 ; 100]; - let mut index_sparse_ao_2e_int_eri = vec![0i32 ; 400]; - for i in 0..100 { - let i : usize = i; - let j : i32 = i as i32; - value_sparse_ao_2e_int_eri[i] = 3.14 + (j as f64); - index_sparse_ao_2e_int_eri[4*i + 0] = 4*j - 3; - index_sparse_ao_2e_int_eri[4*i + 1] = 4*j+1 - 3; - index_sparse_ao_2e_int_eri[4*i + 2] = 4*j+2 - 3; - index_sparse_ao_2e_int_eri[4*i + 3] = 4*j+3 - 3; - } - let nucleus_num = 12; let state_id = 2; let charge = vec![6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.0f64]; @@ -38,21 +24,11 @@ fn write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { let mo_num = 150; let ao_num = 1000; let basis_shell_num = 24; - let basis_nucleus_index = vec![ 0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 ]; + let basis_nucleus_index: Vec = (0..24).collect(); + + let label = vec![ "C", "Na", "C", "C 66", "C", + "C", "H 99", "Ru", "H", "H", "H", "H" ]; - let label = vec![ - "C", - "Na", - "C", - "C 66", - "C", - "C", - "H 99", - "Ru", - "H", - "H", - "H", - "H" ]; let sym_str = "B3U with some comments"; @@ -98,8 +74,27 @@ fn write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { } trex_file.write_mo_spin(spin)?; + // Integrals + let nmax = 100; + let mut ao_2e_int_eri = Vec::<(usize,usize,usize,usize,f64)>::with_capacity(nmax); + + let n_buffers = 5; + let bufsize = nmax/n_buffers; + + for i in 0..100 { + // Quadruplet of indices + value + let data = (4*i, 4*i+1, 4*i+2, 4*i+3, 3.14 + (i as f64)); + ao_2e_int_eri.push(data); + } + + let mut offset = 0; + for i in 0..n_buffers { + trex_file.write_ao_2e_int_eri(offset, &ao_2e_int_eri[offset..offset+bufsize])?; + offset += bufsize; + } + + // Determinants - // let det_num = 50; let mut det_list = Vec::with_capacity(det_num); for i in 0..det_num { @@ -111,11 +106,11 @@ fn write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { } let n_buffers = 5; - let buf_size_det = 50/n_buffers; + let bufsize = 50/n_buffers; let mut offset = 0; for i in 0..n_buffers { - trex_file.write_determinant_list(offset, &det_list[offset..offset+buf_size_det])?; - offset += buf_size_det; + trex_file.write_determinant_list(offset, &det_list[offset..offset+bufsize])?; + offset += bufsize; } @@ -125,16 +120,21 @@ fn write(file_name: &str, back_end: BackEnd) -> Result<(), trexio::ExitCode> { #[test] pub fn info() { - trexio::info(); + let _ = trexio::info(); +} + + +use std::fs; + +#[test] +pub fn text_backend() { + let _ = write("tmp/test_write.dir", trexio::BackEnd::Text); + fs::remove_dir_all("tmp/test_write.dir").unwrap() } #[test] -pub fn wite_text() { - write("test_write.dir", trexio::BackEnd::Text); -} - -#[test] -pub fn wite_hdf5() { - write("test_write.hdf5", trexio::BackEnd::Hdf5); +pub fn hdf5_backend() { + let _ = write("tmp/test_write.hdf5", trexio::BackEnd::Hdf5); + fs::remove_file("tmp/test_write.hdf5").unwrap() } diff --git a/rust/trexio/tmp/.gitignore b/rust/trexio/tmp/.gitignore new file mode 100644 index 0000000..72e8ffc --- /dev/null +++ b/rust/trexio/tmp/.gitignore @@ -0,0 +1 @@ +*