#!/usr/bin/env python3

import argparse
import datetime
import logging
import socket
import struct

class TPPacket: #RFC 868

    EPOCH = datetime.datetime(year=1900, month=1, day=1, tzinfo=datetime.UTC)

    def __init__(self, payload=b'\0'*4):
        self._payload = payload
        assert len(self._payload) == 4

    @property
    def payload(self):
        return self._payload

    @property
    def time(self):
        seconds = struct.unpack('!L', self.payload[0:4])[0]
        return self.EPOCH + datetime.timedelta(seconds=seconds)

    @time.setter
    def time(self, value):
        seconds = int((value.astimezone(datetime.UTC) - self.EPOCH).total_seconds() + 0.5)
        self._payload = struct.pack('!L', seconds)

    def __repr__(self):
        result = f'Time {self.time.astimezone()}'
        return result

    @classmethod
    def TPResponse(cls, timestamp=None):
        if timestamp is None:
            timestamp=datetime.datetime.now(datetime.UTC)
        seconds = int((timestamp - cls.EPOCH).total_seconds() + 0.5)
        return TPPacket(struct.pack('!L', seconds))

if __name__ == '__main__':
    # Command-line RFC 868 client: send one UDP query to every IPv4 address
    # the host resolves to and log the time each server reports.
    parser = argparse.ArgumentParser(description='Time Protocol Client')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--host', default='time.nist.gov')  # alternatives: 20.215.232.215, localhost
    parser.add_argument('--port', default='time')
    parser.add_argument('--timeout', type=float, default=5.0,
                        help='seconds to wait for each reply (default: %(default)s)')
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    try:
        addresses = socket.getaddrinfo(args.host, args.port, family=socket.AF_INET, type=socket.SOCK_DGRAM)
    except socket.gaierror as exc:
        parser.exit(1, f'cannot resolve {args.host}:{args.port}: {exc}\n')

    for family, sock_type, _proto, _canonname, sockaddr in addresses:
        with socket.socket(family, sock_type) as sock:
            # UDP is unreliable: without a timeout a lost request or reply
            # would leave recvfrom() blocked forever.
            sock.settimeout(args.timeout)
            try:
                # RFC 868 over UDP: an empty datagram is the query.
                sock.sendto(b'', sockaddr)
                logging.info(f'sent query to {sockaddr}')
                reply, addr = sock.recvfrom(2**16)
            except TimeoutError:  # must precede OSError (it is a subclass)
                logging.warning(f'no response from {sockaddr} within {args.timeout} s')
                continue
            except OSError as exc:
                logging.warning(f'query to {sockaddr} failed: {exc}')
                continue
            if addr != sockaddr:
                logging.warning(f'unexpected response from {addr}')
            try:
                packet = TPPacket(reply)
            except AssertionError:  # TPPacket rejects payloads that are not 4 bytes
                logging.warning('mangled packet received')
            else:
                logging.info(f'received response from {addr}: {packet}')
