mirror of
https://github.com/RRZE-HPC/asmbench.git
synced 2025-09-05 00:20:06 +02:00
336 lines
11 KiB
Python
Executable File
336 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import copy
|
|
import itertools
|
|
|
|
|
|
# TODO use abc to force implementation of interface requirements
|
|
|
|
class Operand:
    """Base class for an inline-assembly operand.

    Subclasses provide ``get_ir_repr`` (the operand's textual LLVM IR value)
    and ``get_constraint_char`` (its inline-asm constraint letter); both raise
    NotImplementedError here.
    """

    def __init__(self, llvm_type):
        # LLVM IR type string of the operand, e.g. 'i64' in the demo below.
        self.llvm_type = llvm_type

    def get_constraint_char(self):
        """Return the inline-asm constraint character (subclass hook)."""
        raise NotImplementedError()

    def get_ir_repr(self):
        """Return the LLVM IR textual representation (subclass hook)."""
        raise NotImplementedError()
|
|
|
|
|
|
class Immediate(Operand):
    """Constant operand, passed to inline assembly with the 'i' constraint."""

    def __init__(self, llvm_type, value):
        super().__init__(llvm_type)
        # Emitted verbatim into the IR argument list (the demo passes strings
        # such as '1' and '23').
        self.value = value

    def get_constraint_char(self):
        """Immediates always use the 'i' (immediate) constraint."""
        return 'i'

    def get_ir_repr(self):
        """Return the stored value unchanged."""
        return self.value
|
|
|
|
class MemoryReference(Operand):
    '''
    Memory operand of the form offset + base + index*width.

    Written as OFFSET(BASE, INDEX, WIDTH) in AT&T assembly.

    Possible operand values:
        offset: Immediate integer (+/-)
        base:   Register
        index:  Register
        width:  Immediate 1, 2, 4 or 8

    index and width must be given together. At least one of offset, base or
    index/width must be given.

    Raises ValueError if the components violate these rules.
    '''

    def __init__(self, llvm_type, offset=None, base=None, index=None, width=None):
        # Store the type like every other Operand (it was previously dropped).
        Operand.__init__(self, llvm_type)
        self.offset = offset
        self.base = base
        self.index = index
        self.width = width

        # Sanity checks:
        if (index is None) != (width is None):
            raise ValueError("Index and width both need to be set, or None.")
        if index is not None:
            if not (isinstance(width, Immediate) and int(width.value) in (1, 2, 4, 8)):
                raise ValueError("Width may only be immediate 1,2,4 or 8.")
            if not isinstance(index, Register):
                raise ValueError("Index must be a register.")

        if offset is not None and not isinstance(offset, Immediate):
            raise ValueError("Offset must be an immediate.")
        if base is not None and not isinstance(base, Register):
            raise ValueError("Base must be a register.")

        # index/width are either both set or both None at this point.
        if offset is None and base is None and index is None:
            raise ValueError("Must provide at least an offset or base.")

    def get_ir_repr(self):
        # TODO: lowering of memory operands to IR is not implemented yet.
        # Raise instead of silently returning None, which would end up as the
        # literal text "None" inside generated IR.
        raise NotImplementedError("MemoryReference IR representation is not implemented yet.")

    def get_constraint_char(self):
        return 'm'
|
|
|
|
|
|
class Register(Operand):
    '''
    Register operand whose IR name is assigned lazily.

    A unique name is allocated on first use (get_ir_repr or join) and recorded
    in a class-wide list. The base name gets a numeric suffix when already
    taken ('reg', 'reg.0', 'reg.1', ...).
    '''

    # Persistent storage of register names, shared by all Register instances.
    _REGISTER_NAMES_IN_USE = []

    @staticmethod
    def match(source_registers, destination_registers):
        '''
        Pair each source with the first remaining destination of the same type.

        Each matched pair is joined (both share one IR name), and the matched
        destination is deleted from *destination_registers*, which is modified
        in place. Each destination is consumed by at most one source.

        Returns the number of matched pairs.
        '''
        matched = 0
        for src in source_registers:
            for i, dst in enumerate(destination_registers):
                if src.llvm_type == dst.llvm_type:
                    # Delete by position: list.remove() goes through the
                    # value-based __eq__, and we break right after, so the list
                    # is never mutated while iteration continues.
                    del destination_registers[i]
                    src.join(dst)
                    matched += 1
                    break
        return matched

    def __init__(self, llvm_type, constraint_char='r', name='reg'):
        self.llvm_type = llvm_type
        self.constraint_char = constraint_char
        assert len(name) > 0, "name needs to be at least of length 1."
        # Base name; may receive a numeric suffix when the register is named.
        self._name = name
        # True once a unique name has been assigned via _set_name.
        self._named = False

    def get_ir_repr(self):
        '''
        Return the quoted LLVM IR name ('%"name"').

        On the first call a unique name is allocated and recorded.
        '''
        if not self._named:
            base = self._name
            candidate = base
            suffix = 0
            # Append '.0', '.1', ... until the name is unused.
            while candidate in self._REGISTER_NAMES_IN_USE:
                candidate = '{}.{}'.format(base, suffix)
                suffix += 1
            self._set_name(candidate)
        return '%"{}"'.format(self._name)

    def _set_name(self, name):
        # Assign *name* permanently and record it as in use.
        if self._named:
            raise RuntimeError("Already named.")
        else:
            self._name = name
            self._REGISTER_NAMES_IN_USE.append(name)
            self._named = True

    def get_constraint_char(self):
        return self.constraint_char

    def join(self, other):
        '''
        Make *self* and *other* share one IR name.

        If exactly one of them is named, the other adopts that name. If neither
        is named, *self* is given a fresh unique name and *other* adopts it.

        Raises RuntimeError("Already named.") if both are named differently.
        '''
        if self._named and other._named:
            if self._name != other._name:
                raise RuntimeError("Already named.")
            # nothing to do, already joined or equal
        elif self._named:
            other._set_name(self._name)
        elif other._named:
            self._set_name(other._name)
        else:
            # Allocate a unique name for self, then share the raw name. The
            # quoted '%"..."' string returned by get_ir_repr must not become
            # the name itself.
            self.get_ir_repr()
            other._set_name(self._name)

    def __repr__(self):
        return '{}({})'.format(
            self.__class__.__name__,
            ', '.join(['{}={!r}'.format(k, v) for k, v in self.__dict__.items()
                       if not k.startswith('_')]))

    def __eq__(self, other):
        if not isinstance(other, Register):
            return NotImplemented
        return (self.llvm_type == other.llvm_type and
                self.constraint_char == other.constraint_char and
                self._name == other._name and
                self._named == other._named)

    def __hash__(self):
        # NOTE: _name/_named change when the register gets named, so the hash
        # changes too. Avoid keeping unnamed registers in sets/dict keys across
        # naming.
        return hash((self.llvm_type, self.constraint_char, self._name, self._named))
|
|
|
|
class Synthable:
    '''
    Interface for anything that can be synthesized into LLVM IR.

    Subclassed in this module by single operations and by the
    Serialized/Parallelized compositions of them.
    '''

    def __init__(self):
        pass

    def build_ir(self):
        '''Return the LLVM IR text for this element.'''
        raise NotImplementedError()

    def get_source_registers(self):
        '''Return the list of Register objects this element reads.'''
        raise NotImplementedError()

    def get_destination_registers(self):
        '''Return the list of Register objects this element writes.'''
        # Previously misspelled as NotImplementeError, which raised NameError.
        raise NotImplementedError()

    def __repr__(self):
        # ClassName(attr=value, ...) over public (non-underscore) attributes,
        # in attribute-assignment order.
        return '{}({})'.format(
            self.__class__.__name__,
            ', '.join(['{}={!r}'.format(k, v) for k, v in self.__dict__.items()
                       if not k.startswith('_')]))
|
|
|
|
|
|
class Operation(Synthable):
    '''Base class for operations.

    Subclassed in this module by Instruction, Load and AddressGeneration.
    __repr__ is inherited from Synthable, which renders the same
    ClassName(public_attr=value, ...) format.
    '''
|
|
|
|
|
|
class Instruction(Operation):
    '''
    A single inline-assembly instruction with one register destination.

    *instruction* is the asm template string (e.g. 'add $2, $0').
    *destination_operand* must be a Register.
    *source_operands* is a list of Register/Immediate operands.

    The constraint string lists the output first and then the sources in
    order. In the template, $0 therefore refers to the destination and
    $1, $2, ... to the sources.
    '''

    def __init__(self, instruction, destination_operand, source_operands):
        self.instruction = instruction
        self.destination_operand = destination_operand
        assert isinstance(destination_operand, Register), "Destination needs to be a register."
        self.source_operands = source_operands

    def get_source_registers(self):
        return [sop for sop in self.source_operands if isinstance(sop, Register)]

    def get_destination_registers(self):
        if isinstance(self.destination_operand, Register):
            return [self.destination_operand]
        else:
            return []

    def build_ir(self):
        '''
        Build IR string based on in and out operand names and types.

        Raises NotImplementedError for source operands that are neither
        Register nor Immediate.
        '''
        # Build constraint string from operands; '=' marks the output.
        constraints = ','.join(
            ['=' + self.destination_operand.get_constraint_char()] +
            [sop.get_constraint_char() for sop in self.source_operands])

        # Build argument string from operands and register names.
        # Sources are named here, before the destination below; this order
        # decides which register receives which unique name.
        operands = []
        for sop in self.source_operands:
            if not isinstance(sop, (Immediate, Register)):
                # Previously misspelled as NotImplemente, which raised NameError.
                raise NotImplementedError("Only register and immediate operands are supported.")
            operands.append('{type} {repr}'.format(type=sop.llvm_type, repr=sop.get_ir_repr()))
        args = ', '.join(operands)

        # Build instruction from instruction and operands
        return ('{dst_reg} = call {dst_type} asm sideeffect'
                ' "{instruction}", "{constraints}" ({args})').format(
                    dst_reg=self.destination_operand.get_ir_repr(),
                    dst_type=self.destination_operand.llvm_type,
                    instruction=self.instruction,
                    constraints=constraints,
                    args=args)
|
|
|
|
|
|
class Load(Operation):
    # Placeholder operation: only stores its configuration so far.

    def __init__(self, chain_length, structure='linear'):
        '''
        Load operation over a chain of pointers placed in memory.

        *chain_length*: how many pointers are placed in memory.
        *structure*: 'linear' (1-offsets) or 'random'.
        '''
        # Assignment order matters: Synthable.__repr__ lists attributes in
        # __dict__ order.
        self.chain_length = chain_length
        self.structure = structure
        # TODO: build_ir and the register accessors are not implemented yet.
|
|
|
|
|
|
class AddressGeneration(Operation):
    # Placeholder operation: only stores its configuration so far.

    def __init__(self, offset, base, index, width, destination='base'):
        '''
        Address computation over offset + base + index*width.

        *destination* presumably selects which operand receives the computed
        address (default 'base'). TODO confirm.
        '''
        # Tuple assignment binds left to right, keeping the __dict__ order
        # (offset, base, index, width) that Synthable.__repr__ relies on.
        self.offset, self.base, self.index, self.width = offset, base, index, width
        self.destination = destination
        # TODO: build_ir and the register accessors are not implemented yet.
|
|
|
|
|
|
class Serialized(Synthable):
    '''
    Sequential composition of Synthable elements.

    When building IR, each element's source registers are joined with the
    previous element's destination registers of the same LLVM type. This
    creates a dependency chain.
    '''

    def __init__(self, synths):
        self.synths = synths
        assert all(isinstance(s, Synthable) for s in synths), "All elements need to be Sythable"

    def get_source_registers(self):
        '''
        Return the sources that are not fed by the preceding element.

        For every element after the first, each source consumes the first
        remaining destination of the same type from the previous element. Only
        the unmatched sources are external inputs of the chain.
        '''
        sources = []
        last_destinations = []
        for s in self.synths:
            # Work on a copy so the list returned by the previous element is
            # not mutated.
            available = list(last_destinations)
            for src in s.get_source_registers():
                for i, dst in enumerate(available):
                    if dst.llvm_type == src.llvm_type:
                        del available[i]
                        break
                else:
                    # No type-compatible destination left: external input.
                    sources.append(src)
            last_destinations = s.get_destination_registers()
        return sources

    def get_destination_registers(self):
        # The chain's outputs are those of its last element.
        if self.synths:
            return self.synths[-1].get_destination_registers()
        else:
            return []

    def build_ir(self):
        '''
        Return the IR of all elements, newline-joined, after chaining them.

        Raises ValueError if an element has no source register matching a
        destination type of the element before it.
        '''
        code = []
        last = None
        for s in self.synths:
            if last is not None:
                # Connect the previous element's outputs (not its inputs) to
                # this element's inputs. Register.match mutates the list it is
                # given, so pass a copy.
                matched = Register.match(s.get_source_registers(),
                                         list(last.get_destination_registers()))
                if matched == 0:
                    raise ValueError("Could not find a type match to serialize {} to {}.".format(
                        last, s))
            code.append(s.build_ir())
            last = s
        return '\n'.join(code)
|
|
|
|
|
|
class Parallelized(Synthable):
    '''
    Parallel composition of Synthable elements.

    No registers are joined between elements. Sources and destinations are
    the concatenation of every element's registers in element order. The IR
    is each element's IR joined by newlines.
    '''

    def __init__(self, synths):
        self.synths = synths
        assert all(isinstance(s, Synthable) for s in synths), "All elements need to be Sythable"

    def get_source_registers(self):
        return [reg for synth in self.synths for reg in synth.get_source_registers()]

    def get_destination_registers(self):
        return [reg for synth in self.synths for reg in synth.get_destination_registers()]

    def build_ir(self):
        return '\n'.join(synth.build_ir() for synth in self.synths)
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: compose instructions serially and in parallel, then print the
    # generated IR and the external source/destination registers.
    # NOTE(review): templates such as 'add $2, $0' never reference the source
    # register $1, and '=r' is output-only, so the asm itself may not consume
    # the chained register. Confirm whether a tied constraint is intended.
    i1 = Instruction(
        instruction='add $2, $0',
        destination_operand=Register('i64', 'r'),
        source_operands=[Register('i64', 'r'), Immediate('i64', '1')])
    i2 = Instruction(
        instruction='sub $2, $0',
        destination_operand=Register('i64', 'r'),
        source_operands=[Register('i64', 'r'), Immediate('i64', '1')])
    i3 = Instruction(
        instruction='mul $1, $0',
        destination_operand=Register('i64', 'r'),
        source_operands=[Register('i64', 'r'), Register('i64', 'r')])
    i4 = Instruction(
        instruction='div $2, $0',
        destination_operand=Register('i64', 'r'),
        source_operands=[Register('i64', 'r'), Immediate('i64', '23')])
    i5 = Instruction(
        instruction='mul $2, $0',
        destination_operand=Register('i64', 'r'),
        source_operands=[Register('i64', 'r'), Immediate('i64', '23')])
    i6 = Instruction(
        instruction='add $2, $0',
        destination_operand=Register('i64', 'r'),
        source_operands=[Register('i64', 'r'), Register('i64', 'r')])
    s1 = Serialized([i1, i2])
    s2 = Serialized([s1, i3])
    # The result is discarded, but building names s2's registers before i6's
    # registers are named by the parallel build below.
    s2.build_ir()
    s3 = Serialized([i4, i5])
    p1 = Parallelized([i6, s2, s3])
    print(p1.build_ir())
    print('srcs', [r.get_ir_repr() for r in p1.get_source_registers()])
    print('dsts', [r.get_ir_repr() for r in p1.get_destination_registers()])