#!/usr/bin/env bash
# Ryvion node image runtime wrapper for the local FLUX.2 klein 4B model.
# Usage: script [--model M] --prompt P --output PATH [--width W] [--height H] [--prepare]
# Env: RYVION_IMAGE_RUNTIME_ROOT (state dir), RYVION_FLUX2_SNAPSHOT_URL +
#      RYVION_NODE_TOKEN (platform snapshot), HF_TOKEN (Hugging Face fallback).
set -euo pipefail

# CLI defaults; overwritten by the argument loop below.
# MODEL: only "flux-2-klein-4b-local" is accepted (validated further down).
MODEL="flux-2-klein-4b-local"
PROMPT=""
OUTPUT=""
WIDTH="1024"
HEIGHT="1024"
# PREPARE="1" means: warm the model cache and exit without generating.
PREPARE="0"

# Parse command-line flags. Every value option must be followed by its value;
# previously a trailing flag (e.g. `--prompt` as the last argument) made
# `shift 2` fail and the script died silently under `set -e` with no message.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --model|--prompt|--output|--width|--height)
      if [[ $# -lt 2 ]]; then
        echo "runtime.image: missing value for $1" >&2
        exit 2
      fi
      # Original expansions kept: an explicitly empty --width/--height value
      # still falls back to 1024.
      case "$1" in
        --model) MODEL="${2:-}" ;;
        --prompt) PROMPT="${2:-}" ;;
        --output) OUTPUT="${2:-}" ;;
        --width) WIDTH="${2:-1024}" ;;
        --height) HEIGHT="${2:-1024}" ;;
      esac
      shift 2
      ;;
    --prepare)
      PREPARE="1"
      shift
      ;;
    *)
      # Tolerate unknown flags so newer callers can pass extra options.
      echo "runtime.image: ignoring unknown argument $1" >&2
      shift
      ;;
  esac
done

# Validate inputs before touching the filesystem. Prepare mode only needs a
# model; generation additionally requires a prompt and an output path.
# Check order, messages and exit codes are load-bearing for callers.
if [[ "$PREPARE" != "1" ]]; then
  if [[ -z "$PROMPT" ]]; then
    echo "prompt is required" >&2
    exit 2
  fi
  if [[ -z "$OUTPUT" ]]; then
    echo "output path is required" >&2
    exit 2
  fi
fi
if [[ "$MODEL" != "flux-2-klein-4b-local" ]]; then
  echo "unsupported image runtime model: $MODEL" >&2
  exit 2
fi

# Filesystem layout: all runtime state (venv, HF cache, sentinels, generated
# scripts) lives under one root so an operator can relocate or wipe it via
# RYVION_IMAGE_RUNTIME_ROOT.
ROOT="${RYVION_IMAGE_RUNTIME_ROOT:-/var/lib/ryvion/image-runtime}"
VENV="$ROOT/venv"
CACHE="$ROOT/hf-cache"
# Sentinel files; bump the -v2 suffix to force reinstall / re-prepare.
MARKER="$ROOT/.deps-flux2-klein-v2"
READY_MARKER="$ROOT/.model-flux2-klein-ready-v2"
mkdir -p "$ROOT" "$CACHE"
# NOTE(review): presumably forces the classic HTTP download path instead of
# the hf-xet backend — confirm the original motivation before removing.
export HF_HUB_DISABLE_XET="${HF_HUB_DISABLE_XET:-1}"

# Locate a bootstrap interpreter, preferring CPython 3.12.
PYTHON=""
for candidate in python3.12 python3; do
  if command -v "$candidate" >/dev/null 2>&1; then
    PYTHON="$(command -v "$candidate")"
    break
  fi
done
if [[ -z "$PYTHON" ]]; then
  echo "Python 3.12 or python3 is required for Ryvion image runtime." >&2
  exit 127
fi

# Create the virtualenv once; subsequent runs reuse it.
if [[ ! -x "$VENV/bin/python" ]]; then
  echo "runtime.image: creating Python environment"
  "$PYTHON" -m venv "$VENV"
fi
PY="$VENV/bin/python"

# Install the heavyweight ML dependencies exactly once, gated by the MARKER
# sentinel. Bumping the marker filename forces a reinstall on upgrades.
if [[ ! -f "$MARKER" ]]; then
  echo "runtime.image: installing FLUX.2 klein runtime dependencies"
  "$PY" -m pip install --upgrade pip
  # CUDA 12.4 wheel index — assumes Linux/NVIDIA nodes; TODO confirm this
  # index also resolves on CPU-only or macOS nodes.
  "$PY" -m pip install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu124
  # diffusers installed from git — presumably Flux2KleinPipeline is not yet
  # in a tagged release; verify and pin once it lands.
  "$PY" -m pip install --upgrade git+https://github.com/huggingface/diffusers.git transformers accelerate safetensors pillow protobuf sentencepiece huggingface_hub
  touch "$MARKER"
fi

RUN_SCRIPT="$ROOT/run_flux2_klein.py"
# Generation script, rewritten on every invocation so it always matches this
# wrapper. The quoted delimiter prevents shell expansion inside the heredoc.
# IMPORTANT: the Python below must stay flush-left — <<'PY' (without a dash)
# preserves leading whitespace verbatim, and the previous version's literal
# tab prefixes produced a file that died with IndentationError on line 1.
cat > "$RUN_SCRIPT" <<'PY'
import os
import shutil
import sys
import tarfile
import torch
import urllib.request
from pathlib import Path
from diffusers import Flux2KleinPipeline
from huggingface_hub import snapshot_download


def safe_extract_snapshot(url, token, cache_dir):
    """Fetch the platform-hosted model snapshot tarball and unpack it.

    Returns the snapshot directory path, or None when the platform URL or
    node token is not configured (caller then falls back to Hugging Face).
    """
    if not url or not token:
        return None
    local_dir = Path(cache_dir).parent / "flux2-klein-platform-snapshot"
    complete = local_dir / ".complete"
    if complete.exists() and (local_dir / "model_index.json").exists():
        return str(local_dir)
    # Stage into a ".tmp" sibling so an interrupted download can never be
    # mistaken for a complete snapshot.
    tmp = Path(str(local_dir) + ".tmp")
    if tmp.exists():
        shutil.rmtree(tmp)
    tmp.mkdir(parents=True, exist_ok=True)
    req = urllib.request.Request(url, headers={
        "X-Node-Token": token,
        "User-Agent": "ryvion-node-image-runtime/1.0",
    })
    with urllib.request.urlopen(req, timeout=10800) as resp:
        with tarfile.open(fileobj=resp, mode="r|gz") as archive:
            root = tmp.resolve()
            for member in archive:
                # Reject members that would escape the staging directory
                # (path traversal via "../" or absolute member names).
                target = (tmp / member.name).resolve()
                if target != root and not str(target).startswith(str(root) + os.sep):
                    raise SystemExit("runtime.image: unsafe model snapshot path")
                archive.extract(member, tmp)
    if local_dir.exists():
        shutil.rmtree(local_dir)
    tmp.rename(local_dir)
    complete.write_text("ok\n", encoding="utf-8")
    return str(local_dir)


def ready_marker_dir(cache_dir):
    """Return the model dir recorded by a previous prepare run, if still valid."""
    marker = Path(cache_dir).parent / ".model-flux2-klein-ready-v2"
    if not marker.exists():
        return None
    # Tolerate an empty marker file instead of raising IndexError.
    lines = marker.read_text(encoding="utf-8").splitlines()
    first = lines[0].strip() if lines else ""
    if first and (Path(first) / "model_index.json").exists():
        return first
    return None


def resolve_model_dir(cache_dir):
    """Locate model weights: ready marker, then platform snapshot, then HF Hub."""
    cached = ready_marker_dir(cache_dir)
    if cached:
        return cached
    platform_dir = safe_extract_snapshot(
        os.environ.get("RYVION_FLUX2_SNAPSHOT_URL"),
        os.environ.get("RYVION_NODE_TOKEN"),
        cache_dir,
    )
    if platform_dir:
        return platform_dir
    # resume_download is intentionally not passed: downloads resume by default
    # and recent huggingface_hub releases removed the keyword entirely.
    return snapshot_download(
        "black-forest-labs/FLUX.2-klein-4B",
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN") or None,
    )


model, prompt, output, width, height, cache_dir = sys.argv[1:7]
width = int(width)
height = int(height)
if model != "flux-2-klein-4b-local":
    raise SystemExit(f"unsupported model {model}")
# Pick the best available accelerator and a matching dtype.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
elif device == "mps":
    dtype = torch.float16
else:
    dtype = torch.float32
local_dir = resolve_model_dir(cache_dir)
pipe = Flux2KleinPipeline.from_pretrained(
    local_dir,
    torch_dtype=dtype,
    cache_dir=cache_dir,
)
pipe = pipe.to(device)
# Seed on CPU for MPS — presumably torch.Generator(device="mps") is not
# supported here; mirrors the original behavior. Fixed seed => reproducible.
generator = torch.Generator(device=device if device != "mps" else "cpu").manual_seed(0)
image = pipe(
    prompt=prompt,
    height=height,
    width=width,
    guidance_scale=1.0,
    num_inference_steps=4 if device != "cpu" else 2,
    generator=generator,
).images[0]
image.save(output)
print(f"runtime.image: wrote {output}")
PY

PREPARE_SCRIPT="$ROOT/prepare_flux2_klein.py"
# Cache-warming script: downloads the weights, renders a tiny probe image,
# then records the resolved model directory in the ready marker.
# IMPORTANT: flush-left for the same reason as the run script — <<'PY' keeps
# leading whitespace verbatim, and the previous tab-prefixed version produced
# a Python file that failed with IndentationError before doing any work.
cat > "$PREPARE_SCRIPT" <<'PY'
import os
import shutil
import sys
import tarfile
import urllib.request
from pathlib import Path
import torch
from diffusers import Flux2KleinPipeline
from huggingface_hub import snapshot_download


def safe_extract_snapshot(url, token, cache_dir):
    """Fetch the platform-hosted model snapshot tarball and unpack it.

    Returns the snapshot directory path, or None when the platform URL or
    node token is not configured (caller then falls back to Hugging Face).
    """
    if not url or not token:
        return None
    local_dir = Path(cache_dir).parent / "flux2-klein-platform-snapshot"
    complete = local_dir / ".complete"
    if complete.exists() and (local_dir / "model_index.json").exists():
        return str(local_dir)
    # Stage into a ".tmp" sibling so an interrupted download can never be
    # mistaken for a complete snapshot.
    tmp = Path(str(local_dir) + ".tmp")
    if tmp.exists():
        shutil.rmtree(tmp)
    tmp.mkdir(parents=True, exist_ok=True)
    req = urllib.request.Request(url, headers={
        "X-Node-Token": token,
        "User-Agent": "ryvion-node-image-runtime/1.0",
    })
    with urllib.request.urlopen(req, timeout=10800) as resp:
        with tarfile.open(fileobj=resp, mode="r|gz") as archive:
            root = tmp.resolve()
            for member in archive:
                # Reject members that would escape the staging directory
                # (path traversal via "../" or absolute member names).
                target = (tmp / member.name).resolve()
                if target != root and not str(target).startswith(str(root) + os.sep):
                    raise SystemExit("runtime.image: unsafe model snapshot path")
                archive.extract(member, tmp)
    if local_dir.exists():
        shutil.rmtree(local_dir)
    tmp.rename(local_dir)
    complete.write_text("ok\n", encoding="utf-8")
    return str(local_dir)


model, cache_dir, ready_marker = sys.argv[1:4]
if model != "flux-2-klein-4b-local":
    raise SystemExit(f"unsupported model {model}")
local_dir = safe_extract_snapshot(
    os.environ.get("RYVION_FLUX2_SNAPSHOT_URL"),
    os.environ.get("RYVION_NODE_TOKEN"),
    cache_dir,
)
if not local_dir:
    # resume_download is intentionally not passed: downloads resume by default
    # and recent huggingface_hub releases removed the keyword entirely.
    local_dir = snapshot_download(
        "black-forest-labs/FLUX.2-klein-4B",
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN") or None,
    )
# Pick the best available accelerator and a matching dtype.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
elif device == "mps":
    dtype = torch.float16
else:
    dtype = torch.float32
pipe = Flux2KleinPipeline.from_pretrained(local_dir, torch_dtype=dtype, cache_dir=cache_dir)
pipe = pipe.to(device)
# One-step 256x256 render proves the full pipeline works before we mark ready.
probe_path = str(Path(ready_marker).with_name("readiness_probe.png"))
generator = torch.Generator(device=device if device != "mps" else "cpu").manual_seed(1)
image = pipe(
    prompt="ryvion runtime readiness probe",
    height=256,
    width=256,
    guidance_scale=1.0,
    num_inference_steps=1,
    generator=generator,
).images[0]
image.save(probe_path)
# First marker line is the model dir; the run script reads it back.
Path(ready_marker).write_text(f"{local_dir}\nprobe={probe_path}\n", encoding="utf-8")
print(f"runtime.image: smoke-tested model cache at {local_dir}")
PY

# Dispatch: explicit prepare exits after warming the cache; a normal run
# lazily prepares first if the ready marker is missing, then generates.
case "$PREPARE" in
  1)
    "$PY" "$PREPARE_SCRIPT" "$MODEL" "$CACHE" "$READY_MARKER"
    exit 0
    ;;
esac
[[ -f "$READY_MARKER" ]] || {
  echo "runtime.image: model cache not ready; preparing now"
  "$PY" "$PREPARE_SCRIPT" "$MODEL" "$CACHE" "$READY_MARKER"
}
"$PY" "$RUN_SCRIPT" "$MODEL" "$PROMPT" "$OUTPUT" "$WIDTH" "$HEIGHT" "$CACHE"
