-rwxr-xr-x 11926 lib25519-20220726/scripts-build/dispatch raw
#!/usr/bin/env python3
import os
import sys
import re
def cstring(x):
  """Return x rendered as a double-quoted C string literal.

  Backslash, double quote and newline get their usual C escapes.  Any other
  ASCII control character (code < 32, and DEL) is written as a three-digit
  octal escape; exactly three digits are used so the escape can never absorb
  a following digit.  All other characters are copied unchanged.
  """
  special = {'\\':'\\\\','"':'\\"','\n':'\\n'}
  out = []
  for ch in x:
    if ch in special:
      out.append(special[ch])
    elif ord(ch) < 32 or ord(ch) == 127:
      # raw control characters (e.g. CR) are not safe inside a C literal
      out.append('\\%03o' % ord(ch))
    else:
      out.append(ch)
  return '"%s"' % ''.join(out)
def sanitize(x):
  """Turn x into a C-identifier fragment.

  Every character other than an ASCII digit or lowercase ASCII letter is
  replaced by a single underscore (uppercase letters included), so the
  result has the same length as x.
  """
  return re.sub('[^0-9a-z]','_',x)
# Parse the 'api' file in the current directory.  Recognized line forms
# (anything else, including blank lines, is ignored):
#   crypto_<op>/<primitive>                  -> operations, primitives[op]
#   #define crypto_<op>_<NAME> <value>       -> sizes[op] (raw #define lines)
#   <rettype> crypto_<op>[_<more>](<args>);  -> exports[op], prototypes[op]
operations = []
primitives = {}
sizes = {}
exports = {}
prototypes = {}
with open('api') as f:
  for line in f:
    line = line.strip()
    if line.startswith('crypto_'):
      line = line.split('/')
      assert len(line) == 2
      o = line[0].split('_')[1]
      if o not in operations: operations += [o]
      p = line[1]
      primitives[o] = p
      continue
    if line.startswith('#define '):
      # split() rather than split(' ') so runs of spaces/tabs are tolerated
      x = line.split()
      x = x[1].split('_')
      assert len(x) == 3
      assert x[0] == 'crypto'
      o = x[1]
      if o not in sizes: sizes[o] = ''
      sizes[o] += line+'\n'
      continue
    if line.endswith(');'):
      # partition at the first '(' so parameter lists that themselves
      # contain parentheses (e.g. function-pointer parameters) stay intact
      fun,_,args = line[:-2].partition('(')
      # the last word is the function name; everything before it is the
      # return type, which may be several words (e.g. "long long")
      rettype,fun = fun.rsplit(None,1)
      fun = fun.split('_')
      o = fun[1]
      assert fun[0] == 'crypto'
      if o not in exports: exports[o] = []
      exports[o] += ['_'.join(fun[1:])]
      if o not in prototypes: prototypes[o] = []
      prototypes[o] += [(rettype,fun,args)]
# Command line: dispatch auto|manual <operation> <host>
#   auto   -- emit ifunc resolvers that pick one implementation per CPU arch
#   manual -- emit lib25519_dispatch_* functions that pick an implementation by index
# An explicit check is used instead of assert, which is removed under python -O.
if len(sys.argv) < 4 or sys.argv[1] not in ('auto','manual'):
  sys.exit('usage: dispatch auto|manual operation host')
goal = sys.argv[1]
o = sys.argv[2]
host = sys.argv[3]
# stdin: lines of the form <operation>/<implementation>/<compiler>;
# keep [implementation, compiler] for the selected operation only.
impls = []
for line in sys.stdin:
  line = line.strip().split('/')
  if line[0] != o: continue
  impls += [line[1:]]
# For every (implementation, compiler) pair, record the whitespace-stripped
# contents of compilerarch/<compiler> (the target architecture string) and
# compilerversion/<compiler> (the compiler version string).
icarch = {}
iccompiler = {}
for impl,comp in impls:
  for table,pathfmt in ((icarch,'compilerarch/%s'),(iccompiler,'compilerversion/%s')):
    with open(pathfmt % comp) as f:
      table[impl,comp] = f.read().strip()
def archkey(a):
  """Sort key for architecture strings.

  Strings with more '+'-separated features sort first; 'default' always sorts
  last; ties are broken by the string itself.
  """
  return (1,a) if a == 'default' else (-a.count('+'),a)
# Distinct implementation names, in alphabetical order.
allimpls = sorted({impl for impl,_ in impls})
# Distinct target architectures, ordered by archkey
# (most features first, 'default' last).
allarches = sorted({icarch[impl,comp] for impl,comp in impls},key=archkey)
# Benchmark-derived priority records, used to rank implementations in auto
# mode.  Always bound (empty outside auto mode) so selectic() never hits a
# NameError.  Each record is (impl, prio, score, host, cpuid, machine, compiler).
prioritydata = []
if goal == 'auto':
  for i in allimpls:
    priorityfn = 'priority/%s-%s' % (o,i)
    if not os.path.exists(priorityfn): continue
    with open(priorityfn) as f:
      for line in f:
        line = line.split()
        # fields: prio score host cpuid version machine compiler-words...;
        # lines with fewer than 7 fields are skipped
        if len(line) < 7: continue
        prio,score,priohost,cpuid,version,machine = line[:6]
        # the compiler description may contain spaces: rejoin the rest
        c = ' '.join(line[6:])
        prio = float(prio)
        prioritydata += [(i,prio,score,priohost,cpuid,machine,c)]
def asupportsic(a,i,c):
  """Return True if architecture a offers every feature that the
  (implementation, compiler) pair needs.

  Only the '+'-separated components after the first one are compared; the
  leading base-architecture component of both strings is ignored.
  """
  offered = set(a.split('+')[1:])
  needed = icarch[i,c].split('+')[1:]
  return offered.issuperset(needed)
def cpuidsupports(cpuid,a):
  """Return whether a machine's recorded CPUID dump supports architecture a.

  cpuid: a string of 256 hex digits, read as 32 words of 8 hex digits each.
    Word 17/18/20 bits are tested at the positions of leaf-1 ECX, leaf-1 EDX
    and leaf-7 EBX feature flags.  Word 27 is tested as XCR0, the OS-enabled
    register-state mask.  This layout is inferred from the bits tested below;
    confirm it against the tool that records cpuid.
  a: '+'-separated architecture string.  The first component is ignored; each
    later one must be adx, avx, bmi2, avx2, avx512f, avx512vl or avx512ifma.

  Returns True if every listed feature is available, otherwise False.
  Raises ValueError for an unrecognized feature component, unless an earlier
  component has already failed.
  """
  a = a.split('+')
  cpuid = [int('0x'+cpuid[8*j:8*j+8],16) for j in range(32)]
  mmx = cpuid[18] & (1<<23)
  sse = cpuid[18] & (1<<25)
  sse2 = cpuid[18] & (1<<26)
  sse3 = cpuid[17] & (1<<0)
  ssse3 = cpuid[17] & (1<<9)
  sse41 = cpuid[17] & (1<<19)
  sse42 = cpuid[17] & (1<<20)
  osxsave = cpuid[17] & (1<<27)
  avx = cpuid[17] & (1<<28)
  bmi1 = cpuid[20] & (1<<3)
  avx2 = cpuid[20] & (1<<5)
  bmi2 = cpuid[20] & (1<<8)
  avx512f = cpuid[20] & (1<<16)
  adx = cpuid[20] & (1<<19)
  avx512ifma = cpuid[20] & (1<<21)
  avx512vl = cpuid[20] & (1<<31)
  xmmsaved = cpuid[27] & (1<<1)
  ymmsaved = cpuid[27] & (1<<2)
  # AVX-512 additionally requires the OS to save opmask (bit 5), upper-256
  # ZMM (bit 6) and high-16 ZMM (bit 7) state, per the Intel SDM's AVX-512
  # detection procedure.
  zmmsaved = (cpuid[27] & 0xe0) == 0xe0
  for apart in a[1:]:
    if apart not in ('adx','avx','bmi2','avx2','avx512f','avx512vl','avx512ifma'):
      raise ValueError('cpuidsupports does not understand %s' % apart)
    if apart == 'avx512f':
      if not avx512f: return False
    if apart == 'avx512vl':
      if not avx512vl: return False
    if apart == 'avx512ifma':
      if not avx512ifma: return False
    if apart == 'bmi2':
      # bmi2 is only accepted together with bmi1
      if not bmi1: return False
      if not bmi2: return False
    if apart == 'adx':
      if not adx: return False
    if apart == 'avx2':
      if not avx2: return False
    if apart in ('avx','avx2'):
      # AVX needs the SSE family plus OS support for saving XMM/YMM state
      if not avx: return False
      if not mmx: return False
      if not sse: return False
      if not sse2: return False
      if not sse3: return False
      if not ssse3: return False
      if not sse41: return False
      if not sse42: return False
      if not osxsave: return False
      if not xmmsaved: return False
      if not ymmsaved: return False
    if apart.startswith('avx512'):
      # the CPU flag alone is not enough: the OS must have enabled the state
      if not osxsave: return False
      if not xmmsaved: return False
      if not ymmsaved: return False
      if not zmmsaved: return False
  return True
def selectic(a,aexclude):
  """Choose the (implementation, compiler) pair to use for architecture a.

  a: an architecture string from allarches.
  aexclude: architectures already handled by earlier choices.  With direct
    matches, records from machines supporting any of them are skipped;
    otherwise such records get a smaller weight.

  Candidates are the pairs in impls whose compiler architecture is covered by
  a (asupportsic); at least one must exist.  Each candidate's score is a
  weighted average of the 'prio' values in prioritydata recorded for its
  implementation.  The candidate with the smallest average is returned.  If
  no candidate has any priority data, the first compatible pair in impls
  order is returned.

  Reads the module-level impls, host, prioritydata and iccompiler.  Prints
  C comments describing the decision to stdout.
  """
  if len(aexclude) > 0:
    print('/* considering other machines supporting %s */' % a)
  else:
    print('/* considering machines supporting %s */' % a)
  # requirement: icarch[i,c] is a subset of a
  compatibleimpls = [(i,c) for i,c in impls if asupportsic(a,i,c)]
  assert len(compatibleimpls) > 0
  # desideratum: good performance based on prioritydata
  # A "direct match" is a record measured on this host whose CPU supports a
  # and none of the architectures in aexclude.
  directmatches = any(
    priohost == host
    and cpuidsupports(cpuid,a)
    and all(not cpuidsupports(cpuid,b) for b in aexclude)
    for i,prio,score,priohost,cpuid,machine,c in prioritydata
  )
  if not directmatches:
    print('/* no direct matches, so extrapolating from all machines */')
  totalprio = {(i,c):0 for i,c in compatibleimpls}
  totalweight = {(i,c):0 for i,c in compatibleimpls}
  for prioi,prio,score,priohost,cpuid,machine,prioc in prioritydata:
    # when direct matches exist, only those records are used at all
    if directmatches:
      if priohost != host: continue
      if any(cpuidsupports(cpuid,b) for b in aexclude): continue
      if not cpuidsupports(cpuid,a): continue
    for i,c in compatibleimpls:
      if i != prioi: continue
      # XXX: use more serious machine learning here
      # The weight is multiplied by 10 for each of: same host, CPU supports a,
      # CPU supports nothing in aexclude, identical compiler version string.
      # It is also scaled by 1 + the length of the character-wise common
      # prefix of the two compiler version strings.
      weight = 1.0
      if priohost == host: weight *= 10
      if cpuidsupports(cpuid,a): weight *= 10
      if all(not cpuidsupports(cpuid,b) for b in aexclude): weight *= 10
      weight *= 1+len(os.path.commonprefix([iccompiler[i,c],prioc]))
      if iccompiler[i,c] == prioc: weight *= 10
      # print('/* weight %s from %s %s %s %s for %s %s */' % (weight,prio,machine,prioi,prioc,i,c))
      totalprio[i,c] += prio*weight
      totalweight[i,c] += weight
  # note that implementations without priority data are excluded from ranking
  ranking = [(totalprio[i,c]/totalweight[i,c],i,c) for i,c in compatibleimpls if totalweight[i,c] > 0]
  # ascending: the smallest weighted-average prio comes first and is chosen
  ranking.sort()
  for prio,i,c in ranking:
    print('/* priority %s for %s %s */' % (prio,i,c))
  if len(ranking) == 0:
    return compatibleimpls[0]
  return ranking[0][1:]
# Choose one (implementation, compiler) per architecture.  This is only done
# in auto mode: selection relies on prioritydata, which is only loaded there.
# todo and usedimpls are always bound so later references never hit NameError.
todo = []
usedimpls = set()
if goal == 'auto':
  # allarches puts the most-featured arches first and 'default' last; each
  # arch already chosen is passed to selectic for the later, less specialized ones
  handledarches = set()
  for a in allarches:
    i,c = selectic(a,handledarches)
    usedimpls.add((i,c))
    todo += [(a,i,c)]
    handledarches.add(a)
  for a,i,c in todo:
    print('/* decision: for %s use %s %s */' % (a,i,c))
  print('')
# Declarations emitted ahead of the generated functions.
if goal == 'auto':
  # auto mode defines and exports the introspection functions itself
  declarations = [
    'extern const char *lib25519_%s_implementation(void) __attribute__((visibility("default")));',
    'extern const char *lib25519_%s_compiler(void) __attribute__((visibility("default")));',
  ]
else:
  # manual mode does not define the introspection functions itself (no
  # visibility attribute); it exports the index-based dispatch and numimpl
  # entry points instead
  declarations = [
    'extern const char *lib25519_%s_implementation(void);',
    'extern const char *lib25519_%s_compiler(void);',
    'extern const char *lib25519_dispatch_%s_implementation(long long) __attribute__((visibility("default")));',
    'extern const char *lib25519_dispatch_%s_compiler(long long) __attribute__((visibility("default")));',
    'extern long long lib25519_numimpl_%s(void) __attribute__((visibility("default")));',
  ]
for declaration in declarations:
  print(declaration % o)
# one CPU-feature test per non-default architecture
for arch in allarches:
  if arch != 'default':
    print('extern int lib25519_supports_%s(void);' % sanitize(arch))
if len(allarches) > 1: print('')
def printfun_auto(which,fun=None):
  """Print one auto-mode C function that walks todo testing CPU support.

  which:
    'resolver'       -- void *lib25519_auto_<shortfun>(void), returning the
                        selected implementation's function, followed by the
                        ifunc-attributed declaration of lib25519_<shortfun>;
                        fun (a split prototype name) is required.
    'implementation' -- const char *lib25519_<o>_implementation(void).
    'compiler'       -- const char *lib25519_<o>_compiler(void).
  Any other value raises ValueError before anything is printed.

  Reads the module-level o, todo and iccompiler; the resolver variant also
  reads the module-level rettype and args of the prototype being emitted.
  """
  if which not in ('resolver','implementation','compiler'):
    raise ValueError('unknown printfun %s' % which)
  resolver = which == 'resolver'
  if resolver:
    shortfun = '_'.join(fun[1:])
    print('void *lib25519_auto_%s(void)' % shortfun)
  elif which == 'implementation':
    print('const char *lib25519_%s_implementation(void)' % o)
  else:
    print('const char *lib25519_%s_compiler(void)' % o)
  print('{')
  for arch,impl,comp in todo:
    cond = '' if arch == 'default' else 'if (lib25519_supports_%s()) ' % sanitize(arch)
    if resolver:
      print(' %sreturn lib25519_%s_%s_%s_%s;' % (cond,o,sanitize(impl),comp,shortfun))
    elif which == 'implementation':
      print(' %sreturn %s;' % (cond,cstring(impl)))
    else:
      print(' %sreturn %s;' % (cond,cstring(iccompiler[impl,comp])))
    if arch == 'default':
      # the 'default' return is unconditional; later entries are unreachable
      break
  print('}')
  if resolver:
    print('')
    print('%s lib25519_%s(%s) __attribute__((visibility("default"))) __attribute__((ifunc("lib25519_auto_%s")));' % (rettype,shortfun,args,shortfun))
# For each function of operation o emit its declarations, then either an
# auto-mode ifunc resolver (printfun_auto) or a manual-mode
# lib25519_dispatch_<fun>(impl).  The manual version counts impl down over the
# implementations the running CPU supports, falling back to lib25519_<fun>.
# rettype and args are deliberately module-level names:
# printfun_auto('resolver') reads them as globals.
for rettype,fun,args in prototypes[o]:
  shortfun = '_'.join(fun[1:])
  if goal == 'auto':
    print('extern %s lib25519_%s(%s) __attribute__((visibility("default")));' % (rettype,shortfun,args))
  else:
    print('extern %s lib25519_%s(%s);' % (rettype,shortfun,args))
    print('extern %s (*lib25519_dispatch_%s(long long))(%s) __attribute__((visibility("default")));' % (rettype,shortfun,args))
  print('')
  for i,c in impls:
    # auto mode only references the implementations chosen in todo
    if goal == 'auto':
      if (i,c) not in usedimpls:
        continue
    print('extern %s lib25519_%s_%s_%s_%s(%s) __attribute__((visibility("default")));' % (rettype,o,sanitize(i),c,shortfun,args))
  print('')
  if goal == 'auto':
    printfun_auto('resolver',fun)
  if goal == 'manual':
    # (a previously computed, never-used "namedparams" list was removed; it
    # raised IndexError for an empty parameter list)
    print('%s (*lib25519_dispatch_%s(long long impl))(%s)' % (rettype,shortfun,args))
    print('{')
    for a in allarches:
      if a == 'default': continue
      a_csymbol = sanitize(a)
      print(' int supports_%s = lib25519_supports_%s();' % (a_csymbol,a_csymbol))
    print(' if (impl >= 0) {')
    for i,c in impls:
      a = icarch[i,c]
      a_csymbol = sanitize(a)
      if a == 'default':
        print(' if (!impl--) return lib25519_%s_%s_%s_%s;' % (o,sanitize(i),c,shortfun))
      else:
        print(' if (supports_%s) if (!impl--) return lib25519_%s_%s_%s_%s;' % (a_csymbol,o,sanitize(i),c,shortfun))
    print(' }')
    print(' return lib25519_%s;' % shortfun)
    print('}')
  print('')
if goal == 'auto':
  printfun_auto('implementation')
  print('')
  printfun_auto('compiler')
else:
  def _print_manual_lookup(header,describe,fallback):
    """Print one manual-mode const char * lookup function.

    The emitted C function declares one supports_<arch> flag per non-default
    architecture, then counts impl down over impls in order (skipping ones
    whose arch is unsupported) and returns describe(i,c) for the one reached.
    A negative impl, or one past the end, falls through to the fallback
    statement.  A blank line follows the function.
    """
    print(header)
    print('{')
    for arch in allarches:
      if arch == 'default': continue
      sym = sanitize(arch)
      print(' int supports_%s = lib25519_supports_%s();' % (sym,sym))
    print(' if (impl >= 0) {')
    for impl,comp in impls:
      arch = icarch[impl,comp]
      if arch == 'default':
        print(' if (!impl--) return %s;' % (describe(impl,comp)))
      else:
        print(' if (supports_%s) if (!impl--) return %s;' % (sanitize(arch),describe(impl,comp)))
    print(' }')
    print(fallback)
    print('}')
    print('')

  _print_manual_lookup(
    'const char *lib25519_dispatch_%s_implementation(long long impl)' % o,
    lambda impl,comp: cstring(impl),
    ' return lib25519_%s_implementation();' % o)
  _print_manual_lookup(
    'const char *lib25519_dispatch_%s_compiler(long long impl)' % o,
    lambda impl,comp: cstring(iccompiler[impl,comp]),
    ' return lib25519_%s_compiler();' % o)

  # numimpl: default implementations count unconditionally; every other
  # architecture contributes supports_<arch> times its implementation count
  # (a count of usable implementations assuming lib25519_supports_* returns
  # 0 or 1 -- TODO confirm)
  print('long long lib25519_numimpl_%s(void)' % o)
  print('{')
  terms = ['%d' % sum(1 for impl,comp in impls if icarch[impl,comp] == 'default')]
  for arch in allarches:
    if arch == 'default': continue
    sym = sanitize(arch)
    print(' long long supports_%s = lib25519_supports_%s();' % (sym,sym))
    count = sum(1 for impl,comp in impls if icarch[impl,comp] == arch)
    terms.append('supports_%s*%d' % (sym,count))
  print(' return %s;' % '+'.join(terms))
  print('}')