#!/usr/bin/env python3

import socket
import subprocess

# Address/port the server binds to ("0.0.0.0" = every IPv4 interface).
HOST, PORT = "0.0.0.0", 4242

# Reply framing: SEP1 precedes the echoed message, SEP2 sits between the
# echoed message and the fortune.
SEP1, SEP2 = b"\n----\n", b"\n---\n"

def get_fortune(*, timeout=5.0, fallback="The fortune teller is out today."):
    """Return a fortune as bytes, ending with a newline.

    Runs the external ``fortune`` program. If it cannot be started (e.g. it
    is not installed), takes longer than *timeout* seconds, or prints nothing,
    *fallback* is used instead, so a missing or broken ``fortune`` can never
    crash the server.

    Args:
        timeout: Seconds to wait for ``fortune`` before giving up.
        fallback: Text used when no fortune is available.

    Returns:
        The fortune text with trailing whitespace stripped, UTF-8 encoded,
        followed by a single newline byte.
    """
    try:
        result = subprocess.run(
            ["fortune"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            errors="replace",  # undecodable output must not raise
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        # OSError covers FileNotFoundError when `fortune` is not on PATH.
        text = ""
    else:
        text = result.stdout.rstrip()
    return (text or fallback).encode() + b"\n"

def recv_exact(conn, n):
    """Read and return exactly n bytes from conn, blocking until they arrive.

    Returns None if the peer closes the connection before any of the n bytes
    arrive, and raises ConnectionError if it closes part-way through.
    A non-positive n returns b"" without reading.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = conn.recv(n - len(buf))
        if chunk:
            buf.extend(chunk)
            continue
        # Empty chunk: the peer closed its end of the connection.
        if buf:
            raise ConnectionError("Connection closed before full data received")
        return None  # connection closed cleanly
    return bytes(buf)

_PREFIX_DIGITS = 3
# Largest payload a zero-padded 3-digit length prefix can describe (999).
_MAX_PAYLOAD_LEN = 10 ** _PREFIX_DIGITS - 1


def _parse_length_prefix(raw_len):
    """Return the length encoded in an ASCII-decimal prefix, or None if malformed."""
    # bytes.isdigit() accepts only ASCII 0-9, so forms that int() would
    # otherwise accept -- b"-12", b" 12", b"+12", b"1_2" -- are rejected.
    if not raw_len.isdigit():
        return None
    return int(raw_len.decode("ascii"))


def _frame(payload):
    """Return payload prefixed with its zero-padded 3-digit length.

    Payloads longer than the prefix can express are truncated (dropping the
    tail, normally the end of the fortune) so the prefix stays exactly three
    digits and the client's framing does not desynchronise. Truncation is
    byte-based and may split a multi-byte UTF-8 character.
    """
    payload = payload[:_MAX_PAYLOAD_LEN]
    return f"{len(payload):0{_PREFIX_DIGITS}d}".encode() + payload


def handle_client(conn, addr):
    """Serve one client until it disconnects, says "bye", or breaks the protocol.

    Each request is a 3-digit ASCII-decimal length followed by that many bytes
    of message. Each reply is a 3-digit length followed by SEP1, the echoed
    message, SEP2 and a fortune.

    Args:
        conn: Connected socket for this client; always closed before returning.
        addr: Peer address, used only in log lines.
    """
    print(f"[+] Connected: {addr}")
    try:
        while True:
            # 1) Read 3-byte length prefix; None means the client closed cleanly.
            raw_len = recv_exact(conn, _PREFIX_DIGITS)
            if raw_len is None:
                break
            declared_len = _parse_length_prefix(raw_len)
            if declared_len is None:
                print(f"[{addr}] Invalid length prefix")
                break

            # 2) Read exactly declared_len bytes from the stream
            #    (None: client closed before sending any body bytes).
            body = recv_exact(conn, declared_len)
            if body is None:
                break

            message = body.decode(errors="ignore")
            actual_len = len(body)

            # 3) Print message and lengths
            print(f"[{addr}] Received length: {declared_len}")
            print(f"[{addr}] Message: {message}")

            # 4) Check length. recv_exact returns exactly declared_len bytes
            #    (or None, or raises), so this is only a defensive guard.
            #    NOTE(review): this reply is not length-prefixed, unlike normal
            #    replies -- confirm that matches the protocol spec.
            if actual_len != declared_len:
                error_msg = f"Bad length. You sent {actual_len:03d} but should have sent {declared_len:03d}"
                conn.sendall(error_msg.encode())
                break

            # 5) Handle 'bye'
            if message.strip().lower() == "bye":
                print(f"[{addr}] Client said bye. Closing.")
                break

            # 6) Build reply: echo + separators + fortune
            payload = SEP1 + body + SEP2 + get_fortune()

            # 7) Send reply, framed so the prefix never exceeds three digits
            conn.sendall(_frame(payload))

    except ConnectionError as e:
        print(f"[{addr}] Connection error: {e}")
    finally:
        conn.close()
        print(f"[-] Disconnected: {addr}")

def main():
    """Listen on HOST:PORT and serve clients one at a time until interrupted.

    Clients are handled sequentially: the next client is accepted only after
    handle_client returns for the current one. An unexpected error while
    serving one client is logged and the server keeps accepting; Ctrl+C shuts
    the server down cleanly.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow quick restarts without waiting for TIME_WAIT sockets to expire.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((HOST, PORT))
        server.listen()
        print(f"[*] Listening on {HOST}:{PORT}")

        try:
            while True:
                conn, addr = server.accept()  # BLOCKS until a client connects
                try:
                    handle_client(conn, addr)
                except Exception as e:
                    # handle_client only catches ConnectionError (and closes
                    # conn in its finally); anything else, e.g. an OSError from
                    # the fortune subprocess, would otherwise stop the server.
                    print(f"[{addr}] Unhandled error: {e!r}")
        except KeyboardInterrupt:
            print("\n[*] Shutting down")

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
