#!/usr/bin/env python3
"""
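stl_render.py — Render an ASCII STL file to SVG using only the Python standard library.
Isometric-style parallel projection with flat per-face shading (brightness from the
angle between each face's normal and a fixed light direction). Binary STL is not supported.
No pip, no numpy, no matplotlib.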

Usage:
    python3 stl_render.py input.stl output.svg [width]

Or import:
    from stl_render import render_stl_to_svg
    svg_text = render_stl_to_svg("model.stl", width=800)
"""

import math
import re
import sys

# ── Math helpers ──

def dot(a, b):
    """Return the scalar (dot) product of the 3-vectors ``a`` and ``b``."""
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return ax * bx + ay * by + az * bz

def cross(a, b):
    """Return the cross product ``a x b`` of two 3-vectors as a tuple."""
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    x = ay * bz - az * by
    y = az * bx - ax * bz
    z = ax * by - ay * bx
    return (x, y, z)

def sub(a, b):
    """Return the component-wise difference ``a - b`` of two 3-vectors."""
    ax, ay, az = a[0], a[1], a[2]
    return (ax - b[0], ay - b[1], az - b[2])

def length(v):
    """Return the Euclidean norm of the 3-vector ``v``."""
    squared = v[0] ** 2 + v[1] ** 2 + v[2] ** 2
    return math.sqrt(squared)

def normalize(v):
    """Scale the 3-vector ``v`` to unit length.

    A zero-length vector has no direction. In that case the +z unit vector
    ``(0, 0, 1)`` is returned instead of dividing by zero.
    """
    magnitude = math.sqrt(v[0] ** 2 + v[1] ** 2 + v[2] ** 2)
    if magnitude != 0:
        return (v[0] / magnitude, v[1] / magnitude, v[2] / magnitude)
    return (0, 0, 1)

# Screen-axis vectors (screen-right, screen-up) of a true isometric camera.
# NOTE(review): these constants are not referenced by project() or
# render_stl_to_svg(), which use their own simplified 2:1 formula. Read as
# (right, up), ISO_X x ISO_Y points along (1, 1, 1), i.e. this camera looks
# from the +x/+y side rather than from (-1, -1, 1) — confirm before using.
_HALF_ROOT3 = math.sqrt(3) / 2
ISO_X = (-_HALF_ROOT3, _HALF_ROOT3, 0)
ISO_Y = (-0.5, -0.5, 1)

def project(v, scale=1, offset=(0, 0)):
    """Project a 3D point onto 2D SVG screen coordinates.

    Uses a 2:1 isometric-style parallel projection for a viewer on the
    (-1, -1, 1) side of the model. The output follows the SVG convention
    where y grows downward. On screen, +x therefore points right-and-up,
    +y points left-and-up, and +z points straight up.

    Args:
        v: 3D point as a sequence of at least three numbers (x, y, z).
        scale: Multiplier applied to both screen coordinates before the offset.
        offset: (dx, dy) added to the scaled screen coordinates.

    Returns:
        (px, py) screen coordinates.
    """
    x, y, z = v[0], v[1], v[2]
    px = (x - y) * scale + offset[0]
    # SVG y increases downward. Anything higher in the model (larger z, or
    # further back along x+y) must get a *smaller* y to appear above.
    py = -((x + y) * 0.5 + z) * scale + offset[1]
    return (px, py)


# One STL coordinate. It has an optional sign, digits with an optional fraction
# (or a bare ".5"), and an optional exponent. The exponent may carry a '+'
# sign, e.g. "1.000000e+01", which many STL exporters write.
_STL_NUM = r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)'
_STL_XYZ = _STL_NUM + r'\s+' + _STL_NUM + r'\s+' + _STL_NUM + r'\s+'
_STL_FACET_RE = re.compile(
    r'facet\s+normal\s+' + _STL_XYZ
    + r'outer\s+loop\s+'
    + (r'vertex\s+' + _STL_XYZ) * 3
    + r'endloop\s+endfacet',
    re.IGNORECASE,
)


def parse_stl(path):
    """Parse an ASCII STL file into a list of facets.

    Args:
        path: Path to an ASCII STL file.

    Returns:
        A list of (normal, v1, v2, v3) tuples in file order. Each element is a
        3-tuple of floats. Text that does not match the layout
        ``facet normal … / outer loop / vertex …(x3) / endloop / endfacet``
        is skipped silently, so an empty list means nothing was recognised.

    Raises:
        OSError: if the file cannot be opened.
    """
    # Undecodable bytes (e.g. a non-ASCII solid name) are replaced instead of
    # aborting the whole parse.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    facets = []
    for m in _STL_FACET_RE.finditer(content):
        nums = [float(g) for g in m.groups()]
        facets.append((
            tuple(nums[0:3]),
            tuple(nums[3:6]),
            tuple(nums[6:9]),
            tuple(nums[9:12]),
        ))
    return facets


def shade(normal, light=(0.5, -0.5, 1.0)):
    """Return a grayscale brightness in 100..240 for a face.

    The cosine between the normalized face normal and the normalized light
    direction (-1..1) is remapped linearly. A face pointing straight at the
    light gets 240 and one pointing straight away gets 100, which avoids
    pure black and pure white.
    """
    light = normalize(light)
    n = normalize(normal)
    intensity = dot(n, light)
    # Clamp 0..1, then map to 100..240 (avoid pure black/white)
    intensity = max(0.0, min(1.0, (intensity + 1.0) / 2.0))
    val = int(100 + intensity * 140)
    return val


def render_stl_to_svg(stl_path, width=800, margin=40, bg="#1a1a2e"):
    """Render an ASCII STL file to an SVG document string.

    Vertices are placed with project(), which views the model from the
    (-1, -1, 1) side. Back-face culling and the painter's-algorithm draw
    order below use that same viewing direction. Faces get flat shading
    (20% ambient + 80% diffuse) on a warm grey base colour.

    Args:
        stl_path: Path to an ASCII STL file.
        width: SVG width in pixels. The height follows the model's aspect ratio.
        margin: Blank border, in pixels, kept on each side.
        bg: Background colour, inserted verbatim into the SVG.

    Returns:
        The SVG markup. If no facets are found, a small placeholder SVG
        containing an error message is returned instead.

    Raises:
        ValueError: if ``width <= 2 * margin`` (no room left to draw).
        OSError: if ``stl_path`` cannot be read.
    """
    draw_w = width - 2 * margin
    if draw_w <= 0:
        raise ValueError(
            f"width ({width}) must be greater than 2 * margin ({2 * margin})")

    facets = parse_stl(stl_path)
    if not facets:
        return ('<svg xmlns="http://www.w3.org/2000/svg" width="400" height="200">'
                '<text x="20" y="30" fill="red">No facets found in STL</text></svg>')

    # Fit the unscaled projection of every vertex into the drawing area.
    projected = [project(v) for facet in facets for v in facet[1:]]
    min_x = min(p[0] for p in projected)
    max_x = max(p[0] for p in projected)
    min_y = min(p[1] for p in projected)
    max_y = max(p[1] for p in projected)
    content_w = max_x - min_x
    content_h = max_y - min_y
    extent = max(content_w, content_h)
    # A model whose projection collapses to a point has no extent to fit.
    scale = draw_w / extent if extent > 0 else 1.0
    height = math.ceil(content_h * scale + 2 * margin)
    # Centre horizontally, which matters when the model is taller than wide.
    # Vertically, the height is already sized to the content.
    offset = (
        margin - min_x * scale + (draw_w - content_w * scale) / 2,
        margin - min_y * scale,
    )

    # Unit vector pointing toward the viewer; must match project()'s view.
    viewer = normalize((-1, -1, 1))
    light_dir = normalize((0.5, -0.5, 1.0))

    def closeness(facet):
        # Sum of the vertices' positions along the viewing axis. Larger
        # values are closer to the viewer.
        return dot(facet[1], viewer) + dot(facet[2], viewer) + dot(facet[3], viewer)

    lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}" style="background:{bg}">',
        f'<rect width="{width}" height="{height}" fill="{bg}"/>',
    ]

    base_r, base_g, base_b = 220, 218, 215  # warm grey base colour
    # Painter's algorithm: farthest facets first, so nearer ones paint over them.
    for _stl_normal, v1, v2, v3 in sorted(facets, key=closeness):
        # Recompute the normal from the vertex order rather than trusting the
        # stored normal. This assumes the STL convention of counter-clockwise
        # winding seen from outside.
        n = cross(sub(v2, v1), sub(v3, v1))
        if length(n) == 0:
            continue  # zero-area triangle: nothing to fill, no defined normal
        real_normal = normalize(n)

        # Back-face culling. The small negative tolerance keeps faces that
        # are seen almost edge-on.
        if dot(real_normal, viewer) < -0.05:
            continue

        # Ambient 0.2 + diffuse 0.8
        intensity = 0.2 + 0.8 * max(0.0, dot(real_normal, light_dir))
        intensity = max(0.0, min(1.0, intensity))
        r = int(base_r * intensity)
        g = int(base_g * intensity)
        b = int(base_b * intensity)
        fill_color = f"#{r:02x}{g:02x}{b:02x}"
        edge = f"#{max(0, r - 60):02x}{max(0, g - 60):02x}{max(0, b - 60):02x}"

        p1 = project(v1, scale, offset)
        p2 = project(v2, scale, offset)
        p3 = project(v3, scale, offset)
        points = f"{p1[0]:.2f},{p1[1]:.2f} {p2[0]:.2f},{p2[1]:.2f} {p3[0]:.2f},{p3[1]:.2f}"
        lines.append(f'<polygon points="{points}" fill="{fill_color}" stroke="{edge}" '
                     f'stroke-width="0.8" stroke-linejoin="round"/>')

    lines.append('</svg>')
    return '\n'.join(lines)


def main(argv=None):
    """Command-line entry point.

    Usage: ``python3 stl_render.py input.stl output.svg [width]``

    Args:
        argv: Full argument vector including the program name. Defaults to
            ``sys.argv``. Accepting it explicitly lets the CLI be driven from
            code and tests.

    On bad input, a message is printed to stderr and the process exits with
    status 1. Bad input means: missing arguments, a width that is not a
    positive integer, or a ``ValueError`` raised while rendering.
    """
    args = sys.argv if argv is None else argv
    usage = "Usage: python3 stl_render.py input.stl output.svg [width]"
    if len(args) < 3:
        print(usage, file=sys.stderr)
        sys.exit(1)

    stl_path, svg_path = args[1], args[2]
    width = 800
    if len(args) > 3:
        try:
            width = int(args[3])
        except ValueError:
            width = 0  # not an integer: reported as invalid just below
        if width <= 0:
            print(f"Invalid width {args[3]!r}: expected a positive integer.",
                  file=sys.stderr)
            print(usage, file=sys.stderr)
            sys.exit(1)

    try:
        svg = render_stl_to_svg(stl_path, width)
    except ValueError as err:
        print(f"Error: {err}", file=sys.stderr)
        sys.exit(1)

    # Write explicitly as UTF-8 rather than the platform's locale default.
    with open(svg_path, 'w', encoding='utf-8') as f:
        f.write(svg)

    print(f"Rendered {stl_path} → {svg_path} ({width}px wide)")


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
