Introduction:
A library for model interpretability in PyTorch.
Added on:
Jan 20, 2025

What is captum.ai
Captum is a model interpretability library for PyTorch that supports models across modalities such as vision and text. It requires minimal modification to existing neural networks and is extensible for interpretability research.
How to Use captum.ai
- Install Captum:
  - via conda (recommended):
    conda install captum -c pytorch
  - via pip:
    pip install captum
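After installing, it can help to confirm that the package is visible to the interpreter you intend to use. The following is a minimal sketch using only the Python standard library; the helper name `installed_version` is illustrative and not part of Captum.

```python
import importlib.util
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it cannot be imported."""
    if importlib.util.find_spec(package) is None:
        return None
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        # Importable but without distribution metadata (e.g. a local source checkout).
        return "unknown"


version = installed_version("captum")
print("captum:", version if version is not None else "not installed")
```

If this reports "not installed", the install command was probably run in a different environment from the one executing the script.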
- Create and prepare model:
import numpy as np
import torch
import torch.nn as nn

from captum.attr import IntegratedGradients


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1, 3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1, 2))

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))


model = ToyModel()
model.eval()
- Fix random seeds for deterministic computations:
torch.manual_seed(123)
np.random.seed(123)
- Define input and baseline tensors:
input = torch.rand(2, 3)
baseline = torch.zeros(2, 3)
- Select algorithm to instantiate and apply (Integrated Gradients in this example):
ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
print('IG Attributions:', attributions)
print('Convergence Delta:', delta)
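For reference, the quantity that Integrated Gradients approximates can be written as below. The symbols are notation introduced here for illustration: F is the model output for the chosen target, x the input, x' the baseline, and i indexes input features.

```latex
\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1
  \frac{\partial F\bigl(x' + \alpha\,(x - x')\bigr)}{\partial x_i}\, d\alpha,
\qquad
\sum_i \mathrm{IG}_i(x) = F(x) - F(x')
```

The second identity is the completeness property. As far as the Captum documentation describes it, the convergence delta reports, per example, how far the summed attributions are from F(x) - F(x'), so values near zero suggest the numerical integration was accurate.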
- View Output:
IG Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                         [ 0.0000, -0.2219, -5.1991]])
Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
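To see where numbers like these come from, the same toy network can be attributed by hand. The sketch below is a plain-Python approximation, not Captum's implementation. It copies the weights from ToyModel and averages gradients along the straight line from baseline to input using a midpoint Riemann sum, which is a simplification and may differ from the integration rule Captum uses by default. The function names and the input values are illustrative assumptions.

```python
# Weights and biases copied from ToyModel: arange(-4, 5) and arange(-3, 3) reshaped.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]
B1 = [0.0, 0.0, 0.0]
W2 = [[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]
B2 = [1.0, 1.0]


def hidden_pre(x):
    """Pre-activation of the first linear layer."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W1, B1)]


def forward(x, target):
    """Output of the network for one class index."""
    h = [max(z, 0.0) for z in hidden_pre(x)]  # ReLU
    return sum(w * hi for w, hi in zip(W2[target], h)) + B2[target]


def grad(x, target):
    """Analytic gradient of forward(x, target) with respect to x."""
    z = hidden_pre(x)
    mask = [1.0 if zj > 0 else 0.0 for zj in z]  # ReLU derivative
    return [
        sum(W2[target][j] * mask[j] * W1[j][i] for j in range(len(z)))
        for i in range(len(x))
    ]


def integrated_gradients(x, baseline, target, n_steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad(point, target))]
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]


x = [0.3, 0.5, 0.25]          # illustrative input, not the seeded one above
baseline = [0.0, 0.0, 0.0]
attr = integrated_gradients(x, baseline, target=0)
delta = sum(attr) - (forward(x, 0) - forward(baseline, 0))
print("attributions:", attr)  # approximately [-0.6, -1.5, -1.0]
print("delta:", delta)        # approximately 0
```

Because this network has no bias in its first layer and the baseline is zero, each ReLU stays either on or off along the whole path, so the gradient is constant and even a coarse sum is essentially exact. For deeper models the approximation error generally depends on the number of steps.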
Features of captum.ai
- Multi-Modal
  Supports interpretability of models across modalities including vision, text, and more.
- Built on PyTorch
  Supports most types of PyTorch models and can be used with minimal modification to the original neural network.
- Extensible
  Open source, generic library for interpretability research. Easily implement and benchmark new algorithms.