[1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Tutorial B7: Cartesian Product

Sometimes, when using multi-input models, you want to run a function on the Cartesian product of two inputs. Put another way, given a sequence input \(X\) with \(n\) elements and some other input \(Y\) with \(m\) elements, you want to make predictions for \((X_0, Y_0), (X_0, Y_1), \ldots, (X_{n-1}, Y_{m-1})\). Of course, you could simply run the function across all \(X\) for a fixed value of \(Y\) (or vice versa) and then change the value of \(Y\) each time. However, coding this yourself is inconvenient and easy to do inefficiently, particularly if you will sometimes have a small number of \(X\) and other times a small number of \(Y\).

Instead of having to implement this yourself, tangermeme provides apply_pairwise and apply_product to make applying functions across products like this easy and memory efficient. apply_pairwise yields examples from the pairwise product between X and a set of arguments whose ordering is paired. For example, if you have \(X\) with \(5\) elements and arguments \(a\) and \(b\) that each have \(3\), you would get \((X_0, a_0, b_0), (X_0, a_1, b_1), \ldots, (X_4, a_1, b_1), (X_4, a_2, b_2)\). In contrast, apply_product applies a function to the Cartesian product of the sequences and each of the arguments provided. This means that you would instead get \((X_0, a_0, b_0), (X_0, a_0, b_1), (X_0, a_0, b_2), \ldots, (X_2, a_0, b_1), \ldots, (X_4, a_2, b_1), (X_4, a_2, b_2)\). Although apply_product is more general, in the sense that it can be applied across any number of independent arguments, it is not the right function to run when you have paired inputs like cell information.
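To make the difference between the two orderings concrete, here is a minimal sketch that enumerates only the index tuples, using the sizes from the example above (5 sequences, 3 entries each in a and b). It uses plain Python and itertools rather than tangermeme, so it illustrates the ordering described in the text and says nothing about how the library builds its examples internally.

```python
from itertools import product

n_X, n_args = 5, 3  # sizes taken from the example in the text

# Pairwise: a and b share one index j, so each X_i meets (a_j, b_j).
pairwise = [(i, j, j) for i, j in product(range(n_X), range(n_args))]

# Product: a and b are indexed independently, so each X_i meets every (a_j, b_k).
full = list(product(range(n_X), range(n_args), range(n_args)))

print(len(pairwise), pairwise[:3])  # 15 [(0, 0, 0), (0, 1, 1), (0, 2, 2)]
print(len(full), full[:4])          # 45 [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)]
```

Every pairwise tuple also appears in the full product; the remaining tuples are the ones that mix a_j with b_k for j != k, which is exactly what you want to avoid when a and b describe the same cell.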

Conceptually, the simplest way to implement these functions is to expand the entire product in CPU memory and then run the provided function on all of it at once. However, this can take a huge amount of memory, particularly if the product is over several arguments. In practice, it’s better to construct each batch iteratively and only run one batch at a time through the model. That way, only the model predictions are stored in CPU memory as opposed to the (usually much larger) inputs.
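The batching idea can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not tangermeme's implementation: iter_batches and apply_in_batches are hypothetical helper names, and func stands in for a model call on one example.

```python
from itertools import islice, product

def iter_batches(sizes, batch_size):
    """Yield lists of index tuples from the product, batch_size at a time."""
    combos = product(*(range(s) for s in sizes))  # lazy: nothing is materialized yet
    while True:
        batch = list(islice(combos, batch_size))
        if not batch:
            return
        yield batch

def apply_in_batches(func, sizes, batch_size):
    """Apply func to every index tuple, holding only one batch of inputs at a time."""
    outputs = []
    for batch in iter_batches(sizes, batch_size):
        # Only the (small) outputs accumulate; the batch itself is discarded.
        outputs.extend(func(idx) for idx in batch)
    return outputs

batches = list(iter_batches((5, 3, 3), batch_size=16))
print([len(b) for b in batches])  # [16, 16, 13]

results = apply_in_batches(sum, (5, 3, 3), batch_size=16)
print(len(results), results[-1])  # 45 8
```

In a real setting each index tuple would be used to gather the matching sequence and argument rows into tensors before calling the model, but the memory argument is the same: at most batch_size constructed inputs exist at once.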

Let’s see all this in action with a toy model that takes an input, flattens it, applies a dense layer, and then optionally scales and shifts the output.

[2]:
import torch

class FlattenDense(torch.nn.Module):
    def __init__(self, length=10):
        super().__init__()
        # One-hot input has 4 channels, so the flattened size is length * 4
        self.dense = torch.nn.Linear(length * 4, 3)

    def forward(self, X, alpha=0, beta=1):
        # Flatten (batch, 4, length) -> (batch, 4 * length)
        X = X.reshape(X.shape[0], -1)
        return self.dense(X) * beta + alpha

This model has two optional inputs: alpha, which is an additive constant on the output from the dense layer, and beta, which is a multiplicative factor. Yes, it’s redundant to have these factors after a dense layer which is doing a pretty similar thing, but this is meant just to demonstrate how to use the functions and to confirm that it’s doing the expected thing.

Let’s start off by generating some random one-hot encodings and running the model on them.

[3]:
from tangermeme.utils import random_one_hot
torch.manual_seed(0)

X = random_one_hot((5, 4, 10), random_state=0).float()
model = FlattenDense()

y = model(X)
y
[3]:
tensor([[-0.3154, -0.1625, -0.3183],
        [-0.0866,  0.5461, -0.0244],
        [ 0.3089, -0.2828, -0.1485],
        [ 0.1671, -0.1341, -0.3094],
        [-0.0627,  0.0088,  0.3471]], grad_fn=<AddBackward0>)

Apply Pairwise

apply_pairwise is the correct function to use if you have data that has two axes, where one of the axes is sequences, and the other axis contains multiple tensors of paired information. As an example, if you have a DragoNNFruit model which makes predictions for chromatin accessibility for each cell in a single-cell ATAC-seq experiment, the inputs are sequences, a vector representing the state of the cell, and the read depth of the cell. Because cell state and read depth are paired – both come from the same cell – you want to do the product between X and (cell_state, read_depth) such that you get \((X_0, c_0, r_0), (X_0, c_1, r_1), (X_0, c_2, r_2)...\). Importantly, you do not want to do the full cross product because that will create examples where the read depths and cell states come from different cells.

Predict

We can begin by checking what the predictions would be when using this function with arguments that only have a batch size of 1. Conceptually, this should be identical to just running the predict function, and we can compare our results here to the predictions that we got above.

[4]:
from tangermeme.predict import predict
from tangermeme.product import apply_pairwise

torch.manual_seed(0)
alpha = torch.zeros(1, 1)
beta = torch.ones(1, 1)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))[:, 0]
y_product
[4]:
tensor([[-0.3154, -0.1625, -0.3183],
        [-0.0866,  0.5461, -0.0244],
        [ 0.3089, -0.2828, -0.1485],
        [ 0.1671, -0.1341, -0.3094],
        [-0.0627,  0.0088,  0.3471]])

Looks like the values are identical, although we do have to index into the output because the additional axis corresponds to the length of the argument tensors.

Next, we can look at what happens when we set alpha and beta to be more than just one example.

[5]:
alpha = torch.zeros(2, 1)
beta = torch.ones(2, 1)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))
y_product
[5]:
tensor([[[-0.3154, -0.1625, -0.3183],
         [-0.3154, -0.1625, -0.3183]],

        [[-0.0866,  0.5461, -0.0244],
         [-0.0866,  0.5461, -0.0244]],

        [[ 0.3089, -0.2828, -0.1485],
         [ 0.3089, -0.2828, -0.1485]],

        [[ 0.1671, -0.1341, -0.3094],
         [ 0.1671, -0.1341, -0.3094]],

        [[-0.0627,  0.0088,  0.3471],
         [-0.0627,  0.0088,  0.3471]]])

Here, we see that the results are the same for adjacent predictions, which makes sense because alpha is just zeros in both cases and beta is just ones in both cases. Next, we can see that changing the values of alpha and beta will lead to different predictions.

[6]:
alpha = torch.randn(2, 1)
beta = torch.randn(2, 1)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))
y_product
[6]:
tensor([[[ 2.2283,  1.8950,  2.2344],
         [-0.4727, -0.3858, -0.4743]],

        [[ 1.7297,  0.3512,  1.5941],
         [-0.3427,  0.0170, -0.3073]],

        [[ 0.8680,  2.1571,  1.8646],
         [-0.1178, -0.4542, -0.3779]],

        [[ 1.1769,  1.8331,  2.2151],
         [-0.1984, -0.3696, -0.4693]],

        [[ 1.6775,  1.5218,  0.7847],
         [-0.3290, -0.2884, -0.0961]]])

As mentioned repeatedly, tangermeme tries to be as assumption-free as possible. This means that alpha and beta can be any shape that works with the math in the model's forward function. Because three outputs are generated for each example, our alpha and beta tensors can also have three elements per example, one for each output.

[7]:
alpha = torch.zeros(1, 3)
beta = torch.ones(1, 3)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))
y_product.shape
[7]:
torch.Size([5, 1, 3])
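Which shapes "work with the math" here is decided by the usual NumPy/PyTorch broadcasting rules applied to self.dense(X) * beta + alpha. The sketch below implements those rules for shape tuples in plain Python; broadcast_shape is a hypothetical helper written for this illustration. It assumes that each row of alpha is repeated across the batch before reaching the model, so that a (1, 1) or (1, 3) argument arrives as (batch, 1) or (batch, 3).

```python
from itertools import zip_longest

def broadcast_shape(*shapes):
    """Return the broadcast shape, or raise ValueError if the shapes are incompatible."""
    out = []
    # Align shapes on their trailing dimensions, padding missing ones with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"cannot broadcast shapes {shapes}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(reversed(out))

# Dense output for a batch of 5 has shape (5, 3); try a few argument shapes.
for alpha_shape in [(5, 1), (5, 3), (5, 2)]:
    try:
        print(alpha_shape, "->", broadcast_shape((5, 3), alpha_shape))
    except ValueError as err:
        print(alpha_shape, "->", err)
```

The first two shapes broadcast to (5, 3), matching the one-value-per-example and one-value-per-output cases used in this tutorial, while (5, 2) is rejected because 2 matches neither 1 nor 3.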

Attributions

In addition to working with the predict function, these product functions can take in any other tangermeme function and apply it to the respective product of examples. This means that we can apply deep_lift_shap just as easily as we apply predict.

[8]:
from tangermeme.deep_lift_shap import deep_lift_shap

y_attr = apply_pairwise(deep_lift_shap, model, X, args=(alpha, beta))
y_attr.shape
[8]:
torch.Size([5, 1, 4, 10])

The shape follows from the previous examples: the first dimension is the size of X, the second dimension is the size of alpha and beta, and the remaining dimensions are those from the function being applied.
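The shape rule for apply_pairwise can be written down as a tiny helper. pairwise_shape is a hypothetical name used only for this sketch; the three calls reproduce the shapes of the outputs shown in the cells above.

```python
def pairwise_shape(n_X, n_args, func_output_shape):
    """Expected output shape: (number of sequences, number of paired arguments, *per-example output)."""
    return (n_X, n_args, *func_output_shape)

print(pairwise_shape(5, 1, (3,)))     # predict with one argument row: (5, 1, 3)
print(pairwise_shape(5, 2, (3,)))     # predict with two argument rows: (5, 2, 3)
print(pairwise_shape(5, 1, (4, 10)))  # deep_lift_shap attributions: (5, 1, 4, 10)
```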

Marginalize

Next, we can apply marginalize just as easily as we can apply predict. A major difference in the output here is that two tensors are returned: one before making the substitution, and one after. Importantly, when using apply_pairwise and apply_product, additional arguments can be passed through to the inner function simply by adding them as keyword arguments. Note the motif="TGA" below.

[9]:
from tangermeme.marginalize import marginalize

y_before, y_after = apply_pairwise(marginalize, model, X, motif="TGA", args=(alpha, beta))
y_before[:, 0], y_after[:, 0]
[9]:
(tensor([[-0.3154, -0.1625, -0.3183],
         [-0.0866,  0.5461, -0.0244],
         [ 0.3089, -0.2828, -0.1485],
         [ 0.1671, -0.1341, -0.3094],
         [-0.0627,  0.0088,  0.3471]]),
 tensor([[-0.0615, -0.2536, -0.1744],
         [-0.1973,  0.6584,  0.2584],
         [ 0.2046,  0.1125, -0.0750],
         [ 0.0317,  0.0328, -0.1166],
         [ 0.0374,  0.1503,  0.4602]]))

If we wanted to also pass in an argument for start we could just keep adding in arguments.

[10]:
y_before, y_after = apply_pairwise(marginalize, model, X, motif="TGA", start=0, args=(alpha, beta))
y_before[:, 0], y_after[:, 0]
[10]:
(tensor([[-0.3154, -0.1625, -0.3183],
         [-0.0866,  0.5461, -0.0244],
         [ 0.3089, -0.2828, -0.1485],
         [ 0.1671, -0.1341, -0.3094],
         [-0.0627,  0.0088,  0.3471]]),
 tensor([[-0.3603, -0.2071, -0.2159],
         [-0.1005,  0.3231, -0.1471],
         [ 0.1569, -0.3900, -0.1329],
         [ 0.1721, -0.2478, -0.2496],
         [-0.1870, -0.0630,  0.1463]]))

Note that the values before making the substitution are the same, but the values after are different.

Naturally, being able to pass in any function, e.g., marginalize, and being able to pass in any arguments to those functions makes it possible to nest functions even further! After all, marginalize itself defaults to predictions but can apply other functions just as easily. Although the signature will be a little bit messy, we can easily use apply_pairwise with the marginalize function that itself is applying deep_lift_shap instead of predict! All we have to do is use the additional_func_kwargs argument, which is a dictionary of arguments that get passed directly into the provided func. This is somewhat redundant with passing in arguments directly, but circumvents issues where you want to pass an argument into func that has the same name as an argument needed by apply_pairwise.

[11]:
y_before, y_after = apply_pairwise(marginalize, model, X, motif="TGA", alphabet=['A', 'C', 'G', 'T',],
                                   additional_func_kwargs={'func': deep_lift_shap}, args=(alpha, beta))
y_before.shape, y_after.shape
[11]:
(torch.Size([5, 1, 4, 10]), torch.Size([5, 1, 4, 10]))

Even though it is a little messy to define the signature, look at how easy it is to do marginalized attributions across a product of examples, and you have the power to change any of the arguments in any of the functions called along the way. You can now do it in a single line instead of having to think of how to efficiently do each of the parts.
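The name-collision problem that additional_func_kwargs works around can be reproduced with two toy functions. This is only a sketch of the general Python pattern: inner and outer are hypothetical stand-ins, and nothing here reflects how tangermeme actually merges its arguments.

```python
def inner(x, func=abs):
    """Toy stand-in for a function such as marginalize that accepts its own `func`."""
    return func(x)

def outer(func, x, additional_func_kwargs=None, **kwargs):
    """Toy stand-in for a wrapper that forwards keyword arguments to `func`."""
    merged = {**kwargs, **(additional_func_kwargs or {})}
    return func(x, **merged)

print(outer(inner, -3))                                          # 3
print(outer(inner, -3, additional_func_kwargs={"func": float}))  # -3.0

# Passing func=... directly collides with the wrapper's own `func` parameter.
try:
    outer(inner, -3, func=float)
    collided = False
except TypeError:
    collided = True
print(collided)  # True
```

Because the wrapper's first parameter is itself called func, Python cannot tell which func a direct keyword refers to and raises a TypeError; routing the value through a separate dictionary removes the ambiguity.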

Apply Product

In contrast to apply_pairwise, apply_product is a more general function that will construct examples from the product of any number of arguments that have been passed in. If you have a model that takes in many inputs and each input corresponds to an orthogonal sort of value, e.g., a model that takes in DNA sequence, protein sequence, and some set of conditions, and predicts something like binding structure, this would be the function for you. The signature is identical to apply_pairwise, except that the function is applied to every combination of the arguments rather than only to paired ones.

Let’s start off by seeing this in action with the same prediction as before.

[12]:
from tangermeme.product import apply_product

alpha = torch.zeros(1, 1)
beta = torch.ones(1, 1)

y_product = apply_product(predict, model, X, args=(alpha, beta))[:, 0, 0]
y_product
[12]:
tensor([[-0.3154, -0.1625, -0.3183],
        [-0.0866,  0.5461, -0.0244],
        [ 0.3089, -0.2828, -0.1485],
        [ 0.1671, -0.1341, -0.3094],
        [-0.0627,  0.0088,  0.3471]])

Looks like we are getting the same thing as before, except that there is an additional axis that needs to be indexed into because one of the axes corresponds to alpha and one of them corresponds to beta.

Since all the model does with alpha is add it to the output in a broadcasted manner, we can easily check the result by adding in the appropriate dimensions and doing the addition outside the context of this function.

[13]:
alpha = torch.randn(3, 1)

y_product = apply_product(predict, model, X, args=(alpha,))
y_product
[13]:
tensor([[[-1.4000, -1.2470, -1.4028],
         [-1.7140, -1.5611, -1.7168],
         [ 0.0879,  0.2409,  0.0851]],

        [[-1.1711, -0.5384, -1.1089],
         [-1.4852, -0.8525, -1.4230],
         [ 0.3167,  0.9494,  0.3790]],

        [[-0.7756, -1.3673, -1.2330],
         [-1.0897, -1.6814, -1.5471],
         [ 0.7123,  0.1206,  0.2548]],

        [[-0.9174, -1.2186, -1.3939],
         [-1.2315, -1.5327, -1.7080],
         [ 0.5704,  0.2693,  0.0940]],

        [[-1.1472, -1.0757, -0.7374],
         [-1.4613, -1.3898, -1.0515],
         [ 0.3407,  0.4122,  0.7505]]])
[14]:
y.unsqueeze(1) + alpha.unsqueeze(0)
[14]:
tensor([[[-1.4000, -1.2470, -1.4028],
         [-1.7140, -1.5611, -1.7168],
         [ 0.0879,  0.2409,  0.0851]],

        [[-1.1711, -0.5384, -1.1089],
         [-1.4852, -0.8525, -1.4230],
         [ 0.3167,  0.9494,  0.3790]],

        [[-0.7756, -1.3673, -1.2330],
         [-1.0897, -1.6814, -1.5471],
         [ 0.7123,  0.1206,  0.2548]],

        [[-0.9174, -1.2186, -1.3939],
         [-1.2315, -1.5327, -1.7080],
         [ 0.5704,  0.2693,  0.0940]],

        [[-1.1472, -1.0757, -0.7374],
         [-1.4613, -1.3898, -1.0515],
         [ 0.3407,  0.4122,  0.7505]]], grad_fn=<AddBackward0>)

Same values. If we add in a beta value, we see the same thing.

[15]:
torch.manual_seed(0)
alpha = torch.randn(3, 1)
beta = torch.randn(1, 1)

y_product = apply_product(predict, model, X, args=(alpha, beta))[:, :, 0]
y_product
[15]:
tensor([[[ 1.3617,  1.4486,  1.3601],
         [-0.4727, -0.3858, -0.4743],
         [-2.3581, -2.2711, -2.3597]],

        [[ 1.4918,  1.8514,  1.5271],
         [-0.3427,  0.0170, -0.3073],
         [-2.2280, -1.8684, -2.1926]],

        [[ 1.7166,  1.3803,  1.4566],
         [-0.1178, -0.4542, -0.3779],
         [-2.0032, -2.3395, -2.2632]],

        [[ 1.6360,  1.4648,  1.3651],
         [-0.1984, -0.3696, -0.4693],
         [-2.0838, -2.2550, -2.3546]],

        [[ 1.5054,  1.5460,  1.7383],
         [-0.3290, -0.2884, -0.0961],
         [-2.2144, -2.1738, -1.9815]]])
[16]:
y.unsqueeze(1) * beta.unsqueeze(0) + alpha.unsqueeze(0)
[16]:
tensor([[[ 1.3617,  1.4486,  1.3601],
         [-0.4727, -0.3858, -0.4743],
         [-2.3581, -2.2711, -2.3597]],

        [[ 1.4918,  1.8514,  1.5271],
         [-0.3427,  0.0170, -0.3073],
         [-2.2280, -1.8684, -2.1926]],

        [[ 1.7166,  1.3803,  1.4566],
         [-0.1178, -0.4542, -0.3779],
         [-2.0032, -2.3395, -2.2632]],

        [[ 1.6360,  1.4648,  1.3651],
         [-0.1984, -0.3696, -0.4693],
         [-2.0838, -2.2550, -2.3546]],

        [[ 1.5054,  1.5460,  1.7383],
         [-0.3290, -0.2884, -0.0961],
         [-2.2144, -2.1738, -1.9815]]], grad_fn=<AddBackward0>)
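The manual checks above work with a single unsqueeze because beta has only one entry. More generally, each entry of the product output should be the model output for one (X_i, alpha_j, beta_k) combination. The following plain-Python sketch spells out that indexing for this toy model, using made-up numbers rather than the values from the cells above; product_reference is a hypothetical helper, and the axis order (sequences, then alpha, then beta) is assumed to follow the order of args.

```python
def product_reference(y, alpha, beta):
    """Reference result: out[i][j][k][c] = y[i][c] * beta[k] + alpha[j]."""
    return [[[[v * b + a for v in row] for b in beta] for a in alpha] for row in y]

# Toy stand-ins: 2 examples with 2 outputs each, 3 alpha values, 2 beta values.
y = [[1.0, 2.0], [3.0, 4.0]]
alpha = [0.0, 10.0, 20.0]
beta = [1.0, -1.0]

out = product_reference(y, alpha, beta)
shape = (len(out), len(out[0]), len(out[0][0]), len(out[0][0][0]))
print(shape)         # (2, 3, 2, 2)
print(out[1][2][1])  # [17.0, 16.0], i.e. y[1] * beta[1] + alpha[2]
```

With tensors, the same reference could be obtained by broadcasting, for example y[:, None, None] * beta[None, None] + alpha[None, :, None] for alpha of shape (n, 1) and beta of shape (m, 1), which yields shape (5, n, m, 3) for the y used in this tutorial.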