"""
Overview
========

This script evaluates the robustness of **ResNet‑50** and **DenseNet‑121**
classifiers against *adversarial* versions of a clean image set. For every
original/adversarial image pair it:

1. **Computes similarity** using the mean RGB Structural Similarity Index
   (SSIM).
2. **Runs each classifier** on the adversarial image.
3. **Flags successful attacks** (prediction == "real" class).
4. **Aggregates** SSIM‑weighted attack indicators into a final score.

At the end it prints a detailed per‑model breakdown and, if requested, saves
a JSON report.

How It Works
------------
* **Configuration** – All runtime options are provided via a YAML file
  (`--config <path>`). Recognised fields (the first four are required):

  * `models_dir`      – directory with `*.pth` weight files
  * `classifiers`     – list containing **resnet50** and/or **densenet121**
  * `original_root`   – root folder of clean PNG images
  * `adv_root`        – root folder of adversarial PNG counterparts
  * `device`          – optional; `auto` (default, CUDA if available) or `cpu`
  * `aggregate`       – optional; `mean` (default) or `sum` for scoring
  * `save_json`       – optional path where a detailed report is written

Dependencies
------------
* Python ≥ 3.9
* PyTorch ≥ 2.1  (`torch`, `torchvision`)
* NumPy, SciPy, scikit‑image, Pillow
* PyYAML, tqdm

GPU acceleration is automatic when `device: auto` and a CUDA GPU is present.

Quick Start
-----------
```bash
# 1. Install requirements (example with pip)
pip install torch torchvision pyyaml numpy pillow scikit-image tqdm scipy

# 2. Prepare the directories declared in the YAML config:
#    models_dir/    →  *.pth weight files (names match entries in 'classifiers')
#    original_root/ →  clean PNG images
#    adv_root/      →  adversarial PNG counterparts (same sub‑folder structure)

# 3. Run evaluation
python evaluate.py --config config.yaml
```
"""

from __future__ import annotations

import argparse
import json
import warnings
from glob import glob
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import yaml  # PyYAML
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
from torchvision.models import densenet121, resnet50

# Every classifier ends in a two-way output head; index 0 is the "real" class.
CLASSES = 2
CLASS_IDX_REAL = 0

# Device used when automatic selection is requested: the first CUDA GPU if
# one is visible to PyTorch, otherwise the CPU.
DEVICE_DEFAULT = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)


def load_cfg(cfg_path: str) -> dict:
    """Load a YAML configuration file and return it as a dictionary.

    Args:
        cfg_path: Path to the YAML file.

    Returns:
        The parsed top-level mapping.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(cfg_path, "r", encoding="utf-8") as f:
        # safe_load never constructs arbitrary Python objects.
        cfg = yaml.safe_load(f)

    # An empty file parses to None and a top-level list/scalar parses to a
    # non-dict; reject both here instead of failing later on `cfg[...]`.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Configuration file {cfg_path} must contain a YAML mapping at the "
            f"top level, got {type(cfg).__name__}."
        )

    print(f"[CONFIG] Loaded configuration from {cfg_path}")
    return cfg


def get_device(choice: str | None) -> torch.device:
    """Return the torch device requested by the user (or auto-selected).

    Args:
        choice: ``None`` or ``"auto"`` selects ``DEVICE_DEFAULT``; ``"cpu"``
            selects the CPU; a CUDA device string such as ``"cuda"`` or
            ``"cuda:1"`` is honoured when CUDA is available. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        The selected ``torch.device``. Any request that cannot be satisfied
        (unknown string, non-CUDA accelerator, CUDA not available) falls back
        to the CPU with a warning rather than failing.
    """
    requested = "auto" if choice is None else str(choice).strip().lower()

    if requested == "auto":
        print(f"[DEVICE] Using device {DEVICE_DEFAULT}")
        return DEVICE_DEFAULT

    if requested != "cpu":
        try:
            device = torch.device(requested)
        except RuntimeError:
            # torch.device raises RuntimeError for strings it cannot parse.
            warnings.warn(f"Unrecognised device '{choice}' – falling back to CPU.")
        else:
            if device.type == "cuda" and torch.cuda.is_available():
                print(f"[DEVICE] Using device {device}")
                return device
            # Previously every non-'auto' value silently meant CPU; keep the
            # CPU result but tell the user their request was not honoured.
            warnings.warn(
                f"Requested device '{choice}' is not available – falling back to CPU."
            )

    print("[DEVICE] Using CPU")
    return torch.device("cpu")


def pil_to_np_rgb(path: str | Path) -> np.ndarray:
    """Read an image and return an (H, W, 3) RGB uint8 NumPy array.

    Args:
        path: Location of the image file.

    Returns:
        The pixel data converted to RGB mode as a NumPy array.
    """
    # Image.open is lazy; the context manager guarantees the underlying file
    # handle is released even if decoding or conversion fails part-way.
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))


def compute_ssim_rgb(im1: np.ndarray, im2: np.ndarray) -> float:
    """Return the SSIM of two RGB images, averaged over the three channels.

    Each of the first three channels along the last axis is scored
    independently with ``data_range=255`` and the three scores are averaged
    with equal weight.
    """
    total = 0
    for channel in range(3):
        plane_1 = im1[..., channel]
        plane_2 = im2[..., channel]
        total += ssim(plane_1, plane_2, data_range=255)
    return total / 3.0


def build_transform() -> T.Compose:
    """Build the preprocessing pipeline shared by all classifiers.

    The pipeline resizes to 256x256, converts to a tensor, and normalises
    each channel with fixed mean/std constants.
    """
    resize = T.Resize((256, 256))
    to_tensor = T.ToTensor()
    # Presumably the standard ImageNet channel statistics – confirm against
    # the preprocessing used when the weights were trained.
    normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return T.Compose([resize, to_tensor, normalize])


def load_model(name: str, weight_path: Path, device: torch.device) -> nn.Module:
    """Build a classifier, swap in a ``CLASSES``-way head, and load its weights.

    Args:
        name: Backbone identifier, ``"resnet50"`` or ``"densenet121"``.
        weight_path: File containing the state dict to load.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model in evaluation mode on ``device``.

    Raises:
        ValueError: If ``name`` is not one of the supported backbones.
    """
    print(f"[MODEL] Loading '{name}' weights from {weight_path}…")

    if name not in ("resnet50", "densenet121"):
        raise ValueError(
            f"Unsupported model '{name}'. Only 'resnet50' and 'densenet121' "
            "are allowed."
        )

    # The two backbones differ only in the attribute holding the final layer.
    if name == "resnet50":
        net = resnet50()
        head_attr = "fc"
    else:
        net = densenet121()
        head_attr = "classifier"

    old_head = getattr(net, head_attr)
    setattr(net, head_attr, nn.Linear(old_head.in_features, CLASSES))

    # NOTE(review): torch.load unpickles the file, so only load weight files
    # from a trusted source (consider passing weights_only=True).
    weights = torch.load(weight_path, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    net.to(device)

    print(f"[MODEL] '{name}' ready on {device}\n")
    return net


def _classifier_stats(pack: dict) -> dict[str, float]:
    """Summarise one classifier's records as attack-success rate and mean SSIM."""
    indicators = pack["indicators"]
    ssim_vals = pack["ssim_vals"]
    return {
        "attack_success": float(np.mean(indicators)) if indicators else 0.0,
        "mean_ssim": float(np.mean(ssim_vals)) if ssim_vals else 0.0,
    }


def evaluate(cfg: dict) -> None:
    """Score adversarial images against the configured classifiers.

    For every PNG under ``original_root`` that has a counterpart at the same
    relative path under ``adv_root``, the mean RGB SSIM of the pair is
    computed and each classifier is run on the adversarial image. A prediction
    equal to ``CLASS_IDX_REAL`` counts as a successful attack (indicator 1).
    The SSIM-weighted indicators are summed over all pairs and classifiers;
    in ``mean`` mode the sum is divided by ``pairs * classifiers``.

    Args:
        cfg: Configuration mapping. Reads ``models_dir``, ``classifiers``,
            ``original_root`` and ``adv_root`` (required) and ``device``,
            ``aggregate`` and ``save_json`` (optional).

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If ``classifiers`` is empty or names an unsupported model.
        FileNotFoundError: If a classifier's ``<name>.pth`` file is missing.
        RuntimeError: If no PNG images are found under ``original_root``.
    """
    device = get_device(cfg.get("device", "auto"))

    models_dir = Path(cfg["models_dir"])
    clf_names = cfg["classifiers"]

    if not clf_names:
        raise ValueError("The 'classifiers' list in the config cannot be empty.")
    if isinstance(clf_names, str):
        # A scalar `classifiers: resnet50` would otherwise be iterated
        # character by character.
        clf_names = [clf_names]
    # Drop duplicates (keeping order) so no weight file is loaded twice.
    clf_names = list(dict.fromkeys(clf_names))

    # Validate every entry before loading anything, so a bad name or missing
    # file is reported without first paying for earlier model loads.
    weight_paths: dict[str, Path] = {}
    for name in clf_names:
        if name not in {"resnet50", "densenet121"}:
            raise ValueError(
                f"Classifier '{name}' is not supported. "
                "Only 'resnet50' and 'densenet121' are allowed."
            )
        weight_path = models_dir / f"{name}.pth"
        if not weight_path.exists():
            raise FileNotFoundError(
                f"Weight file for '{name}' not found: {weight_path}"
            )
        weight_paths[name] = weight_path

    transform = build_transform()
    classifiers: dict[str, dict] = {
        name: {
            "model": load_model(name, weight_path, device),
            "indicators": [],
            "ssim_vals": [],
        }
        for name, weight_path in weight_paths.items()
    }

    num_classifiers = len(classifiers)
    print(f"[SETUP] {num_classifiers} classifier(s) loaded\n")

    original_root = Path(cfg["original_root"])
    adv_root = Path(cfg["adv_root"])

    # Sorted so the processing order does not depend on the filesystem.
    orig_paths = sorted(glob(str(original_root / "**/*.png"), recursive=True))
    if not orig_paths:
        raise RuntimeError("No PNG images found under 'original_root'.")
    print(f"[DATA] {len(orig_paths)} original images detected\n")

    running_sum = 0.0
    total_pairs = 0

    for o_path in tqdm(orig_paths, desc="Images", unit="image"):
        rel = Path(o_path).relative_to(original_root)
        a_path = adv_root / rel

        if not a_path.exists():
            # Unpaired originals are skipped and do not count towards the score.
            warnings.warn(f"Missing adversarial counterpart for {rel}")
            continue

        img_o = pil_to_np_rgb(o_path)
        img_a = pil_to_np_rgb(a_path)

        ssim_val = compute_ssim_rgb(img_o, img_a)

        # All classifiers share one transform, so preprocess the adversarial
        # image once per pair rather than once per classifier.
        tensor = transform(Image.fromarray(img_a)).unsqueeze(0).to(device)

        pair_contribution = 0.0
        with torch.no_grad():
            for pack in classifiers.values():
                pred = pack["model"](tensor).argmax(1).item()
                indicator = int(pred == CLASS_IDX_REAL)

                pack["indicators"].append(indicator)
                pack["ssim_vals"].append(ssim_val)

                pair_contribution += ssim_val * indicator

        running_sum += pair_contribution
        total_pairs += 1

    if total_pairs == 0:
        print("[RESULT] No valid image pairs found – score = 0")
        return

    # `aggregate: null` (or an empty value) in the YAML means "use the default".
    agg_mode = str(cfg.get("aggregate") or "mean").lower()
    if agg_mode not in {"mean", "sum"}:
        warnings.warn("Unknown 'aggregate' mode – falling back to 'mean'.")
        agg_mode = "mean"

    if agg_mode == "mean":
        final_score = running_sum / (total_pairs * num_classifiers)
    else:  # 'sum'
        final_score = running_sum

    print("[RESULT] SUMMARY:")
    print(f"Images evaluated      : {total_pairs}")
    print(f"Classifiers considered: {num_classifiers}")
    print(f"Final score           : {final_score:.6f}")

    # Computed once and reused for both the console summary and the report.
    per_classifier = {
        name: _classifier_stats(pack) for name, pack in classifiers.items()
    }
    for name, stats in per_classifier.items():
        print(
            f"[{name:<12}] attack_success={stats['attack_success']:.4f}  "
            f"mean_ssim={stats['mean_ssim']:.4f}"
        )

    out_json = cfg.get("save_json")
    if out_json:
        report = {
            "final_score": final_score,
            "images_evaluated": total_pairs,
            "per_classifier": per_classifier,
        }
        # Create the target directory so a long evaluation is not lost to a
        # missing folder at the very last step.
        Path(out_json).parent.mkdir(parents=True, exist_ok=True)
        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print(f"[RESULT] Detailed report written to {out_json}")


if __name__ == "__main__":
    # Command-line entry point: the YAML config path is the only option.
    cli = argparse.ArgumentParser(
        description="Evaluate adversarial images (config-only).",
    )
    cli.add_argument(
        "--config",
        required=True,
        help="Path to YAML configuration file.",
    )
    evaluate(load_cfg(cli.parse_args().config))
