# qmckl/python/process.py
# Mirror of https://github.com/TREX-CoE/qmckl.git (synced 2024-11-19 20:42:50 +01:00)
import os
# Parse the C header `qmckl.h` and record, per function name, the arguments
# that need SWIG typemaps:
#   arrays  -- (array, size_max) argument pairs
#   numbers -- scalar output arguments of getters
# Parser state is kept at module level so it can be inspected after the run.
collect = False        # True while inside a multi-line function declaration
process = False        # NOTE(review): never read in this script
get_name = False       # True when the function name is expected on the next line
block = []             # lines of the multi-line declaration collected so far
res_str = ''           # NOTE(review): never read in this script
func_name = ''         # name of the declaration currently being analysed
arrays = {}            # func_name -> {'datatype': ..., 'pattern': ...}
numbers = {}           # func_name -> {'datatype': ..., 'pattern': ...}
qmckl_public_api = []  # get/set function names found in the header


def _base_type(decl):
    """Return the first token of *decl* once ``const`` and ``*`` are removed.

    For example ``" double* const charge"`` gives ``"double"``.
    """
    return decl.replace('const', '').replace('*', '').split()[0]


def _is_accessor(name):
    """Return True if *name* looks like a getter/setter (substring match on get/set)."""
    return 'get' in name or 'set' in name


with open("qmckl.h", 'r', encoding='utf-8') as f_in:
    for line in f_in:

        # The previous line was a bare `qmckl_exit_code`, so this line should
        # start with the function name. Blank lines are skipped and the name
        # stays pending (indexing words[0] on them used to raise IndexError).
        if get_name:
            words = line.split()
            if words:
                func_name = words[0].split('(')[0]
                if _is_accessor(func_name):
                    qmckl_public_api.append(func_name)
                get_name = False

        if 'qmckl_exit_code' in line:
            words = line.split()
            if len(words) > 1 and 'qmckl_exit_code' in words[0]:
                # the function name is on the same line as `qmckl_exit_code`
                func_name = words[1].split('(')[0]
                if _is_accessor(func_name):
                    qmckl_public_api.append(func_name)
            elif len(words) == 1:
                # the function name is the first word of the next line
                get_name = True
            # Do not `continue` here, otherwise `collect` is not set for some functions.

            if 'size_max' in line and ';' in line:
                # One-line declaration containing an (array, size_max) pair:
                # the array argument is the comma-separated chunk right
                # before the one that mentions size_max.
                chunks = line.split(',')
                for i, chunk in enumerate(chunks):
                    # i == 0 has no preceding argument (chunks[-1] would wrap around)
                    if 'size_max' in chunk and i > 0:
                        end_str = chunk.replace(';', '').replace('\n', '')
                        pattern = f"({chunks[i - 1]} ,{end_str}"
                        datatype = _base_type(chunks[i - 1])
                        arrays[func_name] = {
                            'datatype': datatype,
                            'pattern': pattern,
                        }
                continue
            elif ';' in line and 'get' in func_name:
                # One-line getter without size_max (if size_max were present
                # the branch above would have been taken): the last argument
                # is treated as a scalar output.
                pattern = line.split(',')[-1].strip().replace(')', '').replace(';', '')
                datatype = _base_type(pattern)
                numbers[func_name] = {
                    'datatype': datatype,
                    'pattern': pattern,
                }
                continue
            else:
                # Start of a multi-line declaration: collect it line by line.
                block.append(line)
                collect = True
                continue

        # `size_max` inside a multi-line declaration: the array argument is
        # expected on the previously collected line.
        if 'size_max' in line and collect:
            # Known limitation: 2-line declarations where the array argument
            # shares a line with the function name (block[-1] then contains
            # `qmckl_exit_code`) and size_max is on the next line are skipped.
            if 'qmckl_exit_code' not in block[-1] and '*/' not in line:
                pattern = '(' + block[-1].strip() + line.strip().replace(';', '')
                datatype = _base_type(pattern.replace('(', ''))
                collect = False
                block = []
                arrays[func_name] = {
                    'datatype': datatype,
                    'pattern': pattern,
                }
            continue

        # Last line (contains ';') of a multi-line getter that is not the
        # `qmckl_get...` line itself: its argument is a scalar output.
        if 'get' in func_name and 'qmckl_get' not in line and collect and ';' in line:
            pattern = line.replace(';', '').replace(')', '').strip()
            datatype = _base_type(pattern)
            collect = False
            block = []
            numbers[func_name] = {
                'datatype': datatype,
                'pattern': pattern,
            }
            continue

        # A ')' ends the multi-line declaration being collected. Any other
        # line inside it is appended. Lines outside a declaration are not kept.
        if collect and ')' in line:
            collect = False
            block = []
        elif collect:
            block.append(line)

# 'qmckl_get_electron_rescale_factor_en' was once popped from `arrays` here as buggy.
# Comparing `processed` with `qmckl_public_api` lists get/set functions
# for which no typemap pattern was found.
processed = list(arrays.keys()) + list(numbers.keys())
# Write the SWIG interface fragment `pyqmckl_include.i` with one `%apply`
# directive per argument pattern found by the header parser above
# (`numbers` and `arrays`).
with open("pyqmckl_include.i", 'w', encoding='utf-8') as f_out:

    # Scalar output arguments of getters -> `<type> *OUTPUT` typemaps.
    for v in numbers.values():
        datatype = v['datatype']
        if 'int' in datatype:
            swig_type = 'int'
        elif 'float' in datatype or 'double' in datatype:
            # NOTE(review): `float *OUTPUT` is also applied to double
            # arguments; confirm the generated wrapper does not lose precision.
            swig_type = 'float'
        elif 'char' in datatype or 'bool' in datatype:
            # char/bool outputs are skipped: no %apply is emitted for them
            continue
        else:
            raise TypeError(f"Unknown datatype for swig conversion: {v['datatype']}")
        f_out.write(f"%apply {swig_type} *OUTPUT {{ {v['pattern']} }};\n")

    # (array, size_max) argument pairs -> numpy.i ARGOUT_ARRAY1 (getters)
    # or IN_ARRAY1 (setters) typemaps with an int64_t dimension.
    for k, v in arrays.items():
        # the pattern must consist of exactly two comma-separated arguments
        if len(v['pattern'].split(',')) != 2:
            print('Problemo', k, v)
            continue
        if 'get' in k:
            f_out.write(f"%apply ( {v['datatype']}* ARGOUT_ARRAY1 , int64_t DIM1 ) {{ {v['pattern']} }};\n")
        elif 'set' in k:
            f_out.write(f"%apply ( {v['datatype']}* IN_ARRAY1 , int64_t DIM1 ) {{ {v['pattern']} }};\n")
        # functions that are neither getters nor setters get no directive