#!/usr/bin/env python3

import argparse
import datetime
import logging
import random
import socket
import struct
import time

class DNSPacket:
    """A DNS message (RFC 1035 section 4.1) backed by its raw wire bytes.

    Header fields are decoded on demand from ``payload``; the question,
    answer, authority and additional sections are parsed by the matching
    properties each time they are read.  Malformed input is reported with
    ``AssertionError`` -- note that those checks disappear under ``python -O``.
    """

    def __init__(self, payload):
        """Wrap ``payload``, which must contain at least the 12-byte header."""
        self._payload = payload
        # Offsets where the answer / authority / additional sections start;
        # each one is recorded as a by-product of parsing the section before it.
        self._response_position = None
        self._response_authority_position = None
        self._response_additional_position = None
        assert len(self._payload) >= 12

    # -- header ------------------------------------------------------------

    @property
    def payload(self):
        """The raw message bytes."""
        return self._payload

    @property
    def transaction_id(self):
        """The 16-bit message ID that pairs a response with its query."""
        return int.from_bytes(self.payload[0:2], 'big')

    def _raw_flags(self):
        """Return the 16-bit header flags word as an integer."""
        return int.from_bytes(self.payload[2:4], 'big')

    @property
    def flags(self):
        """The 16 header flag bits as booleans, most significant bit first.

        Index 0 is QR, 1-4 the opcode, 5 AA, 6 TC, 7 RD, 8 RA, 9 reserved,
        10 AD, 11 CD and 12-15 the reply code.
        """
        raw = self._raw_flags()
        return [raw & (1 << (15 - i)) != 0 for i in range(16)]

    @property
    def is_query(self):
        """True when the QR bit is clear."""
        return not self.flags[0]

    @property
    def is_response(self):
        """True when the QR bit is set."""
        return self.flags[0]

    @property
    def opcode(self):
        """The 4-bit operation code (header bits 14-11)."""
        return (self._raw_flags() >> 11) & 0xF

    @property
    def is_authoritative(self):
        """The AA (authoritative answer) bit."""
        return self.flags[5]

    @property
    def is_truncated(self):
        """The TC (truncated) bit."""
        return self.flags[6]

    @property
    def is_recursion_desired(self):
        """The RD (recursion desired) bit."""
        return self.flags[7]

    @property
    def is_recursion_available(self):
        """The RA (recursion available) bit."""
        return self.flags[8]

    @property
    def is_authenticated(self):
        """The AD (authenticated data) bit."""
        return self.flags[10]

    @property
    def is_authentication_disabled(self):
        """The CD (checking disabled) bit."""
        return self.flags[11]

    @property
    def reply_code(self):
        """The 4-bit reply code (header bits 3-0)."""
        return self._raw_flags() & 0xF

    @property
    def query_count(self):
        """Number of entries in the question section."""
        return int.from_bytes(self.payload[4:6], 'big')

    @property
    def response_count(self):
        """Number of records in the answer section."""
        return int.from_bytes(self.payload[6:8], 'big')

    @property
    def response_authority_count(self):
        """Number of records in the authority section."""
        return int.from_bytes(self.payload[8:10], 'big')

    @property
    def response_additional_count(self):
        """Number of records in the additional section."""
        return int.from_bytes(self.payload[10:12], 'big')

    # -- names -------------------------------------------------------------

    @staticmethod
    def parse_name(payload, position):
        """Decode a possibly-compressed domain name starting at ``position``.

        Returns ``(labels, next_position)``: the list of label strings and
        the offset just past the name as stored at ``position`` (i.e. just
        past the first compression pointer, if one was followed).  At most
        255 pointer jumps are followed, which rejects pointer loops.
        """
        labels = []
        resume_position = None
        jumps = 0
        while True:
            assert len(payload) >= position + 1
            size = payload[position]
            position += 1
            if size >= 0xC0:
                # Compression pointer: low 6 bits of this byte + the next byte.
                assert len(payload) >= position + 1
                if resume_position is None:
                    resume_position = position + 1
                jumps += 1
                assert jumps < 256
                position = (size & 0x3F) * 256 + payload[position]
            else:
                # 0x40-0xBF are reserved/extended label types; reject them.
                assert size < 64
                if size == 0:
                    return (labels, resume_position or position)
                assert len(payload) >= position + size
                # Labels may be arbitrary octets, so non-UTF-8 bytes must not
                # escape as UnicodeDecodeError; they are shown as \xNN instead.
                raw_label = payload[position:position + size]
                labels.append(raw_label.decode('utf-8', errors='backslashreplace'))
                position += size

    @staticmethod
    def _encode_name(name):
        """Encode dotted ``name`` as uncompressed wire-format labels.

        One trailing dot (fully-qualified form) is accepted, and ``''`` or
        ``'.'`` encode the root name.  Empty inner labels, labels of 64+
        bytes and names longer than 255 bytes are rejected.
        """
        if name.endswith('.'):
            name = name[:-1]
        encoded = b''
        if name:
            for label in name.split('.'):
                label = label.encode()
                assert 0 < len(label) < 64
                encoded += len(label).to_bytes(1, 'big') + label
        encoded += b'\x00'  # root label terminates the name
        assert len(encoded) <= 255
        return encoded

    # -- section entries -----------------------------------------------------

    class Query:
        """One question-section entry: a name plus its type and class."""

        def __init__(self, payload, position):
            start = position
            self._name, position = DNSPacket.parse_name(payload, position)
            assert len(payload) >= position + 4
            self._type = int.from_bytes(payload[position:position + 2], 'big')
            self._class = int.from_bytes(payload[position + 2:position + 4], 'big')
            # Bytes this entry occupies on the wire, so the caller can find the next one.
            self._payload_length = position + 4 - start

        @property
        def name(self):
            """The queried name in dotted form."""
            return '.'.join(self._name)

        @property
        def type(self):
            """The numeric query type."""
            return self._type

        @property
        def clas(self):
            """The numeric query class."""
            return self._class

        @property
        def payload_length(self):
            """Number of payload bytes this entry occupies."""
            return self._payload_length

        def __repr__(self):
            return f'{self.name} ({self.type}, {self.clas})'

    class Response:
        """One resource record from the answer, authority or additional section."""

        def __init__(self, payload, position):
            self._payload = payload
            self._position = position
            self._name, position = DNSPacket.parse_name(payload, position)
            # Fixed part after the name: type(2) class(2) ttl(4) rdlength(2).
            assert len(payload) >= position + 10
            self._type = int.from_bytes(payload[position:position + 2], 'big')
            self._class = int.from_bytes(payload[position + 2:position + 4], 'big')
            self._ttl = int.from_bytes(payload[position + 4:position + 8], 'big')
            self._length = int.from_bytes(payload[position + 8:position + 10], 'big')
            position += 10
            assert len(payload) >= position + self._length
            # Kept so NS/CNAME data can be decoded against the whole message
            # (compression pointers are offsets into the full payload).
            self._data_position = position
            self._data = payload[position:position + self._length]
            self._payload_length = position + self._length - self._position

        @property
        def name(self):
            """The owner name in dotted form."""
            return '.'.join(self._name)

        @property
        def type(self):
            """The numeric record type."""
            return self._type

        @property
        def clas(self):
            """The numeric record class."""
            return self._class

        @property
        def ttl(self):
            """The time-to-live field."""
            return self._ttl

        @property
        def data(self):
            """The raw RDATA bytes."""
            return self._data

        @property
        def ipv4_address(self):
            """Dotted-quad text of an A (type 1, class 1) record."""
            assert self.type in [1]
            assert self.clas in [1]
            assert len(self.data) == 4
            return '.'.join(str(octet) for octet in self.data)

        @property
        def ipv6_address(self):
            """Uncompressed colon-hex text of an AAAA (type 28, class 1) record."""
            assert self.type in [28]
            assert self.clas in [1]
            assert len(self.data) == 16
            blocks = [f'{self.data[2*i]:02x}{self.data[2*i+1]:02x}' for i in range(8)]
            return ':'.join(blocks)

        @property
        def ns_address(self):
            """Target name of an NS (type 2) record."""
            assert self.type in [2]
            labels, _ = DNSPacket.parse_name(self._payload, self._data_position)
            return '.'.join(labels)

        @property
        def cname_address(self):
            """Target name of a CNAME (type 5) record."""
            assert self.type in [5]
            labels, _ = DNSPacket.parse_name(self._payload, self._data_position)
            return '.'.join(labels)

        @property
        def text(self):
            """Raw RDATA of a TXT (type 16) record (length prefixes included)."""
            assert self.type in [16]
            return self.data

        @property
        def address(self):
            """Best-effort decoded RDATA; falls back to the raw bytes."""
            if self.type == 1 and self.clas == 1:
                return self.ipv4_address
            if self.type == 28 and self.clas == 1:
                return self.ipv6_address
            if self.type == 2:
                return self.ns_address
            if self.type == 5:
                return self.cname_address
            if self.type == 16:
                return self.text
            return self.data

        @property
        def payload_length(self):
            """Number of payload bytes this record occupies."""
            return self._payload_length

        def __repr__(self):
            return f'{self.name} ({self.type}, {self.clas}) = {self.address} TTL {self.ttl}'

    # -- sections ------------------------------------------------------------

    def _parse_records(self, position, count):
        """Parse ``count`` resource records from ``position``; return ``(records, end_position)``."""
        records = []
        for _ in range(count):
            record = self.Response(self.payload, position)
            position += record.payload_length
            records.append(record)
        return records, position

    @property
    def query_position(self):
        """Offset of the question section (it directly follows the header)."""
        return 12

    @property
    def queries(self):
        """Parsed question section as a list of :class:`DNSPacket.Query`."""
        queries = []
        position = self.query_position
        for _ in range(self.query_count):
            query = self.Query(self.payload, position)
            position += query.payload_length
            queries.append(query)
        self._response_position = position
        return queries

    @property
    def response_position(self):
        """Offset of the answer section."""
        if self._response_position is None:
            self.queries  # parsing the questions records where the answers start
        return self._response_position

    @property
    def responses(self):
        """Parsed answer section as a list of :class:`DNSPacket.Response`."""
        responses, self._response_authority_position = self._parse_records(
            self.response_position, self.response_count)
        return responses

    @property
    def response_authority_position(self):
        """Offset of the authority section."""
        if self._response_authority_position is None:
            self.responses  # parsing the answers records where authority starts
        return self._response_authority_position

    @property
    def responses_authority(self):
        """Parsed authority section as a list of :class:`DNSPacket.Response`."""
        responses, self._response_additional_position = self._parse_records(
            self.response_authority_position, self.response_authority_count)
        return responses

    @property
    def response_additional_position(self):
        """Offset of the additional section."""
        if self._response_additional_position is None:
            self.responses_authority  # records where the additional section starts
        return self._response_additional_position

    @property
    def responses_additional(self):
        """Parsed additional section as a list of :class:`DNSPacket.Response`."""
        responses, _ = self._parse_records(
            self.response_additional_position, self.response_additional_count)
        return responses

    def __repr__(self):
        flag_names = ['query' if self.is_query else 'response']
        optional_flags = (
            (self.is_authoritative, 'authoritative'),
            (self.is_truncated, 'truncated'),
            (self.is_recursion_desired, 'recursion desired'),
            (self.is_recursion_available, 'recursion available'),
            (self.is_authenticated, 'authenticated'),
            (self.is_authentication_disabled, 'authentication disabled'),
        )
        flag_names.extend(label for is_set, label in optional_flags if is_set)
        return (
            f'Transaction #{self.transaction_id}:'
            f' Operation {self.opcode}'
            f', Reply {self.reply_code}'
            f', ({", ".join(flag_names)})'
            f', {self.query_count} Queries: {self.queries}'
            f', {self.response_count} Responses: {self.responses}'
            f', {self.response_authority_count} Authority: {self.responses_authority}'
            f', {self.response_additional_count} Additional: {self.responses_additional}'
        )

    # -- construction ----------------------------------------------------------

    @classmethod
    def DNSQuery(cls, name_type_classes, transaction_id=None, authoritative=False, recursion_desired=True, authentication_disabled=False):
        """Build a standard query with one question per ``(name, type, class)`` tuple.

        ``transaction_id`` defaults to a random 16-bit value; an explicit ``0``
        is honoured.  ``authoritative``, ``recursion_desired`` and
        ``authentication_disabled`` set the AA, RD and CD header bits.
        """
        name_type_classes = list(name_type_classes)
        if transaction_id is None:
            # The ID is what lets a client reject spoofed answers, so draw it
            # from the OS CSPRNG rather than the predictable default PRNG.
            transaction_id = random.SystemRandom().randrange(0, 2**16)
        flags = [
                    False,                       #Query
                    False, False, False, False,  #OpCode
                    authoritative,               #Authoritative
                    False,                       #Truncated
                    recursion_desired,           #Recursion Desired
                    False,                       #Recursion Available
                    False,                       #Reserved
                    False,                       #Authenticated
                    authentication_disabled,     #Authentication Disabled
                    False, False, False, False,  #ReplyCode
                ]
        payload = transaction_id.to_bytes(2, 'big')                                  #TransactionID
        payload += sum(2**i for i in range(16) if flags[15 - i]).to_bytes(2, 'big')  #Flags
        payload += len(name_type_classes).to_bytes(2, 'big')                         #QueryCount
        payload += (0).to_bytes(2, 'big')                                            #ResponseCount
        payload += (0).to_bytes(2, 'big')                                            #ResponseAuthorityCount
        payload += (0).to_bytes(2, 'big')                                            #ResponseAdditionalCount
        for name, typ, clas in name_type_classes:
            payload += cls._encode_name(name)   #Name
            payload += typ.to_bytes(2, 'big')   #Type
            payload += clas.to_bytes(2, 'big')  #Class
        return cls(payload)

    @classmethod
    def DNSQueryA(cls, name, **kwargs):
        """Build an IPv4 address (A, type 1, class IN) query for ``name``."""
        return cls.DNSQuery([(name, 1, 1)], **kwargs)

    @classmethod
    def DNSQueryAAAA(cls, names, **kwargs):
        """Build an IPv6 address (AAAA, type 28, class IN) query.

        ``names`` is one domain name or an iterable of names (one question each).
        """
        if isinstance(names, str):
            names = [names]
        return cls.DNSQuery([(name, 28, 1) for name in names], **kwargs)

    @classmethod
    def DNSQueryTXT(cls, names, **kwargs):
        """Build a TXT (type 16, class IN) query.

        ``names`` is one domain name or an iterable of names (one question each).
        """
        if isinstance(names, str):
            names = [names]
        return cls.DNSQuery([(name, 16, 1) for name in names], **kwargs)


DNS_ROOTS = [ socket.gethostbyname(f'{l}.root-servers.net') for l in 'abcdefghijklm' ]

def _query_server(sock, sockaddr, hostname, recursive, timeout, retries):
    """Ask ``sockaddr`` for the A record of ``hostname`` over UDP socket ``sock``.

    Sends up to ``retries`` queries, each waiting at most ``timeout`` seconds.
    Replies from other addresses, replies that are not responses to a query
    sent here, and unparseable replies are logged and ignored.  Returns the
    parsed reply, or ``None`` once the retries are exhausted.
    """
    sent_ids = set()
    for _ in range(retries):
        query = DNSPacket.DNSQuery([(hostname, 1, 1)], recursion_desired=recursive)
        sock.sendto(query.payload, sockaddr)
        logging.info(f'sent UDP query to {sockaddr}: {query}')
        # A late answer to an earlier attempt is still an acceptable answer.
        sent_ids.add(query.transaction_id)
        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            sock.settimeout(remaining)
            try:
                data, addr = sock.recvfrom(2**16)
            except socket.timeout:
                logging.warning('socket timeout')
                break
            finally:
                sock.settimeout(None)
            if addr != sockaddr:
                logging.warning(f'unexpected response from {addr}')
                continue
            try:
                reply = DNSPacket(data)
                if not reply.is_response or reply.transaction_id not in sent_ids:
                    logging.warning(f'unmatched response from {addr} (transaction #{reply.transaction_id})')
                    continue
                # The f-string is evaluated eagerly, so this fully parses the
                # reply and malformed sections surface here as AssertionError.
                logging.info(f'received UDP response from {addr}: {reply}')
            except AssertionError:
                logging.warning('mangled response received')
                continue
            return reply
    logging.error('too many retries')
    return None


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Domain Name System Client UDP')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--recursive', action='store_true')
    # A random root server is picked after parsing (see below) so an explicit
    # --host never depends on the root list.
    parser.add_argument('--host', default=None)
    parser.add_argument('--port', default='domain')
    parser.add_argument('--timeout', type=float, default=2)
    parser.add_argument('--retries', type=int, default=5)
    parser.add_argument('hostname')
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    host = args.host
    if host is None:
        if not DNS_ROOTS:
            parser.error('no root server address is available; pass --host')
        host = random.choice(DNS_ROOTS)

    reply = None
    addresses = socket.getaddrinfo(host, args.port, family=socket.AF_INET, type=socket.SOCK_DGRAM)
    for family, socktype, _proto, _canonname, sockaddr in addresses:
        with socket.socket(family, socktype) as sock:
            reply = _query_server(sock, sockaddr, args.hostname, args.recursive, args.timeout, args.retries)
        if reply is not None:
            break  # one answer is enough; don't re-query the remaining addresses
    if reply is None:
        raise SystemExit(1)
