#!/usr/bin/env python3
#--------------------------------------------------------------------------------
#
#   BREDSAC Microcode Generator
#
#   Usage: microgen source
#
#--------------------------------------------------------------------------------

import os, sys

class SourceError(Exception):
    """Signals a problem in the microcode source text being processed."""

#--------------------------------------------------------------------------------

class Scanner:
    """Line-oriented reader for a microcode source file.

    After construction, and after every call to next_line(), ``self.line``
    holds the next non-blank line in preprocessed form, or None once the end
    of the file has been reached.  Preprocessing strips everything from the
    first ``#`` onwards, splits the rest on whitespace, replaces every token
    that is a key of ``self.definitions`` with its value, and concatenates
    the tokens with no separator.
    """

    def __init__(self, srcpath):
        """Open *srcpath* and position the scanner on its first non-blank line."""
        self.srcpath = srcpath
        self.src = open(srcpath)
        self.lineno = 0          # number of physical lines read so far
        self.definitions = {}    # token -> replacement text
        self.next_line()

    def next_line(self):
        """Advance to the next non-blank line; ``self.line`` is None at EOF."""
        # readline() reports end of file by returning "", not by raising, so
        # read_line() is the single place that detects EOF.
        self.read_line()
        while self.line == "":
            self.read_line()

    def eof(self):
        """Return True once the end of the source file has been reached."""
        return self.line is None

    def error_message(self, arg):
        """Format *arg* as a diagnostic prefixed with the path and line number."""
        return "%s, line %s: %s" % (self.srcpath, self.lineno, arg)

    def error(self, arg):
        """Raise SourceError carrying *arg* as its message.

        The caller that catches the exception is expected to add the
        location with error_message().
        """
        raise SourceError(arg)

    def read_line(self):
        """Read one physical line and store its preprocessed form in ``self.line``.

        A blank or comment-only line yields "".  At end of file the
        underlying file is closed and ``self.line`` becomes None; further
        calls are no-ops.
        """
        if self.src.closed:
            # Already hit EOF earlier: stay there instead of reading a
            # closed file.
            self.line = None
            return
        self.lineno += 1
        line = self.src.readline()
        if not line:
            self.src.close()
            self.line = None
            return
        code = line.partition("#")[0]
        defs = self.definitions
        self.line = "".join(defs.get(token, token) for token in code.split())

#--------------------------------------------------------------------------------

class Data:
    """Memory image under construction.

    Attributes:
        addr_bits: number of address bits; the image has 2**addr_bits words.
        data_bits: number of bits in each word.
        words: list of word values indexed by address, initially all zero.
        occupied: set of addresses that have been explicitly assigned.
    """

    def __init__(self, addr_bits, data_bits):
        self.addr_bits, self.data_bits = addr_bits, data_bits
        word_count = 1 << addr_bits
        self.words = [0] * word_count
        self.occupied = set()

#--------------------------------------------------------------------------------

def replace_suffix(path, suf):
    """Return *path* with its extension (if any) replaced by *suf*."""
    stem, _old_suffix = os.path.splitext(path)
    return stem + suf

def int_from_bits(bits):
    """Convert a string of '0'/'1' characters to an integer.

    The FIRST character of *bits* is bit 0 (least significant), the second
    is bit 1, and so on.  Raises SourceError on any other character.
    """
    value = 0
    weight = 1  # value of the bit position currently being examined
    for digit in bits:
        if digit == "1":
            value |= weight
        elif digit != "0":
            raise SourceError("Invalid digit '%s'" % digit)
        weight <<= 1
    return value

def combinations(bits):
    """Expand an address pattern into the list of addresses it matches.

    *bits* is a string whose LAST character is bit 0 (least significant).
    '0' and '1' fix the bit; 'x' is a don't-care that matches both values,
    doubling the number of addresses produced.  An empty pattern yields [0].

    Raises SourceError for any other character, instead of silently
    treating it as '0' and producing a wrong address.
    """
    result = [0]
    for i, b in enumerate(reversed(bits)):
        mask = 1 << i
        if b == "1":
            result = [n | mask for n in result]
        elif b == "x":
            # Build the new entries as a list first: extending from a lazy
            # iterator over `result` itself would never terminate.
            result.extend([n | mask for n in result])
        elif b != "0":
            raise SourceError("Invalid address digit '%s'" % b)
    return result
            
def process_header(src):
    """Consume the leading ``Name:=value`` parameter lines and build a Data.

    Lines are taken from *src* for as long as they contain ``:=``.  The
    recognised parameter names are ``AddressBits`` and ``DataBits``; both
    must be given and each value must be a non-negative integer.  On return
    *src* is positioned on the first line that is not a parameter line (or
    at end of file).

    Raises SourceError for an unparsable or negative value, an unknown
    parameter name, or a missing parameter.
    """
    addr_bits = None
    data_bits = None
    while not src.eof():
        line = src.line
        if ":=" not in line:
            break
        name, _, value = line.partition(":=")
        try:
            n = int(value)
        except ValueError:
            raise SourceError("Invalid parameter value") from None
        if n < 0:
            # A negative size cannot describe a memory image.
            raise SourceError("Invalid parameter value")
        if name == "AddressBits":
            addr_bits = n
        elif name == "DataBits":
            data_bits = n
        else:
            raise SourceError("Invalid parameter name '%s'" % name)
        src.next_line()
    if addr_bits is None:
        raise SourceError("AddressBits parameter missing")
    if data_bits is None:
        raise SourceError("DataBits parameter missing")
    return Data(addr_bits, data_bits)

def process_definition(src):
    """Record the ``name=value`` definition held in the current line of *src*.

    The text before the first '=' is the name and everything after it is the
    replacement value.  Raises SourceError if the name is already defined.
    """
    key, _sep, replacement = src.line.partition("=")
    table = src.definitions
    if key in table:
        raise SourceError("Redefinition of '%s'" % key)
    table[key] = replacement

def process_microinstruction(src, data):
    """Store the ``address:data`` line currently held by *src* into *data*.

    The address part is a pattern expanded by combinations(); the data part
    is converted by int_from_bits() and written to every matching address,
    each of which is then marked as occupied.

    Raises SourceError if the ':' is missing, if either part has the wrong
    number of characters, or if a matching address is already occupied.
    Addresses written before an occupied one is encountered stay written.
    """
    pattern, sep, payload = src.line.partition(":")
    if not sep:
        raise SourceError("Missing ':'")

    found = len(pattern)
    if found != data.addr_bits:
        raise SourceError("Wrong number of address bits (found %s, expected %s)" % (found, data.addr_bits))

    found = len(payload)
    if found != data.data_bits:
        raise SourceError("Wrong number of data bits (found %s, expected %s)" % (found, data.data_bits))

    word = int_from_bits(payload)
    taken = data.occupied
    store = data.words
    for address in combinations(pattern):
        if address in taken:
            raise SourceError("Word at address 0x%x already occupied" % address)
        store[address] = word
        taken.add(address)

def process_lines(src, data):
    """Process every remaining line of *src* until end of file.

    A line containing '=' is handled as a definition; any other line is
    handled as a microinstruction and stored into *data*.
    """
    while True:
        if src.eof():
            return
        if "=" not in src.line:
            process_microinstruction(src, data)
        else:
            process_definition(src)
        src.next_line()

def process(src):
    """Process the whole source: header first, then all remaining lines.

    Prints the address and data widths taken from the header and returns
    the populated Data object.
    """
    image = process_header(src)
    report = (
        ("Address bits:", image.addr_bits),
        ("Data bits:", image.data_bits),
    )
    for label, width in report:
        print(label, width)
    process_lines(src, image)
    return image

def write_data(data, dstpath):
    """Write ``data.words`` to *dstpath* as a text hex dump.

    The file starts with the line ``v2.0 raw`` (presumably the Logisim
    memory-image header; confirm against the consuming tool), followed by
    every word in lowercase hexadecimal, each followed by a space, with a
    line break before every 16th word.  The destination path is also
    printed to standard output.
    """
    print("Output path:", dstpath)
    # The context manager guarantees the file is closed even if a write fails.
    with open(dstpath, "w") as f:
        f.write("v2.0 raw\n")
        for a, w in enumerate(data.words):
            # Start a new line at addresses 16, 32, ... but not at address 0.
            if a and not (a & 0xf):
                f.write("\n")
            f.write("%x " % w)

def main():
    """Command-line entry point: ``microgen source``.

    Reads the microcode source named by the first argument and writes the
    generated image next to it, with the extension replaced by ``.hex``.
    Usage, file-open and source errors are reported on standard error and
    end the program with a non-zero exit status.
    """
    if len(sys.argv) < 2:
        sys.stderr.write("Usage: microgen source\n")
        sys.exit(2)
    srcpath = sys.argv[1]
    dstpath = replace_suffix(srcpath, ".hex")
    try:
        src = Scanner(srcpath)
    except OSError as e:
        # Missing or unreadable source file: report it without a traceback.
        sys.stderr.write("%s\n" % e)
        sys.exit(1)
    try:
        data = process(src)
    except SourceError as e:
        sys.stderr.write("%s\n" % src.error_message(e))
        sys.exit(1)
    write_data(data, dstpath)


# Run only when executed as a script, so the module can be imported safely.
if __name__ == "__main__":
    main()
