Source code for xaicompare.adapters.explainers.explainer_shap_tree
# xaicompare/adapters/explainers/explainer_shap_tree.py
import numpy as np
import shap
from scipy.sparse import spmatrix
from typing import Any, Dict, List, Tuple
from xaicompare.adapters.explainers.explainer_base import ExplainerAdapter
from xaicompare.registry.xai_registry import register_xai
[docs]
@register_xai("shap_tree")
class ShapTreeExplainerAdapter(ExplainerAdapter):
"""
SHAP Tree Explainer with:
- Safe sparse handling
- Forced vectorization
- Guaranteed 2-D input to SHAP
- Batch processing to avoid memory errors
- Multi-class normalization
"""
[docs]
def limitation_text(self) -> str:
    """Return a human-readable summary of this explainer's known limitations.

    The text covers the general caveats of SHAP's TreeExplainer and the
    fact that global importance is estimated from a bounded number of rows
    (``global_importance`` defaults to ``rows_limit=200``).

    Returns:
        A single plain-text paragraph suitable for inclusion in a report.
    """
    # Adjacent string literals are concatenated at compile time; keeping one
    # sentence per group makes the individual caveats easy to edit.
    return (
        "SHAP TreeExplainer only supports tree-based models (e.g. decision "
        "trees, random forests and gradient-boosted ensembles such as "
        "XGBoost, LightGBM or CatBoost). "
        "Attributions explain the model's output, not causal effects in the "
        "underlying data. "
        "When features are correlated, credit can be split between them or "
        "concentrated on one of them, so individual attributions should be "
        "interpreted with care. "
        "Depending on the model, values may be expressed in the model's raw "
        "margin space (e.g. log-odds) rather than as probabilities. "
        "Global importance is the mean absolute SHAP value over a limited "
        "number of rows (200 by default), so it is an estimate."
    )
def __init__(self, model_adapter, config: Dict[str, Any]):
    """Initialise the adapter; all setup is delegated to the base class.

    Args:
        model_adapter: Forwarded unchanged to the parent ``__init__``.
        config: Configuration mapping, forwarded unchanged to the parent
            ``__init__``.
    """
    # Zero-argument super() still resolves through the MRO of type(self);
    # binding it to a local first does not change which __init__ runs.
    parent = super()
    parent.__init__(model_adapter, config)
# ----------------------------------------------------------------------
# GLOBAL IMPORTANCE
# ----------------------------------------------------------------------
[docs]
def global_importance(self, X, rows_limit: int = 200) -> Tuple[np.ndarray, List[str]]:
    """
    Compute mean|SHAP| across first rows_limit samples safely.

    Args:
        X: Input samples to explain. The comment below suggests numpy or
            pandas objects are expected; the module also imports scipy's
            ``spmatrix``, so sparse input may be supported too — confirm
            which types callers actually pass.
        rows_limit: Number of leading rows of ``X`` used for the estimate
            (default 200). Limiting rows bounds the cost of the SHAP call.

    Returns:
        A 2-tuple ``(ndarray, list[str])`` per the annotation. Presumably
        the array holds one mean |SHAP| score per feature and the list the
        matching feature names — TODO confirm against the implementation.
    """
    # Extract raw values from numpy/pandas objects
    # NOTE(review): no statements follow this comment in the reviewed view,
    # so as shown this method would return None despite its annotation —
    # confirm the implementation was not lost.
# ----------------------------------------------------------------------
# LOCAL EXPLANATIONS
# ----------------------------------------------------------------------
[docs]
def local_explanations(self, x_row) -> np.ndarray:
"""
Compute SHAP for a single example → return signed vector.
"""
# Force shape (1, raw_item)