"""Generate SWIG ``%apply`` typemap directives from the QMCkl C header.

The script scans ``include/qmckl.h`` for declarations that return
``qmckl_exit_code``, extracts the output argument of accessor functions
(names containing ``get`` or ``set``) and writes one ``%apply`` line per
recognised function to ``python/pyqmckl_include.i``.

The parsing is line-oriented and purely textual (substring tests and
``str.split``); it is not a C parser and relies on the formatting of the
header.
"""
import os


def _is_accessor(name):
    """Return True when *name* contains ``get`` or ``set`` (substring test)."""
    return 'get' in name or 'set' in name


def parse_header(lines):
    """Extract array and scalar accessor signatures from header lines.

    Parameters
    ----------
    lines : iterable of str
        Lines of the C header, including their trailing newlines (an open
        text file works).

    Returns
    -------
    (arrays, numbers, public_api)
        ``arrays`` and ``numbers`` map a function name to a dict with the
        keys ``'datatype'`` (first word of the argument once ``const`` and
        ``*`` are removed) and ``'pattern'`` (the argument text used in the
        ``%apply`` directive).  ``arrays`` holds functions whose argument is
        followed by a ``size_max`` argument; ``numbers`` holds getters
        without one.  ``public_api`` lists every accessor name seen, in
        order of appearance (duplicates are kept).
    """
    arrays = {}
    numbers = {}
    public_api = []

    func_name = ''
    # True while inside a declaration that spans several lines.
    collecting = False
    # Last line stored by the multi-line analyser.  The original script kept
    # every such line in a list but only ever read its last element.
    prev_line = ''
    # True when `qmckl_exit_code` stood alone on a line, so the function name
    # is the first word of a following line.
    name_pending = False

    for line in lines:

        if name_pending:
            words = line.split()
            # A blank line has no first word; keep waiting for the name
            # instead of failing with IndexError.
            if words:
                func_name = words[0].split('(')[0]
                if _is_accessor(func_name):
                    public_api.append(func_name)
                name_pending = False

        if 'qmckl_exit_code' in line:
            words = line.split()
            if len(words) > 1 and 'qmckl_exit_code' in words[0]:
                # The function name is on the same line as `qmckl_exit_code`.
                func_name = words[1].split('(')[0]
                if _is_accessor(func_name):
                    public_api.append(func_name)
            elif len(words) == 1:
                # The function name is the first word of the next line.
                # Do not `continue` here: the final branch below still has to
                # switch on multi-line collection for this declaration.
                name_pending = True

            if 'size_max' in line and ';' in line:
                # One-line declaration with an array argument: pair every
                # `size_max` argument with the argument just before it
                # (the last pair found wins).
                chunks = line.split(',')
                for i, chunk in enumerate(chunks):
                    if 'size_max' not in chunk:
                        continue
                    if i == 0:
                        # No preceding argument; chunks[i - 1] would wrap
                        # around to the last chunk and yield a bogus pattern.
                        continue
                    end_str = chunk.replace(';', '').replace('\n', '')
                    arrays[func_name] = {
                        'datatype': chunks[i - 1].replace('const', '').replace('*', '').split()[0],
                        'pattern': f"({chunks[i - 1]} ,{end_str}",
                    }
            elif ';' in line and 'get' in func_name:
                # One-line getter without `size_max`: the last argument is
                # treated as a number or string output.
                pattern = line.split(',')[-1].strip().replace(')', '').replace(';', '')
                numbers[func_name] = {
                    'datatype': pattern.replace('const', '').replace('*', '').split()[0],
                    'pattern': pattern,
                }
            else:
                # Multi-line declaration: remember the line and start
                # collecting.
                prev_line = line
                collecting = True
            continue

        # `size_max` met inside a multi-line declaration: the array argument
        # is taken from the previously stored line.  This does not handle
        # two-line declarations where the array argument shares a line with
        # the function name.
        if 'size_max' in line and collecting:
            if 'qmckl_exit_code' not in prev_line and '*/' not in line:
                pattern = '(' + prev_line.strip() + line.strip().replace(';', '')
                arrays[func_name] = {
                    'datatype': pattern.replace('const', '').replace('*', '').replace('(', '').split()[0],
                    'pattern': pattern,
                }
                collecting = False
                prev_line = ''
                continue

        # Last line of a multi-line getter that has no `size_max` argument.
        if 'get' in func_name and 'qmckl_get' not in line and collecting and ';' in line:
            print(func_name)
            print(line)
            pattern = line.replace(';', '').replace(')', '').strip()
            numbers[func_name] = {
                'datatype': pattern.replace('const', '').replace('*', '').split()[0],
                'pattern': pattern,
            }
            collecting = False
            prev_line = ''
            continue

        # Stop collecting at the first `)`, otherwise remember the line.
        if collecting and ')' in line:
            collecting = False
            prev_line = ''
        else:
            prev_line = line

    return arrays, numbers, public_api


def build_swig_typemaps(numbers, arrays):
    """Return the ``%apply`` lines for the parsed scalar and array accessors.

    Scalars come first, then arrays, each in dict order.  Scalars whose
    datatype contains ``char`` or ``bool`` are skipped with a message.  Array
    entries whose pattern does not consist of exactly two comma-separated
    parts are reported and skipped.

    Raises
    ------
    TypeError
        If a scalar datatype is none of int, float, double, char or bool.
    """
    out = []

    for spec in numbers.values():
        datatype = spec['datatype']
        if 'int' in datatype:
            swig_type = 'int'
        elif 'float' in datatype or 'double' in datatype:
            swig_type = 'float'
        elif 'char' in datatype or 'bool' in datatype:
            print('SWIG, skipping')
            continue
        else:
            raise TypeError(f"Unknown datatype for swig conversion: {datatype}")
        out.append(f"%apply {swig_type} *OUTPUT {{ {spec['pattern']} }};\n")

    # ARGOUT_ARRAY1 / IN_ARRAY1 / DIM1 are presumably the numpy.i typemap
    # names -- confirm against the SWIG interface that includes this file.
    for name, spec in arrays.items():
        if 'char' in spec['datatype']:
            # Only reported; string arrays get the same treatment as the rest.
            print("String type")
        if len(spec['pattern'].split(',')) != 2:
            print('Problemo', name, spec)
            continue
        if 'get' in name:
            out.append(f"%apply ( {spec['datatype']}* ARGOUT_ARRAY1 , int64_t DIM1 ) {{ {spec['pattern']} }};\n")
        elif 'set' in name:
            out.append(f"%apply ( {spec['datatype']}* IN_ARRAY1 , int64_t DIM1 ) {{ {spec['pattern']} }};\n")
        else:
            print("HOW-TO ?", name)

    return out


def main(header_path="include/qmckl.h", output_path="python/pyqmckl_include.i"):
    """Parse *header_path* and write the typemap directives to *output_path*."""
    with open(header_path, 'r', encoding='utf-8') as f_in:
        arrays, numbers, public_api = parse_header(f_in)

    # Report getters for which no typemap could be derived.
    processed = set(arrays) | set(numbers)
    for pub_func in public_api:
        if pub_func not in processed and 'set' not in pub_func:
            print("TODO", pub_func)

    for spec in numbers.values():
        print(spec)

    # Build everything before opening the output so that an unknown datatype
    # does not leave a truncated file behind.
    typemaps = build_swig_typemaps(numbers, arrays)
    with open(output_path, 'w', encoding='utf-8') as f_out:
        f_out.writelines(typemaps)


if __name__ == "__main__":
    main()