Module epiclass.core.shap_values
Module containing SHAP-value-related code (e.g. handling computation, analyzing results).
Classes
class LGBM_SHAP_Handler (model_analyzer: EstimatorAnalyzer, logdir: Path | str)
-
Handle shap computations and data saving/loading.
Source code
class LGBM_SHAP_Handler:
    """Handle shap computations and data saving/loading."""

    def __init__(self, model_analyzer: EstimatorAnalyzer, logdir: Path | str):
        self.logdir = logdir
        self.saver = SHAP_Saver(logdir=logdir)
        self.model_classes = list(model_analyzer.mapping.items())
        self.model: LGBMClassifier = LGBM_SHAP_Handler._check_model_is_lgbm(
            model_analyzer
        )

    @staticmethod
    def _check_model_is_lgbm(model_analyzer: EstimatorAnalyzer) -> LGBMClassifier:
        """Return lightgbm classifier if found, else raise ValueError."""
        model = model_analyzer.classifier
        if isinstance(model, Pipeline):
            # The classifier is the last step of the pipeline.
            model = model.steps[-1][1]
        if not isinstance(model, LGBMClassifier):
            raise ValueError(
                f"Expected model to be a lightgbm classifier, but got {model} instead."
            )
        return model

    def compute_shaps(
        self,
        background_dset: SomeData,
        evaluation_dset: SomeData,
        save=True,
        name="",
        num_workers: int = 4,
    ) -> Tuple[List[np.ndarray], shap.TreeExplainer]:
        """Compute SHAP values of the LGBM model on the evaluation dataset.

        Returns the SHAP values and the explainer.
        """
        explainer = shap.TreeExplainer(
            model=self.model,
            data=background_dset.signals,
            model_output="raw",
            feature_perturbation="interventional",
        )
        if save:
            self.saver.save_to_npz(
                name=name + "_explainer_background",
                background_md5s=background_dset.ids,
                background_expectation=explainer.expected_value,  # type: ignore
                classes=self.model_classes,
            )

        shap_values = LGBM_SHAP_Handler._compute_shap_values_parallel(
            explainer=explainer,
            signals=evaluation_dset.signals,
            num_workers=num_workers,
        )

        if save:
            self.saver.save_to_npz(
                name=name + "_evaluation",
                evaluation_md5s=evaluation_dset.ids,
                shap_values=shap_values,
                expected_value=explainer.expected_value,
                classes=self.model_classes,
            )

        return shap_values, explainer

    @staticmethod
    def _compute_shap_values_parallel(
        explainer: shap.TreeExplainer,
        signals: ArrayLike,
        num_workers: int,
    ) -> List[np.ndarray]:
        # Split the signals into chunks for parallel processing.
        signal_chunks = np.array_split(signals, num_workers)

        # Worker function: each thread computes on its own deep copy of the explainer.
        def worker(chunk):
            explainer_copy = copy.deepcopy(explainer)
            return explainer_copy.shap_values(X=chunk)

        # Use ThreadPoolExecutor to compute shap_values in parallel.
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            shap_values_chunks = list(executor.map(worker, signal_chunks))

        # Concatenate the chunks.
        if isinstance(shap_values_chunks[0], np.ndarray):  # binary case
            shap_values = list(np.concatenate(shap_values_chunks, axis=0))
        else:  # multiclass case: re-stack chunk results class by class
            shap_values = [
                np.concatenate([chunk[i] for chunk in shap_values_chunks], axis=0)
                for i in range(len(shap_values_chunks[0]))
            ]

        return shap_values
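The threaded helper above deep-copies the explainer per worker because a shared `shap.TreeExplainer` is not guaranteed to be thread-safe, and each thread handles one chunk of rows. A minimal, dependency-free sketch of the same split/compute/concatenate pattern, with a plain function standing in for `explainer.shap_values` (`parallel_map_rows` is a hypothetical name, not part of the module):

import concurrent.futures

import numpy as np


def parallel_map_rows(fn, X: np.ndarray, num_workers: int = 4) -> np.ndarray:
    """Apply `fn` to row chunks of X in parallel and stitch the results back."""
    chunks = np.array_split(X, num_workers)  # uneven chunk sizes are fine
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # executor.map preserves chunk order, so rows come back aligned.
        results = list(executor.map(fn, chunks))
    return np.concatenate(results, axis=0)


# Toy check: squaring rows in 4 threads matches the sequential result.
X = np.arange(12, dtype=float).reshape(6, 2)
assert np.array_equal(parallel_map_rows(lambda c: c**2, X), X**2)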
Methods
def compute_shaps(self, background_dset: SomeData, evaluation_dset: SomeData, save=True, name='', num_workers: int = 4) -> Tuple[List[numpy.ndarray], shap.explainers._tree.Tree]
-
Compute SHAP values of the LGBM model on the evaluation dataset.
Returns the SHAP values and the explainer.
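For orientation, a self-contained sketch of the explainer calls `compute_shaps` wraps, assuming `lightgbm` and `shap` are installed; `X_bg` and `X_eval` are synthetic stand-ins for `background_dset.signals` and `evaluation_dset.signals`:

import numpy as np
import shap
from lightgbm import LGBMClassifier

rng = np.random.default_rng(0)
X_bg = rng.normal(size=(100, 5))   # background dataset
X_eval = rng.normal(size=(10, 5))  # evaluation dataset
y = (X_bg[:, 0] > 0).astype(int)

model = LGBMClassifier(n_estimators=10).fit(X_bg, y)

# Same explainer configuration as the handler: raw (log-odds) model output,
# interventional feature perturbation against the background data.
explainer = shap.TreeExplainer(
    model=model,
    data=X_bg,
    model_output="raw",
    feature_perturbation="interventional",
)
shap_values = explainer.shap_values(X=X_eval)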
class NN_SHAP_Handler (model: LightningDenseClassifier, logdir: Path | str)
-
Handle shap computations and data saving/loading.
Source code
class NN_SHAP_Handler:
    """Handle shap computations and data saving/loading."""

    def __init__(self, model: LightningDenseClassifier, logdir: Path | str):
        self.model = model
        self.model.eval()
        self.model_classes = list(self.model.mapping.items())
        self.logdir = logdir
        self.saver = SHAP_Saver(logdir=logdir)

    def compute_shaps(
        self,
        background_dset: SomeData,
        evaluation_dset: SomeData,
        save=True,
        name="",
        num_workers: int = 4,
    ) -> Tuple[shap.DeepExplainer, List[np.ndarray]]:
        """Compute SHAP values of the deep learning model on the evaluation
        dataset, using an explainer created from the background dataset.

        Returns the explainer and SHAP values (as a list of matrices, one per class).
        """
        explainer = shap.DeepExplainer(
            model=self.model, data=torch.from_numpy(background_dset.signals).float()
        )
        if save:
            self.saver.save_to_npz(
                name=name + "_explainer_background",
                background_md5s=background_dset.ids,
                background_expectation=explainer.expected_value,  # type: ignore
                classes=self.model_classes,
            )

        signals = torch.from_numpy(evaluation_dset.signals).float()
        shap_values = NN_SHAP_Handler._compute_shap_values_parallel(
            explainer, signals, num_workers
        )

        if save:
            self.saver.save_to_npz(
                name=name + "_evaluation",
                evaluation_md5s=evaluation_dset.ids,
                shap_values=shap_values,
                classes=self.model_classes,
            )

        return explainer, shap_values  # type: ignore

    @staticmethod
    def _compute_shap_values_parallel(
        explainer: shap.DeepExplainer,
        signals: torch.Tensor,
        num_workers: int,
    ) -> List[np.ndarray]:
        """Compute SHAP values in parallel using a ThreadPoolExecutor.

        Args:
            explainer (shap.DeepExplainer): The SHAP explainer object used for
                computing SHAP values.
            signals (torch.Tensor): The evaluation dataset samples as a torch
                Tensor of shape (#samples, #features).
            num_workers (int): The number of parallel threads to use for computation.

        Returns:
            List[np.ndarray]: A list of SHAP value matrices (one per output class)
                of shape (#samples, #features).
        """
        signal_chunks = torch.tensor_split(signals, num_workers)

        # Each thread computes on its own deep copy of the explainer.
        def worker(chunk: torch.Tensor) -> np.ndarray:
            explainer_copy = copy.deepcopy(explainer)
            return explainer_copy.shap_values(chunk)  # type: ignore

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            shap_values_chunks = list(executor.map(worker, signal_chunks))

        # Re-stack chunk results class by class.
        shap_values = [
            np.concatenate([chunk[i] for chunk in shap_values_chunks], axis=0)
            for i in range(len(shap_values_chunks[0]))
        ]
        return shap_values
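A self-contained sketch of the explainer calls this handler wraps, assuming `shap` and `torch` are installed; a plain `nn.Sequential` stands in for the project's `LightningDenseClassifier`:

import numpy as np
import shap
import torch
from torch import nn

# A tiny dense network standing in for LightningDenseClassifier.
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 3))
model.eval()

rng = np.random.default_rng(0)
background = torch.from_numpy(rng.normal(size=(50, 5))).float()
evaluation = torch.from_numpy(rng.normal(size=(4, 5))).float()

# Same calls the handler makes: build the explainer on the background set,
# then explain the evaluation samples. Depending on the shap version, the
# result is a list with one (samples, features) matrix per class, or a
# single stacked array.
explainer = shap.DeepExplainer(model=model, data=background)
shap_values = explainer.shap_values(evaluation)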
Methods
def compute_shaps(self, background_dset: SomeData, evaluation_dset: SomeData, save=True, name='', num_workers: int = 4) -> Tuple[shap.explainers._deep.Deep, List[numpy.ndarray]]
-
Compute SHAP values of the deep learning model on the evaluation dataset, using an explainer created from the background dataset.
Returns the explainer and SHAP values (as a list of matrices, one per class).
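The per-class reassembly at the end of `_compute_shap_values_parallel` can be checked in isolation. A toy example with fabricated chunk contents (all values illustrative):

import numpy as np

num_classes = 3
# Each worker returns a list of per-class matrices for its chunk of samples.
chunk_a = [np.full((2, 4), c, dtype=float) for c in range(num_classes)]  # 2 samples
chunk_b = [np.full((1, 4), c, dtype=float) for c in range(num_classes)]  # 1 sample
shap_values_chunks = [chunk_a, chunk_b]

# Re-stack chunk results class by class, as the helper does.
shap_values = [
    np.concatenate([chunk[i] for chunk in shap_values_chunks], axis=0)
    for i in range(num_classes)
]
assert [m.shape for m in shap_values] == [(3, 4)] * num_classes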
class SHAP_Analyzer (model: LightningDenseClassifier, explainer: shap.DeepExplainer)
-
SHAP_Analyzer class for analyzing SHAP values of a model.
Attributes
model
- The trained model.
explainer
- The SHAP explainer object used to compute SHAP values.
Source code
class SHAP_Analyzer:
    """SHAP_Analyzer class for analyzing SHAP values of a model.

    Attributes:
        model: The trained model.
        explainer: The SHAP explainer object used to compute SHAP values.
    """

    def __init__(self, model: LightningDenseClassifier, explainer: shap.DeepExplainer):
        self.model = model
        self.explainer = explainer

    def verify_shap_values_coherence(
        self, shap_values: List, dset: SomeData, tolerance=1e-6  # type: ignore
    ):
        """Verify the coherence of SHAP values with the model's output probabilities.

        Checks if the sum of SHAP values for each sample (across all classes)
        and the base values is approximately equal to the model's output
        probabilities.

        Args:
            shap_values: List of SHAP values for each class.
            dset: The dataset used to compute the SHAP values. The samples need
                to be in the same order as in the list of SHAP values.
            tolerance: The allowed tolerance for the difference between the sum
                of SHAP values and the model's output probabilities (default is 1e-6).

        Returns:
            bool: True if the SHAP values are coherent, False otherwise.
        """
        num_classes = len(shap_values)
        num_samples = shap_values[0].shape[0]

        # Calculate the sum of SHAP values for each sample (across all classes)
        # and add base values.
        shap_sum = np.zeros((num_samples, num_classes))
        for i, shap_values_class in enumerate(shap_values):
            # shap_values_class.sum(axis=1).shape = (n_samples,)
            shap_sum[:, i] = (
                shap_values_class.sum(axis=1)
                + self.explainer.expected_value[i]  # type: ignore
            )

        # Compute the model's output probabilities for the samples.
        signals = torch.from_numpy(dset.signals).float()
        model_output_logits = self.model(signals).detach()
        probs = F.softmax(model_output_logits, dim=1).detach().numpy()

        shap_to_prob = (
            F.softmax(torch.from_numpy(shap_sum).float(), dim=1)
            .sum(dim=1)
            .detach()
            .numpy()
        )
        print(
            f"Verifying model output: Sum close to 1? {np.all(np.abs(1 - probs.sum(axis=1)) <= tolerance)}"
        )
        print(
            f"Verifying SHAP output: Sum close to 1 (w expected value)? {np.all(np.abs(1 - shap_to_prob) <= tolerance)}"
        )

        print(f"Shap sum shape, detailing all classes: {shap_sum.shape}")
        total = shap_sum.sum(axis=1)
        print(
            f"Sum of all shap values across classes, for {total.shape} samples: {total}\n"
        )

        # Compare the sum of SHAP values with the model's output.
        diff = np.abs(shap_sum - model_output_logits.numpy())
        coherent = np.all(diff <= tolerance)
        print(
            f"Detailed values for shap sum, model preds and diff:\n"
            f"{shap_sum}\n{model_output_logits}\n{diff}"
        )

        if not coherent:
            problematic_samples = np.argwhere(diff > tolerance)
            print(f"SHAP values are not coherent for samples:\n {problematic_samples}")

        return coherent
Methods
def verify_shap_values_coherence(self, shap_values: List, dset: SomeData, tolerance=1e-06)
-
Verify the coherence of SHAP values with the model's output probabilities.
Checks if the sum of SHAP values for each sample (across all classes) and the base values is approximately equal to the model's output probabilities.
Args
shap_values
- List of SHAP values for each class.
dset
The dataset used to compute the SHAP values. The samples need to be in the same order as in the list of SHAP values.
tolerance
- The allowed tolerance for the difference between the sum of SHAP values and the model's output probabilities (default is 1e-6).
Returns
bool
- True if the SHAP values are coherent, False otherwise.
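The check rests on SHAP's additivity property: per class, a sample's SHAP values plus the explainer's base value should reconstruct the model's raw output for that class. A toy numeric sketch of that bookkeeping (hand-picked numbers, not real model output):

import numpy as np

logits = np.array([[1.2, -0.3]])        # raw model output for one sample
expected_value = np.array([0.5, 0.1])   # explainer base value per class
shap_values = [
    np.array([[0.4, 0.3]]),             # class 0: contributions sum to 0.7
    np.array([[-0.1, -0.3]]),           # class 1: contributions sum to -0.4
]

# Per class: sum of contributions + base value == logit (up to tolerance).
shap_sum = np.stack(
    [sv.sum(axis=1) + ev for sv, ev in zip(shap_values, expected_value)], axis=1
)
assert np.allclose(shap_sum, logits, atol=1e-6)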
class SHAP_Saver (logdir: Path | str)
-
Handle shap data saving/loading.
Source code
class SHAP_Saver:
    """Handle shap data saving/loading."""

    def __init__(self, logdir: Path | str):
        self.logdir = logdir
        self.filename_template = "shap_{name}_{time}.{ext}"

    def _create_filename(self, ext: str, name="") -> Path:
        """Create a filename with the given extension and name, and a timestamp."""
        filename = self.filename_template.format(name=name, ext=ext, time=time_now_str())
        filename = Path(self.logdir) / filename
        return filename

    def save_to_csv(
        self, shap_values_matrix: np.ndarray, ids: List[str], name: str
    ) -> Path:
        """Save a single shap value matrix (shape (n_samples, #features)) to csv.

        Giving a name is mandatory. Returns path of saved file.
        """
        if isinstance(shap_values_matrix, list):
            raise ValueError(
                f"Expected 'shap_values_matrix' to be a numpy array of shape "
                f"(n_samples, #features), but got a list instead: {shap_values_matrix}"
            )
        filename = self._create_filename(name=name, ext="csv")
        n_dims = shap_values_matrix.shape[1]
        df = pd.DataFrame(data=shap_values_matrix, index=ids, columns=range(n_dims))
        print(f"Saving SHAP values to: {filename}")
        df.to_csv(filename)
        return filename

    def save_to_npz(self, name: str, verbose=True, **kwargs):
        """Save kwargs to numpy compressed npz file.

        Transforms everything into numpy arrays.
        """
        filename = self._create_filename(name=name, ext="npz")
        if verbose:
            print(f"Saving SHAP values to: {filename}")
        np.savez_compressed(
            file=filename,
            **kwargs,  # type: ignore
        )

    @staticmethod
    def load_from_csv(path: Path | str) -> pd.DataFrame:
        """Return pandas dataframe of shap values for loaded file."""
        return pd.read_csv(path, index_col=0)
Static methods
def load_from_csv(path: Path | str)
-
Return pandas dataframe of shap values for loaded file.
Methods
def save_to_csv(self, shap_values_matrix: np.ndarray, ids: List[str], name: str) -> pathlib.Path
-
Save a single shap value matrix (shape (n_samples, #features)) to csv. Giving a name is mandatory.
Returns path of saved file.
def save_to_npz(self, name: str, verbose=True, **kwargs)
-
Save kwargs to numpy compressed npz file. Transforms everything into numpy arrays.
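A roundtrip sketch of the two on-disk formats the saver wraps, using plain `numpy`/`pandas` calls (file names illustrative; the real saver adds a timestamp via its filename template):

import numpy as np
import pandas as pd

shap_matrix = np.random.default_rng(0).normal(size=(3, 4))
ids = ["md5_a", "md5_b", "md5_c"]

# npz: arbitrary named arrays, compressed; everything becomes a numpy array.
np.savez_compressed("shap_demo.npz", shap_values=shap_matrix, ids=np.array(ids))
loaded = np.load("shap_demo.npz")
assert np.allclose(loaded["shap_values"], shap_matrix)

# csv: one matrix with sample ids as the index, matching save_to_csv/load_from_csv.
pd.DataFrame(shap_matrix, index=ids, columns=range(4)).to_csv("shap_demo.csv")
df = pd.read_csv("shap_demo.csv", index_col=0)
assert df.index.tolist() == ids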