Introduction: A library for model interpretability in PyTorch.
Added on: Jan 20, 2025
captum.ai

What is captum.ai

Captum is an open-source model interpretability library built on PyTorch. It supports models across modalities such as vision and text, works with most PyTorch models with minimal changes to the original network, and is extensible for interpretability research.

How to Use captum.ai

  1. Install Captum:
    • via conda (recommended): conda install captum -c pytorch
    • via pip: pip install captum
  2. Create and prepare the model:
    import numpy as np
    import torch
    import torch.nn as nn
    from captum.attr import IntegratedGradients
    
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = nn.Linear(3, 3)
            self.relu = nn.ReLU()
            self.lin2 = nn.Linear(3, 2)
    
            # initialize weights and biases
            self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
            self.lin1.bias = nn.Parameter(torch.zeros(1,3))
            self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
            self.lin2.bias = nn.Parameter(torch.ones(1,2))
    
        def forward(self, input):
            return self.lin2(self.relu(self.lin1(input)))
    
    model = ToyModel()
    model.eval()
    
  3. Fix random seeds for deterministic computations:
    torch.manual_seed(123)
    np.random.seed(123)
    
  4. Define input and baseline tensors:
    input = torch.rand(2, 3)
    baseline = torch.zeros(2, 3)
    
  5. Instantiate an attribution algorithm and apply it (Integrated Gradients in this example):
    ig = IntegratedGradients(model)
    # target=0 attributes the first output unit; delta reports the approximation error
    attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
    print('IG Attributions:', attributions)
    print('Convergence Delta:', delta)
    
  6. View the output:
    IG Attributions: tensor([[-0.5922, -1.5497, -1.0067],
                             [ 0.0000, -0.2219, -5.1991]])
    Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
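
To make step 5 less of a black box, here is a minimal plain-Python sketch of what Integrated Gradients computes for ToyModel and target 0. It is not Captum's implementation: it hard-codes the weights from step 2, uses a simple midpoint Riemann sum (Captum's own approximation method may differ), and the input values are an assumption, back-computed from the first printed attribution row rather than taken from torch.rand. The helper names forward, gradient, and integrated_gradients are hypothetical.

```python
# Plain-Python sketch of Integrated Gradients for ToyModel, output index 0.
# Weights mirror step 2: lin1.weight = arange(-4, 5).view(3, 3), lin1.bias = 0,
# lin2.weight[0] = [-3, -2, -1], lin2.bias = 1.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]
W2_ROW0 = [-3.0, -2.0, -1.0]
B2_0 = 1.0


def forward(x):
    """Output 0 of the toy model: lin2(relu(lin1(x)))[0]."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return sum(w * h for w, h in zip(W2_ROW0, hidden)) + B2_0


def gradient(x):
    """Analytic gradient of forward(x) with respect to each input feature."""
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]
    # ReLU passes gradient only where the pre-activation is positive.
    gates = [1.0 if z > 0 else 0.0 for z in pre]
    return [
        sum(W2_ROW0[j] * gates[j] * W1[j][i] for j in range(3))
        for i in range(3)
    ]


def integrated_gradients(x, baseline, n_steps=50):
    """Midpoint Riemann approximation of the path integral of gradients."""
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight line from the baseline to the input.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(gradient(point)):
            total[i] += g / n_steps
    # Scale the averaged gradient by the input-baseline difference.
    return [(xi - b) * t for xi, b, t in zip(x, baseline, total)]


x = [0.2961, 0.5166, 0.2517]  # assumed input, not the actual torch.rand draw
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(x, baseline)
# Completeness check: attributions should sum to forward(x) - forward(baseline).
delta = sum(attributions) - (forward(x) - forward(baseline))
print("attributions:", [round(a, 4) for a in attributions])
print("delta:", delta)
```

With a zero baseline and a zero first-layer bias, the ReLU gates do not change along the path, so the approximation is essentially exact here: delta is near zero and the attributions land close to the first row printed in step 6. For other baselines or models, the number of steps matters more.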
    

Features of captum.ai

  • Multi-Modal

    Supports interpretability of models across modalities including vision, text, and more.

  • Built on PyTorch

    Supports most types of PyTorch models and can be used with minimal modification to the original neural network.

  • Extensible

    An open-source, generic library for interpretability research that makes it easy to implement and benchmark new algorithms.