mirror of
https://github.com/TREX-CoE/trexio.git
synced 2025-01-03 01:56:13 +01:00
refactor to set common dtype substitutions in a function
This commit is contained in:
parent
1ad20c1cb9
commit
bf5c651220
@ -325,9 +325,9 @@ def special_populate_text_group(fname: str, paths: dict, group_dict: dict, detai
|
||||
if group != detailed_dset[dset]['group']:
|
||||
continue
|
||||
|
||||
if ('REPEAT GROUP_DSET_STR' in line) and (detailed_dset[dset]['dtype'] != 'char*'):
|
||||
if ('REPEAT GROUP_DSET_STR' in line) and (detailed_dset[dset]['group_dset_dtype'] != 'char*'):
|
||||
continue
|
||||
if ('REPEAT GROUP_DSET_NUM' in line) and (detailed_dset[dset]['dtype'] == 'char*'):
|
||||
if ('REPEAT GROUP_DSET_NUM' in line) and (detailed_dset[dset]['group_dset_dtype'] == 'char*'):
|
||||
continue
|
||||
|
||||
dset_allocated.append(dset)
|
||||
@ -458,6 +458,110 @@ def get_group_dict (configuration: dict) -> dict:
|
||||
return group_dict
|
||||
|
||||
|
||||
def get_dtype_dict (dtype: str, target: str, rank = None, int_len_printf = None) -> dict:
    """
    Returns the dictionary of dtype-related templated variables set for a given `dtype`.
    Keys are names of templated variables, values are strings to be used by the generator.

    Every key except `default_prec` is prefixed with `group_{target}_`, so the same
    table serves both `num` and `dset` templates.

    Parameters:
        dtype (str)          : dtype corresponding to the trex.json (i.e. int/dim/index/float/float sparse/str)
        target (str)         : `num` or `dset`
        rank (int)           : [optional] value of n in n-index (sparse) dset; needed to build the printf/scanf format string
        int_len_printf (int) : [optional] length reserved for one index when printing n-index (sparse) dset (e.g. 10 for int32_t)

    Returns:
        dtype_dict (dict) : dictionary of dtype-related substitutions; empty if `dtype` is not one of the recognised kinds

    Raises:
        Exception : if `target` is neither `num` nor `dset`; if `dtype` is sparse and `rank` or
                    `int_len_printf` is missing; if a provided `rank` is lower than 2; if a provided
                    `int_len_printf` is not positive
    """
    if target not in ('num', 'dset'):
        raise Exception('Only num or dset target can be set.')

    is_sparse = 'sparse' in dtype
    if is_sparse and (rank is None or int_len_printf is None):
        raise Exception("Both rank and int_len_printf arguments has to be provided to build the dtype_dict for sparse data.")
    # the range checks apply to any value that was explicitly provided
    if rank is not None and rank <= 1:
        raise Exception('Rank of sparse quantity cannot be lower than 2.')
    if int_len_printf is not None and int_len_printf <= 0:
        raise Exception('Length of an index of sparse quantity has to be positive value.')

    # set up the suffix-value pairs depending on the dtype;
    # the `group_{target}_` prefix is attached once at the end
    if dtype == 'float':
        default_prec = '64'
        substitutions = {
            'dtype'           : 'double',
            'h5_dtype'        : 'native_double',
            'f_dtype_default' : 'real(8)',
            'f_dtype_double'  : 'real(8)',
            'f_dtype_single'  : 'real(4)',
            'dtype_default'   : 'double',
            'dtype_double'    : 'double',
            'dtype_single'    : 'float',
            'format_printf'   : '24.16e',
            'format_scanf'    : 'lf',
            'py_dtype'        : 'float',
        }
    elif dtype in ('int', 'dim', 'index'):
        default_prec = '32'
        substitutions = {
            'dtype'           : 'int64_t',
            'h5_dtype'        : 'native_int64',
            'f_dtype_default' : 'integer(4)',
            'f_dtype_double'  : 'integer(8)',
            'f_dtype_single'  : 'integer(4)',
            'dtype_default'   : 'int32_t',
            'dtype_double'    : 'int64_t',
            'dtype_single'    : 'int32_t',
            'format_printf'   : '" PRId64 "',
            'format_scanf'    : '" SCNd64 "',
            'py_dtype'        : 'int',
        }
    elif dtype == 'str':
        default_prec = ''
        substitutions = {
            'dtype'           : 'char*',
            'h5_dtype'        : '',
            'f_dtype_default' : '',
            'f_dtype_double'  : '',
            'f_dtype_single'  : '',
            'dtype_default'   : 'char*',
            'dtype_double'    : '',
            'dtype_single'    : '',
            'format_printf'   : 's',
            'format_scanf'    : 's',
            'py_dtype'        : 'str',
        }
    elif is_sparse:
        # build format string for n-index sparse quantity:
        # one integer item per index followed by the format of the float value
        item_printf = f'%{int_len_printf}" PRId32 " '
        item_scanf = '%" SCNd32 " '
        default_prec = ''
        substitutions = {
            'dtype'           : 'double',
            'h5_dtype'        : '',
            'f_dtype_default' : '',
            'f_dtype_double'  : '',
            'f_dtype_single'  : '',
            'dtype_default'   : '',
            'dtype_double'    : '',
            'dtype_single'    : '',
            'format_printf'   : item_printf * rank + '%24.16e',
            'format_scanf'    : item_scanf * rank + '%lf',
            'py_dtype'        : '',
        }
    else:
        # unrecognised dtype: no substitutions are defined for it
        return {}

    dtype_dict = {'default_prec': default_prec}
    dtype_dict.update(
        (f'group_{target}_{suffix}', value) for suffix, value in substitutions.items()
    )

    return dtype_dict
|
||||
|
||||
|
||||
|
||||
def get_detailed_num_dict (configuration: dict) -> dict:
|
||||
"""
|
||||
Returns the dictionary of all `num`-suffixed variables.
|
||||
@ -480,33 +584,8 @@ def get_detailed_num_dict (configuration: dict) -> dict:
|
||||
tmp_dict['group_num'] = tmp_num
|
||||
num_dict[tmp_num] = tmp_dict
|
||||
|
||||
# TODO the arguments below are almost the same as for group_dset (except for trex_json_int_type) and can be exported from somewhere
|
||||
if v2[0] == 'float':
|
||||
tmp_dict['datatype'] = 'double'
|
||||
tmp_dict['group_num_h5_dtype'] = 'native_double'
|
||||
tmp_dict['group_num_f_dtype_default']= 'real(8)'
|
||||
tmp_dict['group_num_f_dtype_double'] = 'real(8)'
|
||||
tmp_dict['group_num_f_dtype_single'] = 'real(4)'
|
||||
tmp_dict['group_num_dtype_default']= 'double'
|
||||
tmp_dict['group_num_dtype_double'] = 'double'
|
||||
tmp_dict['group_num_dtype_single'] = 'float'
|
||||
tmp_dict['default_prec'] = '64'
|
||||
tmp_dict['group_num_format_printf'] = '24.16e'
|
||||
tmp_dict['group_num_format_scanf'] = 'lf'
|
||||
tmp_dict['group_num_py_dtype'] = 'float'
|
||||
elif v2[0] in ['int', 'dim']:
|
||||
tmp_dict['datatype'] = 'int64_t'
|
||||
tmp_dict['group_num_h5_dtype'] = 'native_int64'
|
||||
tmp_dict['group_num_f_dtype_default']= 'integer(4)'
|
||||
tmp_dict['group_num_f_dtype_double'] = 'integer(8)'
|
||||
tmp_dict['group_num_f_dtype_single'] = 'integer(4)'
|
||||
tmp_dict['group_num_dtype_default']= 'int32_t'
|
||||
tmp_dict['group_num_dtype_double'] = 'int64_t'
|
||||
tmp_dict['group_num_dtype_single'] = 'int32_t'
|
||||
tmp_dict['default_prec'] = '32'
|
||||
tmp_dict['group_num_format_printf'] = '" PRId64 "'
|
||||
tmp_dict['group_num_format_scanf'] = '" SCNd64 "'
|
||||
tmp_dict['group_num_py_dtype'] = 'int'
|
||||
tmp_dict.update(get_dtype_dict(v2[0], 'num'))
|
||||
if v2[0] in ['int', 'dim']:
|
||||
tmp_dict['trex_json_int_type'] = v2[0]
|
||||
|
||||
return num_dict
|
||||
@ -578,109 +657,47 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
||||
# create a temp dictionary
|
||||
tmp_dict = {}
|
||||
rank = len(v[1])
|
||||
datatype = v[0]
|
||||
|
||||
# define whether the dset is sparse
|
||||
is_sparse = False
|
||||
# specify details required to replace templated variables later
|
||||
if v[0] == 'float':
|
||||
datatype = 'double'
|
||||
group_dset_h5_dtype = 'native_double'
|
||||
group_dset_f_dtype_default= 'real(8)'
|
||||
group_dset_f_dtype_double = 'real(8)'
|
||||
group_dset_f_dtype_single = 'real(4)'
|
||||
group_dset_dtype_default= 'double'
|
||||
group_dset_dtype_double = 'double'
|
||||
group_dset_dtype_single = 'float'
|
||||
default_prec = '64'
|
||||
group_dset_format_printf = '24.16e'
|
||||
group_dset_format_scanf = 'lf'
|
||||
group_dset_py_dtype = 'float'
|
||||
elif v[0] in ['int', 'index']:
|
||||
datatype = 'int64_t'
|
||||
group_dset_h5_dtype = 'native_int64'
|
||||
group_dset_f_dtype_default= 'integer(4)'
|
||||
group_dset_f_dtype_double = 'integer(8)'
|
||||
group_dset_f_dtype_single = 'integer(4)'
|
||||
group_dset_dtype_default= 'int32_t'
|
||||
group_dset_dtype_double = 'int64_t'
|
||||
group_dset_dtype_single = 'int32_t'
|
||||
default_prec = '32'
|
||||
group_dset_format_printf = '" PRId64 "'
|
||||
group_dset_format_scanf = '" SCNd64 "'
|
||||
group_dset_py_dtype = 'int'
|
||||
elif v[0] == 'str':
|
||||
datatype = 'char*'
|
||||
group_dset_h5_dtype = ''
|
||||
group_dset_f_dtype_default = ''
|
||||
group_dset_f_dtype_double = ''
|
||||
group_dset_f_dtype_single = ''
|
||||
group_dset_dtype_default = 'char*'
|
||||
group_dset_dtype_double = ''
|
||||
group_dset_dtype_single = ''
|
||||
default_prec = ''
|
||||
group_dset_format_printf = 's'
|
||||
group_dset_format_scanf = 's'
|
||||
group_dset_py_dtype = 'str'
|
||||
elif 'sparse' in v[0]:
|
||||
if 'sparse' in datatype:
|
||||
is_sparse = True
|
||||
datatype = 'double'
|
||||
group_dset_h5_dtype = ''
|
||||
group_dset_f_dtype_default= ''
|
||||
group_dset_f_dtype_double = ''
|
||||
group_dset_f_dtype_single = ''
|
||||
group_dset_dtype_default= ''
|
||||
group_dset_dtype_double = ''
|
||||
group_dset_dtype_single = ''
|
||||
default_prec = ''
|
||||
group_dset_py_dtype = ''
|
||||
group_dset_sparse_value_format_printf = '%24.16e'
|
||||
group_dset_sparse_value_format_scanf = '%lf'
|
||||
# build format string for n-index sparse quantity
|
||||
int32_len_printf = 10
|
||||
item_printf = f'%{int32_len_printf}" PRId32 " '
|
||||
item_scanf = '%" SCNd32 " '
|
||||
group_dset_format_printf = ''
|
||||
group_dset_format_scanf = ''
|
||||
for i in range(rank):
|
||||
group_dset_format_printf += item_printf
|
||||
group_dset_format_scanf += item_scanf
|
||||
# int64_len_printf = ??
|
||||
# int16_len_printf = ??
|
||||
|
||||
group_dset_format_printf += group_dset_sparse_value_format_printf
|
||||
group_dset_format_scanf += group_dset_sparse_value_format_scanf
|
||||
# get the dtype-related substitutions required to replace templated variables later
|
||||
if not is_sparse:
|
||||
dtype_dict = get_dtype_dict(datatype, 'dset')
|
||||
else:
|
||||
dtype_dict = get_dtype_dict(datatype, 'dset', rank, int32_len_printf)
|
||||
|
||||
tmp_dict.update(dtype_dict)
|
||||
|
||||
# set the group_dset key to the full name of the dset
|
||||
tmp_dict['group_dset'] = k
|
||||
# add flag to detect index types
|
||||
if 'index' == v[0]:
|
||||
if 'index' in datatype:
|
||||
tmp_dict['is_index'] = 'file->one_based'
|
||||
else:
|
||||
tmp_dict['is_index'] = 'false'
|
||||
# add the datatypes for templates
|
||||
tmp_dict['dtype'] = datatype
|
||||
tmp_dict['group_dset_dtype'] = datatype
|
||||
tmp_dict['group_dset_h5_dtype'] = group_dset_h5_dtype
|
||||
tmp_dict['group_dset_f_dtype_default'] = group_dset_f_dtype_default
|
||||
tmp_dict['group_dset_f_dtype_double'] = group_dset_f_dtype_double
|
||||
tmp_dict['group_dset_f_dtype_single'] = group_dset_f_dtype_single
|
||||
tmp_dict['group_dset_dtype_default'] = group_dset_dtype_default
|
||||
tmp_dict['group_dset_dtype_double'] = group_dset_dtype_double
|
||||
tmp_dict['group_dset_dtype_single'] = group_dset_dtype_single
|
||||
tmp_dict['default_prec'] = default_prec
|
||||
tmp_dict['group_dset_format_printf'] = group_dset_format_printf
|
||||
tmp_dict['group_dset_format_scanf'] = group_dset_format_scanf
|
||||
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
|
||||
|
||||
# add the rank
|
||||
tmp_dict['rank'] = rank
|
||||
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
|
||||
tmp_dict['group_dset_rank'] = str(rank)
|
||||
# add the list of dimensions
|
||||
tmp_dict['dims'] = [dim.replace('.','_') for dim in v[1]]
|
||||
# build a list of dimensions to be inserted in the dims array initialization, e.g. {ao_num, ao_num}
|
||||
dim_list = tmp_dict['dims'][0]
|
||||
if tmp_dict['rank'] > 1:
|
||||
for i in range(1, tmp_dict['rank']):
|
||||
if rank > 1:
|
||||
for i in range(1, rank):
|
||||
dim_toadd = tmp_dict['dims'][i]
|
||||
dim_list += f', {dim_toadd}'
|
||||
|
||||
tmp_dict['group_dset_dim_list'] = dim_list
|
||||
|
||||
if tmp_dict['rank'] == 0:
|
||||
if rank == 0:
|
||||
dim_f_list = ""
|
||||
else:
|
||||
dim_f_list = "(*)"
|
||||
@ -712,7 +729,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
|
||||
tmp_dict['group'] = v[2]
|
||||
|
||||
# split datasets in numeric- and string- based
|
||||
if datatype == 'char*':
|
||||
if 'str' in datatype:
|
||||
dset_string_dict[k] = tmp_dict
|
||||
elif is_sparse:
|
||||
dset_sparse_dict[k] = tmp_dict
|
||||
|
Loading…
Reference in New Issue
Block a user