Module epiclass.utils.shap.shap_utils

Module containing utility functions for shap files handling and a bit of analysis.

Functions

def extract_shap_values_and_info(shap_logdir: str | Path, verbose: bool = True)

Extract and print basic statistics about SHAP values from an archive.

Args

shap_logdir : str | Path
The directory where the SHAP values archive is located.
verbose : bool
Whether to print basic statistics about the SHAP values.

Returns

shap_matrices : np.ndarray
SHAP matrices.
eval_md5s : List[str]
List of evaluation MD5s.
classes : List[Tuple[str, str]]
List of classes. Each class is a tuple containing the class index and the class label.

def get_archives(shap_values_dir: str | Path)

Extracts SHAP values and explainer background information from .npz files in a specified directory.

This function searches for files in the provided directory, specifically looking for files that match the patterns "evaluation.npz" and "explainer_background.npz". It loads these .npz files as dictionaries and returns them. The function raises a FileNotFoundError if the required files are not found in the directory.

Args

shap_values_dir (str | Path): The directory path where the .npz files are located.

Returns

Tuple[Dict, Dict]
The first dictionary contains the SHAP values extracted from the "evaluation.npz" file, and the second contains the explainer background information extracted from the "explainer_background.npz" file.

Raises

FileNotFoundError
If either the SHAP values file or the explainer background file is not found in the specified directory.
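
The lookup described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the module's actual implementation: the helper name `load_shap_archives`, the glob pattern `*<suffix>`, the choice of the first match when several files exist, and the file and key names in the demo are all hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Dict, Tuple

import numpy as np


def load_shap_archives(shap_values_dir) -> Tuple[Dict, Dict]:
    """Hypothetical stand-in for get_archives: load both .npz files as dicts."""
    directory = Path(shap_values_dir)
    loaded = []
    for suffix in ("evaluation.npz", "explainer_background.npz"):
        # Assumption: the archive names end with the documented patterns.
        matches = sorted(directory.glob(f"*{suffix}"))
        if not matches:
            raise FileNotFoundError(f"No file matching '*{suffix}' in {directory}")
        # Read every array eagerly so the archive can be closed right away.
        with np.load(matches[0]) as archive:
            loaded.append({key: archive[key] for key in archive.files})
    return loaded[0], loaded[1]


# Demo on a throwaway directory; file names and array keys are made up.
with tempfile.TemporaryDirectory() as tmp:
    np.savez(Path(tmp) / "shap_evaluation.npz", shap_values=np.zeros((2, 3)))
    np.savez(Path(tmp) / "shap_explainer_background.npz", background=np.ones(3))
    evaluation, background = load_shap_archives(tmp)
    empty_dir = Path(tmp) / "empty"
    empty_dir.mkdir()
    try:
        load_shap_archives(empty_dir)
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
```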

def get_shap_matrix(meta: Metadata, shap_matrices: np.ndarray, eval_md5s: List[str], label_category: str, selected_labels: List[str], class_idx: int, copy_meta: bool = True) -> Tuple[numpy.ndarray, List[int]]

Generates a SHAP matrix corresponding to a selected subset of samples.

This function selects a subset of samples based on the specified criteria and builds a SHAP matrix for them. It filters the metadata if a specific target subsample is provided, selects the samples identified by their md5 hashes, and then takes the SHAP values of those samples from the matrix for the given class index.

Args

meta : metadata.Metadata
Metadata object containing information about the samples.
shap_matrices : np.ndarray
Array of SHAP matrices for each class.
eval_md5s : List[str]
List of md5 hashes identifying the evaluation samples.
label_category : str
Name of the category in the metadata that contains the desired labels.
selected_labels : List[str]
Names of the classes for which samples will be considered.
class_idx : int
Index of the class for which the shap values matrix will be used.

Returns

np.ndarray
The selected SHAP matrix for the selected class and for the chosen samples based on the provided criteria.
List[int]
The indices of the chosen samples in the original SHAP matrix.

Raises

IndexError
If the class_idx is out of bounds for the shap_matrices.
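
The final selection step might look like the sketch below. It assumes `shap_matrices` has shape (n_classes, n_samples, n_features), which the documentation does not state, and it takes the sample indices as given rather than deriving them from the metadata. The helper name `select_class_rows` is hypothetical.

```python
import numpy as np


def select_class_rows(shap_matrices: np.ndarray, class_idx: int, row_indices: list) -> np.ndarray:
    """Hypothetical helper: rows of one class matrix for the chosen samples."""
    # Assumed layout: (n_classes, n_samples, n_features).
    if not 0 <= class_idx < len(shap_matrices):
        raise IndexError(f"class_idx {class_idx} is out of bounds")
    return shap_matrices[class_idx][row_indices, :]


# Two classes, three samples, four features, filled with 0..23.
shap_matrices = np.arange(2 * 3 * 4).reshape(2, 3, 4)
selected = select_class_rows(shap_matrices, class_idx=1, row_indices=[0, 2])
print(selected.shape)  # (2, 4)
```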

def n_most_important_features(sample_shaps: np.ndarray, n: int) -> numpy.ndarray

Return indices of features with the highest absolute shap values.

Args

sample_shaps : np.ndarray
Array of SHAP values for a single sample.
n : int
Number of top features to return.

Returns

np.ndarray
Indices of the top n features with the highest absolute SHAP values.
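
One way to obtain such indices with NumPy is sketched here. Whether the real function returns them sorted by decreasing magnitude is not documented, so that ordering is an assumption of this sketch, as is the helper name `top_n_abs_indices`.

```python
import numpy as np


def top_n_abs_indices(sample_shaps: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical helper: indices of the n largest |values|, largest first."""
    # Negate the magnitudes so an ascending sort puts the largest first;
    # a stable sort keeps tied values in their original index order.
    order = np.argsort(-np.abs(sample_shaps), kind="stable")
    return order[:n]


shaps = np.array([0.1, -0.7, 0.3, 0.0, 0.5])
print(top_n_abs_indices(shaps, 2))  # [1 4]
```

For very wide feature vectors, `np.argpartition` avoids a full sort, at the cost of returning the top n in arbitrary order.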

def select_random_shap_samples(shap_dict: Dict[str, List[np.ndarray]], n: int) -> Dict[str, List[numpy.ndarray]]

Selects a random subset of SHAP values and their corresponding IDs from a given dictionary.

This function randomly selects 'n' samples from the provided SHAP values. It ensures that the selection is non-repetitive. The function is designed to work with a dictionary containing SHAP values and their corresponding IDs. The resulting subset contains both SHAP values and IDs, maintaining their association.

Args

shap_dict : Dict[str, List[np.ndarray]]
A dictionary with two keys: 'shap' and 'ids'. 'shap' should be a list of numpy arrays containing SHAP values, and 'ids' should be a list of identifiers corresponding to each SHAP value.
n : int
The number of random samples to select. If 'n' is larger than the total number of samples available, all samples are returned without duplication.

Returns

Dict[str, List[np.ndarray]]
A dictionary containing two keys: 'shap' and 'ids'. 'shap' is a list of numpy arrays representing the randomly selected SHAP values, and 'ids' is a list of the corresponding identifiers. The length of the lists equals 'n', or the total number of samples if 'n' is larger than the available samples.

Raises

ValueError
If 'n' is negative.
IndexError
If the provided 'shap_dict' does not contain the required keys ('shap' and 'ids').
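
The paired, non-repeating selection can be illustrated with the standard library alone. The helper name `pick_random_pairs` and its `seed` parameter are hypothetical and not part of the documented signature. The sketch also uses plain dictionary access, so a missing key raises `KeyError` here and the documented `IndexError` is not reproduced.

```python
import random
from typing import Dict, Optional


def pick_random_pairs(shap_dict: Dict[str, list], n: int, seed: Optional[int] = None) -> Dict[str, list]:
    """Hypothetical helper: sample n (shap, id) pairs without replacement."""
    if n < 0:
        raise ValueError("n must be non-negative")
    shaps, ids = shap_dict["shap"], shap_dict["ids"]
    # Cap n so that asking for too many simply returns every sample once.
    n = min(n, len(ids))
    rng = random.Random(seed)
    # Sampling positions (not values) keeps each SHAP entry tied to its ID.
    chosen = rng.sample(range(len(ids)), n)
    return {"shap": [shaps[i] for i in chosen], "ids": [ids[i] for i in chosen]}


# Plain lists stand in for NumPy arrays to keep the demo dependency-free.
data = {"shap": [[1.0], [2.0], [3.0]], "ids": ["a", "b", "c"]}
subset = pick_random_pairs(data, 2, seed=0)
everything = pick_random_pairs(data, 10, seed=0)
```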

def subsample_md5s(md5s: List[str], metadata: Metadata, category_label: str, labels: List[str], copy_metadata: bool = True) -> List[int]

Return the indices of the md5s that remain after filtering the metadata on the given category and labels.

Args

md5s : List[str]
A list of MD5 hashes.
metadata : Metadata
A metadata object containing the data to be filtered.
category_label : str
The category label to be used for filtering the metadata.
labels : List[str]
A list of labels to be used for selecting category subsets in the metadata.

Returns

List[int]
A list of indices corresponding to the selected md5s.
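
The index filtering can be sketched as below. The API of the `Metadata` class is not shown in this reference, so the sketch substitutes a plain dictionary mapping each md5 to its fields. The helper name `indices_matching_labels` and the `assay` category in the demo are hypothetical.

```python
from typing import Dict, List


def indices_matching_labels(
    md5s: List[str],
    metadata: Dict[str, Dict[str, str]],
    category_label: str,
    labels: List[str],
) -> List[int]:
    """Hypothetical helper: positions in md5s whose category value is wanted."""
    wanted = set(labels)
    # md5s missing from the metadata, or lacking the category, are skipped.
    return [
        i
        for i, md5 in enumerate(md5s)
        if metadata.get(md5, {}).get(category_label) in wanted
    ]


md5s = ["aaa", "bbb", "ccc", "ddd"]
metadata = {
    "aaa": {"assay": "rna_seq"},
    "bbb": {"assay": "h3k27ac"},
    "ccc": {"assay": "rna_seq"},
}
print(indices_matching_labels(md5s, metadata, "assay", ["rna_seq"]))  # [0, 2]
```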