CaseExplainer
The main class for creating case-based explanations.
- class case_explainer.CaseExplainer(X_train, y_train, k=5, feature_names=None, class_names=None, metric='euclidean', algorithm='auto', scale_data=True, class_weights=None, metadata=None, n_jobs=-1)[source]
Bases: object
General-purpose case-based explainability module.
Provides model-agnostic explanations through training-set precedent and nearest-neighbor correspondence. The k-NN index is built during initialization so that lookups at explanation time are fast.
Based on refined Method 2 from the hardware trojan detection pipeline:
- Pre-builds a NearestNeighbors index on the training data
- Uses distance-weighted correspondence: weight = 1 / (distance + 1)^3
- Supports class weights for imbalanced datasets
- Compatible with any classifier (sklearn, XGBoost, etc.)
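The distance-weighted correspondence above can be sketched as follows. This is a minimal reimplementation of the stated formula for illustration, not the library's internal code; the helper name is hypothetical:

```python
import numpy as np

def distance_weighted_correspondence(distances, neighbor_labels, predicted_class):
    """Fraction of total neighbor weight that agrees with the predicted class,
    using the module's stated scheme: weight = 1 / (distance + 1)**3."""
    distances = np.asarray(distances, dtype=float)
    neighbor_labels = np.asarray(neighbor_labels)
    weights = 1.0 / (distances + 1.0) ** 3          # closer neighbors count more
    agree = neighbor_labels == predicted_class      # boolean agreement mask
    return float(weights[agree].sum() / weights.sum())

# Three close neighbors agree with the prediction, two distant ones disagree,
# so the score lands well above the raw 3/5 majority fraction.
corr = distance_weighted_correspondence(
    distances=[0.1, 0.2, 0.3, 2.0, 3.0],
    neighbor_labels=[1, 1, 1, 0, 0],
    predicted_class=1,
)
```

The cubic decay means a handful of very close precedents can dominate many distant ones, which is the intended behavior for case-based reasoning.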
Example
>>> from case_explainer import CaseExplainer
>>> explainer = CaseExplainer(X_train, y_train, k=5)
>>> explanation = explainer.explain_instance(test_sample, model=clf)
>>> print(f"Correspondence: {explanation.correspondence:.2%}")
>>> explanation.plot()
- __init__(X_train, y_train, k=5, feature_names=None, class_names=None, metric='euclidean', algorithm='auto', scale_data=True, class_weights=None, metadata=None, n_jobs=-1)[source]
Initialize CaseExplainer with training data and build k-NN index.
- Parameters:
- X_train (Union[ndarray, DataFrame]) – Training features (n_samples, n_features)
- y_train (Union[ndarray, Series, List]) – Training labels (n_samples,)
- k (int) – Number of nearest neighbors for explanations (default: 5)
- feature_names (Optional[List[str]]) – Names of features (optional)
- class_names (Optional[Dict[int, str]]) – Mapping from class labels to names (optional)
- metric (str) – Distance metric (default: 'euclidean')
- algorithm (str) – k-NN algorithm: 'auto', 'ball_tree', 'kd_tree', or 'brute' (default: 'auto')
- scale_data (bool) – Whether to standardize features (recommended: True)
- class_weights (Optional[Dict[int, float]]) – Optional per-class weights for correspondence computation, e.g. {0: 1.0, 1: 2.0} to weight class 1 twice as much
- metadata (Optional[Dict[str, List]]) – Optional dict of per-sample metadata, e.g. {'sample_id': [...], 'source': [...], ...}
- n_jobs (int) – Number of parallel jobs for k-NN search (-1 = all CPUs)
- explain_instance(test_sample, test_index=None, true_class=None, predicted_class=None, model=None, k=None, return_provenance=True, distance_weighted=True)[source]
Explain a prediction using case-based reasoning with k-NN precedent.
This method:
1. Finds the k nearest neighbors in the pre-built index
2. Computes weighted correspondence based on the neighbor labels
3. Returns an explanation with neighbor details and a correspondence score
- Parameters:
- test_sample (Union[ndarray, Series, List]) – Sample to explain (n_features,)
- test_index (Optional[int]) – Index in test set (optional, for tracking)
- true_class (Optional[int]) – True class label (optional, for validation)
- predicted_class (Optional[int]) – Predicted class (optional; computed with model if not provided)
- model (Optional[Any]) – Trained model with a predict() method (optional)
- k (Optional[int]) – Number of neighbors (optional; uses the default from __init__ if not provided)
- return_provenance (bool) – Include metadata in the explanation
- distance_weighted (bool) – Use distance weighting for correspondence
- Return type: Explanation
- Returns:
Explanation object with neighbors and correspondence
- explain_batch(X_test, y_test=None, predictions=None, model=None, k=None, return_provenance=True, distance_weighted=True)[source]
Explain multiple predictions efficiently.
- Parameters:
- X_test (Union[ndarray, DataFrame]) – Test samples (n_samples, n_features)
- y_test (Union[ndarray, Series, List, None]) – True labels (optional)
- predictions (Union[ndarray, List, None]) – Predicted labels (optional; computed with model if not provided)
- model (Optional[Any]) – Trained model with a predict() method (optional)
- k (Optional[int]) – Number of neighbors (optional; uses the default from __init__)
- return_provenance (bool) – Include metadata
- distance_weighted (bool) – Use distance weighting
- Return type: List[Explanation]
- Returns:
List of Explanation objects
Example Usage
Basic Example
from case_explainer import CaseExplainer
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
# Load and split data
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
data.data, data.target, test_size=0.3, random_state=42
)
# Train classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
# Create explainer
explainer = CaseExplainer(
X_train=X_train,
y_train=y_train,
feature_names=data.feature_names,
algorithm='ball_tree',
scale_data=True
)
# Explain a prediction
explanation = explainer.explain_instance(
test_sample=X_test[0],
k=5,
model=clf,
true_class=y_test[0]
)
print(f"Correspondence: {explanation.correspondence:.2%}")
print(f"Predicted class: {explanation.predicted_class}")
print(f"Correct: {explanation.is_correct()}")
Batch Explanations
# Explain multiple predictions at once
explanations = explainer.explain_batch(
X_test[:100],
k=5,
y_test=y_test[:100],
model=clf
)
# Analyze correspondence distribution
correspondences = [exp.correspondence for exp in explanations]
correct_corr = [exp.correspondence for exp in explanations if exp.is_correct()]
incorrect_corr = [exp.correspondence for exp in explanations if not exp.is_correct()]
print(f"Mean correspondence (all): {sum(correspondences)/len(correspondences):.2%}")
print(f"Mean correspondence (correct): {sum(correct_corr)/len(correct_corr):.2%}")
print(f"Mean correspondence (incorrect): {sum(incorrect_corr)/len(incorrect_corr):.2%}")
Working with Metadata
# Attach metadata to training samples
metadata = {
'sample_id': [f"patient_{i}" for i in range(len(X_train))],
'date': ['2024-01-01'] * len(X_train),
'source': ['hospital_A'] * len(X_train)
}
explainer = CaseExplainer(
X_train=X_train,
y_train=y_train,
metadata=metadata,
algorithm='ball_tree'
)
# Access metadata in explanations
explanation = explainer.explain_instance(X_test[0], k=5, model=clf)
for neighbor in explanation.neighbors:
print(f"Neighbor {neighbor.index}: {neighbor.metadata}")
Configuration Options
Algorithm Selection
Choose the indexing algorithm based on your data characteristics:
kd_tree: Best for low-dimensional data (<20 features), fastest for small to medium datasets
ball_tree: Better for high-dimensional data (>20 features), good all-around choice
brute: Exhaustive pairwise search with O(n) query cost, only recommended for small datasets (<5k samples)
auto: Let scikit-learn choose based on data characteristics (default)
# For low-dimensional data
explainer = CaseExplainer(X_train, y_train, algorithm='kd_tree')
# For high-dimensional data
explainer = CaseExplainer(X_train, y_train, algorithm='ball_tree')
# For very small datasets
explainer = CaseExplainer(X_train, y_train, algorithm='brute')
Feature Scaling
Feature scaling is recommended to prevent features with large ranges from dominating distance calculations:
# With scaling (recommended)
explainer = CaseExplainer(X_train, y_train, scale_data=True)
# Without scaling (if features are already normalized)
explainer = CaseExplainer(X_train, y_train, scale_data=False)
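To see why scaling matters, compare nearest-neighbor lookups on a synthetic dataset where one feature's range dwarfs the other's. This standalone sketch uses scikit-learn directly, mirroring what `scale_data=True` does internally under the assumption that standardization is applied before indexing:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
# Feature 0 spans ~[0, 1]; feature 1 spans ~[0, 1000] and dominates raw distances
X = np.column_stack([rng.random(200), rng.random(200) * 1000.0])

# Unscaled: neighbor choice is driven almost entirely by feature 1
nn_raw = NearestNeighbors(n_neighbors=5).fit(X)

# Scaled: both features contribute comparably to the distance
scaler = StandardScaler()
nn_scaled = NearestNeighbors(n_neighbors=5).fit(scaler.fit_transform(X))

query = X[:1]
_, idx_raw = nn_raw.kneighbors(query)
_, idx_scaled = nn_scaled.kneighbors(scaler.transform(query))
```

Because the query is itself a training point, both searches return it first, but the remaining neighbor sets typically differ substantially between the scaled and unscaled index.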
Class Weights
For imbalanced datasets, you can weight classes differently in correspondence computation:
# Weight minority class more heavily
explainer = CaseExplainer(
X_train, y_train,
class_weights={0: 1.0, 1: 5.0} # Weight class 1 five times more
)
Notes
Performance Considerations
Index building time is O(n log n) for tree-based methods
Query time is O(log n) for tree-based methods, O(n) for brute force
Memory usage scales with dataset size and dimensionality
Use n_jobs=-1 to parallelize the nearest neighbor search
Correspondence Interpretation
High (≥85%): Strong agreement with training precedent, high confidence
Medium (70-85%): Moderate agreement, reasonable confidence
Low (<70%): Weak agreement, prediction may be uncertain or unusual
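When triaging batch explanations, these bands can be applied programmatically. The helper below is illustrative, not part of the library API; the thresholds simply mirror the guidance above:

```python
def correspondence_band(correspondence):
    """Map a correspondence score in [0, 1] to the qualitative bands above."""
    if correspondence >= 0.85:
        return "high"    # strong agreement with training precedent
    elif correspondence >= 0.70:
        return "medium"  # moderate agreement, reasonable confidence
    return "low"         # weak agreement; prediction may be uncertain or unusual

# e.g. flag low-band explanations from a batch for manual review:
# to_review = [e for e in explanations if correspondence_band(e.correspondence) == "low"]
```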
See Also
Explanation: The explanation object returned by explain_instance
case_explainer.metrics: Correspondence and distance metrics