1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-12-22 12:23:54 +01:00

adapt the generator to work for an arbitrary number of indices in sparse dset

This commit is contained in:
q-posev 2021-12-03 16:56:45 +01:00
parent ed3bde973e
commit e340c6541d

View File

@ -577,6 +577,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
for k,v in datasets.items():
# create a temp dictionary
tmp_dict = {}
rank = len(v[1])
is_sparse = False
# specify details required to replace templated variables later
if v[0] == 'float':
@ -629,9 +630,21 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
group_dset_dtype_double = ''
group_dset_dtype_single = ''
default_prec = ''
group_dset_format_printf = '%10" PRId32 " %10" PRId32 " %10" PRId32 " %10" PRId32 " %24.16e'
group_dset_format_scanf = '%" SCNd32 " %" SCNd32 " %" SCNd32 " %" SCNd32 " %lf'
group_dset_py_dtype = ''
group_dset_sparse_value_format_printf = '%24.16e'
group_dset_sparse_value_format_scanf = '%lf'
# build format string for n-index sparse quantity
int32_len_printf = 10
item_printf = f'%{int32_len_printf}" PRId32 " '
item_scanf = '%" SCNd32 " '
group_dset_format_printf = ''
group_dset_format_scanf = ''
for i in range(rank):
group_dset_format_printf += item_printf
group_dset_format_scanf += item_scanf
group_dset_format_printf += group_dset_sparse_value_format_printf
group_dset_format_scanf += group_dset_sparse_value_format_scanf
tmp_dict['group_dset'] = k
# add flag to detect index types
@ -654,7 +667,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
tmp_dict['group_dset_format_scanf'] = group_dset_format_scanf
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
# add the rank
tmp_dict['rank'] = len(v[1])
tmp_dict['rank'] = rank
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
# add the list of dimensions
tmp_dict['dims'] = [dim.replace('.','_') for dim in v[1]]
@ -674,9 +687,26 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
tmp_dict['group_dset_f_dims'] = dim_f_list
if is_sparse:
tmp_dict['group_dset_sparse_line_length'] = "69"
tmp_dict['group_dset_sparse_indices_printf'] = "*(index_sparse + 4*i), *(index_sparse + 4*i+1), *(index_sparse + 4*i+2), *(index_sparse + 4*i+3)"
tmp_dict['group_dset_sparse_indices_scanf'] = "index_sparse + 4*i, index_sparse + 4*i+1, index_sparse + 4*i+2, index_sparse + 4*i+3"
# build printf/scanf sequence and compute line length for n-index sparse quantity
index_printf = '*(index_sparse + 4*i'
index_scanf = 'index_sparse + 4*i'
# one index item consumes up to index_length characters (int32_len_printf for int32 + 1 for space)
index_len = int32_len_printf + 1
group_dset_sparse_indices_printf = index_printf + ')'
group_dset_sparse_indices_scanf = index_scanf
group_dset_sparse_line_len = index_len
# loop from 1 because we already have stored one index
for index_count in range(1,rank):
group_dset_sparse_indices_printf += f', {index_printf} + {index_count})'
group_dset_sparse_indices_scanf += f', {index_scanf} + {index_count}'
group_dset_sparse_line_len += index_len
# add 24 chars occupied by the floating point value of sparse dataset + 1 char for "\n"
group_dset_sparse_line_len += 24 + 1
tmp_dict['group_dset_sparse_line_length'] = str(group_dset_sparse_line_len)
tmp_dict['group_dset_sparse_indices_printf'] = group_dset_sparse_indices_printf
tmp_dict['group_dset_sparse_indices_scanf'] = group_dset_sparse_indices_scanf
# add group name as a key-value pair to the dset dict
tmp_dict['group'] = v[2]