assembler: refactor and start tests

This commit is contained in:
Peter McGoron 2023-02-07 17:25:52 +00:00
parent 5084bb7771
commit 96e8eb95b0
2 changed files with 93 additions and 36 deletions

17
asm_test.py Normal file
View File

@ -0,0 +1,17 @@
from creole_asm import *
import unittest
class ProgramTest(unittest.TestCase):
def test_oneline(self):
p = Program()
p.parse_asm_line("PUSH r0")
b = p()
self.assertEqual(b, b'\x01\xC2\x80\x00')
def test_large_reg(self):
p = Program(regnum=0x8000000)
p.parse_asm_line("PUSH r134217727")
b = p()
self.assertEqual(b, b'\x01\xFC\x87\xbf\xbf\xbf\xbf\x00')
if __name__ == "__main__":
unittest.main()

View File

@ -2,56 +2,66 @@
from enum import Enum from enum import Enum
class MalformedArgument(Exception):
pass
class ArgType(Enum): class ArgType(Enum):
TYPE_IMM = 1 IMM = 1
TYPE_REG = 2 REG = 2
TYPE_VAL = 3 VAL = 3
TYPE_LAB = 4 LAB = 4
def gettype(s): def gettype(s):
if s.isnumeric(): if s.isnumeric():
return (TYPE_IMM, int(s)) return (ArgType.IMM, int(s))
elif s[0] == 'r' and s[1:].isnumeric(): elif s[0] == 'r' and s[1:].isnumeric():
return (TYPE_REG, int(s[1:])) return (ArgType.REG, int(s[1:]))
elif s[0] == 'l' and s[1:].isnumeric(): elif s[0] == 'l' and s[1:].isnumeric():
return (TYPE_LAB, int(s[1:])) return (ArgType.LAB, int(s[1:]))
else: else:
return None raise MalformedArgument(s)
def typecheck(self, s): def typecheck(self, s):
t = ArgType.gettype(s) t = ArgType.gettype(s)
if t is None: if self == ArgType.VAL:
return None return t[0] == ArgType.REG or t[0] == ArgType.IMM
if self == TYPE_VAL:
return t[0] == TYPE_REG or t[0] == TYPE_IMM
else: else:
return t[0] == self return t[0] == self
class OpcodeException(Exception):
pass
class TypecheckLenException(Exception):
pass
class TypecheckException(Exception):
pass
class Instruction: class Instruction:
def __init__(self, opcode, argtypes): def __init__(self, opcode, argtypes):
if opcode > 0x7F or opcode < 0:
raise OpcodeException(opcode)
self.opcode = opcode self.opcode = opcode
assert self.opcode < 0x80 and self.opcode >= 0
self.argtypes = argtypes self.argtypes = argtypes
def typecheck(self, sargs): def typecheck(self, sargs):
rargs = [] rargs = []
if len(sargs) != len(self.argtypes): if len(sargs) != len(self.argtypes):
return None raise TypecheckLenException(sargs, self.argtypes)
for i in range(0, len(sargs)): for i in range(0, len(sargs)):
if not self.argtypes[i].typecheck(sargs[i]): if not self.argtypes[i].typecheck(sargs[i]):
return None raise TypecheckException(self.argtypes[i],
sargs[i])
rargs.append(ArgType.gettype(sargs[i])) rargs.append(ArgType.gettype(sargs[i]))
return rargs return rargs
instructions = { instructions = {
"NOP" : Instruction(0, []), "nop" : Instruction(0, []),
"PUSH" : Instruction(1, [ArgType.TYPE_REG]), "push" : Instruction(1, [ArgType.REG]),
"POP" : Instruction(2, [ArgType.TYPE_REG]), "pop" : Instruction(2, [ArgType.REG]),
"ADD" : Instruction(3, [ArgType.TYPE_VAL, ArgType.TYPE_VAL, ArgType.TYPE_VAL]), "add" : Instruction(3, [ArgType.VAL, ArgType.VAL, ArgType.VAL]),
"MUL" : Instruction(4, [ArgType.TYPE_VAL, ArgType.TYPE_VAL, ArgType.TYPE_VAL]), "mul" : Instruction(4, [ArgType.VAL, ArgType.VAL, ArgType.VAL]),
"DIV" : Instruction(5, [ArgType.TYPE_VAL, ArgType.TYPE_VAL, ArgType.TYPE_VAL]), "div" : Instruction(5, [ArgType.VAL, ArgType.VAL, ArgType.VAL]),
"JL" : Instruction(6, [ArgType.TYPE_LAB, ArgType.TYPE_VAL, ArgType.TYPE_VAL]), "jl" : Instruction(6, [ArgType.LAB, ArgType.VAL, ArgType.VAL]),
"CLB" : Instruction(7, [ArgType.TYPE_LAB]), "clb" : Instruction(7, [ArgType.LAB]),
"SYS" : Instruction(8, [ArgType.TYPE_VAL]) "sys" : Instruction(8, [ArgType.VAL])
} }
encoding_types = { encoding_types = {
@ -66,18 +76,31 @@ encoding_types = {
# B : Total number of bits excluding high bits # B : Total number of bits excluding high bits
} }
class InvalidNumberException(Exception):
pass
class InvalidLengthException(Exception):
pass
def encode_pseudo_utf8(n, high_bits, to): def encode_pseudo_utf8(n, high_bits, to):
if n < 0:
raise InvalidNumberException(n)
if to is None or to < 0:
for k in sorted(encoding_types):
if n <= encoding_types[k][0]:
to = k
break
if to is None:
raise InvalidNumberException(n)
if to > 8 or to < 0: if to > 8 or to < 0:
return None raise InvalidLengthException(to)
elif to == 1: elif to == 1:
if n < 0x80: if n < 0x80:
return bytes([n]) return bytes([n])
else: else:
return None raise InvalidNumberException(n,to)
(maxval, start_byte, n_tot) = encoding_types[to] (maxval, start_byte, n_tot) = encoding_types[to]
if n > maxval or high_bits > 15: if n > maxval or high_bits > 15:
return None raise InvalidNumberException(n, high_bits)
n = n | (high_bits << n_tot) n = n | (high_bits << n_tot)
all_bytes = [] all_bytes = []
for i in range(0, to - 1): for i in range(0, to - 1):
@ -86,32 +109,47 @@ def encode_pseudo_utf8(n, high_bits, to):
all_bytes.append(start_byte | n) all_bytes.append(start_byte | n)
return bytes(reversed(all_bytes)) return bytes(reversed(all_bytes))
class RangeCheckException(Exception):
pass
class Line: class Line:
def __init__(self, opcode, args): def __init__(self, opcode, args, labnum, regnum):
self.opcode = opcode self.opcode = opcode
self.args = args self.args = args
for a in args:
if a[0] == ArgType.REG:
if a[1] < 0 or a[1] >= regnum:
raise RangeCheckException(a[0],
a[1],
regnum)
elif a[0] == ArgType.LAB:
if a[1] < 0 or a[1] >= labnum:
raise RangeCheckException(a[0],
a[1],
regnum)
def __call__(self): def __call__(self):
b = bytes([self.opcode]) b = bytes([self.opcode])
for a in args: for a in self.args:
if a[0] == TYPE_REG: if a[0] == ArgType.REG:
b = b + encode_pseudo_utf8(a[1],1,None) b = b + encode_pseudo_utf8(a[1],1,None)
else: else:
b = b + encode_pseudo_utf8(a[1],0,None) b = b + encode_pseudo_utf8(a[1],0,None)
return b + bytes([0]) return b + bytes([0])
class InstructionNotFoundException(Exception):
pass
class Program: class Program:
def asm_push_line(self, ins, args): def asm_push_line(self, ins, args):
self.asm.append(Line(ins, args)) self.asm.append(Line(ins, args, self.labnum, self.regnum))
def parse_asm_line(self, line): def parse_asm_line(self, line):
line = line.split() line = line.split()
line[0] = line[0].casefold()
if line[0] not in instructions: if line[0] not in instructions:
raise Exception raise InstructionNotFoundException(line[0])
else: else:
ins = instructions[line[0]] ins = instructions[line[0]]
args_w_type = ins.typecheck(line[1:]) args_w_type = ins.typecheck(line[1:])
if r is None:
raise Exception
self.asm_push_line(ins.opcode, args_w_type) self.asm_push_line(ins.opcode, args_w_type)
def __call__(self): def __call__(self):
@ -120,5 +158,7 @@ class Program:
b = b + line() b = b + line()
return b return b
def __init__(self): def __init__(self, labnum=16, regnum=16):
self.asm = [] self.asm = []
self.labnum = labnum
self.regnum = regnum