[1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Tutorial B2: Ablation

predict + ersatz (shuffle/dinucleotide shuffle) + predict

In the context of evaluating sequence features, the conceptual opposite of a marginalization experiment is an ablation experiment. A marginalization experiment involves holding fixed a short sequence of interest, such as a protein binding site, but having variable background sequence everywhere else. In contrast, an ablation experiment involves shuffling given coordinates while keeping everything else fixed. Put another way: a marginalization experiment usually involves adding a potentially informative feature to uninformative sequence, whereas an ablation experiment involves removes a potentially informative feature from informative sequence.

A common use-case of an ablation experiment is to systematically remove regions of interest to evaluate their marginal and joint effects. For example, in the genomics setting, one may systematically remove all individual protein binding sites from an enhancer or promoter of interest and evaluate how the predictions change. As a more sophisticated example, one might try deleting each of the parts of a composite fusion motif, where multiple motifs are merged together in non-canonical ways, to better understand each component.

Ablate

Ablations can be done in tangermeme using the ablate function. Much like the other functions, ablate takes in a predictive model, one or more sequences, and coordinates to shuffle. Since there are multiple ways that one could shuffle a set of given coordinates, e.g. uniformly random shuffling, replacing with randomly generated sequence, dinucleotide shuffling, etc., ablate allows the user to optionally pass in a shuffling function.

Likewise similar to other functions, ablate returns the predictions on the original sequences and on all n shuffles of the given coordinates. This means that if y_before.shape == (5, 1) because the model yields one value for each of five examples, then y_after.shape == (5, n, 1) where n is the number of shuffles to do on the specified regions. It is up to the user to determine the best way to take these predictions and yield a single score.

[2]:
import torch
from model import Beluga

model = Beluga()
model.load_state_dict(torch.load("deepsea.beluga.torch"))
[2]:
<All keys matched successfully>

Let’s construct a region that the model will respond to by generating some random sequence and then inserting an AP-1 motif into it.

[3]:
from tangermeme.utils import random_one_hot
from tangermeme.ersatz import substitute

X = random_one_hot((100, 4, 2000)).type(torch.float32)  # Generate 100 random sequences to marginalize over
X = substitute(X, "GTGACTCATC")

Now, we can look at the predictions from the model before and after shuffling the middle of the examples – where the motif has been inserted.

[4]:
from tangermeme.ablate import ablate

y_before, y_after = ablate(model, X, 995, 1005, device='cpu')
y_before.shape, y_after.shape
[4]:
(torch.Size([100, 2002]), torch.Size([100, 20, 2002]))

Note that y_after has an additional dimension when compared with y_before. As mentioned before, this is because the shuffling is done n times (default 20) and predictions are returned for each shuffle. For simplicity, we can just take the average across the 20 runs.

[5]:
from matplotlib import pyplot as plt
import seaborn; seaborn.set_style('whitegrid')


plt.scatter(y_before.mean(axis=0).numpy(force=True), y_after.mean(axis=(0, 1)).numpy(force=True), s=1)
plt.xlabel("Before Ablation")
plt.ylabel("After Ablation")
plt.tight_layout()
plt.show()
../_images/tutorials_Tutorial_B2_Ablation_8_0.png

As one might expect, it looks like the majority of tasks are unaffected by the ablation of the AP-1 motif from the sequence. But we can rank tasks by those that are most effected to see which ones they are.

[6]:
import numpy

names = numpy.array([name[:40].replace("_None", "") for name in numpy.loadtxt("beluga_target_names.txt", delimiter=',', dtype=str)])

y_delta = y_after.mean(axis=1) - y_before
idxs = torch.argsort(y_delta.mean(axis=0))
ymax = max(abs(y_delta.min()), abs(y_delta.max()))


plt.imshow(y_delta[:, idxs].T, aspect='auto', cmap='RdBu_r', vmin=-ymax, vmax=ymax, interpolation='nearest')
plt.colorbar(label="Difference in Logits")
plt.ylim(0, 30)
plt.yticks(numpy.arange(0.5, 30.5, 1), names[idxs[:30]], fontsize=8)
plt.xlabel("Sequence")
plt.grid(False)
plt.show()
../_images/tutorials_Tutorial_B2_Ablation_10_0.png

It seems like those most effected by the ablation of the AP-1 motif are tasks related to the AP-1 factors (c-/Jun, c-/Fos, JunD, etc..).

To continue with this trend of results being the opposite of what one might expect from marginalizations, let’s load up a BPNet GATA2 model and see what happens to model predictions before and after removing the GATA motif from a genomic sequence.

[7]:
bpnet = torch.load("gata2.bpnet.torch")

X = random_one_hot((100, 4, 2114)).type(torch.float32)  # Generate 100 random sequences with the right size for BPNet to marginalize over
X = substitute(X, "CTTATCT")

y_before, y_after = ablate(bpnet, X, 1052, 1062, args=(torch.zeros(100, 2, 2114),), device='cpu')

plt.figure(figsize=(6, 4))
plt.scatter(y_before[1], y_after[1].mean(axis=1))
plt.plot([1.1, 2.1], [1.1, 2.1], color='0.7')
plt.xlabel("Without Motif")
plt.ylabel("With Motif")
plt.show()
../_images/tutorials_Tutorial_B2_Ablation_12_0.png

As expected, the model predicts a lower value after ablating out the motif.

[8]:
yb = torch.nn.functional.softmax(y_before[0].mean(axis=0), dim=-1)
ya = torch.nn.functional.softmax(y_after[0].mean(axis=(0, 1)), dim=-1)

plt.figure(figsize=(6, 4))
plt.plot(yb[0], color='0.7', linewidth=1, label="+ strand before")
plt.plot(yb[1], color='0.3', linewidth=1, label="- strand before")

plt.plot(ya[0], color='r', linewidth=1, label="+ strand after")
plt.plot(ya[1], color='b', linewidth=1, label="- strand after")
plt.xlabel("Genomic Position")
plt.ylabel("Predicted Probability")
plt.legend(fontsize=8)
plt.show()
../_images/tutorials_Tutorial_B2_Ablation_14_0.png

Likewise, we can see the inverse pattern here from the marginalization experiment. Before ablating, the profile predictions show clear bimodal peaks in the profile predictions. After ablating, the model is predicting largely flat signal.

A note for all of these ablation experiments (the DeepSEA and the BPNet ones) is that, although they do seem to show exactly the opposite results as the marginalization experiments, this is because we are not using the ablation function in the setting it is most likely to be used in practice. Basically, here, we are constructing synthetic background sequences, inserting a motif in, and then ablating it out. In practice, one is going to take a real region of the genome and ablate out one or more motifs instead of focusing on the individual effect of one motif isolated from the others.

Ablate Annotations

Above, we saw the core functionality for how one can perform ablations. In that functionality, one passes usually one sequence at a time with a fixed coordinate to ablate out. However, sometimes one would like to ablate out many locations on the same example individually and observe the effect that removing each span has. The ablate_annotations function allows you to do that by passing in a set of annotations in the form of (example_idx, start, end) coordinates. This format is similar to the variant effect prediction format and allows for any number of annotations to be ablated for each example. Each coordinate will then be ablated independently of the other annotations, and the result will be the model predictions before and after the ablation is performed.

Let’s start by getting some annotations for a sequence. An easy way to do this (and usually the way you will do this in practice) is to call seqlets on attributions.

[9]:
from tangermeme.ersatz import multisubstitute
from tangermeme.deep_lift_shap import deep_lift_shap
from tangermeme.plot import plot_logo
from tangermeme.seqlet import recursive_seqlets

X = random_one_hot((1, 4, 2000), random_state=0).type(torch.float32)
X = multisubstitute(X, ["GTGACTCATC", "GTGACTCATC", "ACTTCCGCG"], [25, 25])

X_attr = deep_lift_shap(model, X, device='cpu', n_shuffles=20, target=57, random_state=0)
seqlets = recursive_seqlets(X_attr.sum(dim=1), additional_flanks=5)

plt.figure(figsize=(12, 2))
ax = plt.subplot(111)
plot_logo(X_attr[0], ax=ax, start=925, end=1100, annotations=seqlets, score_key='attribution')
plt.ylim(-0.025, 0.25)

plt.xlabel("Genomic Coordinate")
plt.ylabel("Attribution")
plt.title("DeepLIFT/SHAP")
plt.show()
../_images/tutorials_Tutorial_B2_Ablation_17_0.png

Without needing to even know the identity of these binding sites (so no need to run TOMTOM to map the sequence to a motif database), we can calculate the ablation effect by simply iterating over each seqlet, scrambling the sequence within, and returning the model predictions before and after making the scramble.

[10]:
from tangermeme.ablate import ablate_annotations

annotations = seqlets[['example_idx', 'start', 'end']].values

y_before, y_after = ablate_annotations(model, X, annotations)
y_before.shape, y_after.shape
[10]:
(torch.Size([5, 1, 2002]), torch.Size([5, 1, 20, 2002]))

Because the function is only returning the predictions before and after the ablation, we can do whatever we would like to aggregate these differences into a score function. For simplicity, we will just look at the difference in predictions averaged over the 20 shuffles that ablate does by default.

[11]:
deltas = (y_after[:, :, :, 57] - y_before[:, :, None, 57]).mean(axis=-1)[:, 0]
deltas
[11]:
tensor([ 0.0501, -0.6054,  0.1467, -0.1593, -0.2766])
[12]:
seqlets
[12]:
example_idx start end attribution p-value
1 0 1573 1588 -0.032544 4.662937e-15
0 0 1028 1044 0.727307 4.798464e-03
2 0 927 942 -0.001295 5.285412e-03
3 0 957 972 0.643915 6.717850e-03
4 0 993 1007 0.548725 8.579600e-03

Remember (from the seqlet calling tutorial) that the first and third seqlets are mostly artifacts of the single example not having enough data to properly fit a background distribution for negative seqlets. Accordingly, we see that the ablation score for the first and third seqlet are slightly positive (matching with the seqlets having negative attributions) but also very small magnitude. In contrast, the other seqlets have a much higher magnitude ablation score because the predictions are significantly reduced when the seqlet is ablated.

The ablate_annotations function is just a wrapper around the ablate function and so can accept any of the underlying arguments. For example, if you want to increase the number of ablations done…

[13]:
y_before, y_after = ablate_annotations(model, X, annotations, n=50)
y_before.shape, y_after.shape
[13]:
(torch.Size([5, 1, 2002]), torch.Size([5, 1, 50, 2002]))

… or if you’d like to ablate but calculate attributions instead of make predictions after doing the ablation.

[14]:
y_before, y_after = ablate_annotations(model, X, annotations, func=deep_lift_shap)
y_before.shape, y_after.shape
/users/jmschr/github/tangermeme/tangermeme/deep_lift_shap.py:436: RuntimeWarning: Convergence deltas too high: tensor([7.1526e-07, 2.7227e-04, 9.7036e-05, 3.0994e-05, 3.8481e-04, 2.2590e-04,
        1.3590e-03, 5.8365e-04, 8.9169e-05, 7.1526e-05, 1.2445e-04, 9.6798e-05,
        3.5620e-04, 1.6785e-04, 1.7011e-04, 1.8167e-04, 1.9288e-04, 1.2684e-04,
        2.3651e-04, 7.0572e-05, 3.1853e-04, 1.2922e-04, 5.6982e-05, 6.8665e-05,
        2.0671e-04, 1.3554e-04, 2.9588e-04, 1.4210e-04, 9.9659e-05, 1.3471e-05,
        3.9196e-04, 3.5334e-04], device='cuda:0', grad_fn=<AbsBackward0>)
  warnings.warn("Convergence deltas too high: " +
[14]:
(torch.Size([5, 1, 4, 2000]), torch.Size([5, 1, 20, 4, 2000]))

We can visualize the results of these ablations + attributions and confirm that, yes, each motif is being knocked out.

[15]:
plt.figure(figsize=(12, 5))
ax = plt.subplot(311)
plot_logo(y_after[1, 0].mean(axis=0), ax=ax, start=925, end=1100, annotations=seqlets, score_key='attribution')
plt.ylim(-0.025, 0.25)
plt.xlabel("Genomic Coordinate")
plt.ylabel("Attribution")
plt.title("DeepLIFT/SHAP of Ablations")

ax = plt.subplot(312)
plot_logo(y_after[3, 0].mean(axis=0), ax=ax, start=925, end=1100, annotations=seqlets, score_key='attribution')
plt.ylim(-0.025, 0.25)
plt.xlabel("Genomic Coordinate")
plt.ylabel("Attribution")

ax = plt.subplot(313)
plot_logo(y_after[4, 0].mean(axis=0), ax=ax, start=925, end=1100, annotations=seqlets, score_key='attribution')
plt.ylim(-0.025, 0.25)
plt.xlabel("Genomic Coordinate")
plt.ylabel("Attribution")

plt.tight_layout()
plt.show()
../_images/tutorials_Tutorial_B2_Ablation_28_0.png
[ ]: