From 91585a6811ee363ed082d16564c4d58699896ae4 Mon Sep 17 00:00:00 2001 From: q-posev Date: Thu, 2 Jan 2025 21:29:36 +0100 Subject: [PATCH] Fix Python tests: align with the C ones --- python/test/test_api.py | 104 ++++++++++++++++++++++++++++++---------- 1 file changed, 80 insertions(+), 24 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index 3ce3d9a..1660b22 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -6,6 +6,62 @@ import trexio from benzene_data import * +# this function is copied from the trexio-tools github repository (BSD-3 license): +# https://github.com/TREX-CoE/trexio_tools/blob/master/src/trexio_tools/group_tools/determinant.py +def to_determinant_list(orbital_list: list, int64_num: int) -> list: + """ + Convert a list of occupied orbitals from the `orbital_list` + into a list of Slater determinants (in their bit string representation). + + Orbitals in the `orbital_list` should be 0-based, namely the lowest orbital has index 0, not 1. + + int64_num is the number of 64-bit integers needed to represent a Slater determinant bit string. + It depends on the number of molecular orbitals as follows: int64_num = int((mo_num-1)/64) + 1 + """ + + if not isinstance(orbital_list, list): + raise TypeError(f"orbital_list should be a list, not {type(orbital_list)}") + + det_list = [] + bitfield = 0 + shift = 0 + + # since orbital indices are 0-based but the code below works for 1-based --> increment the input indices by +1 + orb_list_upshifted = [ orb+1 for orb in orbital_list] + + # orbital list has to be sorted in increasing order for the bitfields to be set correctly + orb_list_sorted = sorted(orb_list_upshifted) + + for orb in orb_list_sorted: + + if orb-shift > 64: + # this removes the 0 bit from the beginning of the bitfield + bitfield = bitfield >> 1 + # append a bitfield to the list + det_list.append(bitfield) + bitfield = 0 + + modulo = int((orb-1)/64) + shift = modulo*64 + bitfield |= (1 << (orb-shift)) + + # this removes the 0 bit from the beginning of the bitfield + bitfield = bitfield >> 1 + det_list.append(bitfield) + #print('Popcounts: ', [bin(d).count('1') for d in det_list) + #print('Bitfields: ', [bin(d) for d in det_list]) + + bitfield_num = len(det_list) + if bitfield_num > int64_num: + raise Exception(f'Number of bitfields {bitfield_num} cannot be more than the int64_num {int64_num}.') + if bitfield_num < int64_num: + for _ in range(int64_num - bitfield_num): + print("Appending an empty bitfield.") + det_list.append(0) + + return det_list + + def clean(back_end, filename): """Remove test files.""" if back_end == trexio.TREXIO_HDF5: @@ -206,16 +262,36 @@ class TestIO: """Write CI determinants and coefficients.""" self.open() # write mo_num (needed later to write determinants) - trexio.write_mo_num(self.test_file, mo_num) + MO_NUM_TEST = 100 + trexio.write_mo_num(self.test_file, MO_NUM_TEST) # get the number of bit-strings per spin component + INT64_NUM_TEST = int((MO_NUM_TEST-1)/64) + 1 int_num = trexio.get_int64_num(self.test_file) - assert int_num == int64_num + assert int_num == INT64_NUM_TEST + # write the number of up and down electrons + trexio.write_electron_up_num(self.test_file, 4) + trexio.write_electron_dn_num(self.test_file, 3) + # orbital lists + orb_list_up = [0,1,2,3] + orb_list_dn = [0,1,2] + + # data to write + DET_NUM_TEST = 100 + det_up = to_determinant_list(orb_list_up, INT64_NUM_TEST) + det_dn = to_determinant_list(orb_list_dn, INT64_NUM_TEST) + + det_list = [] + coeff_list = [] + for i in range(DET_NUM_TEST): + det_list.append(det_up + det_dn) + coeff_list.append(3.14 + float(i)) + # write the data for the ground state offset = 0 trexio.write_state_id(self.test_file, 0) - trexio.write_determinant_list(self.test_file, offset, det_num, dets) + trexio.write_determinant_list(self.test_file, offset, DET_NUM_TEST, det_list) assert trexio.has_determinant_list(self.test_file) - trexio.write_determinant_coefficient(self.test_file, offset, det_num, coeffs) + trexio.write_determinant_coefficient(self.test_file, offset, DET_NUM_TEST, coeff_list) assert trexio.has_determinant_coefficient(self.test_file) # manually check the consistency between coefficient_size and number of determinants assert trexio.read_determinant_coefficient_size(self.test_file) == trexio.read_determinant_num(self.test_file) @@ -350,26 +426,6 @@ class TestIO: if self.test_file.isOpen: self.test_file.close() - def test_determinant_read(self): - """Read the CI determinants.""" - self.open(mode='r') - # read determinants (list of ints and float coefficients) - buf_size = 100 - offset_file = 0 - # read full buf_size (i.e. the one that does not reach EOF) - dets_np, read_buf_size, eof = trexio.read_determinant_list(self.test_file, offset_file, buf_size) - #print(f'First complete read of determinant list: {read_buf_size}') - assert not eof - assert read_buf_size == buf_size - assert dets_np[0][0] == 0 - assert dets_np[read_buf_size-1][int64_num*2-1] == read_buf_size * int64_num * 2- 1 - - coefficients_np, read_buf_size, eof = trexio.read_determinant_coefficient(self.test_file, offset_file, buf_size) - #print(f'First complete read of determinant coefficients: {read_buf_size}') - assert not eof - assert read_buf_size == buf_size - if self.test_file.isOpen: - self.test_file.close() def test_array_str_read(self): """Read an array of strings."""