#!/usr/bin/env python3
# vim:ts=4:sts=4:sw=4:expandtab

import argparse
import binascii
import random
import struct
import sys
import wave
import numpy as np

# Command-line interface. Every option has a default, so running the script
# with no arguments writes a short demo transmission to modem.wav.
parser = argparse.ArgumentParser(
    prog='Modem Encoder',
    description='Encode a text message as a two-tone audio frame and write it to a WAV file.',
    # Append "(default: ...)" to each option's help text.
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--debug', action='store_true',
                    help='print the frame octets and the bit sequence being sent')
parser.add_argument('--signal-time', type=float, default=0.1,
                    help='duration of the tone for one bit, in seconds')
parser.add_argument('--frequency-zero', type=int, default=440,
                    help='tone frequency used for a 0 bit, in Hz')
parser.add_argument('--frequency-one', type=int, default=880,
                    help='tone frequency used for a 1 bit, in Hz')
parser.add_argument('--destination', type=int, default=1,
                    help='destination address, packed into 6 octets')
parser.add_argument('--source', type=int, default=2,
                    help='source address, packed into 6 octets')
parser.add_argument('--message', default='hello world!',
                    help='payload text, sent UTF-8 encoded')
parser.add_argument('--framerate', type=int, default=44000,
                    help='sample rate of the output file, in frames per second')
parser.add_argument('--amplitude', type=float, default=0.1,
                    help='signal amplitude as a fraction of 16-bit full scale')
parser.add_argument('--delay-time', type=float, default=3.0,
                    help='seconds of silence written before the first bit')
parser.add_argument('--noise', type=float, default=0.0,
                    help='relative strength of the Gaussian noise added to the signal (0 disables it)')
parser.add_argument('--harder', action='store_true',
                    help='stretch every tone by a random factor between 1.0 and 1.1')
parser.add_argument('--output', default='modem.wav',
                    help='path of the WAV file to write')
args = parser.parse_args()

def strtobits(msg):
    """Convert a string of '0'/'1' characters into a list of booleans.

    Only the character '1' maps to True; every other character maps to False.
    """
    bits = []
    for symbol in msg:
        bits.append(symbol == '1')
    return bits
def bitstostr(bits):
    """Render an iterable of truthy/falsy values as a string of '1'/'0' characters."""
    # Index into '01' with the truth value: False -> '0', True -> '1'.
    return ''.join('01'[bool(bit)] for bit in bits)
def octetstobits(byts):
    """Expand each octet into eight booleans, most significant bit first.

    `byts` is any iterable of integers (e.g. a bytes object); the result is
    one flat list holding 8 entries per input value.
    """
    return [
        bool((octet >> shift) & 1)
        for octet in byts
        for shift in range(7, -1, -1)
    ]

def preamble():
    """Return the 64-bit preamble: '10101010' seven times, then '10101011'."""
    pattern = '10101010' * 7 + '10101011'
    return strtobits(pattern)
def destination(dst = args.destination):
    """Pack the destination address into 6 octets in network byte order.

    The value is split into its upper 32 bits and lower 16 bits, which are
    packed back to back; struct.pack rejects values that do not fit.
    """
    upper, lower = divmod(dst, 2**16)
    return struct.pack('!LH', upper, lower)
def source(src = args.source):
    """Pack the source address into 6 octets in network byte order.

    The value is split into its upper 32 bits and lower 16 bits, which are
    packed back to back; struct.pack rejects values that do not fit.
    """
    upper, lower = divmod(src, 2**16)
    return struct.pack('!LH', upper, lower)
def length(m = args.message):
    """Return the UTF-8 byte length of `m` as a 2-octet network-order field."""
    payload = bytes(m, 'utf8')
    octet_count = len(payload)
    return struct.pack('!H', octet_count)
def message(m = args.message):
    """Return the message text `m` encoded as UTF-8 bytes."""
    encoded = bytes(m, encoding='utf8')
    return encoded
def crc(octets):
    """Return the CRC-32 of `octets` as 4 octets in network byte order."""
    checksum = binascii.crc32(octets)
    return struct.pack('!L', checksum)


def nrzi(bitstring):
    """NRZI-encode a sequence of bits: a 1 toggles the level, a 0 holds it.

    The level before the first bit is taken to be True, so the first output
    is the inverse of the first input bit.

    Accepts any iterable of truthy/falsy values and returns a list of
    booleans of the same length; empty input yields an empty list.
    """
    level = True
    result = []
    for bit in bitstring:
        if bit:
            level = not level
        result.append(level)
    return result

# 4b/5b code table: each 4-bit data nibble maps to a 5-bit code group.
bitstring_map = {
    '0000': '11110',
    '0001': '01001',
    '0010': '10100',
    '0011': '10101',
    '0100': '01010',
    '0101': '01011',
    '0110': '01110',
    '0111': '01111',
    '1000': '10010',
    '1001': '10011',
    '1010': '10110',
    '1011': '10111',
    '1100': '11010',
    '1101': '11011',
    '1110': '11100',
    '1111': '11101',
}
def encode4b5b(bitstring):
    """Replace every group of 4 bits in `bitstring` with its 5-bit code group.

    `bitstring` is a list of booleans; the result is a new list of booleans.
    A group that is not a key of `bitstring_map` (for example a trailing
    group shorter than 4 bits) raises KeyError.
    """
    result = []
    for start in range(0, len(bitstring), 4):
        nibble = bitstostr(bitstring[start:start+4])
        result.extend(strtobits(bitstring_map[nibble]))
    return result

# Frame layout: destination, source, payload length, payload, then a CRC-32
# computed over all of the preceding octets.
octets = b''.join([destination(), source(), length(), message()])
octets = octets + crc(octets)
if args.debug:
    print(f'Encoding octets {octets}')

# The frame octets are 4b/5b coded and then NRZI coded; the preamble is
# prepended afterwards and is not run through either coder.
bitstring = octetstobits(octets)
bitstring = encode4b5b(bitstring)
bitstring = preamble() + nrzi(bitstring)
if args.debug:
    print(f'Sending bits {bitstring}')

# Phase carried from one tone to the next, in cycles (fractional part only).
angle = 0.0
def tone(frequency, time = args.signal_time, amplitude = args.amplitude, noise = args.noise):
    """Return one sine burst as raw signed 16-bit PCM bytes in native byte order.

    frequency -- tone frequency in cycles per second
    time      -- burst duration in seconds; with --harder it is stretched by
                 a random factor between 1.0 and 1.1
    amplitude -- output scale as a fraction of 16-bit full scale
    noise     -- Gaussian noise is drawn with standard deviation
                 noise*amplitude and added before the amplitude scaling

    Each burst starts at the phase where the previous one ended, tracked in
    the module-level `angle`, so consecutive tones join without a phase jump.
    Samples are clipped to the int16 range, so a large amplitude or a noise
    excursion saturates instead of raising struct.error.
    """
    global angle
    if args.harder:
        time = random.uniform(1.0, 1.1) * time
    points = int(np.round(args.framerate * time))

    # The burst spans frequency*time cycles, spread evenly over `points` samples.
    phase = np.arange(points)/points * 2*np.pi * frequency*time + angle * 2*np.pi
    angle = np.modf(angle + frequency*time)[0]
    noisy = np.sin(phase) + np.random.normal(0, noise*amplitude, points)

    scaled = noisy * (2**15) * amplitude
    # astype truncates toward zero like int(); clip first so that
    # out-of-range values saturate rather than overflow.
    samples = np.clip(scaled, -(2**15), 2**15 - 1).astype(np.int16)
    return samples.tobytes()

# Write the transmission as mono 16-bit PCM. The context manager closes the
# file even if generating or writing a tone fails part-way through.
with wave.open(args.output, 'wb') as wave_file:
    wave_file.setnchannels(1)
    wave_file.setsampwidth(2)
    wave_file.setframerate(args.framerate)

    # Leading delay: a zero-amplitude tone, i.e. silence.
    wave_file.writeframes(tone(args.frequency_zero, args.delay_time, 0))
    for bit in bitstring:
        wave_file.writeframes(tone(args.frequency_one if bit else args.frequency_zero))
