#!/usr/bin/env python3
# vim:ts=4:sts=4:sw=4:expandtab

import argparse
import logging
import pathlib
import socket
import struct
import threading
import traceback

# Defaults for the command-line options --port, --path and --part-size.
DEFAULT_PORT = 4567        # UDP port the server binds to
DEFAULT_PATH = '.'         # directory whose files may be requested
DEFAULT_PART_SIZE = 1024   # bytes of file data carried by each part reply

def _parse_request(client_message, base_path):
    """Decode a request datagram and resolve the file it names.

    Returns ``(trans_id, part_num, file_path)`` where *file_path* is fully
    resolved. Raises ``ValueError`` (including ``UnicodeDecodeError``) when the
    datagram is truncated, the path is not UTF-8, the path escapes *base_path*,
    or it does not name a regular file.
    """
    header_size = struct.calcsize('!QQH')  # 8 + 8 + 2 = 18 bytes
    if len(client_message) < header_size:
        raise ValueError(f'request too short: {len(client_message)} bytes')
    trans_id, part_num, path_length = struct.unpack_from('!QQH', client_message)
    path_end = header_size + path_length
    if len(client_message) < path_end:
        raise ValueError(
            f'request path truncated: expected {path_length} bytes, '
            f'got {len(client_message) - header_size}')
    request_subpath = client_message[header_size:path_end].decode('utf-8')

    base = pathlib.Path(base_path).resolve()
    # Resolve before checking so '..' components and symlinks cannot escape
    # base; an absolute request path replaces base under '/' and is rejected too.
    # These are explicit raises, not asserts, so the guard survives `python -O`.
    file_path = (base / request_subpath).resolve()
    if base not in file_path.parents:
        raise ValueError(f'{request_subpath!r} is outside {base}')
    if not file_path.is_file():
        raise ValueError(f'{request_subpath!r} is not a regular file')
    return trans_id, part_num, file_path


def _build_reply(trans_id, part_num, file_path, part_size):
    """Return the reply datagram for an already validated request.

    Part 0 is answered with ``!QQQ`` (trans_id, 0, number of parts). Part
    N >= 1 is answered with ``!QQ`` (trans_id, N) followed by up to
    *part_size* bytes read from offset ``part_size * (N - 1)``. A part number
    past the end of the file yields an empty payload.
    """
    if part_num == 0:
        file_size = file_path.stat().st_size
        part_count = (file_size + part_size - 1) // part_size  # ceiling division
        return struct.pack('!QQQ', trans_id, part_num, part_count)
    with file_path.open('rb') as request_file:
        request_file.seek(part_size * (part_num - 1))
        return struct.pack('!QQ', trans_id, part_num) + request_file.read(part_size)


def client_thread(client_socket, client_address, client_message, base_path, part_size):
    """Answer one request datagram by sending a single reply to *client_address*.

    Request layout (network byte order): 8-byte transaction id, 8-byte part
    number (0 asks for the part count, N >= 1 asks for part N), 2-byte path
    length L, then L bytes of UTF-8 path relative to *base_path*.

    Malformed or disallowed requests are logged with a traceback and get no
    reply. Exceptions never escape, because this runs as a thread target.
    """
    try:
        trans_id, part_num, file_path = _parse_request(client_message, base_path)
        logging.debug(f'{client_address} requested part {part_num} of {file_path}')
        reply = _build_reply(trans_id, part_num, file_path, part_size)
        client_socket.sendto(reply, client_address)
    except Exception:
        logging.exception(f'Responding to {client_address} failed:')

if __name__ == '__main__':
    # Largest UDP payload over IPv4 (65535 - 20-byte IP header - 8-byte UDP
    # header). A data reply is a 16-byte '!QQ' header plus up to part-size
    # bytes, so part-size must leave room for that header.
    max_part_size = 65507 - struct.calcsize('!QQ')

    parser = argparse.ArgumentParser(description='File Transfer UDP Server')
    parser.add_argument('--debug', action='store_true',
                        help='enable debug logging')
    # type=int is required: without it a user-supplied value stays a str,
    # which breaks socket.bind() and the part offset arithmetic.
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='UDP port to listen on')
    parser.add_argument('--path', default=DEFAULT_PATH,
                        help='directory whose files may be requested')
    parser.add_argument('--part-size', type=int, default=DEFAULT_PART_SIZE,
                        help='bytes of file data per reply datagram')
    args = parser.parse_args()

    if not 0 <= args.port <= 65535:
        parser.error(f'--port must be between 0 and 65535, got {args.port}')
    if not 1 <= args.part_size <= max_part_size:
        parser.error(f'--part-size must be between 1 and {max_part_size}, '
                     f'got {args.part_size}')
    if not pathlib.Path(args.path).is_dir():
        parser.error(f'--path {args.path!r} is not a directory')

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as server_socket:
        server_socket.bind(('', args.port))
        while True:
            # 2**16 bytes covers any possible UDP datagram.
            (client_message, client_address) = server_socket.recvfrom(2**16)
            logging.debug(f'Client {client_address} sent a message')
            # One short-lived thread per datagram, all replying via the shared socket.
            threading.Thread(target=client_thread,
                    args=(server_socket, client_address, client_message, args.path, args.part_size)
                ).start()
