mirror of
https://github.com/TREX-CoE/trexio.git
synced 2024-11-04 05:03:58 +01:00
adapt the generator to work for an arbitrary number of indices in sparse dset
This commit is contained in:
parent
ed3bde973e
commit
e340c6541d
@ -577,6 +577,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
|||||||
for k,v in datasets.items():
|
for k,v in datasets.items():
|
||||||
# create a temp dictionary
|
# create a temp dictionary
|
||||||
tmp_dict = {}
|
tmp_dict = {}
|
||||||
|
rank = len(v[1])
|
||||||
is_sparse = False
|
is_sparse = False
|
||||||
# specify details required to replace templated variables later
|
# specify details required to replace templated variables later
|
||||||
if v[0] == 'float':
|
if v[0] == 'float':
|
||||||
@ -629,9 +630,21 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
|||||||
group_dset_dtype_double = ''
|
group_dset_dtype_double = ''
|
||||||
group_dset_dtype_single = ''
|
group_dset_dtype_single = ''
|
||||||
default_prec = ''
|
default_prec = ''
|
||||||
group_dset_format_printf = '%10" PRId32 " %10" PRId32 " %10" PRId32 " %10" PRId32 " %24.16e'
|
|
||||||
group_dset_format_scanf = '%" SCNd32 " %" SCNd32 " %" SCNd32 " %" SCNd32 " %lf'
|
|
||||||
group_dset_py_dtype = ''
|
group_dset_py_dtype = ''
|
||||||
|
group_dset_sparse_value_format_printf = '%24.16e'
|
||||||
|
group_dset_sparse_value_format_scanf = '%lf'
|
||||||
|
# build format string for n-index sparse quantity
|
||||||
|
int32_len_printf = 10
|
||||||
|
item_printf = f'%{int32_len_printf}" PRId32 " '
|
||||||
|
item_scanf = '%" SCNd32 " '
|
||||||
|
group_dset_format_printf = ''
|
||||||
|
group_dset_format_scanf = ''
|
||||||
|
for i in range(rank):
|
||||||
|
group_dset_format_printf += item_printf
|
||||||
|
group_dset_format_scanf += item_scanf
|
||||||
|
|
||||||
|
group_dset_format_printf += group_dset_sparse_value_format_printf
|
||||||
|
group_dset_format_scanf += group_dset_sparse_value_format_scanf
|
||||||
|
|
||||||
tmp_dict['group_dset'] = k
|
tmp_dict['group_dset'] = k
|
||||||
# add flag to detect index types
|
# add flag to detect index types
|
||||||
@ -654,7 +667,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
|||||||
tmp_dict['group_dset_format_scanf'] = group_dset_format_scanf
|
tmp_dict['group_dset_format_scanf'] = group_dset_format_scanf
|
||||||
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
|
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
|
||||||
# add the rank
|
# add the rank
|
||||||
tmp_dict['rank'] = len(v[1])
|
tmp_dict['rank'] = rank
|
||||||
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
|
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
|
||||||
# add the list of dimensions
|
# add the list of dimensions
|
||||||
tmp_dict['dims'] = [dim.replace('.','_') for dim in v[1]]
|
tmp_dict['dims'] = [dim.replace('.','_') for dim in v[1]]
|
||||||
@ -674,9 +687,26 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
|||||||
tmp_dict['group_dset_f_dims'] = dim_f_list
|
tmp_dict['group_dset_f_dims'] = dim_f_list
|
||||||
|
|
||||||
if is_sparse:
|
if is_sparse:
|
||||||
tmp_dict['group_dset_sparse_line_length'] = "69"
|
# build printf/scanf sequence and compute line length for n-index sparse quantity
|
||||||
tmp_dict['group_dset_sparse_indices_printf'] = "*(index_sparse + 4*i), *(index_sparse + 4*i+1), *(index_sparse + 4*i+2), *(index_sparse + 4*i+3)"
|
index_printf = '*(index_sparse + 4*i'
|
||||||
tmp_dict['group_dset_sparse_indices_scanf'] = "index_sparse + 4*i, index_sparse + 4*i+1, index_sparse + 4*i+2, index_sparse + 4*i+3"
|
index_scanf = 'index_sparse + 4*i'
|
||||||
|
# one index item consumes up to index_length characters (int32_len_printf for int32 + 1 for space)
|
||||||
|
index_len = int32_len_printf + 1
|
||||||
|
group_dset_sparse_indices_printf = index_printf + ')'
|
||||||
|
group_dset_sparse_indices_scanf = index_scanf
|
||||||
|
group_dset_sparse_line_len = index_len
|
||||||
|
# loop from 1 because we already have stored one index
|
||||||
|
for index_count in range(1,rank):
|
||||||
|
group_dset_sparse_indices_printf += f', {index_printf} + {index_count})'
|
||||||
|
group_dset_sparse_indices_scanf += f', {index_scanf} + {index_count}'
|
||||||
|
group_dset_sparse_line_len += index_len
|
||||||
|
|
||||||
|
# add 24 chars occupied by the floating point value of sparse dataset + 1 char for "\n"
|
||||||
|
group_dset_sparse_line_len += 24 + 1
|
||||||
|
|
||||||
|
tmp_dict['group_dset_sparse_line_length'] = str(group_dset_sparse_line_len)
|
||||||
|
tmp_dict['group_dset_sparse_indices_printf'] = group_dset_sparse_indices_printf
|
||||||
|
tmp_dict['group_dset_sparse_indices_scanf'] = group_dset_sparse_indices_scanf
|
||||||
|
|
||||||
# add group name as a key-value pair to the dset dict
|
# add group name as a key-value pair to the dset dict
|
||||||
tmp_dict['group'] = v[2]
|
tmp_dict['group'] = v[2]
|
||||||
|
Loading…
Reference in New Issue
Block a user