1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-08-25 06:31:43 +02:00

refactor to set common dtype substitutions in a function

This commit is contained in:
q-posev 2021-12-07 10:14:27 +01:00
parent 1ad20c1cb9
commit bf5c651220

View File

@ -325,9 +325,9 @@ def special_populate_text_group(fname: str, paths: dict, group_dict: dict, detai
if group != detailed_dset[dset]['group']:
continue
if ('REPEAT GROUP_DSET_STR' in line) and (detailed_dset[dset]['dtype'] != 'char*'):
if ('REPEAT GROUP_DSET_STR' in line) and (detailed_dset[dset]['group_dset_dtype'] != 'char*'):
continue
if ('REPEAT GROUP_DSET_NUM' in line) and (detailed_dset[dset]['dtype'] == 'char*'):
if ('REPEAT GROUP_DSET_NUM' in line) and (detailed_dset[dset]['group_dset_dtype'] == 'char*'):
continue
dset_allocated.append(dset)
@ -458,6 +458,110 @@ def get_group_dict (configuration: dict) -> dict:
return group_dict
def get_dtype_dict (dtype: str, target: str, rank = None, int_len_printf = None) -> dict:
"""
Returns the dictionary of dtype-related templated variables set for a given `dtype`.
Keys are names of templated variables, values are strings to be used by the generator.
Parameters:
dtype (str) : dtype corresponding to the trex.json (i.e. int/dim/float/float sparse/str)
target (str) : `num` or `dset`
rank (int) : [optional] value of n in n-index (sparse) dset; needed to build the printf/scanf format string
int_len_printf (int): [optional] length reserved for one index when printing n-index (sparse) dset (e.g. 10 for int32_t)
Returns:
dtype_dict (dict) : dictionary dtype-related substitutions
"""
if not target in ['num', 'dset']:
raise Exception('Only num or dset target can be set.')
if 'sparse' in dtype:
if rank is None or int_len_printf is None:
raise Exception("Both rank and int_len_printf arguments has to be provided to build the dtype_dict for sparse data.")
if rank is not None and rank <= 1:
raise Exception('Rank of sparse quantity cannot be lower than 2.')
if int_len_printf is not None and int_len_printf <= 0:
raise Exception('Length of an index of sparse quantity has to be positive value.')
dtype_dict = {}
# set up the key-value pairs dependending on the dtype
if dtype == 'float':
dtype_dict.update({
'default_prec' : '64',
f'group_{target}_dtype' : 'double',
f'group_{target}_h5_dtype' : 'native_double',
f'group_{target}_f_dtype_default' : 'real(8)',
f'group_{target}_f_dtype_double' : 'real(8)',
f'group_{target}_f_dtype_single' : 'real(4)',
f'group_{target}_dtype_default' : 'double',
f'group_{target}_dtype_double' : 'double',
f'group_{target}_dtype_single' : 'float',
f'group_{target}_format_printf' : '24.16e',
f'group_{target}_format_scanf' : 'lf',
f'group_{target}_py_dtype' : 'float'
})
elif dtype in ['int', 'dim', 'index']:
dtype_dict.update({
'default_prec' : '32',
f'group_{target}_dtype' : 'int64_t',
f'group_{target}_h5_dtype' : 'native_int64',
f'group_{target}_f_dtype_default' : 'integer(4)',
f'group_{target}_f_dtype_double' : 'integer(8)',
f'group_{target}_f_dtype_single' : 'integer(4)',
f'group_{target}_dtype_default' : 'int32_t',
f'group_{target}_dtype_double' : 'int64_t',
f'group_{target}_dtype_single' : 'int32_t',
f'group_{target}_format_printf' : '" PRId64 "',
f'group_{target}_format_scanf' : '" SCNd64 "',
f'group_{target}_py_dtype' : 'int'
})
elif dtype == 'str':
dtype_dict.update({
'default_prec' : '',
f'group_{target}_dtype' : 'char*',
f'group_{target}_h5_dtype' : '',
f'group_{target}_f_dtype_default': '',
f'group_{target}_f_dtype_double' : '',
f'group_{target}_f_dtype_single' : '',
f'group_{target}_dtype_default' : 'char*',
f'group_{target}_dtype_double' : '',
f'group_{target}_dtype_single' : '',
f'group_{target}_format_printf' : 's',
f'group_{target}_format_scanf' : 's',
f'group_{target}_py_dtype' : 'str'
})
elif 'sparse' in dtype:
# build format string for n-index sparse quantity
item_printf = f'%{int_len_printf}" PRId32 " '
item_scanf = '%" SCNd32 " '
group_dset_format_printf = ''
group_dset_format_scanf = ''
for i in range(rank):
group_dset_format_printf += item_printf
group_dset_format_scanf += item_scanf
# append the format string for float values
group_dset_format_printf += '%24.16e'
group_dset_format_scanf += '%lf'
# set up the dictionary for sparse
dtype_dict.update({
'default_prec' : '',
f'group_{target}_dtype' : 'double',
f'group_{target}_h5_dtype' : '',
f'group_{target}_f_dtype_default': '',
f'group_{target}_f_dtype_double' : '',
f'group_{target}_f_dtype_single' : '',
f'group_{target}_dtype_default' : '',
f'group_{target}_dtype_double' : '',
f'group_{target}_dtype_single' : '',
f'group_{target}_format_printf' : group_dset_format_printf,
f'group_{target}_format_scanf' : group_dset_format_scanf,
f'group_{target}_py_dtype' : ''
})
return dtype_dict
def get_detailed_num_dict (configuration: dict) -> dict:
"""
Returns the dictionary of all `num`-suffixed variables.
@ -480,33 +584,8 @@ def get_detailed_num_dict (configuration: dict) -> dict:
tmp_dict['group_num'] = tmp_num
num_dict[tmp_num] = tmp_dict
# TODO the arguments below are almost the same as for group_dset (except for trex_json_int_type) and can be exported from somewhere
if v2[0] == 'float':
tmp_dict['datatype'] = 'double'
tmp_dict['group_num_h5_dtype'] = 'native_double'
tmp_dict['group_num_f_dtype_default']= 'real(8)'
tmp_dict['group_num_f_dtype_double'] = 'real(8)'
tmp_dict['group_num_f_dtype_single'] = 'real(4)'
tmp_dict['group_num_dtype_default']= 'double'
tmp_dict['group_num_dtype_double'] = 'double'
tmp_dict['group_num_dtype_single'] = 'float'
tmp_dict['default_prec'] = '64'
tmp_dict['group_num_format_printf'] = '24.16e'
tmp_dict['group_num_format_scanf'] = 'lf'
tmp_dict['group_num_py_dtype'] = 'float'
elif v2[0] in ['int', 'dim']:
tmp_dict['datatype'] = 'int64_t'
tmp_dict['group_num_h5_dtype'] = 'native_int64'
tmp_dict['group_num_f_dtype_default']= 'integer(4)'
tmp_dict['group_num_f_dtype_double'] = 'integer(8)'
tmp_dict['group_num_f_dtype_single'] = 'integer(4)'
tmp_dict['group_num_dtype_default']= 'int32_t'
tmp_dict['group_num_dtype_double'] = 'int64_t'
tmp_dict['group_num_dtype_single'] = 'int32_t'
tmp_dict['default_prec'] = '32'
tmp_dict['group_num_format_printf'] = '" PRId64 "'
tmp_dict['group_num_format_scanf'] = '" SCNd64 "'
tmp_dict['group_num_py_dtype'] = 'int'
tmp_dict.update(get_dtype_dict(v2[0], 'num'))
if v2[0] in ['int', 'dim']:
tmp_dict['trex_json_int_type'] = v2[0]
return num_dict
@ -578,109 +657,47 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
# create a temp dictionary
tmp_dict = {}
rank = len(v[1])
datatype = v[0]
# define whether the dset is sparse
is_sparse = False
# specify details required to replace templated variables later
if v[0] == 'float':
datatype = 'double'
group_dset_h5_dtype = 'native_double'
group_dset_f_dtype_default= 'real(8)'
group_dset_f_dtype_double = 'real(8)'
group_dset_f_dtype_single = 'real(4)'
group_dset_dtype_default= 'double'
group_dset_dtype_double = 'double'
group_dset_dtype_single = 'float'
default_prec = '64'
group_dset_format_printf = '24.16e'
group_dset_format_scanf = 'lf'
group_dset_py_dtype = 'float'
elif v[0] in ['int', 'index']:
datatype = 'int64_t'
group_dset_h5_dtype = 'native_int64'
group_dset_f_dtype_default= 'integer(4)'
group_dset_f_dtype_double = 'integer(8)'
group_dset_f_dtype_single = 'integer(4)'
group_dset_dtype_default= 'int32_t'
group_dset_dtype_double = 'int64_t'
group_dset_dtype_single = 'int32_t'
default_prec = '32'
group_dset_format_printf = '" PRId64 "'
group_dset_format_scanf = '" SCNd64 "'
group_dset_py_dtype = 'int'
elif v[0] == 'str':
datatype = 'char*'
group_dset_h5_dtype = ''
group_dset_f_dtype_default = ''
group_dset_f_dtype_double = ''
group_dset_f_dtype_single = ''
group_dset_dtype_default = 'char*'
group_dset_dtype_double = ''
group_dset_dtype_single = ''
default_prec = ''
group_dset_format_printf = 's'
group_dset_format_scanf = 's'
group_dset_py_dtype = 'str'
elif 'sparse' in v[0]:
if 'sparse' in datatype:
is_sparse = True
datatype = 'double'
group_dset_h5_dtype = ''
group_dset_f_dtype_default= ''
group_dset_f_dtype_double = ''
group_dset_f_dtype_single = ''
group_dset_dtype_default= ''
group_dset_dtype_double = ''
group_dset_dtype_single = ''
default_prec = ''
group_dset_py_dtype = ''
group_dset_sparse_value_format_printf = '%24.16e'
group_dset_sparse_value_format_scanf = '%lf'
# build format string for n-index sparse quantity
int32_len_printf = 10
item_printf = f'%{int32_len_printf}" PRId32 " '
item_scanf = '%" SCNd32 " '
group_dset_format_printf = ''
group_dset_format_scanf = ''
for i in range(rank):
group_dset_format_printf += item_printf
group_dset_format_scanf += item_scanf
# int64_len_printf = ??
# int16_len_printf = ??
group_dset_format_printf += group_dset_sparse_value_format_printf
group_dset_format_scanf += group_dset_sparse_value_format_scanf
# get the dtype-related substitutions required to replace templated variables later
if not is_sparse:
dtype_dict = get_dtype_dict(datatype, 'dset')
else:
dtype_dict = get_dtype_dict(datatype, 'dset', rank, int32_len_printf)
tmp_dict.update(dtype_dict)
# set the group_dset key to the full name of the dset
tmp_dict['group_dset'] = k
# add flag to detect index types
if 'index' == v[0]:
if 'index' in datatype:
tmp_dict['is_index'] = 'file->one_based'
else:
tmp_dict['is_index'] = 'false'
# add the datatypes for templates
tmp_dict['dtype'] = datatype
tmp_dict['group_dset_dtype'] = datatype
tmp_dict['group_dset_h5_dtype'] = group_dset_h5_dtype
tmp_dict['group_dset_f_dtype_default'] = group_dset_f_dtype_default
tmp_dict['group_dset_f_dtype_double'] = group_dset_f_dtype_double
tmp_dict['group_dset_f_dtype_single'] = group_dset_f_dtype_single
tmp_dict['group_dset_dtype_default'] = group_dset_dtype_default
tmp_dict['group_dset_dtype_double'] = group_dset_dtype_double
tmp_dict['group_dset_dtype_single'] = group_dset_dtype_single
tmp_dict['default_prec'] = default_prec
tmp_dict['group_dset_format_printf'] = group_dset_format_printf
tmp_dict['group_dset_format_scanf'] = group_dset_format_scanf
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
# add the rank
tmp_dict['rank'] = rank
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
tmp_dict['group_dset_rank'] = str(rank)
# add the list of dimensions
tmp_dict['dims'] = [dim.replace('.','_') for dim in v[1]]
# build a list of dimensions to be inserted in the dims array initialization, e.g. {ao_num, ao_num}
dim_list = tmp_dict['dims'][0]
if tmp_dict['rank'] > 1:
for i in range(1, tmp_dict['rank']):
if rank > 1:
for i in range(1, rank):
dim_toadd = tmp_dict['dims'][i]
dim_list += f', {dim_toadd}'
tmp_dict['group_dset_dim_list'] = dim_list
if tmp_dict['rank'] == 0:
if rank == 0:
dim_f_list = ""
else:
dim_f_list = "(*)"
@ -712,7 +729,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
tmp_dict['group'] = v[2]
# split datasets in numeric- and string- based
if datatype == 'char*':
if 'str' in datatype:
dset_string_dict[k] = tmp_dict
elif is_sparse:
dset_sparse_dict[k] = tmp_dict