# Source code for case_explainer.explanation

"""
Explanation object for case-based explanations.
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional, Dict, Any


class Neighbor:
    """A single nearest-neighbor case retrieved for a test sample."""

    def __init__(
        self,
        index: int,
        distance: float,
        label: int,
        features: np.ndarray,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Store one retrieved case.

        Args:
            index: Position of the case in the reference data.
            distance: Distance from the test sample to this case.
            label: Class label of this case.
            features: Feature values of this case.
            metadata: Optional extra information; any falsy value
                (``None``, ``{}``) is replaced by a fresh empty dict.
        """
        self.index = index
        self.distance = distance
        self.label = label
        self.features = features
        self.metadata = metadata if metadata else {}

    def __repr__(self) -> str:
        body = ", ".join(
            [
                f"index={self.index}",
                f"distance={self.distance:.4f}",
                f"label={self.label}",
            ]
        )
        # Metadata is only shown when there is something to show.
        if self.metadata:
            body += f", {self.metadata}"
        return f"Neighbor({body})"
class Explanation:
    """
    Explanation object containing case-based explanation details.

    Holds a test sample, the nearest cases retrieved for it, the model's
    prediction (and the true label, when known) and a correspondence score.
    These can be rendered as text (``summary``), as a JSON-friendly dict
    (``to_dict``) or as a matplotlib figure (``plot``).
    """

    def __init__(
        self,
        test_sample: np.ndarray,
        test_index: Optional[int],
        neighbors: List["Neighbor"],
        predicted_class: int,
        true_class: Optional[int],
        correspondence: float,
        correspondence_interpretation: str,
        feature_names: Optional[List[str]] = None,
        class_names: Optional[Dict[int, str]] = None,
    ):
        """
        Initialize explanation.

        Args:
            test_sample: Feature values of the sample being explained; its
                length determines the default feature names.
            test_index: Index in test set (if applicable).
            neighbors: Neighbor objects; each must expose ``index``,
                ``distance``, ``label``, ``features`` and ``metadata``.
            predicted_class: Predicted class label.
            true_class: True class label (if available).
            correspondence: Correspondence score [0, 1].
            correspondence_interpretation: "high", "medium", or "low".
            feature_names: Names of features. ``None`` or an empty sequence
                yields ``feature_0 .. feature_{d-1}``. Any sequence, including
                a NumPy array, is accepted and stored as a list.
            class_names: Either a label -> name mapping, or a sequence (list,
                tuple, array) of names indexed by label. Labels without a
                name are displayed as ``str(label)``.
        """
        self.test_sample = test_sample
        self.test_index = test_index
        self.neighbors = neighbors
        self.predicted_class = predicted_class
        self.true_class = true_class
        self.correspondence = correspondence
        self.correspondence_interpretation = correspondence_interpretation

        # Test emptiness with len() rather than truthiness: ``array or x``
        # raises for NumPy arrays with more than one element.
        if feature_names is not None and len(feature_names) > 0:
            self.feature_names = list(feature_names)
        else:
            self.feature_names = [f"feature_{i}" for i in range(len(test_sample))]

        if class_names is None:
            self.class_names = {}
        elif hasattr(class_names, "items"):
            # Already a label -> name mapping.
            self.class_names = dict(class_names)
        else:
            # Any other sequence (list, tuple, ndarray) is indexed by label.
            self.class_names = dict(enumerate(class_names))

    @staticmethod
    def _to_builtin(value: Any) -> Any:
        """Return *value* as a native Python scalar if it is a NumPy scalar."""
        # NumPy scalars (np.int64, np.float32, np.bool_, np.str_ ...) are not
        # all JSON serializable; ``.item()`` yields the matching builtin.
        if isinstance(value, np.generic):
            return value.item()
        return value

    def _label_name(self, label: Any) -> str:
        """Return the display name for a class label, or ``str(label)``."""
        return self.class_names.get(label, str(label))

    def get_predicted_class_name(self) -> str:
        """Get the predicted class name."""
        return self._label_name(self.predicted_class)

    def get_true_class_name(self) -> Optional[str]:
        """Get the true class name, or ``None`` when the true class is unknown."""
        if self.true_class is None:
            return None
        return self._label_name(self.true_class)

    def is_correct(self) -> Optional[bool]:
        """
        Check if prediction matches true label (if available).

        Returns:
            ``None`` when the true class is unknown, otherwise a plain
            ``bool``. It is coerced so that NumPy labels do not leak a
            ``np.bool_``.
        """
        if self.true_class is None:
            return None
        return bool(self.predicted_class == self.true_class)

    def summary(self) -> str:
        """
        Generate a text summary of the explanation.

        Neighbors whose label equals the predicted class are marked ``[*]``.
        Any neighbor metadata is listed beneath its entry.
        """
        rule = "=" * 60
        lines = [rule, "CASE-BASED EXPLANATION", rule]
        if self.test_index is not None:
            lines.append(f"Test sample index: {self.test_index}")
        lines.append(f"Predicted class: {self.get_predicted_class_name()}")
        if self.true_class is not None:
            correct_str = "[OK]" if self.is_correct() else "[X]"
            lines.append(f"True class: {self.get_true_class_name()} {correct_str}")
        lines.append(f"Correspondence: {self.correspondence:.2%} ({self.correspondence_interpretation})")
        lines.append("")
        lines.append(f"Nearest {len(self.neighbors)} neighbors:")
        lines.append("-" * 60)

        for i, neighbor in enumerate(self.neighbors, 1):
            neighbor_class = self._label_name(neighbor.label)
            match_str = "*" if neighbor.label == self.predicted_class else " "
            lines.append(f"{i}. [{match_str}] Index {neighbor.index}: class {neighbor_class}, "
                         f"distance {neighbor.distance:.4f}")
            if neighbor.metadata:
                for key, value in neighbor.metadata.items():
                    lines.append(f" {key}: {value}")

        lines.append(rule)
        return "\n".join(lines)

    def _neighbor_to_dict(self, neighbor: "Neighbor") -> Dict[str, Any]:
        """Export one neighbor as a dict of JSON-serializable values."""
        conv = self._to_builtin
        return {
            "index": conv(neighbor.index),
            "distance": float(neighbor.distance),
            "label": conv(neighbor.label),
            "label_name": conv(self._label_name(neighbor.label)),
            # asarray accepts plain lists as well as ndarrays.
            "features": np.asarray(neighbor.features).tolist(),
            # Copy so callers mutating the export don't touch the neighbor;
            # top-level NumPy scalars are converted, nested values are not.
            "metadata": {key: conv(value) for key, value in (neighbor.metadata or {}).items()},
        }

    def to_dict(self) -> Dict[str, Any]:
        """
        Export explanation as dictionary (for JSON serialization).

        NumPy scalars and arrays at the top level and in neighbor entries are
        converted to native Python values, so ``json.dumps`` works on the
        result.
        """
        conv = self._to_builtin
        return {
            "test_index": conv(self.test_index),
            "test_sample": np.asarray(self.test_sample).tolist(),
            "predicted_class": conv(self.predicted_class),
            "predicted_class_name": conv(self.get_predicted_class_name()),
            "true_class": conv(self.true_class),
            "true_class_name": conv(self.get_true_class_name()),
            "is_correct": self.is_correct(),
            "correspondence": float(self.correspondence),
            "correspondence_interpretation": self.correspondence_interpretation,
            "neighbors": [self._neighbor_to_dict(n) for n in self.neighbors],
            "feature_names": [conv(name) for name in self.feature_names],
        }

    def plot(
        self,
        plot_type: str = 'radar',
        highlight_differences: bool = True,
        show_distances: bool = True,
        save_path: Optional[str] = None,
        figsize: Tuple[int, int] = (12, 8)
    ) -> None:
        """
        Visualize the explanation.

        Args:
            plot_type: 'radar', 'bar', or 'parallel'. 'radar' and 'parallel'
                are not implemented yet; they emit a warning and draw the
                bar plot instead.
            highlight_differences: Whether to highlight feature differences.
                Accepted for API compatibility; not yet used by the bar
                renderer.
            show_distances: Whether to append each neighbor's distance to its
                legend entry.
            save_path: Path to save figure. When given, the figure is saved
                and closed instead of being shown.
            figsize: Figure size.

        Raises:
            ValueError: If ``plot_type`` is not one of the supported values.
        """
        if plot_type == 'bar':
            self._plot_bar(figsize, save_path, show_distances=show_distances)
        elif plot_type == 'radar':
            self._plot_radar(figsize, save_path, show_distances=show_distances)
        elif plot_type == 'parallel':
            self._plot_parallel(figsize, save_path, show_distances=show_distances)
        else:
            raise ValueError(f"Unknown plot type: {plot_type}")

    def _plot_bar(
        self,
        figsize: Tuple[int, int],
        save_path: Optional[str],
        show_distances: bool = False,
    ) -> None:
        """
        Create bar plot comparing features.

        For every feature, the test sample's bar is drawn first, followed by
        one bar per neighbor. Neighbors whose label equals the predicted class
        get a ``*`` in their legend entry.
        """
        n_features = len(self.test_sample)
        n_neighbors = len(self.neighbors)

        fig, ax = plt.subplots(figsize=figsize)
        x = np.arange(n_features)
        # The n_neighbors + 1 bars of one feature share 80% of its slot.
        width = 0.8 / (n_neighbors + 1)

        ax.bar(x, self.test_sample, width, label='Test Sample',
               color='red', alpha=0.8, edgecolor='black', linewidth=2)

        colors = plt.cm.Blues(np.linspace(0.3, 0.8, n_neighbors))
        for i, neighbor in enumerate(self.neighbors):
            offset = width * (i + 1)
            legend_label = f'Neighbor {i + 1}'
            if neighbor.label == self.predicted_class:
                legend_label += ' *'
            if show_distances:
                legend_label += f' (d={neighbor.distance:.4f})'
            ax.bar(x + offset, neighbor.features, width,
                   label=legend_label, color=colors[i], alpha=0.6)

        ax.set_xlabel('Features')
        ax.set_ylabel('Feature Values')
        ax.set_title(f'Case-Based Explanation (Correspondence: {self.correspondence:.2%})')
        # Center each tick under its group of bars.
        ax.set_xticks(x + width * n_neighbors / 2)
        ax.set_xticklabels(self.feature_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            # Release the figure; otherwise saving many explanations keeps
            # every figure alive in pyplot's registry.
            plt.close(fig)
        else:
            plt.show()

    def _fallback_to_bar(
        self,
        what: str,
        figsize: Tuple[int, int],
        save_path: Optional[str],
        show_distances: bool,
    ) -> None:
        """Warn that *what* is not implemented, then draw the bar plot."""
        import warnings

        # stacklevel=4 attributes the warning to the caller of plot().
        warnings.warn(f"{what} not yet implemented, using bar plot instead.", stacklevel=4)
        self._plot_bar(figsize, save_path, show_distances=show_distances)

    def _plot_radar(
        self,
        figsize: Tuple[int, int],
        save_path: Optional[str],
        show_distances: bool = False,
    ) -> None:
        """Create radar plot (not implemented yet - falls back to bar plot)."""
        self._fallback_to_bar("Radar plot", figsize, save_path, show_distances)

    def _plot_parallel(
        self,
        figsize: Tuple[int, int],
        save_path: Optional[str],
        show_distances: bool = False,
    ) -> None:
        """Create parallel coordinates plot (not implemented yet - falls back to bar plot)."""
        self._fallback_to_bar("Parallel coordinates plot", figsize, save_path, show_distances)

    def __repr__(self) -> str:
        return (f"Explanation(test_index={self.test_index}, "
                f"predicted={self.get_predicted_class_name()}, "
                f"correspondence={self.correspondence:.2%}, "
                f"neighbors={len(self.neighbors)})")