from argparse import ArgumentParser
from hashlib import md5
import re
import struct
import tempfile
from pathlib import Path
from arc4 import ARC4


def decrypt(flash_file: str):
    """Read, verify and decrypt the unit-info record from a flash dump.

    The record starts at offset 0x40000 of *flash_file* and begins with the
    16-byte magic ``b"OB100 UNIT INFO\\0"``. Two big-endian u32 fields are read
    from it: the record length at 0x28 and ``rk_loc`` at 0x34, where the
    encrypted region starts.

    The 16-byte checksum stored at 0x14..0x24 is verified against an MD5
    computed with that field filled with 0xFF. After that, everything from
    ``rk_loc`` to the end of the buffer is RC4-decrypted in place.

    Args:
        flash_file: path to the raw flash image.

    Returns:
        The record as a ``bytearray`` of ``max(length, 0x120)`` bytes. The
        checksum field is overwritten with 0xFF and the tail from ``rk_loc``
        is decrypted.

    Raises:
        ValueError: if the header is missing or short, the record is
            truncated, or the checksum does not match.
    """
    header_size = 0x120
    with open(flash_file, "rb") as f:
        f.seek(0x40000)
        data = bytearray(f.read(header_size))
        if len(data) < header_size:
            raise ValueError("missing header")
        # Compare raw bytes: decoding first could raise UnicodeDecodeError on junk.
        if data[:16] != b"OB100 UNIT INFO\0":
            raise ValueError("missing header")

        (length,) = struct.unpack_from(">L", data, offset=0x28)
        (rk_loc,) = struct.unpack_from(">L", data, offset=0x34)
        print(f"length=0x{length:08x} rk_loc=0x{rk_loc:08x}")

        # A negative size makes read() consume the whole rest of the file,
        # so never pass one when length is smaller than the header read.
        data.extend(f.read(max(length - len(data), 0)))

    if len(data) < length:
        raise ValueError(
            f"truncated record: expected 0x{length:x} bytes, got 0x{len(data):x}"
        )

    checksum = bytes(data[0x14:0x24])
    # The checksum (and the key derived below) are computed with the
    # checksum field itself filled with 0xFF.
    data[0x14:0x24] = b"\xff" * 16
    key = md5(data[:rk_loc] + b"thisisthesecretofobihaimfd").digest()
    calc = md5(b"thisisanothersecretofobihaimfd" + key[1:] + data).digest()
    if calc != checksum:
        raise ValueError("incorrect checksum")

    data[rk_loc:] = ARC4(key[:-1]).decrypt(bytes(data[rk_loc:]))
    return data


def extract(data: bytearray, dir: Path):
    """Print the unit-info fields of a decrypted record and dump its key/cert.

    Offsets below are relative to the start of the record. The offset/length
    pointer fields are read as big-endian u32s.

    Args:
        data: decrypted record, as returned by ``decrypt``.
        dir: directory that receives ``privkey.der`` and ``cert.der``. It must
            already exist.
    """
    major, minor = struct.unpack_from("BB", data, offset=0x11)
    print(f"hardware version: {major}.{minor}")
    (build_date,) = struct.unpack_from("<L", data, offset=0x24)
    print(f"build date: {build_date:08X} (?)")

    # re.match returns a Match object (or None), not a string. Take the
    # matched text so the "GFPB100" comparison below can actually succeed.
    raw_pn = data[0x64:0x74].decode(encoding="ascii", errors="ignore")
    pn_match = re.match(r"^[a-zA-Z0-9 -]+", raw_pn)
    oem_pn = pn_match.group(0).rstrip() if pn_match else ""
    print(f"oem pn: {oem_pn}")

    sn_len = 15 if oem_pn == "GFPB100" else 12
    (sn_off,) = struct.unpack_from(">L", data, offset=0x30)
    print(f"serial no offset=0x{sn_off:08x} length=0x{sn_len:08x}")
    sn = data[sn_off : sn_off + sn_len].decode(errors="replace")
    print(f"serial no: {sn}")

    (rk_off, rk_len) = struct.unpack_from(">LL", data, offset=0x34)
    print(f"rk offset=0x{rk_off:08x} length=0x{rk_len:08x} (?)")

    (privkey_off, privkey_len) = struct.unpack_from(">LL", data, offset=0x40)
    print(f"private key offset=0x{privkey_off:08x} length=0x{privkey_len:08x}")
    privkey = data[privkey_off : privkey_off + privkey_len]
    (dir / "privkey.der").write_bytes(privkey)

    (cert_off, cert_len) = struct.unpack_from(">LL", data, offset=0x4C)
    print(f"certificate offset=0x{cert_off:08x} length=0x{cert_len:08x}")
    cert = data[cert_off : cert_off + cert_len]
    (dir / "cert.der").write_bytes(cert)

    # Big-endian, like every other offset field in this header table.
    # NOTE(review): the original read this one little-endian; confirm on a real dump.
    (obi_no_off,) = struct.unpack_from(">L", data, offset=0x58)
    obi_no_len = 9
    print(f"obi no offset=0x{obi_no_off:08x} length=0x{obi_no_len:08x}")
    obi_no = data[obi_no_off : obi_no_off + obi_no_len].decode(errors="replace")
    print(f"obi no: {obi_no}")

    mac_addr = data[0x100:0x106]
    print(f"mac address: {':'.join(f'{b:02x}' for b in mac_addr)}")

    (pn,) = struct.unpack_from(">L", data, offset=0x114)
    print(f"part no: OBi{pn}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("flash", metavar="FLASH", help="input flash blob file")
    parser.add_argument("--plain", "-p", help="write decrypted blob to file")
    parser.add_argument("--dir", "-d", help="write blobs to this directory")
    args = parser.parse_args()

    plain = decrypt(args.flash)
    if args.plain:
        Path(args.plain).write_bytes(plain)

    # Without --dir, fall back to a fresh temporary directory. It is not
    # removed, so the extracted blobs remain available after the run.
    out_dir = Path(args.dir) if args.dir else Path(tempfile.mkdtemp())
    # extract() writes files straight into out_dir, so it must exist.
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"writing blobs to {out_dir}")

    extract(plain, out_dir)
