1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-12-23 04:44:03 +01:00
qmckl/python/src/process_header.py

189 lines
6.5 KiB
Python

#!/usr/bin/env python3
import os
import sys
# Path to the qmckl.h header to scan; it is the only required CLI argument.
if len(sys.argv) < 2:
    # Fail with a usage message (on stderr, exit status 1) instead of a
    # bare IndexError traceback.
    sys.exit("usage: process_header.py <path/to/qmckl.h>")
qmckl_h = sys.argv[1]

# --- parser state for the line-by-line scan of the header ---
collect = False    # True while inside a multi-line function declaration
process = False    # NOTE(review): not read anywhere in the visible script
get_name = False   # True when the function name is expected on the next line
block = []         # lines buffered by the parser; only block[-1] is ever read
res_str = ''       # NOTE(review): not read anywhere in the visible script
func_name = ''     # name of the function declaration currently being parsed

# --- results of the scan ---
arrays = {}            # func name -> {'datatype', 'pattern'} for args paired with size_max
numbers = {}           # getter name -> {'datatype', 'pattern'} for scalar out-arguments
qmckl_public_api = []  # function names containing 'get' or 'set'
qmckl_errors = []      # exit-code '#define' lines with the (qmckl_exit_code) cast removed
# Scan qmckl.h line by line and fill the module-level tables:
#   - qmckl_errors: exit-code '#define' lines with the type cast removed
#   - qmckl_public_api: function names containing 'get' or 'set'
#   - arrays: functions with a (pointer, size_max) argument pair
#   - numbers: 'get' functions without size_max (scalar out-argument)
# The scan is a small state machine driven by `get_name`, `collect` and
# `block`; the order of the checks below matters.
# NOTE(review): the indentation of this copy of the file was lost; the
# nesting below (in particular where each `continue` sits) is
# reconstructed from the control flow -- confirm against upstream.
with open(qmckl_h, 'r') as f_in:
    for line in f_in:
        # get the errors but without the type cast because SWIG does not recognize it
        if '#define' in line and 'qmckl_exit_code' in line:
            qmckl_errors.append(line.strip().replace('(qmckl_exit_code)',''))
            continue
        # The previous line held only the return type, so this line starts
        # with the function name (possibly followed directly by '(').
        # NOTE(review): an empty line here would make words[0] raise IndexError.
        if get_name:
            words = line.strip().split()
            if '(' in words[0]:
                func_name = words[0].split('(')[0]
            else:
                func_name = words[0]
            # plain substring test: any name containing 'get'/'set' is recorded
            if 'get' in func_name or 'set' in func_name:
                qmckl_public_api.append(func_name)
            get_name = False
        # Start of a declaration returning qmckl_exit_code.
        if 'qmckl_exit_code' in line:
            words = line.strip().split()
            if len(words) > 1 and 'qmckl_exit_code' in words[0]:
                # this means that the function name is on the same line as `qmckl_exit_code`
                func_name = words[1].split('(')[0]
                if 'get' in func_name or 'set' in func_name:
                    qmckl_public_api.append(func_name)
            elif len(words) == 1:
                # this means that the function name is the first element on the next line
                get_name = True
            #continue # do not `continue` here otherwise collect is not True for some functions
            # process functions - oneliners (for arrays)
            if 'size_max' in line and ';' in line:
                # Split the whole prototype on commas; the argument just before
                # the size_max one is taken as the array pointer.
                tmp_list = line.split(',')
                for i,s in enumerate(tmp_list):
                    if 'size_max' in s:
                        # end_str keeps the closing ')', so the pattern reads
                        # "(<pointer arg> ,<size_max arg>)".
                        end_str = tmp_list[i].replace(';','').replace('\n','')
                        pattern = f"({tmp_list[i-1]} ,{end_str}"
                        # first word of the pointer argument once 'const'/'*' are removed
                        datatype = tmp_list[i-1].replace('const','').replace('*','').split()[0]
                        arrays[func_name] = {
                            'datatype' : datatype,
                            'pattern' : pattern
                        }
                        #if 'qmckl_get_jastrow_type_nucl_vector' in func_name:
                        #    print(line)
                        #    print(pattern)
                continue
            # if size_max is not provided then the function should deal with numbers or string
            #elif 'num' in line and 'get' in func_name:
            elif ';' in line and 'get' in func_name:
                # special case
                # (unreachable: a line containing both ';' and 'size_max' is
                # already handled by the branch above)
                if 'size_max' in line:
                    continue
                #print(line)
                # The last comma-separated argument of the one-line prototype
                # is taken as the scalar out-argument.
                tmp_str = line.split(',')[-1].strip()
                pattern = tmp_str.replace(')','').replace(';','')
                datatype = pattern.replace('const','').replace('*','').split()[0]
                numbers[func_name] = {
                    'datatype' : datatype,
                    'pattern' : pattern
                }
                continue
            # for multi-line functions - append line by line to the list
            # (also reached by complete one-line declarations whose name lacks
            # 'get' and that have no size_max; collect is then set anyway)
            else:
                block.append(line)
                collect = True
                continue
        # if size_max is encountered within the multiline function
        if 'size_max' in line and collect:
            #if 'qmckl_get_electron_rescale_factor_en' in func_name:
            #    print("LOL")
            # this will not work for 2-line functions where array argument is on the same line as
            # func name and size_max argument is on the next line
            # block[-1] is normally the previous line and is expected to hold
            # the pointer argument; skip when it is the declaration head or
            # when the current line closes a comment.
            if not 'qmckl_exit_code' in block[-1] and not '*/' in line:
                pattern = '(' + block[-1].strip() + line.strip().replace(';','')
                datatype = pattern.replace('const','').replace('*','').replace('(','').split()[0]
                collect = False
                block = []
                arrays[func_name] = {
                    'datatype' : datatype,
                    'pattern' : pattern
                }
            continue
        #if 'num' in line and 'get' in func_name and not 'qmckl_get' in line and collect:
        # Last line (';') of a multi-line getter without size_max: the whole
        # line, minus ';' and ')', is taken as the scalar out-argument.
        if 'get' in func_name and not 'qmckl_get' in line and collect and ';' in line:
            #print(func_name)
            #print(line)
            pattern = line.replace(';','').replace(')','').strip()
            datatype = pattern.replace('const','').replace('*','').split()[0]
            collect = False
            block = []
            numbers[func_name] = {
                'datatype' : datatype,
                'pattern' : pattern
            }
            continue
        # stop/continue multiline function analyzer:
        # a ')' while collecting ends the declaration; every other line
        # (including lines outside any declaration) is appended to block.
        if collect and ')' in line:
            collect = False
            block = []
            continue
        else:
            block.append(line)
            continue
# Names of every function for which a SWIG pattern was extracted, arrays
# first, then scalars (dict insertion order).
#
# Disabled maintenance aids, kept for reference:
#   * drop the problematic entry:
#       arrays.pop('qmckl_get_electron_rescale_factor_en')
#   * list public functions (excluding setters) the parser missed:
#       for pub_func in qmckl_public_api:
#           if pub_func not in processed and 'set' not in pub_func:
#               print("TODO", pub_func)
processed = [*arrays, *numbers]
# Emit the SWIG interface fragment "qmckl_include.i" from the tables filled
# by the header scan: error constants, scalar OUTPUT typemaps and numpy.i
# 1-D array typemaps.
with open("qmckl_include.i", 'w') as f_out:
    # write the list of errors as constants without the type cast, e.g.
    # "#define NAME ( 0)" -> "%constant qmckl_exit_code NAME = 0;"
    for e in qmckl_errors:
        line = e.replace('#define', '%constant qmckl_exit_code').replace('(','=').replace(')',';')
        f_out.write(line + '\n')

    # Scalar out-arguments of getters: choose the SWIG OUTPUT typemap that
    # matches the C type. double must use `double *OUTPUT`; SWIG's
    # `float *OUTPUT` returns the value through its float converter, which
    # narrows a double result to single precision.
    swig_type = ''
    for v in numbers.values():
        if 'int' in v['datatype']:
            swig_type = 'int'
        elif 'double' in v['datatype']:
            swig_type = 'double'
        elif 'float' in v['datatype']:
            swig_type = 'float'
        elif 'char' in v['datatype'] or 'bool' in v['datatype']:
            # strings and booleans are not wrapped as OUTPUT scalars
            continue
        else:
            raise TypeError(f"Unknown datatype for swig conversion: {v['datatype']}")
        f_out.write(f"%apply {swig_type} *OUTPUT {{ {v['pattern']} }};\n")

    # Array arguments: map each (pointer, size_max) pair onto numpy.i's 1-D
    # typemaps. Getters get output buffers and setters get input buffers;
    # functions whose name matches neither are skipped.
    for k, v in arrays.items():
        # the pattern must consist of exactly two comma-separated arguments
        if len(v['pattern'].split(',')) != 2:
            print('Problemo', k, v)
            continue
        if 'get' in k:
            f_out.write(f"%apply ( {v['datatype']}* ARGOUT_ARRAY1 , int64_t DIM1 ) {{ {v['pattern']} }};\n")
        elif 'set' in k:
            f_out.write(f"%apply ( {v['datatype']}* IN_ARRAY1 , int64_t DIM1 ) {{ {v['pattern']} }};\n")