# mirror of https://github.com/TREX-CoE/qmckl.git
# synced 2024-11-19 12:32:40 +01:00
# 176 lines, 6.0 KiB, Python
import os
|
|
|
|
|
|
# Module-level state used by the header-scanning loop.

# True while the lines of a multi-line prototype are being accumulated.
collect = False
# Not referenced by the code visible in this file.
process = False
# True when the function name is expected as the first token of the next line.
get_name = False
# Lines accumulated for the prototype currently being analysed.
block = []
# Not referenced by the code visible in this file.
res_str = ''
# Name of the most recently seen function.
func_name = ''
# func_name -> {'datatype', 'pattern'} for functions taking a `size_max` argument.
arrays = {}
# func_name -> {'datatype', 'pattern'} for getters without a `size_max` argument.
numbers = {}
# Every function name seen that contains 'get' or 'set'.
qmckl_public_api = []
|
|
|
|
|
|
# Scan `qmckl.h` line by line.  Every prototype introduced by
# `qmckl_exit_code` is classified into one of two tables:
#   * `arrays`  - functions whose argument list contains `size_max`;
#   * `numbers` - getters whose prototype has no `size_max`.
# Each entry stores the C datatype and the argument text ("pattern") that is
# later substituted into a SWIG `%apply` directive.
with open("qmckl.h", 'r') as f_in:
    for line in f_in:

        # The previous line held a lone `qmckl_exit_code`, so the function
        # name is the first token of this line.
        if get_name:
            words = line.strip().split()
            # A blank line has no tokens; keep waiting for the name instead
            # of failing on `words[0]`.
            if words:
                # Cut at '(' in case the parenthesis is glued to the name.
                func_name = words[0].split('(')[0]
                if 'get' in func_name or 'set' in func_name:
                    qmckl_public_api.append(func_name)
                get_name = False

        if 'qmckl_exit_code' in line:
            words = line.strip().split()
            if len(words) > 1 and 'qmckl_exit_code' in words[0]:
                # The function name is on the same line as `qmckl_exit_code`.
                func_name = words[1].split('(')[0]
                if 'get' in func_name or 'set' in func_name:
                    qmckl_public_api.append(func_name)
            elif len(words) == 1:
                # The function name is the first element on the next line.
                # Do not `continue` here, otherwise `collect` is not set to
                # True for some functions.
                get_name = True

            if 'size_max' in line and ';' in line:
                # One-line prototype with an array argument: the pattern is
                # the argument preceding `size_max` plus `size_max` itself.
                tmp_list = line.split(',')
                for i, s in enumerate(tmp_list):
                    if 'size_max' in s:
                        end_str = s.replace(';', '').replace('\n', '')
                        pattern = f"({tmp_list[i-1]} ,{end_str}"
                        datatype = tmp_list[i-1].replace('const', '').replace('*', '').split()[0]
                        arrays[func_name] = {
                            'datatype': datatype,
                            'pattern': pattern,
                        }
                continue

            elif ';' in line and 'get' in func_name:
                # One-line getter without `size_max` (lines with `size_max`
                # and ';' were consumed by the branch above): the last
                # argument is taken as the number or string.
                tmp_str = line.split(',')[-1].strip()

                pattern = tmp_str.replace(')', '').replace(';', '')
                datatype = pattern.replace('const', '').replace('*', '').split()[0]

                numbers[func_name] = {
                    'datatype': datatype,
                    'pattern': pattern,
                }
                continue

            else:
                # Multi-line prototype: append line by line to the list.
                block.append(line)
                collect = True
                continue

        # `size_max` encountered within a multi-line prototype.
        if 'size_max' in line and collect:
            # This will not work for 2-line functions where the array
            # argument is on the same line as the function name and the
            # `size_max` argument is on the next line.  Lines containing
            # '*/' are skipped as well.
            if 'qmckl_exit_code' not in block[-1] and '*/' not in line:
                pattern = '(' + block[-1].strip() + line.strip().replace(';', '')
                datatype = pattern.replace('const', '').replace('*', '').replace('(', '').split()[0]

                collect = False
                block = []
                arrays[func_name] = {
                    'datatype': datatype,
                    'pattern': pattern,
                }
                continue

        # Closing line of a multi-line getter: this line is taken as the
        # number or string argument.
        if 'get' in func_name and 'qmckl_get' not in line and collect and ';' in line:
            pattern = line.replace(';', '').replace(')', '').strip()
            datatype = pattern.replace('const', '').replace('*', '').split()[0]

            collect = False
            block = []
            numbers[func_name] = {
                'datatype': datatype,
                'pattern': pattern,
            }
            continue

        # Stop or continue the multi-line analyzer.
        if collect and ')' in line:
            collect = False
            block = []
            continue
        else:
            block.append(line)
            continue
|
|
|
|
|
|
# remove buggy qmckl_get_electron_rescale_factor_en key
|
|
#arrays.pop('qmckl_get_electron_rescale_factor_en')
|
|
|
|
# Names of all functions that ended up in either table, arrays first.
processed = [*arrays, *numbers]
|
|
|
|
#for pub_func in qmckl_public_api:
|
|
#if pub_func not in processed and 'set' not in pub_func:
|
|
#print("TODO", pub_func)
|
|
#print(v['datatype'])
|
|
|
|
#for k,v in numbers.items():
|
|
# print(v)
|
|
|
|
|
|
# Write one SWIG `%apply` directive per collected function argument into
# `pyqmckl_include.i`.
with open("pyqmckl_include.i", 'w') as f_out:

    # Scalar getters: map the C datatype onto a SWIG `*OUTPUT` typemap.
    swig_type = ''
    for v in numbers.values():

        if 'int' in v['datatype']:
            swig_type = 'int'
        elif 'float' in v['datatype'] or 'double' in v['datatype']:
            swig_type = 'float'
        elif 'char' in v['datatype'] or 'bool' in v['datatype']:
            # No directive is written for char and bool arguments.
            continue
        else:
            raise TypeError(f"Unknown datatype for swig conversion: {v['datatype']}")

        f_out.write(f"%apply {swig_type} *OUTPUT {{ {v['pattern']} }};\n")

    # Array arguments: getters use ARGOUT_ARRAY1, setters use IN_ARRAY1
    # (presumably the numpy.i typemap signatures - TODO confirm).
    # Note that 'char' arrays get no special treatment here.
    for k, v in arrays.items():

        # A usable pattern is exactly "<array argument>, <size_max argument>".
        if len(v['pattern'].split(',')) != 2:
            print('Problemo', k, v)
            continue

        if 'get' in k:
            f_out.write(f"%apply ( {v['datatype']}* ARGOUT_ARRAY1 , int64_t DIM1 ) {{ {v['pattern']} }};\n")
        elif 'set' in k:
            f_out.write(f"%apply ( {v['datatype']}* IN_ARRAY1 , int64_t DIM1 ) {{ {v['pattern']} }};\n")
|
|
#else:
|
|
#print("HOW-TO ?", k)
|
|
|