#!/usr/bin/env python3
# vim:ts=4:sts=4:sw=4:expandtab

import argparse
import logging
import pathlib
import socket
import threading
import traceback

# TCP port the server listens on when --port is not supplied.
DEFAULT_PORT = 4567
# Directory whose files are served when --path is not supplied.
DEFAULT_PATH = '.'

def _recv_request(client_socket, max_size):
    """Return the bytes the client sends before its first NUL terminator.

    Anything the client sends after the NUL is ignored.

    Raises:
        ConnectionError: the client closed the connection before sending NUL.
        ValueError: more than *max_size* bytes arrived without a NUL.
    """
    received = b''
    while True:
        chunk = client_socket.recv(8192)
        if not chunk:
            # recv() returns b'' once the peer has closed its side; without
            # this check the loop would spin forever on a dead connection.
            raise ConnectionError(
                'connection closed before the request was NUL-terminated')
        terminator = chunk.find(0)
        if terminator >= 0:
            return received + chunk[:terminator]
        received += chunk
        if len(received) > max_size:
            raise ValueError(
                f'request exceeds {max_size} bytes without a NUL terminator')


def _resolve_request(base_path, request_subpath):
    """Map a client-supplied path to an existing regular file under *base_path*.

    Returns the fully resolved path of the file.

    Raises:
        PermissionError: the path resolves outside *base_path* (via ``..``,
            an absolute path, or a symlink that points elsewhere).
        FileNotFoundError: the path does not name an existing regular file.
    """
    base = pathlib.Path(base_path).resolve()
    target = (base / request_subpath).resolve()
    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, which would silently disable the traversal guard.
    if base not in target.parents:
        raise PermissionError(
            f'{request_subpath!r} is outside the served directory')
    if not target.is_file():
        raise FileNotFoundError(f'{request_subpath!r} is not a regular file')
    return target


def client_thread(client_socket, client_address, base_path,
                  max_request_size=8192):
    """Serve a single file-transfer request, then close *client_socket*.

    Protocol: the client sends a UTF-8 path relative to *base_path*,
    terminated by a NUL byte. The server replies with the raw file contents
    and closes the connection. Any failure is logged and the connection is
    closed. Failures include a malformed or oversized request, a path
    outside *base_path*, a missing file, or an I/O error. If the failure
    happens mid-transfer, the client may have received a partial file.

    Args:
        client_socket: Connected socket for this client; always closed on
            return.
        client_address: Peer address, used only in log messages.
        base_path: Directory that requested paths are resolved against.
        max_request_size: Maximum number of request bytes accepted without a
            NUL terminator. This stops a client from growing the receive
            buffer without limit.
    """
    try:
        raw_request = _recv_request(client_socket, max_request_size)
        request_subpath = raw_request.decode('utf-8')
        # !r escapes newlines etc. so a client cannot forge log lines.
        logging.info(
            f'Client {client_address} requests file {request_subpath!r}')

        file_path = _resolve_request(base_path, request_subpath)
        # Open the already-validated resolved path instead of re-joining
        # base_path and request_subpath.
        with file_path.open('rb') as request_file:
            while True:
                data = request_file.read(8192)
                if not data:
                    break
                client_socket.sendall(data)
        logging.info(
            f'Sending a file {request_subpath!r} to {client_address} succeeded')
    except Exception:
        # Top-level boundary of the worker thread: log with traceback and
        # drop this client, but let BaseException (e.g. SystemExit) through.
        logging.exception(f'Sending a file to {client_address} failed')
    finally:
        client_socket.close()

def main(argv=None):
    """Parse command-line arguments and serve files over TCP.

    Listens on all IPv4 interfaces. Each accepted connection is handed to
    client_thread() on its own thread. Runs until the process is
    interrupted or accept() raises.

    Args:
        argv: Arguments to parse; argparse uses sys.argv[1:] when None.
    """
    parser = argparse.ArgumentParser(description='File Transfer TCP Server')
    parser.add_argument('--debug', action='store_true',
                        help='enable debug-level logging')
    # type=int is required: without it a port given on the command line
    # arrives as a str, and socket.bind() rejects a string port.
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='TCP port to listen on (default: %(default)s)')
    parser.add_argument('--path', default=DEFAULT_PATH,
                        help='directory to serve files from '
                             '(default: %(default)s)')
    args = parser.parse_args(argv)
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        # Let the server rebind the port immediately after a restart.
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # '' binds to all IPv4 interfaces.
        server_socket.bind(('', args.port))
        server_socket.listen(64)
        while True:
            client_socket, client_address = server_socket.accept()
            logging.info(f'Client {client_address} connected')
            threading.Thread(
                target=client_thread,
                args=(client_socket, client_address, args.path),
            ).start()


if __name__ == '__main__':
    main()
