#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Mini 386 assembler.

This (very crudely) implements a subset of gas syntax; it supports the
following:

    #comments
    .equiv
    .globl
    labels
    mov $32bitconst, %32bitreg
    mov %32bitreg, %32bitreg
    int const
    test %32bitreg, %32bitreg
    jz const
    jmp const

-nostdlib only

That’s enough to compile cat.s.

This is incomplete; although it now successfully computes the bytes of
the executable code and the symbol table, it doesn’t yet know how to
encode them as ELF.

"""

import re
import sys


def main(argv):
    """Assemble the file named by argv[1] and dump the result to stdout.

    Prints a hex dump of the assembled bytes, then the symbol table,
    then the relocation table.  Parse problems are reported through
    warn() and the offending statement is skipped.

    Returns 2 if no input file was given, otherwise None.
    """
    if len(argv) < 2:
        warn("usage: %s FILE" % (argv[0] if argv else 'asm'))
        return 2

    filename = argv[1]
    output = []        # assembled bytes, one int per byte
    symbols = {}       # symbol name -> value (offset into output, or .equiv)
    relocations = {}   # symbol name -> [(offset, fixup function), ...]

    with open(filename) as source:
        for lineno, line in enumerate(source, 1):
            line = comment.sub('', line)
            for instruction in line.split(';'):
                # Peel off every label in the statement; each one names
                # the current end of the output.
                label_mo = label.search(instruction)
                while label_mo:
                    symbols[label_mo.group(1)] = len(output)
                    instruction = label.sub('', instruction, 1)
                    label_mo = label.search(instruction)

                op_mo = op.match(instruction)
                if not op_mo:
                    # Blank statements are fine; anything else is an error.
                    if not wsp.match(instruction):
                        warn("%s:%s: can’t parse %r"
                             % (filename, lineno, instruction))
                    continue

                directive, arg1, arg2 = op_mo.groups()
                compiler = directives.get(directive.lower())
                if not compiler:
                    warn("%s:%s: don’t know %r"
                         % (filename, lineno, directive))
                    continue
                compiler(arg1, arg2, output, symbols, relocations)

    resolve(output, symbols, relocations)
    hexdump(output)
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(symbols)
    print(relocations)


comment = re.compile('#.*')
label = re.compile(r'\s*(\w+):\s*')
op = re.compile(r'\s*([\w.]+)\s+([^,]*\S)(?:\s*,\s*([^,]*\S))?\s*$')
wsp = re.compile(r'\s*$')

# 3-bit register numbers used in opcode and ModRM encodings.
regs_32_patterns = {
    '%eax': 0,
    '%ecx': 1,
    '%edx': 2,
    '%ebx': 3,
    '%esp': 4,
    '%ebp': 5,
    '%esi': 6,
    '%edi': 7,
}


def _modrm_reg_reg(reg_operand, rm_operand):
    """Return the register-direct ModRM byte for two 32-bit registers.

    reg_operand fills the `reg` field and rm_operand the `r/m` field.
    Returns None if either operand is not a key of regs_32_patterns.
    """
    reg = regs_32_patterns.get(reg_operand)
    rm = regs_32_patterns.get(rm_operand)
    if reg is None or rm is None:
        return None
    return 0xc0 | (reg << 3) | rm


def compile_mov(arg1, arg2, output, symbols, relocations):
    """Encode `mov $const, %reg` or `mov %reg, %reg`.

    Unknown registers and unsupported operand forms produce a warning
    and emit nothing.
    """
    if arg1.startswith('$'):
        dest = regs_32_patterns.get(arg2)
        if dest is None:
            warn("don't know what to make of mov %s, %s" % (arg1, arg2))
            return
        output.append(0xb8 | dest)  # B8+rd: mov imm32, r32
        compile_32bit_const(arg1[1:], output, symbols, relocations)
    elif arg1.startswith('%'):
        # 89 /r: source register in `reg`, destination in `r/m`.
        modrm = _modrm_reg_reg(arg1, arg2)
        if modrm is None:
            warn("don't know what to make of mov %s, %s" % (arg1, arg2))
            return
        output.append(0x89)
        output.append(modrm)
    else:
        warn("don't know what to make of mov %s, %s" % (arg1, arg2))


def compile_32bit_const(n, output, symbols, relocations):
    """Append a 32-bit little-endian constant to output.

    If n is not a numeric literal it is treated as a symbol name: four
    zero bytes are emitted and a relocation is recorded to patch them.
    """
    val = parse_int(n)
    if val is None:
        val = 0
        add_relocation(n, len(output), const_32bit_relocation, relocations)
    for _ in range(4):
        output.append(val & 0xff)
        val >>= 8


def parse_int(n):
    """Parse a decimal or 0x-prefixed hex literal; return None otherwise."""
    if integer_format.match(n):
        return int(n)
    elif hex_format.match(n):
        return int(n, 16)
    return None


integer_format = re.compile(r'\d+$')
hex_format = re.compile(r'0x[\da-fA-F]+$')


def add_relocation(symbol, where, what_kind, relocations):
    """Record that output[where...] must be patched with symbol's value.

    what_kind is the fixup function, called later by resolve() as
    what_kind(output, where, value).
    """
    relocations.setdefault(symbol, []).append((where, what_kind))


def const_32bit_relocation(output, where, value):
    """Overwrite four bytes at output[where] with value, little-endian."""
    for ii in range(4):
        output[where + ii] = value & 0xff
        value >>= 8


def compile_int(arg1, arg2, output, symbols, relocations):
    """Encode `int $const` as CD ib.

    The opcode byte is emitted first; if the operand is not a `$`
    constant that parse_int understands, a warning is issued and no
    operand byte follows.
    """
    output.append(0xcd)
    val = parse_int(arg1[1:]) if arg1.startswith('$') else None
    if val is None:
        warn("int what? %r" % arg1)
        return
    if val > 0xff:
        warn("int operand %r doesn't fit in a byte" % arg1)
    output.append(val & 0xff)


def compile_test(arg1, arg2, output, symbols, relocations):
    """Encode `test %reg, %reg` as 85 /r.

    Unknown registers produce a warning and emit nothing.
    """
    modrm = _modrm_reg_reg(arg1, arg2)
    if modrm is None:
        warn("don't know what to make of test %s, %s" % (arg1, arg2))
        return
    output.append(0x85)
    output.append(modrm)


def compile_jz(arg1, arg2, output, symbols, relocations):
    """Encode `jz target` as 74 cb with the offset patched by resolve()."""
    output.append(0x74)
    add_relocation(arg1, len(output), relative_jump_8bit, relocations)
    output.append(0x00)


def relative_jump_8bit(output, where, value):
    """Patch an 8-bit displacement relative to the byte after it."""
    diff = value - (where + 1)
    if diff < -128 or diff > 127:
        warn("relative jump at %d too far, %d" % (where, diff))
    output[where] = 0xff & diff


def compile_jmp(arg1, arg2, output, symbols, relocations):
    """Encode `jmp target` as E9 cd with the offset patched by resolve()."""
    output.append(0xe9)  # long offset jump
    add_relocation(arg1, len(output), relative_jump_32bit, relocations)
    for _ in range(4):
        output.append(0)


def relative_jump_32bit(output, where, value):
    """Patch a 32-bit displacement relative to the byte after it."""
    const_32bit_relocation(output, where, value - (where + 4))


def compile_equiv(arg1, arg2, output, symbols, relocations):
    """Handle `.equiv name, value` by defining name in symbols.

    Redefinitions and values that parse_int cannot read are reported
    through warn() and leave symbols unchanged.
    """
    if arg1 in symbols:
        warn("double definition of %s as %r and %r"
             % (arg1, symbols[arg1], arg2))
        return
    value = parse_int(arg2) if arg2 is not None else None
    if value is None:
        # Storing None would make resolve() fail on the first use.
        warn("can't use %r as the value of %s" % (arg2, arg1))
        return
    symbols[arg1] = value


def compile_globl(arg1, arg2, output, symbols, relocations):
    """Accept `.globl name`; nothing is recorded."""
    pass


directives = {
    'mov': compile_mov,
    'int': compile_int,
    'test': compile_test,
    'jz': compile_jz,
    'jmp': compile_jmp,
    '.equiv': compile_equiv,
    '.globl': compile_globl,
}


# XXX this probably needs to somehow get the filename and line number
def warn(msg):
    """Write msg and a newline to stderr."""
    sys.stderr.write(msg + '\n')


def resolve(output, symbols, relocations):
    """Apply every recorded relocation whose symbol is defined.

    Relocations against symbols missing from `symbols` are reported
    through warn() and left unpatched.
    """
    for symbol in sorted(relocations):
        if symbol not in symbols:
            warn("undefined symbol %r" % symbol)
            continue
        value = symbols[symbol]
        for where, what_kind in relocations[symbol]:
            what_kind(output, where, value)


def hexdump(output):
    """Print output as two-digit hex bytes, sixteen per line.

    A final (possibly empty) line is always printed after the last full
    row of sixteen.
    """
    # Stepping to len(output) inclusive yields the trailing short or
    # empty row in addition to every full row.
    for start in range(0, len(output) + 1, 16):
        row = output[start:start + 16]
        print(" ".join("%02x" % byte for byte in row))


if __name__ == '__main__':
    sys.exit(main(sys.argv))