#!/usr/bin/env python3
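"""Proxy: Anthropic Messages API → Ollama Cloud (OpenAI-compatible API).

Serves HTTPS (port 443 by default, configurable via PORT) with a self-signed
certificate. Claude Code connects to api.anthropic.com, which /etc/hosts
redirects to this host.
Incoming requests are not authenticated — Claude Code's OAuth token is opaque
to us. Every request is forwarded to one fixed upstream model, regardless of
the model the client asks for.
"""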
import sys
sys.path.insert(0, "/DATA/.local/lib/python3.12/site-packages")
import os, json, time, httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse

# --- Configuration, read once from the environment at import time ---
# Upstream OpenAI-compatible endpoint and its bearer token.
OLLAMA_BASE = os.getenv("OLLAMA_BASE_URL", "https://ollama.com/v1")
OLLAMA_KEY = os.getenv("OLLAMA_API_KEY", "")
# Read but not referenced anywhere else in this file (no inbound auth check).
PROXY_KEY = os.getenv("PROXY_API_KEY", "")
PORT = int(os.getenv("PORT", "443"))
# TLS certificate / private key served by uvicorn.
CERT_FILE = os.getenv("CERT_FILE", "/etc/kiwi-proxy/certs/cert.pem")
KEY_FILE = os.getenv("KEY_FILE", "/etc/kiwi-proxy/certs/key.pem")
# Truthy values: "1", "true", "yes" (case-insensitive).
INSECURE = os.getenv("INSECURE", "").lower() in {"1", "true", "yes"}

app = FastAPI(title="Kiwi Proxy", version="2.0.0")
# Shared upstream HTTP client; 600 s timeout (presumably to allow long generations).
client = httpx.AsyncClient(timeout=600.0)


def convert_messages(an_msgs: list, system: str = "") -> list:
    """Anthropic → OpenAI messages."""
    oai = []
    if system:
        oai.append({"role": "system", "content": system})
    for m in an_msgs:
        role = m.get("role", "user")
        content = m.get("content", "")
        if isinstance(content, list):
            parts = []
            for b in content:
                t = b.get("type", "")
                if t == "text":
                    parts.append(b.get("text", ""))
                elif t == "image":
                    parts.append(f"[image: {b.get('source', {}).get('media_type', 'unknown')}]")
                elif t in ("tool_use", "tool_result"):
                    parts.append(json.dumps(b))
            content = "\n".join(parts)
        oai.append({"role": "assistant" if role == "assistant" else "user", "content": content})
    return oai


def make_an_response(choice: dict, an_model: str) -> dict:
    """Build an Anthropic Messages response from one OpenAI chat-completion choice.

    Args:
        choice: One element of the upstream ``choices`` list. Token usage is
            read from ``choice["usage"]``; OpenAI-compatible servers report
            usage at the top level of the response, so callers should copy it
            in, otherwise the fallback estimates below are used.
        an_model: Model name echoed back to the client.

    Returns:
        A dict shaped like an Anthropic ``message`` object holding a single
        text content block.
    """
    msg = choice.get("message") or {}
    # OpenAI-style responses may carry "content": null; treat it as empty text.
    content = msg.get("content") or ""
    # finish_reason may also be null; treat that like a normal stop.
    finish = choice.get("finish_reason") or "stop"
    stop_map = {"stop": "end_turn", "length": "max_tokens", "tool_calls": "tool_use"}
    usage = choice.get("usage") or {}
    return {
        "id": f"msg_{int(time.time()*1000)}",
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": content}],
        "model": an_model,
        # Unknown finish reasons are passed through unchanged.
        "stop_reason": stop_map.get(finish, finish),
        "stop_sequence": None,
        "usage": {
            # Fallbacks when upstream reports nothing: 1 input token, and a
            # whitespace word count (at least 1) for output.
            "input_tokens": usage.get("prompt_tokens") or 1,
            "output_tokens": usage.get("completion_tokens") or max(len(content.split()), 1),
        },
    }


async def stream_response(upstream_resp, an_model: str):
    """Translate an upstream OpenAI-style SSE stream into Anthropic SSE events.

    Emits ``message_start`` and ``ping`` immediately. It then opens a single
    text content block (index 0) on the first non-empty text delta and
    finishes with ``content_block_stop`` (if a block was opened),
    ``message_delta`` carrying the mapped stop reason, and ``message_stop``.
    The upstream response is closed when the generator finishes or is closed
    early.

    Args:
        upstream_resp: An ``httpx.Response`` opened with ``stream=True``. Any
            object with async ``aiter_lines()`` and ``aclose()`` works.
        an_model: Model name echoed back to the client.

    Yields:
        Anthropic SSE frames as ``str``.
    """
    stop_map = {"stop": "end_turn", "length": "max_tokens", "tool_calls": "tool_use"}

    def sse(event: str, payload: dict) -> str:
        # json.dumps escapes client-supplied values such as an_model.
        return f"event: {event}\ndata: {json.dumps(payload)}\n\n"

    msg_id = f"msg_{int(time.time()*1000)}"
    block_open = False
    stop_reason = "end_turn"
    output_tokens = 1  # placeholder unless the upstream reports usage
    try:
        yield sse("message_start", {
            "type": "message_start",
            "message": {
                "id": msg_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": an_model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": 1, "output_tokens": 1},
            },
        })
        yield sse("ping", {"type": "ping"})
        # aiter_lines decodes incrementally, so multi-byte UTF-8 characters
        # split across network chunks are not corrupted.
        async for raw in upstream_resp.aiter_lines():
            line = raw.strip()
            if not line.startswith("data:"):
                continue
            data = line[5:].strip()
            if data == "[DONE]":
                break
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict):
                continue
            usage = event.get("usage")
            if isinstance(usage, dict) and usage.get("completion_tokens"):
                output_tokens = usage["completion_tokens"]
            choices = event.get("choices") or []
            if not choices or not isinstance(choices[0], dict):
                continue
            choice = choices[0]
            finish = choice.get("finish_reason")
            if finish:
                stop_reason = stop_map.get(finish, finish)
            text = (choice.get("delta") or {}).get("content")
            if not text:
                continue
            if not block_open:
                block_open = True
                yield sse("content_block_start", {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                })
            yield sse("content_block_delta", {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": text},
            })
        if block_open:
            yield sse("content_block_stop", {"type": "content_block_stop", "index": 0})
        yield sse("message_delta", {
            "type": "message_delta",
            "delta": {"stop_reason": stop_reason, "stop_sequence": None},
            "usage": {"output_tokens": output_tokens},
        })
        yield sse("message_stop", {"type": "message_stop"})
    finally:
        # A response opened with stream=True must be closed explicitly,
        # otherwise its pooled connection is never released.
        await upstream_resp.aclose()


@app.get("/health")
async def health():
    """Liveness probe: report that the proxy is up and which upstream it targets."""
    payload = dict(status="ok", https=True, upstream=OLLAMA_BASE)
    return payload


@app.get("/v1/models")
async def list_models():
    """Return a static model list — Claude Code checks this endpoint.

    Exactly one model is advertised, so it is both the first and last id of
    the (single, non-paginated) page.
    """
    model_id = "claude-sonnet-4-20250514"
    listing = {
        "data": [{"id": model_id, "type": "model", "display_name": "Claude Sonnet 4"}],
        "has_more": False,
        "first_id": model_id,
        "last_id": model_id,
    }
    return JSONResponse(listing)


@app.api_route("/v1/messages", methods=["POST"])
async def messages(request: Request):
    """Handle /v1/messages: translate to an upstream chat completion and back.

    The request body is read ONCE. Every request is sent to the fixed
    upstream model ``deepseek-v4-pro``; the client's ``model`` value is only
    echoed back in the response.

    Raises:
        HTTPException: 400 if the body is not a JSON object; 502 if the
            upstream is unreachable, answers non-200, returns invalid JSON,
            or returns no choices.
    """
    try:
        body = await request.json()
    except ValueError as exc:  # covers json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(400, detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(400, detail="Request body must be a JSON object")

    print(f"[PROXY] POST /v1/messages stream={body.get('stream')} model={body.get('model')} max_tokens={body.get('max_tokens')}", flush=True)

    an_model = body.get("model", "claude-sonnet-4-20250514")
    mapped = "deepseek-v4-pro"
    stream = body.get("stream", False)

    oai_body = {
        "model": mapped,
        "messages": convert_messages(body.get("messages", []), body.get("system", "")),
        # Floor of 100 tokens on the requested output budget.
        "max_tokens": max(body.get("max_tokens", 4096), 100),
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "stream": stream,
    }
    upstream_headers = {
        "Authorization": f"Bearer {OLLAMA_KEY}",
        "Content-Type": "application/json",
    }
    url = f"{OLLAMA_BASE}/chat/completions"

    def upstream_failed(exc: Exception) -> HTTPException:
        # Transport-level failures (connect errors, timeouts, ...) become a
        # 502 instead of an unhandled 500.
        print(f"[PROXY] ERR: upstream request failed: {exc!r}", flush=True)
        return HTTPException(502, detail=f"Upstream request failed: {exc}")

    if stream:
        req = client.build_request("POST", url, json=oai_body, headers=upstream_headers)
        try:
            resp = await client.send(req, stream=True)
        except httpx.HTTPError as exc:
            raise upstream_failed(exc) from exc
        if resp.status_code != 200:
            try:
                # errors="replace": a non-UTF-8 error body must not turn into a 500.
                err = (await resp.aread()).decode("utf-8", errors="replace")[:500]
            except httpx.HTTPError as exc:
                raise upstream_failed(exc) from exc
            finally:
                # Streamed responses must be closed explicitly to free the connection.
                await resp.aclose()
            print(f"[PROXY] ERR: upstream {resp.status_code}: {err}", flush=True)
            raise HTTPException(502, detail=err)
        return StreamingResponse(
            stream_response(resp, an_model),
            media_type="text/event-stream",
            headers={
                "x-request-id": f"msg_{int(time.time()*1000)}",
                "anthropic-version": "2023-06-01",
            },
        )

    try:
        resp = await client.post(url, json=oai_body, headers=upstream_headers)
    except httpx.HTTPError as exc:
        raise upstream_failed(exc) from exc
    if resp.status_code != 200:
        print(f"[PROXY] ERR: upstream {resp.status_code}: {resp.text[:300]}", flush=True)
        raise HTTPException(502, detail=resp.text[:500])
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(502, detail="Upstream returned invalid JSON") from exc
    choices = data.get("choices") if isinstance(data, dict) else None
    if not choices:
        raise HTTPException(502, "No choices in upstream response")
    # OpenAI-compatible servers report usage at the top level of the response,
    # not per choice; copy it in so the reported token counts are real.
    choice = dict(choices[0])
    choice.setdefault("usage", data.get("usage") or {})
    result = make_an_response(choice, an_model)
    print(f"[PROXY] OK: model={mapped} tokens={result['usage']['output_tokens']}", flush=True)
    return JSONResponse(result)


@app.api_route("/v1/messages/count_tokens", methods=["POST"])
async def count_tokens(request: Request):
    """Return a rough input-token estimate for a /v1/messages request body.

    No tokenizer is available here, so the estimate is about 4 characters
    per token. It is computed over the JSON-serialized ``system`` and
    ``messages`` fields, the parts this proxy forwards upstream. If the body
    is not a JSON object, the previous fixed answer of 100 is returned.
    """
    try:
        body = await request.json()
    except ValueError:  # invalid JSON or undecodable bytes
        body = None
    if not isinstance(body, dict):
        return JSONResponse({"input_tokens": 100})
    forwarded = json.dumps([body.get("system", ""), body.get("messages", [])], ensure_ascii=False)
    return JSONResponse({"input_tokens": max(1, len(forwarded) // 4)})


@app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def catch_all(request: Request, full_path: str = ""):
    """Log any request no other route matched and answer with a JSON 404."""
    method = request.method
    print(f"[PROXY] UNHANDLED: {method} /{full_path}", flush=True)
    body = {"error": "not_found", "path": full_path}
    return JSONResponse(body, status_code=404)


if __name__ == "__main__":
    import uvicorn

    # uvicorn needs both halves of the TLS pair. Fail fast with a clear
    # message rather than failing later inside uvicorn's SSL setup.
    for label, path in (("cert", CERT_FILE), ("key", KEY_FILE)):
        if not os.path.exists(path):
            print(f"ERROR: {label} not found at {path}", file=sys.stderr)
            sys.exit(1)

    # In this file INSECURE only raises log verbosity from "warning" to "info".
    log_level = "info" if INSECURE else "warning"
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=PORT,
        ssl_keyfile=KEY_FILE,
        ssl_certfile=CERT_FILE,
        log_level=log_level,
    )
