#!/usr/bin/env python3

import argparse
import datetime
import logging
import socket
import struct
import time

class TPPacket: #RFC 868

    EPOCH = datetime.datetime(year=1900, month=1, day=1, tzinfo=datetime.UTC)

    def __init__(self, payload=b'\0'*4):
        self._payload = payload
        assert len(self._payload) == 4

    @property
    def payload(self):
        return self._payload

    @property
    def time(self):
        seconds = struct.unpack('!L', self.payload[0:4])[0]
        return self.EPOCH + datetime.timedelta(seconds=seconds)

    @time.setter
    def time(self, value):
        seconds = int((value.astimezone(datetime.UTC) - self.EPOCH).total_seconds() + 0.5)
        self._payload = struct.pack('!L', seconds)

    def __repr__(self):
        result = f'Time {self.time.astimezone()}'
        return result

    @classmethod
    def TPResponse(cls, timestamp=None):
        if timestamp is None:
            timestamp=datetime.datetime.now(datetime.UTC)
        seconds = int((timestamp - cls.EPOCH).total_seconds() + 0.5)
        return TPPacket(struct.pack('!L', seconds))

def _parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description='Time Protocol Client UDP')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--host', default='time.nist.gov')
    parser.add_argument('--port', default='time')
    parser.add_argument('--timeout', type=float, default=2)
    parser.add_argument('--retries', type=int, default=5)
    return parser.parse_args(argv)


def _receive_reply(sock, sockaddr, deadline):
    """Wait until *deadline* for a well-formed reply from *sockaddr*.

    *deadline* is a ``time.monotonic()`` value, so the wait is immune to
    wall-clock adjustments. Datagrams from other peers and payloads of the
    wrong size are logged and ignored (they do not extend the deadline).
    Returns the decoded TPPacket, or None if nothing valid arrived in time.
    """
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return None
        sock.settimeout(remaining)
        try:
            data, addr = sock.recvfrom(2**16)
        except socket.timeout:
            logging.warning('socket timeout')
            return None
        except OSError as err:
            # e.g. ICMP port-unreachable surfacing as ConnectionResetError on
            # some platforms; treat like a lost reply and let the caller retry.
            logging.warning(f'receive failed: {err}')
            return None
        if addr != sockaddr:
            logging.warning(f'unexpected response from {addr}')
            continue
        try:
            packet = TPPacket(data)
        except AssertionError:
            logging.warning('mangled response received')
            continue
        logging.info(f'received UDP response from {addr}: {packet}')
        return packet


def _query_server(family, socktype, sockaddr, timeout, retries):
    """Query one server address, retrying up to *retries* times.

    Each try sends an empty datagram and waits *timeout* seconds for a reply.
    Returns the first valid TPPacket, or None on failure.
    """
    with socket.socket(family, socktype) as sock:
        for _ in range(retries):
            try:
                sock.sendto(b'', sockaddr)
            except OSError as err:
                logging.error(f'failed to send UDP query to {sockaddr}: {err}')
                return None
            logging.info(f'sent UDP query to {sockaddr}')
            packet = _receive_reply(sock, sockaddr, time.monotonic() + timeout)
            if packet is not None:
                return packet
        logging.error('too many retries')
        return None


def main(argv=None):
    """Query the configured host; return 0 on the first valid reply, else 1."""
    args = _parse_args(argv)
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
    try:
        addresses = socket.getaddrinfo(
            args.host, args.port, family=socket.AF_INET, type=socket.SOCK_DGRAM)
    except socket.gaierror as err:
        logging.error(f'cannot resolve {args.host}:{args.port}: {err}')
        return 1
    for family, socktype, _proto, _canonname, sockaddr in addresses:
        # Stop at the first address that answers instead of querying them all.
        if _query_server(family, socktype, sockaddr, args.timeout, args.retries) is not None:
            return 0
    return 1


if __name__ == '__main__':
    raise SystemExit(main())
