#!/usr/bin/env python3
"""Kiwi Proxy: Anthropic Messages API → Ollama Cloud (OpenAI-compatible)
Drop-in CCR replacement for Ollama Cloud Pro users.
Set: ANTHROPIC_BASE_URL=http://localhost:3456 ANTHROPIC_API_KEY=kiwi-ccr-local

Requirements: pip install fastapi uvicorn httpx
"""

import os, json, time, httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse

# Base URL of the OpenAI-compatible upstream; "/chat/completions" is appended per request.
OLLAMA_BASE = os.environ.get("OLLAMA_BASE_URL", "https://ollama.com/v1")
# Credential sent upstream as "Authorization: Bearer <key>" (empty string when unset).
OLLAMA_KEY = os.environ.get("OLLAMA_API_KEY", "")
# Key that callers must present in "x-api-key" or "Authorization".
# An empty value disables the check in handle_messages.
PROXY_KEY = os.environ.get("PROXY_API_KEY", "kiwi-ccr-local")
# Requested model name -> upstream model name. Lookup is by exact key;
# any name not listed falls back to the "default" entry.
MODEL_MAP = {
    "claude-sonnet-4-20250514": "deepseek-v4-pro",
    "claude-3.5-sonnet": "deepseek-v4-pro",
    "claude": "deepseek-v4-pro",
    "default": "deepseek-v4-pro",
}
# TCP port the server listens on when this file is run as a script.
PORT = int(os.environ.get("PORT", "3456"))

app = FastAPI(title="Kiwi Proxy", version="1.0.0")
# One shared async HTTP client for all upstream calls; the 600 s timeout
# leaves room for long generations.
client = httpx.AsyncClient(timeout=600.0, follow_redirects=True)


def _an_block_to_text(block):
    """Render one Anthropic content block as text, or return None to drop it."""
    if isinstance(block, str):
        return block
    if not isinstance(block, dict):
        return None
    kind = block.get("type")
    if kind == "text":
        return block.get("text", "")
    if kind == "image":
        # Images are not forwarded; only a placeholder naming the media type is kept.
        source = block.get("source") or {}
        return f"[image: {source.get('media_type', 'unknown')}]"
    if kind == "tool_use":
        return json.dumps({"tool_use": block})
    if kind == "tool_result":
        return json.dumps({"tool_result": block})
    return None


def an_to_oai(messages: list) -> list:
    """Convert Anthropic messages → OpenAI messages format.

    Args:
        messages: Anthropic-style message dicts. ``content`` may be a string,
            ``None``, or a list of content blocks.

    Returns:
        A list of ``{"role", "content"}`` dicts with string content. System
        messages come first, in their original relative order, followed by
        the remaining messages in order. Roles other than ``system`` and
        ``assistant`` are sent as ``user``.

    List content is flattened to newline-joined text: text blocks contribute
    their text, images become an ``[image: <media_type>]`` placeholder,
    ``tool_use``/``tool_result`` blocks are embedded as JSON, and any other
    block type is dropped.
    """
    system_msgs = []
    chat_msgs = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        if content is None:
            # An explicit null would otherwise be forwarded upstream as None.
            content = ""
        elif isinstance(content, list):
            rendered = (_an_block_to_text(block) for block in content)
            content = "\n".join(part for part in rendered if part is not None)
        if role == "system":
            # Collected separately so several system messages keep their order.
            system_msgs.append({"role": "system", "content": content})
        elif role == "assistant":
            chat_msgs.append({"role": "assistant", "content": content})
        else:
            chat_msgs.append({"role": "user", "content": content})
    return system_msgs + chat_msgs


def oai_to_an(choice: dict) -> dict:
    """Convert OpenAI response → Anthropic format.

    Args:
        choice: One entry of an OpenAI-style ``choices`` list. Token counts
            are read from ``choice["usage"]``; a caller that wants them
            reported has to place the usage dict under that key.

    Returns:
        An Anthropic-style message dict with a single text content block.
        Missing or null ``content``, ``finish_reason`` and ``usage`` fall
        back to ``""``, ``"end_turn"`` and zero token counts.
    """
    # ``or`` rather than a .get() default: upstream may send explicit nulls
    # (for example content=null), which a default does not cover.
    msg = choice.get("message") or {}
    text = msg.get("content") or ""
    finish_reason = choice.get("finish_reason") or "end_turn"
    usage = choice.get("usage") or {}

    reason_map = {"stop": "end_turn", "length": "max_tokens", "tool_calls": "tool_use"}
    return {
        "id": f"msg_{int(time.time()*1000)}",
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": text}],
        "model": "claude-sonnet-4-20250514",
        # Finish reasons without a mapping are passed through unchanged.
        "stop_reason": reason_map.get(finish_reason, finish_reason),
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage.get("prompt_tokens", 0) or 0,
            "output_tokens": usage.get("completion_tokens", 0) or 0,
        },
    }


# Marker returned by _parse_oai_sse_line for the upstream "[DONE]" line.
_STREAM_DONE = object()


def _sse_event(payload: dict) -> str:
    """Serialize one event dict as an SSE frame named after its ``type``."""
    return f"event: {payload['type']}\ndata: {json.dumps(payload)}\n\n"


def _parse_oai_sse_line(raw_line: bytes):
    """Classify one upstream line: a chunk dict, ``_STREAM_DONE``, or None to ignore."""
    line = raw_line.strip()
    if not line.startswith(b"data:"):
        return None
    payload = line[5:].strip()
    if payload == b"[DONE]":
        return _STREAM_DONE
    try:
        chunk = json.loads(payload.decode("utf-8", errors="replace"))
    except json.JSONDecodeError:
        return None
    return chunk if isinstance(chunk, dict) else None


async def _iter_oai_chunks(response):
    """Yield the JSON chunks of an upstream SSE body until ``[DONE]`` or end of body."""
    pending = b""
    async for data in response.aiter_bytes():
        # Buffer raw bytes and decode per complete line, so a multi-byte
        # character split across two network chunks is not corrupted.
        pending += data
        *complete, pending = pending.split(b"\n")
        for raw_line in complete:
            parsed = _parse_oai_sse_line(raw_line)
            if parsed is _STREAM_DONE:
                return
            if parsed is not None:
                yield parsed
    # A final line that arrived without a trailing newline.
    parsed = _parse_oai_sse_line(pending)
    if isinstance(parsed, dict):
        yield parsed


async def stream_oai_to_an(response, model: str):
    """Stream SSE from OpenAI → Anthropic SSE.

    Args:
        response: Streaming upstream response; must provide ``aiter_bytes()``
            and ``aclose()``.
        model: Model name reported in the ``message_start`` event.

    Yields:
        SSE frames (``event: ...\\ndata: ...\\n\\n``) in this order:
        ``message_start``, ``content_block_start``, one
        ``content_block_delta`` per non-empty text delta,
        ``content_block_stop``, ``message_delta`` carrying the stop reason,
        and ``message_stop``.

    The closing events are sent whether or not upstream sends ``[DONE]``.
    The upstream response is closed when the generator finishes or is
    closed early by its consumer.
    """
    reason_map = {"stop": "end_turn", "length": "max_tokens", "tool_calls": "tool_use"}
    stop_reason = "end_turn"
    output_tokens = 0
    try:
        yield _sse_event({
            "type": "message_start",
            "message": {
                "id": f"msg_{int(time.time()*1000)}",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": 0, "output_tokens": 0},
            },
        })
        yield _sse_event({
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""},
        })

        async for chunk in _iter_oai_chunks(response):
            # Usage, when upstream reports it, may arrive on a chunk with no choices.
            usage = chunk.get("usage")
            if isinstance(usage, dict):
                output_tokens = usage.get("completion_tokens") or output_tokens
            choices = chunk.get("choices") or []
            if not choices:
                continue
            choice = choices[0]
            finish_reason = choice.get("finish_reason")
            if finish_reason:
                stop_reason = reason_map.get(finish_reason, finish_reason)
            text = (choice.get("delta") or {}).get("content")
            if text:
                yield _sse_event({
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text},
                })

        yield _sse_event({"type": "content_block_stop", "index": 0})
        yield _sse_event({
            "type": "message_delta",
            "delta": {"stop_reason": stop_reason, "stop_sequence": None},
            "usage": {"output_tokens": output_tokens},
        })
        yield _sse_event({"type": "message_stop"})
    finally:
        # Release the upstream connection even if the consumer stops reading early.
        await response.aclose()


def _system_to_text(system) -> str:
    """Flatten an Anthropic ``system`` value (string or list of text blocks) to text."""
    if isinstance(system, str):
        return system
    if isinstance(system, list):
        return "\n".join(
            block.get("text", "")
            for block in system
            if isinstance(block, dict) and block.get("type") == "text"
        )
    return ""


async def handle_messages(request: Request) -> "dict | StreamingResponse":
    """Handle Anthropic /v1/messages → Ollama /v1/chat/completions.

    Args:
        request: Incoming request whose JSON body is an Anthropic Messages payload.

    Returns:
        An Anthropic-style message dict for a normal request, or a
        ``StreamingResponse`` of server-sent events when the body sets
        ``stream`` to a true value.

    Raises:
        HTTPException: 401 for a wrong proxy key, 400 for a body that is not
            a JSON object, 502 when upstream is unreachable or returns an
            unusable body, or the upstream status code when it is not 200.
    """
    auth = request.headers.get("x-api-key") or request.headers.get("authorization", "")
    auth = auth.replace("Bearer ", "").strip()
    if PROXY_KEY and auth != PROXY_KEY:
        raise HTTPException(401, "Invalid API key")

    try:
        body = await request.json()
    except ValueError as exc:
        # Malformed JSON is a client error, not an internal server error.
        raise HTTPException(400, "Request body is not valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")

    model = body.get("model", "claude-sonnet-4-20250514")
    mapped_model = MODEL_MAP.get(model, MODEL_MAP.get("default", "deepseek-v4-pro"))
    messages = body.get("messages", [])
    # ``system`` may be a plain string or a list of text blocks; upstream gets text.
    system = _system_to_text(body.get("system"))
    # ``or`` covers an explicit null; values below 50 are raised to 50.
    max_tokens = max(body.get("max_tokens") or 4096, 50)
    temperature = body.get("temperature", 0.7)
    top_p = body.get("top_p", 1.0)
    stream = body.get("stream", False)

    oai_msgs = an_to_oai(messages)
    if system:
        oai_msgs.insert(0, {"role": "system", "content": system})

    oai_body = {
        "model": mapped_model,
        "messages": oai_msgs,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": stream,
    }

    headers = {
        "Authorization": f"Bearer {OLLAMA_KEY}",
        "Content-Type": "application/json",
    }
    url = f"{OLLAMA_BASE}/chat/completions"

    if stream:
        req = client.build_request("POST", url, json=oai_body, headers=headers)
        try:
            resp = await client.send(req, stream=True)
        except httpx.HTTPError as exc:
            raise HTTPException(502, f"Upstream request failed: {exc}") from exc
        if resp.status_code != 200:
            try:
                body_err = await resp.aread()
            finally:
                # The streamed response is not handed on, so release it here.
                await resp.aclose()
            raise HTTPException(
                resp.status_code, detail=body_err.decode(errors="replace")[:1000]
            )
        return StreamingResponse(
            stream_oai_to_an(resp, mapped_model),
            media_type="text/event-stream",
        )

    try:
        resp = await client.post(url, json=oai_body, headers=headers)
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"Upstream request failed: {exc}") from exc
    if resp.status_code != 200:
        raise HTTPException(resp.status_code, detail=resp.text[:1000])
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(502, "Upstream returned a non-JSON response") from exc
    choices = data.get("choices", []) if isinstance(data, dict) else []
    if not choices:
        raise HTTPException(502, "No choices in upstream response")
    # Token usage sits at the top level of the upstream response, while
    # oai_to_an reads it from the choice, so copy it across.
    choice = dict(choices[0])
    choice.setdefault("usage", data.get("usage") or {})
    return oai_to_an(choice)


@app.get("/health")
async def health():
    """Liveness probe: report a fixed status plus the configured upstream base URL."""
    payload = dict(status="ok", proxy="kiwi-ccr", upstream=OLLAMA_BASE)
    return payload


@app.api_route("/v1/messages", methods=["POST"])
@app.api_route("/v1/messages/", methods=["POST"])
async def proxy_messages(request: Request):
    """Proxy POST /v1/messages (with or without a trailing slash).

    Returns the streaming response from ``handle_messages`` unchanged, and
    wraps any other result in a ``JSONResponse``.
    """
    result = await handle_messages(request)
    if isinstance(result, StreamingResponse):
        # Already a complete response; wrapping it in JSONResponse would
        # fail because a response object is not JSON-serializable.
        return result
    return JSONResponse(result)


if __name__ == "__main__":
    # Script entry point: announce the configuration, then serve on all interfaces.
    import uvicorn

    startup_banner = f"[PROXY] Starting on port {PORT}, upstream: {OLLAMA_BASE}"
    print(startup_banner)
    server_options = {"host": "0.0.0.0", "port": PORT, "log_level": "warning"}
    uvicorn.run(app, **server_options)
