Module epiclass.utils.shap.shap_utils
Module containing utility functions for handling SHAP files and performing basic analysis of SHAP values.
Functions
def extract_shap_values_and_info(shap_logdir: str | Path, verbose: bool = True)
-
Extract SHAP values and associated information from an archive and, if verbose is set, print basic statistics about them.
Args
shap_logdir
:str
- The directory where the SHAP values archive is located.
verbose
:bool
- Whether to print basic statistics about the SHAP values.
Returns
shap_matrices
:np.ndarray
- SHAP matrices.
eval_md5s
:List[str]
- List of evaluation MD5s.
classes
:List[Tuple[str, str]]
- List of classes. Each class is a tuple containing the class index and the class label.
def get_archives(shap_values_dir: str | Path)
-
Extracts SHAP values and explainer background information from .npz files in a specified directory.
This function searches the given directory for files matching the patterns "evaluation.npz" and "explainer_background.npz", loads each of these .npz files as a dictionary, and returns both dictionaries. It raises a FileNotFoundError if either file is missing from the directory.
Args
shap_values_dir
:str | Path
- The directory where the .npz files are located.
Returns
Tuple[Dict, Dict]
- The first dictionary contains the SHAP values extracted from the "evaluation.npz" file, and the second contains the explainer background information extracted from the "explainer_background.npz" file.
Raises
FileNotFoundError
- If either the SHAP values file or the explainer background file is not found in the specified directory.
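The lookup-and-load step described above can be sketched as follows. This is a minimal sketch, not the module's implementation: it assumes each archive is a standard NumPy `.npz` file and that a file matches when its name ends with the given pattern. The helper name `load_shap_archives` is hypothetical.

```python
from __future__ import annotations

from pathlib import Path
from typing import Dict, Tuple

import numpy as np


def load_shap_archives(shap_values_dir: str | Path) -> Tuple[Dict, Dict]:
    """Hypothetical sketch: load the evaluation and background .npz archives as dicts."""
    directory = Path(shap_values_dir)
    loaded = {}
    for pattern in ("evaluation.npz", "explainer_background.npz"):
        # Assumption: any file whose name ends with the pattern is a match.
        matches = sorted(directory.glob(f"*{pattern}"))
        if not matches:
            raise FileNotFoundError(f"No file matching '*{pattern}' in {directory}")
        # np.load returns a lazy NpzFile; dict() reads every array before it is closed.
        with np.load(matches[0], allow_pickle=True) as npz:
            loaded[pattern] = dict(npz)
    return loaded["evaluation.npz"], loaded["explainer_background.npz"]
```

Materializing the arrays with `dict()` inside the `with` block lets the file handle be closed immediately, at the cost of loading everything into memory at once.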
def get_shap_matrix(meta: Metadata, shap_matrices: np.ndarray, eval_md5s: List[str], label_category: str, selected_labels: List[str], class_idx: int, copy_meta: bool = True) ‑> Tuple[numpy.ndarray, List[int]]
-
Generates a SHAP matrix corresponding to a selected subset of samples.
This function selects a subset of samples based on the specified criteria and then builds a SHAP matrix for them. It filters the metadata on label_category so that only samples with one of the selected_labels are kept, identifies those samples among the evaluation samples by their MD5 hash, and then extracts their SHAP values from the matrix of class class_idx.
Args
meta
:metadata.Metadata
- Metadata object containing information about the samples.
shap_matrices
:np.ndarray
- Array of SHAP matrices for each class.
eval_md5s
:List[str]
- List of md5 hashes identifying the evaluation samples.
label_category
:str
- Name of the category in the metadata that contains the desired labels.
selected_labels
:List[str]
- Names of the classes (labels) whose samples will be considered.
class_idx
:int
- Index of the class whose SHAP value matrix will be used.
Returns
np.ndarray
- The SHAP matrix of the selected class, restricted to the samples that match the provided criteria.
List[int]
- The indices of the chosen samples in the original SHAP matrix.
Raises
IndexError
- If class_idx is out of bounds for shap_matrices.
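The selection described above can be sketched with plain NumPy indexing. This sketch assumes that the metadata behaves like a mapping from MD5 to a dict of category values and that `shap_matrices[class_idx]` has shape `(n_samples, n_features)` with rows in the same order as `eval_md5s`. The plain-dict metadata and the name `select_class_shap_rows` are hypothetical stand-ins for `Metadata` and `get_shap_matrix`.

```python
from typing import Dict, List, Tuple

import numpy as np


def select_class_shap_rows(
    meta: Dict[str, Dict[str, str]],
    shap_matrices: np.ndarray,
    eval_md5s: List[str],
    label_category: str,
    selected_labels: List[str],
    class_idx: int,
) -> Tuple[np.ndarray, List[int]]:
    """Hypothetical sketch: rows of one class's SHAP matrix for samples with selected labels."""
    # Explicit check so negative indices are rejected instead of wrapping around.
    if not 0 <= class_idx < len(shap_matrices):
        raise IndexError(f"class_idx {class_idx} out of bounds for {len(shap_matrices)} classes")
    wanted = set(selected_labels)
    # Positions in eval_md5s of samples whose label is among the selected ones.
    chosen_idx = [
        i for i, md5 in enumerate(eval_md5s)
        if meta.get(md5, {}).get(label_category) in wanted
    ]
    # Fancy indexing keeps the rows in the order of chosen_idx.
    return shap_matrices[class_idx][chosen_idx], chosen_idx
```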
def n_most_important_features(sample_shaps: np.ndarray, n: int) ‑> numpy.ndarray
-
Return the indices of the n features with the highest absolute SHAP values.
Args
sample_shaps
:np.ndarray
- Array of SHAP values for a single sample.
n
:int
- Number of top features to return.
Returns
np.ndarray
- Indices of the top n features with the highest absolute SHAP values.
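One way to compute these indices is sketched below with `np.argpartition`, which isolates the top n entries without sorting the whole array. This is an illustration, not the module's code: the function name `top_n_abs_features` is hypothetical, and returning the indices in decreasing order of magnitude is an assumption, since the description above does not specify an order.

```python
import numpy as np


def top_n_abs_features(sample_shaps: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical sketch: indices of the n largest |SHAP| values, largest first."""
    magnitudes = np.abs(sample_shaps)
    n = min(n, magnitudes.size)
    if n <= 0:
        return np.array([], dtype=int)
    # argpartition places the n largest magnitudes in the last n slots (unordered).
    top = np.argpartition(magnitudes, -n)[-n:]
    # Sort only those n indices by decreasing magnitude.
    return top[np.argsort(magnitudes[top])[::-1]]
```

For example, with `x = np.array([0.1, -3.0, 2.0, -0.5])`, `top_n_abs_features(x, 2)` selects the features at positions 1 and 2, because their absolute values (3.0 and 2.0) are the largest.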
def select_random_shap_samples(shap_dict: Dict[str, List[np.ndarray]], n: int) ‑> Dict[str, List[numpy.ndarray]]
-
Selects a random subset of SHAP values and their corresponding IDs from a given dictionary.
This function randomly selects 'n' samples from the provided SHAP values without replacement, so no sample is selected twice. It expects a dictionary containing SHAP values and their corresponding IDs, and the returned subset keeps each SHAP value paired with its ID.
Args
shap_dict
:Dict[str, List[np.ndarray]]
- A dictionary with two keys: 'shap' and 'ids'. 'shap' should be a list of numpy arrays containing SHAP values, and 'ids' should be a list of identifiers corresponding to each SHAP value.
n
:int
- The number of random samples to select. If 'n' is larger than the total number of samples available, all samples are returned without duplication.
Returns
Dict[str, List[np.ndarray]]
- A dictionary containing two keys: 'shap' and 'ids'. 'shap' is a list of numpy arrays representing the randomly selected SHAP values, and 'ids' is a list of the corresponding identifiers. The length of the lists equals 'n', or the total number of samples if 'n' is larger than the available samples.
Raises
ValueError
- If 'n' is negative.
IndexError
- If the provided 'shap_dict' does not contain the required keys ('shap' and 'ids').
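Sampling without replacement while keeping SHAP values and IDs paired can be sketched by drawing one set of positions and using it for both lists. This is a sketch under assumptions: the name `pick_random_shap_samples` and the optional `seed` parameter are hypothetical, and the errors raised simply mirror the ones documented above.

```python
from typing import Dict, List, Optional

import numpy as np


def pick_random_shap_samples(
    shap_dict: Dict[str, List], n: int, seed: Optional[int] = None
) -> Dict[str, List]:
    """Hypothetical sketch: sample n (shap, id) pairs without replacement."""
    if n < 0:
        raise ValueError("n must be non-negative")
    if "shap" not in shap_dict or "ids" not in shap_dict:
        raise IndexError("shap_dict must contain the keys 'shap' and 'ids'")
    total = len(shap_dict["ids"])
    rng = np.random.default_rng(seed)
    # replace=False means no position is drawn twice; n is capped at the total.
    chosen = rng.choice(total, size=min(n, total), replace=False)
    # The same positions index both lists, so each value stays paired with its ID.
    return {
        "shap": [shap_dict["shap"][i] for i in chosen],
        "ids": [shap_dict["ids"][i] for i in chosen],
    }
```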
def subsample_md5s(md5s: List[str], metadata: Metadata, category_label: str, labels: List[str], copy_metadata: bool = True) ‑> List[int]
-
Return the indices of the MD5s whose metadata value for the given category is one of the given labels.
Args
md5s
:List[str]
- A list of MD5 hashes.
metadata
:Metadata
- A metadata object containing the data to be filtered.
category_label
:str
- The category label to be used for filtering the metadata.
labels
:List[str]
- Labels of category_label used to select the subset of samples in the metadata.
Returns
List[int]
- A list of indices corresponding to the selected md5s.
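The filtering can be sketched as a single pass over the MD5 list. This sketch assumes that the metadata behaves like a mapping from MD5 to a dict of category values; the plain-dict metadata and the name `subsample_md5_indices` are hypothetical. Because the sketch never modifies the metadata, it has no equivalent of `copy_metadata`.

```python
from typing import Dict, List


def subsample_md5_indices(
    md5s: List[str],
    metadata: Dict[str, Dict[str, str]],
    category_label: str,
    labels: List[str],
) -> List[int]:
    """Hypothetical sketch: indices of md5s whose category value is one of `labels`."""
    keep = set(labels)
    return [
        i
        for i, md5 in enumerate(md5s)
        # MD5s missing from the metadata, or lacking the category, are skipped.
        if metadata.get(md5, {}).get(category_label) in keep
    ]
```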