from argparse import ArgumentParser
import struct
from arc4 import ARC4


def extract_mac_address(file):
    """Return the MAC address bytes stored in the flash image.

    Seeks *file* to offset 0x40100 and reads up to 6 bytes from there.
    If the image ends early, fewer bytes (possibly none) are returned.
    """
    mac_offset, mac_length = 0x40100, 6
    file.seek(mac_offset)
    mac = file.read(mac_length)
    return mac


def factory_defaults(file, mac_addr, pos=0x460000):
    """Decrypt, verify and print the factory-default parameter partition.

    Reads 0x10000 bytes at *pos* from *file*. It checks the
    ``OBI_PARAM_PT`` marker and the 0xfffffffb tag. It then derives an ARC4
    key from the first 15 partition bytes patched with *mac_addr* and
    decrypts the payload stored at offset 0x100 of the partition. Finally it
    verifies the payload checksum and prints every parameter record.

    Args:
        file: Seekable binary file object holding the flash image.
        mac_addr: MAC address bytes (at least 6), e.g. from
            ``extract_mac_address``.
        pos: Flash offset of the partition; defaults to 0x460000.

    Returns:
        The decrypted payload as returned by ``ARC4.decrypt``.

    Raises:
        AssertionError: If the header, tag or checksum does not match.
    """
    file.seek(pos)
    data = bytearray(file.read(0x10000))
    # Compare raw bytes: a corrupt header must fail this check rather than
    # raise UnicodeDecodeError.
    assert data[16:29] == b"OBI_PARAM_PT\0", "missing header"

    (tag, _, checksum, payload_len) = struct.unpack_from("<LLLL", data)
    print(f"tag: 0x{tag:08x}")
    print(f"checksum: 0x{checksum:08x}")
    print(f"length: 0x{payload_len:08x}")
    assert tag == 0xfffffffb, "incorrect tag"

    # Key = first 15 header bytes. Byte 0 becomes the masked first MAC byte,
    # bytes 1-3 are replaced by MAC bytes 1-3, and bytes 4-5 are ANDed with
    # MAC bytes 4-5.
    key = data[:15]
    key[0] = mac_addr[0] & 0xfd
    key[1:4] = mac_addr[1:4]
    key[4] &= mac_addr[4]
    key[5] &= mac_addr[5]
    plain = ARC4(bytes(key)).decrypt(bytes(data[0x100 : 0x100 + payload_len]))

    # The checksum is the sum of the payload bytes read as *signed* 8-bit
    # values. The stored field is an unsigned 32-bit integer, so compare
    # modulo 2**32; otherwise a negative sum could never match.
    # NOTE(review): assumes the device accumulates in 32 bits — confirm.
    total = sum(memoryview(plain).cast("b")) & 0xFFFFFFFF
    print(f"sum: 0x{total:08x}")
    assert total == checksum, "incorrect checksum"

    # Record layout: <u32 tag><u16 len><u8 flags><u8 type>, then `len` value
    # bytes. A tag of 0xffffffff ends the list. Stop when fewer than 8 bytes
    # remain, so a truncated payload ends the dump instead of raising
    # struct.error.
    off = 0
    while off + 8 <= len(plain):
        (tag, length, flags, kind) = struct.unpack_from("<LHBB", plain, offset=off)
        if tag == 0xffffffff:
            break
        print(f"tag={tag:08x} type={kind} flags={flags:02x} len={length}")
        off += 8
        if length == 0:
            continue
        # 1: free text
        # 3: integer (unsigned?)
        # 4: bool
        # 5: select text
        if kind in (1, 4, 5):
            print("  " + plain[off : off + length].decode())
        elif kind == 3:
            (number,) = struct.unpack_from("<L", plain, offset=off)
            print(f"  {number}")
        off += length

    return plain


def config(file, pos, mac_addr):
    """Decrypt, verify and print a configuration parameter partition.

    Reads 0x20000 bytes at *pos* from *file*. It checks the
    ``OBI_PARAM_PT`` marker, the outer 0xfffffffd tag and the total length
    stored at offset 0x30, and the inner 0xfffffffc tag at offset 0x100.
    It then derives an ARC4 key from partition bytes 0x100-0x10e patched
    with *mac_addr* and decrypts the payload at offset 0x110. Finally it
    verifies the payload checksum and prints every parameter record.

    Args:
        file: Seekable binary file object holding the flash image.
        pos: Flash offset of the partition.
        mac_addr: MAC address bytes (at least 6), e.g. from
            ``extract_mac_address``.

    Returns:
        The decrypted payload as returned by ``ARC4.decrypt``.

    Raises:
        AssertionError: If the header, a tag, the total length or the
            checksum does not match.
    """
    file.seek(pos)
    data = bytearray(file.read(0x20000))
    # Compare raw bytes: a corrupt header must fail this check rather than
    # raise UnicodeDecodeError.
    assert data[16:29] == b"OBI_PARAM_PT\0", "missing header"

    (tag,) = struct.unpack_from("<L", data)
    print(f"1st tag: 0x{tag:08x}")
    assert tag == 0xfffffffd, "incorrect tag"

    (version, total_len) = struct.unpack_from("<LL", data, offset=0x30)
    print(f"version: 0x{version:08x}")
    print(f"total len: 0x{total_len:08x}")
    assert total_len == len(data), "incorrect total length"

    (tag, _, payload_len, checksum) = struct.unpack_from("<LLLL", data, offset=0x100)
    print(f"2nd tag: 0x{tag:08x}")
    print(f"checksum: 0x{checksum:08x}")
    print(f"length: 0x{payload_len:08x}")
    assert tag == 0xfffffffc, "incorrect tag"

    # Key = 15 bytes starting at 0x100. Byte 0 becomes the masked first MAC
    # byte, bytes 1-3 are replaced by MAC bytes 1-3, and bytes 4-5 are ANDed
    # with MAC bytes 4-5.
    key = data[0x100:0x10f]
    key[0] = mac_addr[0] & 0xfe
    key[1:4] = mac_addr[1:4]
    key[4] &= mac_addr[4]
    key[5] &= mac_addr[5]
    plain = ARC4(bytes(key)).decrypt(bytes(data[0x110 : 0x110 + payload_len]))

    # The checksum is the sum of the payload bytes read as *signed* 8-bit
    # values. The stored field is an unsigned 32-bit integer, so compare
    # modulo 2**32; otherwise a negative sum could never match.
    # NOTE(review): assumes the device accumulates in 32 bits — confirm.
    total = sum(memoryview(plain).cast("b")) & 0xFFFFFFFF
    print(f"sum: 0x{total:08x}")
    assert total == checksum, "incorrect checksum"

    # Record layout: <u32 tag><u16 len><u8 flags><u8 type>, then `len` value
    # bytes. A tag of 0xffffffff ends the list. Stop when fewer than 8 bytes
    # remain, so a truncated payload ends the dump instead of raising
    # struct.error.
    off = 0
    while off + 8 <= len(plain):
        (tag, length, flags, kind) = struct.unpack_from("<LHBB", plain, offset=off)
        if tag == 0xffffffff:
            break
        print(f"tag={tag:08x} type={kind} flags={flags:02x} len={length}")
        off += 8
        if length == 0:
            continue
        # 1: free text
        # 3: integer (unsigned?)
        # 4: bool
        # 5: select text
        if kind in (1, 4, 5):
            print("  " + plain[off : off + length].decode())
        elif kind == 3:
            (number,) = struct.unpack_from("<L", plain, offset=off)
            print(f"  {number}")
        off += length

    return plain


if __name__ == "__main__":
    # Command line: a single positional path to the raw flash dump.
    parser = ArgumentParser()
    parser.add_argument("flash", metavar="FLASH", help="input flash blob file")
    args = parser.parse_args()

    with open(args.flash, "rb") as f:
        mac_addr = extract_mac_address(f)
        mac_text = ":".join(f"{octet:02x}" for octet in mac_addr)
        print(f"mac address: {mac_text}")

        print("=== FACTORY DEFAULTS")
        factory_defaults(f, mac_addr)

        # The two configuration partitions, dumped in flash order.
        for pos in (0x400000, 0x420000):
            print(f"=== CONFIG @ 0x{pos:06x}")
            config(f, pos, mac_addr)
