#!/usr/bin/env python3

import os
import sys
import subprocess

# Host architecture name, taken from the first command-line argument.
# doit() only knows how to check 'amd64'; a missing argument raises IndexError.
host = sys.argv[1]

def warn(w):
  """Report warning `w` on stderr and record it in the current result file.

  The stderr copy carries a 'warning: ' prefix; the copy written to the
  module-level `fresult` file object is the bare message.  `fresult` must
  already be bound (the main loop binds it) when this is called.
  """
  line = '%s\n' % w
  sys.stderr.write('warning: ' + line)
  fresult.write(line)

# Map each file name in compilerarch/ to the architecture string it contains:
# the host name optionally followed by '+'-separated extension names.
# A file containing the literal 'default' stands for the bare host.
carch = {}
for c in os.listdir('compilerarch'):
  with open('compilerarch/%s' % c) as f:
    carch[c] = f.read().strip()
  if carch[c] == 'default': carch[c] = host
  # Explicit raise rather than assert: the check validates file contents and
  # must not disappear when Python runs with -O.
  if carch[c].split('+')[0] != host:
    raise ValueError('compilerarch/%s: architecture %r does not start with host %r' % (c,carch[c],host))

# mnemonic -> list of (operand pattern, extension) pairs, in file order.
insns = {}

with open('scripts/intel-instructions') as f:
  # every line must hold exactly three whitespace-separated fields
  for mnem,args,ext in map(str.split,f):
    insns.setdefault(mnem,[]).append((args,ext))

# Condition-code suffix synonyms: (alias, spelling used as key in `insns`).
# The same table applies to the cmovCC, setCC and jCC families.
_ccaliases = (
  ('nae','b'),
  ('c','b'),
  ('na','be'),
  ('nge','l'),
  ('ng','le'),
  ('ae','nb'),
  ('nc','nb'),
  ('a','nbe'),
  ('ge','nl'),
  ('g','nle'),
  ('po','np'),
  ('ne','nz'),
  ('pe','p'),
  ('e','z'),
)

# Mnemonic rewrites applied by checkinsn() before looking a mnemonic up.
aliases = {
  'movabs':'mov', # XXX
  'movs':'movsb', # XXX
  'stos':'stosb', # XXX
}
for _prefix in ('cmov','set','j'):
  for _alias,_canon in _ccaliases:
    aliases[_prefix+_alias] = _prefix+_canon

# General-purpose register names, grouped by operand width.
_gpr64 = frozenset(('rax','rbx','rcx','rdx','rsi','rdi','rbp','rsp','r8','r9','r10','r11','r12','r13','r14','r15'))
_gpr32 = frozenset(('eax','ebx','ecx','edx','esi','edi','ebp','esp','r8d','r9d','r10d','r11d','r12d','r13d','r14d','r15d'))
_gpr16 = frozenset(('ax','bx','cx','dx','si','di','sp','bp','r8w','r9w','r10w','r11w','r12w','r13w','r14w','r15w'))
_gpr8 = frozenset(('al','bl','cl','dl','ah','bh','ch','dh','sil','dil','bpl','spl','r8b','r9b','r10b','r11b','r12b','r13b','r14b','r15b'))
_gprany = _gpr64|_gpr32|_gpr16|_gpr8

# Operand pattern -> general-purpose register names it accepts.
# Note that the *pattern* 'r8' means "8-bit register", not the register r8.
_gprclass = {
  'r':_gprany, 'vr':_gprany,
  'r8':_gpr8, 'vr8':_gpr8,
  'r16':_gpr16, 'vr16':_gpr16,
  'r32':_gpr32, 'vr32':_gpr32,
  'r64':_gpr64, 'vr64':_gpr64,
}

# (register prefix, register count) -> acceptable vector register names.
_vecregs = {}
for _vecprefix in ('xmm','ymm','zmm'):
  for _veccount in (16,32):
    _vecregs[_vecprefix,_veccount] = frozenset('%s%d'%(_vecprefix,_j) for _j in range(_veccount))

_maskregs = frozenset('k%d'%_j for _j in range(8))
_hexdigits = frozenset('0123456789abcdef')

def _operandmatch(pattern,operand,veclimit):
  """Return True if one disassembled `operand` fits one operand `pattern`.

  `veclimit` is the number of xmm/ymm registers available (16 or 32).
  """
  regs = _gprclass.get(pattern)
  if regs is not None:
    return operand in regs
  if pattern in ('xmm','ymm'):
    return operand in _vecregs[pattern,veclimit]
  if pattern == 'zmm':
    return operand in _vecregs['zmm',32]
  if pattern == 'k':
    return operand in _maskregs
  if pattern in ('m','agen'):
    return '[' in operand or 'PTR' in operand
  if pattern == 'relbr':
    # every character must be a lowercase hex digit (vacuously true if empty)
    return _hexdigits.issuperset(operand)
  if pattern == 'imm':
    return operand.startswith('0x') or operand.isnumeric()
  return False

def argsmatch(args,insnargs,ext):
  """Return True if disassembled operands `insnargs` fit the pattern `args`.

  args:     comma-separated operand patterns from the instruction table, or
            '-' for "no operands".  A trailing 'base','index' pair is ignored.
  insnargs: operand text of one disassembled instruction (comma-separated);
            '{k0}'..'{k7}' mask annotations count as separate operands.
  ext:      extension name of the table entry; names starting with 'avx512'
            allow xmm/ymm registers 16..31 as well.
  """
  if args == '-' and insnargs == '':
    return True

  # XXX: this is really just for avx512vl
  veclimit = 32 if ext.startswith('avx512') else 16

  patterns = args.split(',')
  if patterns[-2:] == ['base','index']:
    patterns = patterns[:-2]

  # turn a mask annotation such as 'zmm0{k1}' into two operands 'zmm0,k1'
  for n in range(8):
    insnargs = insnargs.replace('{k%d}'%n,',k%d'%n)
  operands = insnargs.split(',')

  if len(operands) != len(patterns):
    return False
  return all(_operandmatch(p,o,veclimit) for p,o in zip(patterns,operands))

# Extensions every amd64 CPU provides.
_baseexts = frozenset(('base','i86','i186','i286','i386','i486real','longmode','x87','sse','sse2','cmov','fat_nop'))

# Instruction-table extension name -> name expected in the architecture string.
_archname = {
  'popcnt':'popcnt',
  'bmi1':'bmi1',
  'bmi2':'bmi2',
  'sse3':'sse3',
  'ssse3':'ssse3',
  'sse4.1':'sse41',
  'sse4.2':'sse42',
  'sse4a':'sse4a',
  'avx':'avx',
  'adx':'adx',
  'avx2':'avx2',
  'avx512f':'avx512f',
  'avx512vl':'avx512vl',
  'avx512ifma':'avx512ifma',
  # XXX: could test for vaes, but not actually interested in using it
  # XXX: could test for waitpkg, but not actually interested in using it
}

def extok(arch,ext):
  """Return True if architecture string `arch` permits extension `ext`.

  `arch` is a host name followed by zero or more '+'-separated extension
  names; only the part after the first '+' is consulted.  Extensions that
  are neither baseline nor listed in the table are always rejected.
  """
  enabled = arch.split('+')[1:]
  if ext in _baseexts:
    return True # amd64 guarantees support for these
  wanted = _archname.get(ext)
  if wanted is None:
    return False # XXX
  return wanted in enabled

def checkinsn(obj,arch,insn):
  """Check one disassembled instruction against the extensions `arch` allows.

  obj:  name of the object file the instruction came from (for messages).
  arch: architecture string passed on to extok().
  insn: instruction text: mnemonic, optionally preceded by prefixes and
        followed by operands.

  The extension of the first table entry whose operand pattern matches is
  added to the global `extset`; a warning is issued if that extension is not
  allowed, if the mnemonic is unknown, or if no operand pattern matches.
  """
  global extset

  words = insn.split()
  if host == 'amd64':
    if words == ['endbr64']:
      return # XXX
    if words == ['xgetbv'] and obj.startswith(('compilers/','cpuid/')):
      return # assume dispatch code is using xgetbv carefully

  # skip over any leading instruction prefixes
  start = 0
  while start < len(words) and words[start] in ('cs','data16','rep','repe','repz','notrack'):
    start += 1
  words = words[start:]
  if not words:
    return

  mnem = aliases.get(words[0],words[0])
  if mnem not in insns:
    warn('unknown mnemonic %s' % mnem)
    return

  operands = ' '.join(words[1:])
  for args,ext in insns[mnem]:
    if not argsmatch(args,operands,ext):
      continue
    # only the first matching table entry is considered
    extset.add(ext)
    if not extok(arch,ext):
      warn('%s: %s instruction set does not allow %s: %s' % (obj,arch,ext,' '.join(words)))
    return

  warn('%s: no args match for %s' % (obj,' '.join(words)))

def dir2objs(d,_seen=None):
  """Return paths of the .o files in directory `d` and in its dependencies.

  If `d` contains a file named 'dependencies', each non-blank line of it
  names another directory that is searched the same way.  Every directory
  is visited at most once, so cyclic dependency lists terminate and shared
  dependencies do not contribute their objects twice.

  `_seen` is internal bookkeeping for the recursion; callers omit it.
  Raises OSError if `d` or a listed dependency cannot be read.
  """
  if _seen is None:
    _seen = set()
  key = os.path.normpath(d)
  if key in _seen:
    return []
  _seen.add(key)

  objs = [d+'/'+fn for fn in os.listdir(d) if fn.endswith('.o')]
  if os.path.exists(d+'/dependencies'):
    with open(d+'/dependencies') as f:
      for line in f:
        dep = line.strip()
        if not dep:
          continue # a blank line would otherwise become os.listdir('')
        objs += dir2objs(dep,_seen)
  return objs

def doit(d):
  """Disassemble every object file under `d` and check each instruction.

  Only host 'amd64' is supported; for other hosts a warning is issued and
  nothing is checked.  The permitted extensions come from `carch`, keyed by
  the fourth '/'-separated component of `d`; a shorter `d` gets the bare
  host architecture.  Problems are reported through warn(); failures while
  running or parsing objdump are reported the same way rather than raised.
  """
  if host != 'amd64':
    warn('this script does not know how to check instruction-set extensions for %s' % host)
    return

  objs = dir2objs(d)

  if len(d.split('/')) < 4:
    arch = host # no instruction-set extensions allowed
  else:
    arch = carch[d.split('/')[3]]

  try:
    # XXX: in principle should inspect raw insns for, e.g., EVEX vs. VEX
    # stderr gets its own pipe: when it is merged into stdout, communicate()
    # returns None for it and the `if err` check below can never fire.
    p = subprocess.Popen(['objdump','-d','-M','intel-mnemonic','-M','intel','--no-show-raw-insn']+objs,stdout=subprocess.PIPE,stderr=subprocess.PIPE,universal_newlines=True)
    out,err = p.communicate()
    if err:
      warn('objdump error: %s' % err)
    elif p.returncode:
      warn('objdump failure: %s' % p.returncode)
    else:
      obj = 'unknown-object-file'
      for line in out.splitlines():
        # a line starting with the next expected file name followed by ':'
        # switches attribution to that object file
        if len(objs) > 0 and line.startswith('%s:'%objs[0]):
          obj,objs = objs[0],objs[1:]
        if '\t' in line:
          # the instruction is whatever follows the first tab, minus any
          # trailing '#' comment and '<...>' annotation
          line = line[line.index('\t')+1:]
          if '#' in line:
            line = line[:line.index('#')]
          if '<' in line:
            line = line[:line.index('<')]
          checkinsn(obj,arch,line)
  except Exception as e:
    warn('objdump failure: %s' % e)

# Process every directory named after the host on the command line:
# warnings go to <dir>/result-insns, the sorted extension names that were
# seen go to <dir>/info-insns.
for d in map(str.strip,sys.argv[2:]):
  extset = set() # filled in by checkinsn() during doit()
  with open('%s/result-insns'%d,'w') as fresult: # warn() writes to fresult
    doit(d)
  with open('%s/info-insns'%d,'w') as finfo:
    finfo.writelines(name+'\n' for name in sorted(extset))