#!/usr/bin/env python3
#--------------------------------------------------------------------------------
#
#   LOGSAC RAM Image Generator
#
#   Usage: ramgen [-o output] source
#
#--------------------------------------------------------------------------------

import optparse
import os
import sys

class SourceError(Exception):
    """Raised when a source line cannot be encoded into a RAM word."""

#--------------------------------------------------------------------------------

class Scanner:
    """Line-oriented reader for a RAM image source file.

    Every line has the text from the first ``#`` onwards removed and all
    whitespace deleted; lines that end up empty are skipped.  The current
    cleaned line is held in ``line`` (``None`` at end of file) and the
    1-based number of the last line read in ``lineno``.
    """

    def __init__(self, srcpath):
        """Open *srcpath* and move to its first non-empty line."""
        self.srcpath = srcpath
        self.src = open(srcpath)
        self.lineno = 0
        self.line = None
        self.next_line()

    def next_line(self):
        """Advance to the next non-empty line, or to end of file."""
        self.read_line()
        while not self.eof() and self.line == "":
            self.read_line()

    def eof(self):
        """Return True once the end of the source has been reached."""
        return self.line is None

    def error_message(self, arg):
        """Return *arg* prefixed with the source path and current line number."""
        return "%s, line %s: %s" % (self.srcpath, self.lineno, arg)

    def read_line(self):
        """Read one raw line into ``line``, stripped of comment and whitespace.

        At end of file ``line`` becomes ``None`` and the file is closed;
        further calls are harmless no-ops.
        """
        if self.src.closed:
            self.line = None
            return
        raw = self.src.readline()
        if not raw:
            # readline() signals end of file with "", it never raises.
            self.line = None
            self.src.close()
            return
        # Count only lines that really exist so lineno never overshoots.
        self.lineno += 1
        code, _, _ = raw.partition("#")
        # Drop every kind of whitespace (tabs too), not just spaces.
        self.line = "".join(code.split())

#--------------------------------------------------------------------------------

def replace_suffix(path, suf):
    """Return *path* with its extension (if it has one) swapped for *suf*."""
    stem, _old_suffix = os.path.splitext(path)
    return stem + suf

# The two 32-entry character sets; a character's position is its code.
# The letter set doubles as the opcode table.  The lower-case entries are
# presumably stand-ins for codes with no plain letter (they line up with the
# "." positions in the reference string below).
letters = 'PQWERTYUIOJpSZKebFtDoHNMnLXGABCV'
figures = '0123456789_f"+(ab$r; &,.n)/#-?:='
#characters = (letters, figures)
# Opcode letters by code, kept for reference:
#opcodes = "PQWERTYUIOJ.SZK..F.D.HNM.LXGABCV"
# Length suffix letters; a letter's position is the value of the length bit.
lengths = "FD"

assert len(letters) == len(figures) == 32
# A repeated entry would make one code unreachable through str.index().
assert len(set(letters)) == len(set(figures)) == 32

def opcode(c):
    """Return the code (position in ``letters``) of opcode character *c*.

    Raises SourceError unless *c* is exactly one character found in
    ``letters``.
    """
    # str.index()/find() search for substrings, so without the length check
    # "" and runs such as "PQ" would be accepted as code 0.
    if len(c) == 1:
        code = letters.find(c)
        if code >= 0:
            return code
    raise SourceError("Invalid opcode character %r" % c)

def character(c):
    """Return the code of character literal *c*.

    The letter set is searched first, then the figure set; the result is the
    character's position in whichever set contains it.

    Raises SourceError unless *c* is exactly one character found in one of
    the two sets.
    """
    # str.index()/find() search for substrings, so without the length check
    # an empty literal or a multi-character one such as 'PQ' would silently
    # be encoded as code 0.
    if len(c) == 1:
        for charset in (letters, figures):
            code = charset.find(c)
            if code >= 0:
                return code
    raise SourceError("Invalid character literal %r" % c)

def address(n):
    """Convert the decimal text *n* to an address in the range 0..1023.

    An empty *n* means address 0.  Raises SourceError when *n* is not a
    decimal integer or lies outside the range.
    """
    if not n:
        return 0
    try:
        value = int(n)
    except ValueError:
        raise SourceError("Invalid address %r" % n)
    if not 0 <= value <= 1023:
        raise SourceError("Invalid address %r" % n)
    return value

def length(l):
    """Return the length bit for suffix letter *l* (its position in ``lengths``).

    Raises SourceError unless *l* is exactly one character found in
    ``lengths``.
    """
    # str.index("") is 0, so without the length check a missing suffix would
    # silently be treated as the first length letter.
    if len(l) == 1:
        bit = lengths.find(l)
        if bit >= 0:
            return bit
    raise SourceError("Invalid length character %r" % l)

def number(d, base):
    """Parse the text *d* as an integer written in *base*.

    Raises SourceError when *d* is not a valid number in that base.
    """
    try:
        value = int(d, base)
    except ValueError:
        raise SourceError("Invalid base %s number %r" % (base, d))
    else:
        return value

def fraction(d, l):
    """Convert the decimal fraction text *d* to a scaled integer.

    The magnitude is multiplied by 2**34 when *l* is "D" and by 2**16
    otherwise, truncated towards zero, and then given the sign of the
    original value.

    Raises SourceError when *d* is not a finite decimal number.
    """
    try:
        f = float(d)
        scale_bits = 34 if l == "D" else 16
        # int() raises OverflowError for infinities and ValueError for NaN.
        x = int(abs(f) * (2 ** scale_bits))
    except (ValueError, OverflowError):
        raise SourceError("Invalid fraction %r" % d) from None
    return -x if f < 0 else x

def process(src):
    """Assemble every line of *src* into a list of RAM words.

    Each line has the form ``address:item``.  An item whose first character
    is a letter is an order: opcode character, optional decimal address and a
    length letter, packed as ``opcode << 12 | address << 1 | length``.

    Any other item is a constant, optionally ending in ``F`` (one word, the
    default) or ``D`` (two consecutive 18-bit words, low half first):

        'c'     character literal, placed in the opcode field
        $hex    hexadecimal integer
        %bin    binary integer
        d.ddd   decimal fraction
        ddd     decimal integer

    NOTE: the F/D suffix is stripped before the constant is examined, so a
    hexadecimal constant whose last digit is F or D needs an explicit suffix.

    src -- a Scanner-like object providing line, eof() and next_line().

    Returns the words from address 0 up to the highest address written.
    Raises SourceError for anything that cannot be encoded.
    """
    data = [0] * 1024
    imax = 0
    while not src.eof():
        a, _, item = src.line.partition(":")
        i = address(a)
        c, rest = item[:1], item[1:]
        if "A" <= c.upper() <= "Z":
            # Order: everything between the opcode and the final length
            # letter is the address field.
            n, size = rest[:-1], rest[-1:]
            data[i] = (opcode(c) << 12) + (address(n) << 1) + length(size)
        else:
            if item.endswith(("F", "D")):
                d, size = item[:-1], item[-1:]
            else:
                d, size = item, "F"
            if d.startswith("'") and d.endswith("'"):
                word = character(d[1:-1]) << 12
            elif d.startswith("$"):
                word = number(d[1:], 16)
            elif d.startswith("%"):
                word = number(d[1:], 2)
            elif "." in d:
                word = fraction(d, size)
            else:
                word = number(d, 10)
            # A long constant occupies address i and i + 1; make sure the
            # second word exists before storing anything.
            if size == "D" and i + 1 >= len(data):
                raise SourceError(
                    "Long constant at address %d does not fit in memory" % i)
            data[i] = word & 0x3ffff
            if size == "D":
                i += 1
                data[i] = (word >> 18) & 0x3ffff
        imax = max(imax, i + 1)
        src.next_line()
    return data[:imax]

def write_data(data, dstpath):
    """Write *data* to *dstpath* as a RAM image file.

    The file starts with a ``v2.0 raw`` header line.  Each word then gets one
    line holding its 32 low bits, least significant first, each followed by a
    space, with one extra space after bit 17.
    """
    #print("Writing RAM image:", dstpath)
    # "with" guarantees the file is closed even if a write fails part-way.
    with open(dstpath, "w") as f:
        f.write("v2.0 raw\n")
        for w in data:
            pieces = []
            for i in range(32):
                pieces.append("%d " % ((w >> i) & 1))
                if i == 17:
                    pieces.append(" ")
            pieces.append("\n")
            f.write("".join(pieces))

def parse_args(argv=None):
    """Parse the command line.

    argv -- list of arguments to parse; None (the default) makes optparse
            use sys.argv[1:].

    Returns (options, args): options.hexpath is the -o/--output value (None
    when the option is absent) and args holds the single source path.  When
    the number of source files is not exactly one, the parser's error()
    reports it and exits.
    """
    p = optparse.OptionParser(usage="%prog [options] source")
    p.add_option("-o", "--output", dest="hexpath", metavar="FILE",
                 help="Name of hex output file")
    options, args = p.parse_args(argv)
    if len(args) != 1:
        p.error("Wrong number of source files, expected 1")
    return options, args

def main():
    """Assemble the source named on the command line into a RAM image file.

    The output path is the -o/--output value or, failing that, the source
    path with its extension replaced by ".hex".  Source errors and file
    access errors are reported on stderr and end the program with status 1.
    """
    options, args = parse_args()
    srcpath = args[0]
    dstpath = options.hexpath or replace_suffix(srcpath, ".hex")
    try:
        src = Scanner(srcpath)
    except OSError as e:
        # Missing or unreadable source: report it instead of a traceback.
        sys.stderr.write("%s\n" % e)
        sys.exit(1)
    try:
        data = process(src)
    except SourceError as e:
        sys.stderr.write("%s\n" % src.error_message(e))
        sys.exit(1)
    finally:
        # Release the source file whether or not assembly succeeded;
        # closing an already closed file is harmless.
        src.src.close()
    try:
        write_data(data, dstpath)
    except OSError as e:
        sys.stderr.write("%s\n" % e)
        sys.exit(1)

# Run the tool only when executed as a script, so that importing this module
# does not parse the command line.
if __name__ == "__main__":
    main()
