mirror of
https://github.com/TREX-CoE/trexio.git
synced 2024-12-22 12:23:54 +01:00
adapt the generator to work for an arbitrary number of indices in sparse dset
This commit is contained in:
parent
ed3bde973e
commit
e340c6541d
@ -577,6 +577,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
||||
for k,v in datasets.items():
|
||||
# create a temp dictionary
|
||||
tmp_dict = {}
|
||||
rank = len(v[1])
|
||||
is_sparse = False
|
||||
# specify details required to replace templated variables later
|
||||
if v[0] == 'float':
|
||||
@ -629,9 +630,21 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
||||
group_dset_dtype_double = ''
|
||||
group_dset_dtype_single = ''
|
||||
default_prec = ''
|
||||
group_dset_format_printf = '%10" PRId32 " %10" PRId32 " %10" PRId32 " %10" PRId32 " %24.16e'
|
||||
group_dset_format_scanf = '%" SCNd32 " %" SCNd32 " %" SCNd32 " %" SCNd32 " %lf'
|
||||
group_dset_py_dtype = ''
|
||||
group_dset_sparse_value_format_printf = '%24.16e'
|
||||
group_dset_sparse_value_format_scanf = '%lf'
|
||||
# build format string for n-index sparse quantity
|
||||
int32_len_printf = 10
|
||||
item_printf = f'%{int32_len_printf}" PRId32 " '
|
||||
item_scanf = '%" SCNd32 " '
|
||||
group_dset_format_printf = ''
|
||||
group_dset_format_scanf = ''
|
||||
for i in range(rank):
|
||||
group_dset_format_printf += item_printf
|
||||
group_dset_format_scanf += item_scanf
|
||||
|
||||
group_dset_format_printf += group_dset_sparse_value_format_printf
|
||||
group_dset_format_scanf += group_dset_sparse_value_format_scanf
|
||||
|
||||
tmp_dict['group_dset'] = k
|
||||
# add flag to detect index types
|
||||
@ -654,7 +667,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
||||
tmp_dict['group_dset_format_scanf'] = group_dset_format_scanf
|
||||
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
|
||||
# add the rank
|
||||
tmp_dict['rank'] = len(v[1])
|
||||
tmp_dict['rank'] = rank
|
||||
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
|
||||
# add the list of dimensions
|
||||
tmp_dict['dims'] = [dim.replace('.','_') for dim in v[1]]
|
||||
@ -674,9 +687,26 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
||||
tmp_dict['group_dset_f_dims'] = dim_f_list
|
||||
|
||||
if is_sparse:
|
||||
tmp_dict['group_dset_sparse_line_length'] = "69"
|
||||
tmp_dict['group_dset_sparse_indices_printf'] = "*(index_sparse + 4*i), *(index_sparse + 4*i+1), *(index_sparse + 4*i+2), *(index_sparse + 4*i+3)"
|
||||
tmp_dict['group_dset_sparse_indices_scanf'] = "index_sparse + 4*i, index_sparse + 4*i+1, index_sparse + 4*i+2, index_sparse + 4*i+3"
|
||||
# build printf/scanf sequence and compute line length for n-index sparse quantity
|
||||
index_printf = '*(index_sparse + 4*i'
|
||||
index_scanf = 'index_sparse + 4*i'
|
||||
# one index item consumes up to index_length characters (int32_len_printf for int32 + 1 for space)
|
||||
index_len = int32_len_printf + 1
|
||||
group_dset_sparse_indices_printf = index_printf + ')'
|
||||
group_dset_sparse_indices_scanf = index_scanf
|
||||
group_dset_sparse_line_len = index_len
|
||||
# loop from 1 because we already have stored one index
|
||||
for index_count in range(1,rank):
|
||||
group_dset_sparse_indices_printf += f', {index_printf} + {index_count})'
|
||||
group_dset_sparse_indices_scanf += f', {index_scanf} + {index_count}'
|
||||
group_dset_sparse_line_len += index_len
|
||||
|
||||
# add 24 chars occupied by the floating point value of sparse dataset + 1 char for "\n"
|
||||
group_dset_sparse_line_len += 24 + 1
|
||||
|
||||
tmp_dict['group_dset_sparse_line_length'] = str(group_dset_sparse_line_len)
|
||||
tmp_dict['group_dset_sparse_indices_printf'] = group_dset_sparse_indices_printf
|
||||
tmp_dict['group_dset_sparse_indices_scanf'] = group_dset_sparse_indices_scanf
|
||||
|
||||
# add group name as a key-value pair to the dset dict
|
||||
tmp_dict['group'] = v[2]
|
||||
|
Loading…
Reference in New Issue
Block a user