import http.client
import sys
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlsplit

# Base URL of the upstream Veracity API that proxied requests are forwarded to.
# For local development, point this at "http://localhost:8000" instead.
VERACITY_HOST = "https://dev-veracity.groundedai.company"
# Path prefix marking a request as one to proxy; it is stripped before forwarding.
PROXY_PREFIX = "/veracity"
# Sent upstream as the X-Veracity-Mode header on every proxied request;
# set to an empty string to omit the header.
VERACITY_MODE = "law_word"


class Handler(SimpleHTTPRequestHandler):
    """Serve the static UI and proxy API calls to ``VERACITY_HOST``.

    Requests whose path starts with ``PROXY_PREFIX + "/"`` are forwarded
    upstream with the prefix removed. Other GETs are served as static files
    from the current working directory (SimpleHTTPRequestHandler's default),
    and other POSTs get a 404. Every response is marked non-cacheable.
    """

    # Seconds to wait on the upstream socket (connect and each read) before
    # answering 502. NOTE(review): deliberately generous so slow upstream calls
    # are not cut off; tighten if the API has a known latency bound.
    upstream_timeout = 300.0

    def end_headers(self) -> None:
        """Add no-cache headers to every response, then finish the header block."""
        self.send_header("Cache-Control", "no-store")
        self.send_header("Pragma", "no-cache")
        super().end_headers()

    def do_OPTIONS(self) -> None:
        """Answer CORS preflight; proxied paths advertise allowed methods/headers."""
        self.send_response(204)
        if self._is_proxy_path():
            self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Access-Control-Allow-Methods", "GET,POST,OPTIONS")
            self.send_header(
                "Access-Control-Allow-Headers",
                "Content-Type,X-Api-Key,Accept,X-Veracity-Mode",
            )
        self.end_headers()

    def do_GET(self) -> None:
        """Proxy GETs under PROXY_PREFIX; serve everything else as static files."""
        if self._is_proxy_path():
            self._proxy()
        else:
            super().do_GET()

    def do_POST(self) -> None:
        """Proxy POSTs under PROXY_PREFIX; any other POST path is a 404."""
        if self._is_proxy_path():
            self._proxy()
        else:
            self.send_error(404)

    def _is_proxy_path(self) -> bool:
        """Return True when the request path belongs to the proxied API."""
        return self.path.startswith(PROXY_PREFIX + "/")

    @staticmethod
    def _parse_upstream(host: str):
        """Split an ``http(s)://netloc[/base]`` URL into (scheme, netloc, base_path).

        ``base_path`` has any trailing slash removed ("" when there is none),
        so it can be prefixed directly onto the forwarded request path.

        Raises:
            ValueError: if the scheme is not http/https or no host is given.
        """
        upstream = urlsplit(host)
        if upstream.scheme not in ("http", "https") or not upstream.netloc:
            raise ValueError("VERACITY_HOST must start with http:// or https://")
        return upstream.scheme, upstream.netloc, upstream.path.rstrip("/")

    @staticmethod
    def _parse_content_length(value):
        """Return the body length for a Content-Length header value.

        A missing or empty header means no body (0). Returns None when the
        value is not a non-negative integer, so the caller can reject it.
        """
        if not value:
            return 0
        try:
            length = int(value)
        except ValueError:
            return None
        # A negative length would make rfile.read() block until EOF.
        return length if length >= 0 else None

    def _send_json_error(self, status: int, body: bytes) -> None:
        """Send a small JSON error response that browsers can read via CORS."""
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _proxy(self) -> None:
        """Forward the current request to VERACITY_HOST and relay the response.

        PROXY_PREFIX is stripped from the path (the query string is kept).
        The caller's X-Api-Key, Accept (default "application/json") and
        Content-Type headers are forwarded. X-Veracity-Mode is set from
        VERACITY_MODE when it is non-empty. The request gets:

        - 401 when X-Api-Key is missing;
        - 400 when Content-Length is malformed;
        - 502 when the upstream cannot be reached or answers invalidly.

        Raises:
            ValueError: if VERACITY_HOST is not an http:// or https:// URL.
        """
        scheme, netloc, base_path = self._parse_upstream(VERACITY_HOST)

        parts = urlsplit(self.path)
        target_path = base_path + parts.path[len(PROXY_PREFIX) :]
        if parts.query:
            target_path += "?" + parts.query

        api_key = self.headers.get("X-Api-Key")
        if not api_key:
            self._send_json_error(401, b'{"error":"Missing X-Api-Key"}')
            return

        # Bodies are read only via Content-Length; chunked uploads are not decoded.
        content_length = self._parse_content_length(self.headers.get("Content-Length"))
        if content_length is None:
            self._send_json_error(400, b'{"error":"Invalid Content-Length"}')
            return
        body = self.rfile.read(content_length) if content_length else b""

        outgoing_headers = {
            "X-Api-Key": api_key,
            "Accept": self.headers.get("Accept") or "application/json",
        }
        if VERACITY_MODE:
            outgoing_headers["X-Veracity-Mode"] = VERACITY_MODE
        content_type = self.headers.get("Content-Type")
        if content_type:
            outgoing_headers["Content-Type"] = content_type

        connection_cls = (
            http.client.HTTPSConnection
            if scheme == "https"
            else http.client.HTTPConnection
        )
        conn = connection_cls(netloc, timeout=self.upstream_timeout)
        try:
            conn.request(self.command, target_path, body=body, headers=outgoing_headers)
            resp = conn.getresponse()
            resp_body = resp.read()
            status = resp.status
            resp_content_type = resp.getheader("Content-Type")
        except (OSError, http.client.HTTPException) as err:
            # Covers refused connections, timeouts, TLS failures and
            # malformed upstream responses.
            self.log_error("Upstream request to %s failed: %r", netloc, err)
            self._send_json_error(502, b'{"error":"Upstream request failed"}')
            return
        finally:
            conn.close()

        self.send_response(status)
        if resp_content_type:
            self.send_header("Content-Type", resp_content_type)
        self.send_header("Access-Control-Allow-Origin", "*")
        # Cache-Control/Pragma are added by end_headers(); not repeated here.
        # Content-Length is forbidden on 1xx/204 and meaningless on 304.
        if status >= 200 and status not in (204, 304):
            self.send_header("Content-Length", str(len(resp_body)))
        self.end_headers()
        self.wfile.write(resp_body)


def main():
    """Start the UI/proxy server on the port given as argv[1] (default 3111).

    Serves until interrupted with Ctrl+C, then closes the listening socket
    instead of exiting with a traceback.
    """
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 3111
    # NOTE(review): binds every interface, so the proxy (which forwards the
    # caller's X-Api-Key upstream) is reachable from the network; use
    # "127.0.0.1" if only local access is intended.
    with ThreadingHTTPServer(("0.0.0.0", port), Handler) as server:
        print(f"Serving UI on http://0.0.0.0:{port}")
        print(f"Proxying {PROXY_PREFIX}/* -> {VERACITY_HOST}/*")
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("\nShutting down")


# Run the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
