{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "b5b34936-2c18-4cd2-ac9e-5cfbcd2972df", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "markdown", "id": "41bdd281-0cd7-4bb0-9fe1-e7f079c2fea5", "metadata": {}, "source": [ "### Wrappers are Productivity Hacks\n", "\n", "A common issue with machine learning models in regulatory genomics is that their outputs are not... exactly... what you need. Many downstream analyses work best when a model predicts a single number per example, but many of the best-known models predict much more than that. For example, BPNet predicts a base-pair resolution profile as well as a count, and Enformer predicts binned profiles. How can you use these models with existing downstream functions that were not built for those outputs?\n", "\n", "One potential solution is to modify all of your analysis functions to slice, dice, and aggregate the model outputs until they are the right shape. Maybe you write a `bpnet_deep_lift_shap` function that copies the `deep_lift_shap` function in tangermeme but slices out the count head and ignores the profile head. Or... maybe you write an `enformer_deep_lift_shap` function that sums each track across the length dimension before calculating attributions. Although these might technically work, they leave a lot of brittle, near-duplicate code lying around that makes things messy.\n", "\n", "An alternative solution is to use wrappers! PyTorch conveniently allows you to put a model inside another model. The outer model, colloquially called a wrapper, can modify the inputs before calling the inner model and the outputs after, which makes it extremely flexible.\n", "\n", "#### Slicing and Dicing Inputs and Outputs\n", "\n", "Let's take a look using the built-in wrappers in bpnet-lite, a lightweight library for loading and using BPNet and ChromBPNet models.\n", "\n", "First, we can load up a BPNet model."
] }, { "cell_type": "code", "execution_count": 2, "id": "f6c899d9-0a00-4354-a69a-38cfeb0ab5b8", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "model = torch.load(\"../../../../models/bpnet/GATA2.torch\", weights_only=False)" ] }, { "cell_type": "markdown", "id": "f1ebbfee-0714-407e-b54f-65703c024fc5", "metadata": {}, "source": [ "For those unfamiliar with the model, let's look at its output for a random sequence. BPNet models also take a control track, the per-strand counts from a control experiment, which is usually set to all zeroes after training." ] }, { "cell_type": "code", "execution_count": 3, "id": "443c758f-0757-4bce-888d-ce9d1458b101", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[[ 0.4295, 0.5723, 0.4151, ..., -0.1043, -0.0767, -0.1003],\n", " [-0.2152, 0.0224, 0.1121, ..., -0.3272, -0.2603, -0.2479]]]),\n", " tensor([[0.3379]])]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tangermeme.utils import random_one_hot\n", "from tangermeme.predict import predict\n", "\n", "X = random_one_hot((1, 4, 2114)).float()\n", "X_ctl = torch.zeros_like(X)[:, :2]\n", "\n", "predict(model, X, args=(X_ctl,))" ] }, { "cell_type": "markdown", "id": "8d7f2185-7e77-45ac-91e6-f10172b9a321", "metadata": {}, "source": [ "Okay, looks like the output is a pair of tensors: the first contains logits for the profile predictions on each strand, and the second contains the count prediction across both strands for that locus. Since we passed in only a single example, both tensors have size 1 in the first dimension.\n", "\n", "But... what happens if I want attributions?"
] }, { "cell_type": "code", "execution_count": 4, "id": "9d5c1151-6e7c-4581-a5de-52905ecaebcf", "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "tuple indices must be integers or slices, not tuple", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[4], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtangermeme\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdeep_lift_shap\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m deep_lift_shap\n\u001b[0;32m----> 3\u001b[0m deep_lift_shap(model, X, args\u001b[38;5;241m=\u001b[39m(X_ctl,))\n", "File \u001b[0;32m~/github/tangermeme/tangermeme/deep_lift_shap.py:444\u001b[0m, in \u001b[0;36mdeep_lift_shap\u001b[0;34m(model, X, args, target, batch_size, references, n_shuffles, return_references, hypothetical, warning_threshold, additional_nonlinear_ops, print_convergence_deltas, raw_outputs, device, random_state, verbose)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 443\u001b[0m \tmodel\u001b[38;5;241m.\u001b[39mapply(_clear_hooks)\n\u001b[0;32m--> 444\u001b[0m \t\u001b[38;5;28;01mraise\u001b[39;00m(e)\n\u001b[1;32m 446\u001b[0m \u001b[38;5;66;03m# If not returning the raw multipliers then apply the correction for\u001b[39;00m\n\u001b[1;32m 447\u001b[0m \u001b[38;5;66;03m# character encodings\u001b[39;00m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m raw_outputs \u001b[38;5;241m==\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m:\n", "File \u001b[0;32m~/github/tangermeme/tangermeme/deep_lift_shap.py:422\u001b[0m, in \u001b[0;36mdeep_lift_shap\u001b[0;34m(model, X, args, target, batch_size, references, n_shuffles, return_references, hypothetical, 
warning_threshold, additional_nonlinear_ops, print_convergence_deltas, raw_outputs, device, random_state, verbose)\u001b[0m\n\u001b[1;32m 420\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _args \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 421\u001b[0m \t_args \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39mcat([arg, arg]) \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m _args)\n\u001b[0;32m--> 422\u001b[0m \ty \u001b[38;5;241m=\u001b[39m model(X_, \u001b[38;5;241m*\u001b[39m_args)[:, target]\n\u001b[1;32m 423\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 424\u001b[0m \ty \u001b[38;5;241m=\u001b[39m model(X_)[:, target]\n", "\u001b[0;31mTypeError\u001b[0m: tuple indices must be integers or slices, not tuple" ] } ], "source": [ "from tangermeme.deep_lift_shap import deep_lift_shap\n", "\n", "deep_lift_shap(model, X, args=(X_ctl,))" ] }, { "cell_type": "markdown", "id": "2ed86a73-29f8-455c-9f68-935ae033792e", "metadata": {}, "source": [ "We get an error. This is because `deep_lift_shap`, like many other downstream functions, cannot handle outputs of arbitrary structure: it indexes the model output with `[:, target]`, which assumes that the model returns a single tensor rather than a tuple of tensors.\n", "\n", "Time for a wrapper! Let's do something simple and just slice out the count predictions." ] }, { "cell_type": "code", "execution_count": 5, "id": "bce4d418-9571-47bf-a992-a496621cde80", "metadata": {}, "outputs": [], "source": [ "class CountWrapper(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super(CountWrapper, self).__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X, *args):\n", "        return self.model(X, *args)[1]" ] }, { "cell_type": "markdown", "id": "3fe14d99-6620-4e1f-9aba-724c1bf44351", "metadata": {}, "source": [ "All we are doing here is running the underlying model but only returning the second output, which for BPNet models is the count head."
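, "\n",
"\n",
"To see what the wrapper does without the GATA2 checkpoint, here is a minimal sketch that swaps in a hypothetical `ToyTwoHeadModel`. It only mimics the structure of BPNet's output, a `(profile, counts)` tuple; the `(n, 2, 1000)` profile shape is an assumption for illustration. The point is that the wrapped model returns a single tensor, so the `[:, target]` indexing that failed inside `deep_lift_shap` now works.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"class ToyTwoHeadModel(torch.nn.Module):\n",
"    # Hypothetical stand-in that mimics only BPNet's output structure:\n",
"    # an (n, 2, 1000) profile tensor and an (n, 1) count tensor.\n",
"    def forward(self, X, X_ctl):\n",
"        profile = torch.zeros(X.shape[0], 2, 1000)\n",
"        counts = X.sum(dim=(1, 2)).unsqueeze(1)\n",
"        return profile, counts\n",
"\n",
"\n",
"class CountWrapper(torch.nn.Module):\n",
"    def __init__(self, model):\n",
"        super(CountWrapper, self).__init__()\n",
"        self.model = model\n",
"\n",
"    def forward(self, X, *args):\n",
"        # Keep only the second output (the count head for BPNet-style models).\n",
"        return self.model(X, *args)[1]\n",
"\n",
"\n",
"X = torch.ones(3, 4, 2114)\n",
"X_ctl = torch.zeros(3, 2, 2114)\n",
"\n",
"y = CountWrapper(ToyTwoHeadModel())(X, X_ctl)\n",
"print(y.shape)        # torch.Size([3, 1])\n",
"print(y[:, 0].shape)  # torch.Size([3]); indexing a tuple this way raised TypeError\n",
"```"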
] }, { "cell_type": "code", "execution_count": 6, "id": "117cf40c-1b32-4a5d-8c12-6bd651c253b1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.3379]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_model = CountWrapper(model)\n", "\n", "predict(count_model, X, args=(X_ctl,))" ] }, { "cell_type": "markdown", "id": "3be82ed4-2f33-4d62-b365-160f289be2c1", "metadata": {}, "source": [ "Simple! Now we can pass this into downstream functions without having to modify them." ] }, { "cell_type": "code", "execution_count": 7, "id": "ac860b08-cb29-41cb-b628-98fae5292172", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, -1.0636e-08, -8.1422e-09, ..., -0.0000e+00,\n", " -0.0000e+00, -0.0000e+00],\n", " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n", " 0.0000e+00, 0.0000e+00],\n", " [ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00]]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "deep_lift_shap(count_model, X, args=(X_ctl,))" ] }, { "cell_type": "markdown", "id": "f3f6a1d4-c0f8-444e-9da5-f455a50f382a", "metadata": {}, "source": [ "Hooray, it ran.\n", "\n", "But it's been pretty annoying to have to keep passing in this empty control track, right? Good news. Wrappers can do more than just modify the outputs from models -- they can run arbitrary code, including code that changes the inputs before they reach the model.\n", "\n", "Let's make a wrapper that automatically creates an all-zeroes control track matching the input's batch size and length. Naturally, if we wanted to pass in informative control tracks this wrapper wouldn't help, but if all we are doing in these downstream steps is using an all-zeroes track it should be fine."
] }, { "cell_type": "code", "execution_count": 8, "id": "5aa120aa-9fda-4242-b520-5c04f8eed6c6", "metadata": {}, "outputs": [], "source": [ "class ControlWrapper(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super(ControlWrapper, self).__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X):\n", "        # Build an all-zeroes, two-strand control track on the fly\n", "        X_ctl = torch.zeros_like(X)[:, :2]\n", "        return self.model(X, X_ctl)" ] }, { "cell_type": "markdown", "id": "b4767e31-9fdc-44d7-8014-60020b926afc", "metadata": {}, "source": [ "Using this wrapper means that we no longer have to pass in the `X_ctl` input." ] }, { "cell_type": "code", "execution_count": 9, "id": "4207b875-ba97-44c4-839d-7da8d3c3d20f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.3379]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "control_model = ControlWrapper(count_model)\n", "\n", "predict(control_model, X)" ] }, { "cell_type": "markdown", "id": "6a6b88ce-0772-4d54-94cf-98561ee197da", "metadata": {}, "source": [ "Same output as before, without having to write out `args=(X_ctl,)`!\n", "\n", "These wrappers are already implemented in bpnet-lite and make those models much easier to use. Loading just becomes the following."
] }, { "cell_type": "code", "execution_count": 10, "id": "2bf1b81f-0d3d-4a84-b0fb-ee75f1e8c5af", "metadata": {}, "outputs": [], "source": [ "from bpnetlite.bpnet import CountWrapper\n", "from bpnetlite.bpnet import ControlWrapper\n", "\n", "model = torch.load(\"../../../../models/bpnet/GATA2.torch\", weights_only=False)\n", "model = CountWrapper(ControlWrapper(model))" ] }, { "cell_type": "markdown", "id": "6f7af1b2-1a32-406f-a70b-f6ffeb0d5d57", "metadata": {}, "source": [ "Saving some characters like this is always nice, but this sort of wrapper, which lets you control or modify the inputs, is extremely valuable when working with code that is not as flexible as you would like -- an unfortunately common occurrence in research settings. Imagine, for example, that you want to use a function from another library but that function *does not allow you to pass in any arguments beyond the sequence*. Without a wrapper, you would not be able to use a BPNet model with it *even though the additional argument is just an uninformative all-zeroes track*, because the control track argument is required.\n", "\n", "Now let's pretend we have an actually informative control track, but the function still does not allow us to pass in more than a single tensor, which it assumes is the one-hot encoded sequence. The wrapper above does not help here because we do not want all zeroes -- we want actual values!\n", "\n", "A potential solution (assuming the function does not check the number of channels) is to concatenate the two inputs and have a wrapper split them apart internally. Because both tensors have the same length in this case, we can concatenate the four channels of the one-hot encoding with the two strand channels of the control track into a single `(n, 6, 2114)` tensor."
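, "\n",
"\n",
"The packing and unpacking can be checked in isolation before involving a model. In this sketch the inputs are hypothetical (an all-`A` sequence and a random, non-negative control track); it simply confirms that slicing channels `:4` and `4:` recovers both tensors exactly, which is all a splitting wrapper needs to do.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Hypothetical inputs: 2 one-hot sequences (every position is 'A') and a\n",
"# random two-strand control track with the same length.\n",
"X = torch.zeros(2, 4, 2114)\n",
"X[:, 0] = 1\n",
"X_ctl = torch.rand(2, 2, 2114)\n",
"\n",
"# Pack: 4 sequence channels + 2 control channels -> (2, 6, 2114)\n",
"Xp = torch.cat([X, X_ctl], dim=1)\n",
"\n",
"# Unpack: the same slices a splitting wrapper would apply inside forward()\n",
"X_back, X_ctl_back = Xp[:, :4], Xp[:, 4:]\n",
"\n",
"print(Xp.shape)  # torch.Size([2, 6, 2114])\n",
"print(torch.equal(X_back, X), torch.equal(X_ctl_back, X_ctl))  # True True\n",
"```"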
] }, { "cell_type": "code", "execution_count": 11, "id": "950486c1-5a66-42bd-90c4-a0eddacf5e9c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[6.5609]]), tensor([[6.5609]]))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class ControlSplitter(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super(ControlSplitter, self).__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X):\n", "        # Channels 0-3: one-hot sequence; channels 4-5: control strands\n", "        return self.model(X[:, :4], X[:, 4:])\n", "\n", "splitter_model = ControlSplitter(count_model)\n", "\n", "X_ctl2 = torch.abs(torch.randn(1, 2, 2114))\n", "Xp_ctl = torch.cat([X, X_ctl2], dim=1)\n", "\n", "predict(count_model, X, args=(X_ctl2,)), predict(splitter_model, Xp_ctl)" ] }, { "cell_type": "markdown", "id": "d74f4735-3a5d-4081-bd1a-bdea38242e7f", "metadata": {}, "source": [ "Looks like we get the same answer in either case." ] }, { "cell_type": "markdown", "id": "c889550e-ad10-4e34-8f94-ff8dc56b485b", "metadata": {}, "source": [ "#### Advanced Slicing+Dicing with Enformer\n", "\n", "Enformer poses potentially even more challenges than BPNet models. Its output comes in the form of a dictionary, because predictions are made for both human and mouse tracks. For each species, predictions are made for thousands of different experiments, and for each track multiple bins are reported. Almost no function is going to have built-in functionality to go from the raw predictions of this model to the single number you might be interested in.\n", "\n", "Oh, also: the `enformer_pytorch` implementation expects inputs with shape `(n, length, 4)`, with the length axis before the channel axis, even though PyTorch's convention (e.g., for its convolutional layers) is `(n, channels, length)` with the length axis last. This fact alone can break many functions.\n", "\n", "What to do? A wrapper!"
] }, { "cell_type": "code", "execution_count": 12, "id": "7feff185-5dd1-4761-84bc-3f4622c5f4e5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'human': tensor([[[0.0840, 0.1147, 0.2047, ..., 0.0486, 0.8827, 0.8969],\n", " [0.0832, 0.0926, 0.1280, ..., 0.0093, 0.0721, 0.0559],\n", " [0.0995, 0.1134, 0.1707, ..., 0.0017, 0.0161, 0.0112],\n", " ...,\n", " [0.0958, 0.1037, 0.1234, ..., 0.0039, 0.0474, 0.0564],\n", " [0.0682, 0.0980, 0.1097, ..., 0.0054, 0.0606, 0.1010],\n", " [0.0879, 0.1223, 0.1424, ..., 0.0026, 0.0240, 0.0295]]],\n", " grad_fn=),\n", " 'mouse': tensor([[[0.0811, 0.0877, 0.0587, ..., 0.6569, 5.2159, 1.7070],\n", " [0.0698, 0.1061, 0.0628, ..., 0.2593, 0.6219, 0.4996],\n", " [0.0818, 0.0929, 0.0783, ..., 0.2032, 0.2541, 0.3682],\n", " ...,\n", " [0.0584, 0.0597, 0.0430, ..., 0.1296, 0.1003, 0.2206],\n", " [0.0301, 0.0577, 0.0282, ..., 0.1081, 0.0935, 0.1894],\n", " [0.0332, 0.0570, 0.0394, ..., 0.1293, 0.1594, 0.2478]]],\n", " grad_fn=)}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "os.environ['POLARS_ALLOW_FORKING_THREAD'] = '1' # Needed for Enformer for whatever reason\n", "\n", "from enformer_pytorch import from_pretrained\n", "\n", "class EnformerInputSwapper(torch.nn.Module):\n", " def __init__(self, model):\n", " super(EnformerInputSwapper, self).__init__()\n", " self.model = model\n", " \n", " def forward(self, X):\n", " return self.model(X.permute(0, 2, 1))\n", " \n", "\n", "enformer_base = from_pretrained('EleutherAI/enformer-official-rough', target_length=16, use_tf_gamma=False)\n", "enformer = EnformerInputSwapper(enformer_base)\n", "enformer(X)" ] }, { "cell_type": "markdown", "id": "f9f49a6c-5c21-4e5d-ba5f-4dd0f2c21d7c", "metadata": {}, "source": [ "Now that we have resolved the dimension issue, we can see the dictionary that gets provided by the implementation. 
Functions that can pick one element out of a tuple of outputs are rare, and functions that can index into a dictionary are rarer still. So, let's put that indexing into our wrapper." ] }, { "cell_type": "code", "execution_count": 13, "id": "8374475c-927f-4998-aab4-210026b82afd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[0.0840, 0.1147, 0.2047, ..., 0.0486, 0.8827, 0.8969],\n", " [0.0832, 0.0926, 0.1280, ..., 0.0093, 0.0721, 0.0559],\n", " [0.0995, 0.1134, 0.1707, ..., 0.0017, 0.0161, 0.0112],\n", " ...,\n", " [0.0958, 0.1037, 0.1234, ..., 0.0039, 0.0474, 0.0564],\n", " [0.0682, 0.0979, 0.1096, ..., 0.0054, 0.0606, 0.1010],\n", " [0.0879, 0.1223, 0.1423, ..., 0.0026, 0.0240, 0.0295]]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class EnformerWrapper(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super(EnformerWrapper, self).__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X):\n", "        # (n, 4, L) -> (n, L, 4) for Enformer, then keep only the human head\n", "        return self.model(X.permute(0, 2, 1))['human']\n", "\n", "enformer = EnformerWrapper(enformer_base)\n", "\n", "predict(enformer, X)" ] }, { "cell_type": "markdown", "id": "e3801153-83a6-462c-bbe5-ce155467bc57", "metadata": {}, "source": [ "Great, now the input and output are in formats that tangermeme can readily use. But what do we do next? Well, we can collapse the predictions across the length dimension. Here, predictions are made for each 128 bp bin in the sequence, and we can simply sum the values across those bins."
] }, { "cell_type": "code", "execution_count": 14, "id": "a1f92249-bbfa-43e7-a018-aa70fe78f3d3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1.8133, 2.1466, 3.1470, ..., 0.0994, 1.3929, 1.4304]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class EnformerWrapper2(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super(EnformerWrapper2, self).__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X):\n", "        # Sum over the bin axis: (n, bins, tracks) -> (n, tracks)\n", "        return self.model(X.permute(0, 2, 1))['human'].sum(dim=-2)\n", "\n", "enformer = EnformerWrapper2(enformer_base)\n", "\n", "predict(enformer, X)" ] }, { "cell_type": "markdown", "id": "1cb6c290-02f5-48a4-8310-45d70156434c", "metadata": {}, "source": [ "At this point we have solved several issues with both the input and output of the Enformer model using only a few lines of code, and this wrapper should make the model usable by most downstream functions without modifying them. As a final addition, we could slice out a specific target from the 5,313 human tracks, yielding a single prediction per example from Enformer."
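, "\n",
"\n",
"The shape bookkeeping here can be checked without downloading any weights by using a random tensor as a stand-in for the `'human'` output, assumed to have shape `(n, bins, tracks)` with 16 bins (matching `target_length=16`) and 5,313 tracks. One detail worth checking against your downstream code: selecting a track with an integer index drops the track axis, whereas a function that indexes outputs as `[:, target]`, as `deep_lift_shap` does in the traceback earlier, likely needs the `(n, 1)` shape that a slice keeps.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Random stand-in for Enformer's 'human' output: (n, bins, tracks).\n",
"y = torch.rand(2, 16, 5313)\n",
"\n",
"summed = y.sum(dim=-2)           # (2, 5313): one value per track\n",
"one_track = summed[:, 15]        # (2,): an integer index drops the track axis\n",
"one_track_2d = summed[:, 15:16]  # (2, 1): a slice keeps it\n",
"\n",
"print(summed.shape, one_track.shape, one_track_2d.shape)\n",
"# torch.Size([2, 5313]) torch.Size([2]) torch.Size([2, 1])\n",
"```"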
] }, { "cell_type": "code", "execution_count": 15, "id": "9caad630-ca8c-4840-bfcd-f0f53559d293", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2.0108])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class EnformerWrapper3(torch.nn.Module):\n", "    def __init__(self, model, target):\n", "        super(EnformerWrapper3, self).__init__()\n", "        self.model = model\n", "        self.target = target\n", "\n", "    def forward(self, X):\n", "        return self.model(X.permute(0, 2, 1))['human'].sum(dim=-2)[:, self.target]\n", "\n", "enformer = EnformerWrapper3(enformer_base, 15)\n", "\n", "predict(enformer, X)" ] }, { "cell_type": "markdown", "id": "e890a873-d11b-4eb4-9686-a1bc4e6015e7", "metadata": {}, "source": [ "#### Correcting Mismatched Shapes\n", "\n", "Sometimes you want to compare the predictions from two models that do not operate on the same sequence length. Naturally, no wrapper can make a model operate faithfully outside the constraints it was trained under, but a wrapper can resize the inputs to the expected shape, as long as that caveat is noted.\n", "\n", "For example, the Enformer model above could be applied directly to sequences of the same length as the BPNet inputs (2,114 bp). Now let's load it with `target_length=24` instead, which requires an input of roughly 3 kbp."
] }, { "cell_type": "code", "execution_count": 16, "id": "104ad33c-d824-4c3b-ba1e-29f956bee912", "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "sequence length 17 is less than target length 24", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[16], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m enformer_base \u001b[38;5;241m=\u001b[39m from_pretrained(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEleutherAI/enformer-official-rough\u001b[39m\u001b[38;5;124m'\u001b[39m, target_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m24\u001b[39m, use_tf_gamma\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 2\u001b[0m enformer \u001b[38;5;241m=\u001b[39m EnformerWrapper3(enformer_base, \u001b[38;5;241m15\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m predict(enformer, X)\n", "File \u001b[0;32m~/github/tangermeme/tangermeme/predict.py:107\u001b[0m, in \u001b[0;36mpredict\u001b[0;34m(model, X, args, batch_size, dtype, device, verbose)\u001b[0m\n\u001b[1;32m 105\u001b[0m \t\ty_ \u001b[38;5;241m=\u001b[39m model(X_, \u001b[38;5;241m*\u001b[39margs_)\n\u001b[1;32m 106\u001b[0m \t\u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 107\u001b[0m \t\ty_ \u001b[38;5;241m=\u001b[39m model(X_)\n\u001b[1;32m 109\u001b[0m \u001b[38;5;66;03m# Move to the CPU\u001b[39;00m\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(y_, torch\u001b[38;5;241m.\u001b[39mTensor):\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mset\u001b[39m()\n", "Cell \u001b[0;32mIn[15], line 8\u001b[0m, in \u001b[0;36mEnformerWrapper3.forward\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, X):\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel(X\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m))[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhuman\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39msum(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)[:, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget]\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 
1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/enformer_pytorch/modeling_enformer.py:462\u001b[0m, in \u001b[0;36mEnformer.forward\u001b[0;34m(self, x, target, return_corr_coef, return_embeddings, return_only_embeddings, head, target_length)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_target_length(target_length)\n\u001b[1;32m 461\u001b[0m trunk_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrunk_checkpointed \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_checkpointing \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trunk\n\u001b[0;32m--> 462\u001b[0m x \u001b[38;5;241m=\u001b[39m trunk_fn(x)\n\u001b[1;32m 464\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m no_batch:\n\u001b[1;32m 465\u001b[0m x \u001b[38;5;241m=\u001b[39m rearrange(x, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m() ... -> ...\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks 
\u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/container.py:250\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 249\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 250\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m module(\u001b[38;5;28minput\u001b[39m)\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/anaconda3/lib/python3.12/site-packages/enformer_pytorch/modeling_enformer.py:205\u001b[0m, in \u001b[0;36mTargetLengthCrop.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m seq_len 
\u001b[38;5;241m<\u001b[39m target_len:\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msequence length \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mseq_len\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is less than target length \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtarget_len\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 207\u001b[0m trim \u001b[38;5;241m=\u001b[39m (target_len \u001b[38;5;241m-\u001b[39m seq_len) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "\u001b[0;31mValueError\u001b[0m: sequence length 17 is less than target length 24" ] } ], "source": [ "enformer_base = from_pretrained('EleutherAI/enformer-official-rough', target_length=24, use_tf_gamma=False)\n", "enformer = EnformerWrapper3(enformer_base, 15)\n", "\n", "predict(enformer, X)" ] }, { "cell_type": "markdown", "id": "aa4ae523-d02d-46ac-9537-7179cf8d7acf", "metadata": {}, "source": [ "Oh no, an error message. Enformer summarizes the sequence in 128 bp bins, so our 2114 bp input only yields 17 bins, fewer than the 24 bins requested with `target_length`.\n", "\n", "There are two ways around this. The first is to write a wrapper that plops the 2114 bp sequence into the middle of an otherwise all-zeros tensor. This is equivalent to padding the sequence with Ns on both sides."
] }, { "cell_type": "code", "execution_count": 17, "id": "b1fa3384-1d34-44e0-b64b-dc3b5b449e4b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1.9488])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class PaddingWrapper(torch.nn.Module):\n", "    def __init__(self, model, n_padding=500):\n", "        super(PaddingWrapper, self).__init__()\n", "        self.model = model\n", "        self.n_padding = n_padding\n", "\n", "    def forward(self, X):\n", "        # Allocate an all-zeros (all-N) tensor with n_padding extra positions on each side\n", "        X_ = torch.zeros(X.shape[0], X.shape[1], X.shape[2]+self.n_padding*2, dtype=X.dtype, device=X.device)\n", "        # Copy the original sequence into the middle, then run the wrapped model on the padded input\n", "        X_[:, :, self.n_padding:self.n_padding+X.shape[2]] = X\n", "        return self.model(X_)\n", "\n", "enformer_pad = PaddingWrapper(enformer, 500)\n", "\n", "predict(enformer_pad, X)" ] }, { "cell_type": "markdown", "id": "95cd0770-e22f-4bf3-a84b-908710176908", "metadata": {}, "source": [ "Ta-da, fixed! This wrapper lets us pass the same input tensor to models even when they require different input lengths. An important caveat: if a model was not trained to handle Ns, it can still technically make predictions on these padded sequences, but the sequences will be out of distribution and the predictions may not be robust.\n", "\n", "A second way to handle this issue is to use a longer input sequence for all of the models and trim it down to the 2114 bp that the BPNet models expect."
] }, { "cell_type": "code", "execution_count": 18, "id": "f6f3d4a1-3dd7-45ee-9de8-51001cfeb9bc", "metadata": {}, "outputs": [], "source": [ "class TrimmingWrapper(torch.nn.Module):\n", "    def __init__(self, model, n_trim):\n", "        super(TrimmingWrapper, self).__init__()\n", "        self.model = model\n", "        self.n_trim = n_trim\n", "\n", "    def forward(self, X):\n", "        # Drop n_trim positions from each end; slicing up to X.shape[-1] - n_trim\n", "        # (rather than -n_trim) keeps the full sequence when n_trim is 0\n", "        return self.model(X[:, :, self.n_trim:X.shape[-1]-self.n_trim])" ] }, { "cell_type": "markdown", "id": "c5b99243-2de4-4251-8dd4-5b22d63aa628", "metadata": {}, "source": [ "This wrapper trims `n_trim` positions off each end of the sequence and then runs the stored model on the trimmed sequence. This approach may be more robust than padding with Ns because the entire sequence is real. However, if critical information sits in the flanks that get trimmed off, the model has no chance of responding to it correctly." ] }, { "cell_type": "code", "execution_count": 19, "id": "8f4b03b3-c853-4a56-854b-df4a23dcb89a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([2.1116]), tensor([[0.9279]]))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X2 = random_one_hot((1, 4, 3000)).float()\n", "# (3000-2114)//2 = 443 positions are trimmed from each end, leaving 2114 bp for BPNet\n", "trim_bpnet = TrimmingWrapper(control_model, (3000-2114)//2)\n", "\n", "predict(enformer, X2), predict(trim_bpnet, X2)" ] }, { "cell_type": "markdown", "id": "685e48ed-adbc-4f8b-bfe8-c7b4d78f3c52", "metadata": {}, "source": [ "A technical detail: the BPNet models can actually be run on sequences of other lengths, but because they were only trained on 2114 bp sequences, their predictions on other lengths may be unreliable." ] }, { "cell_type": "markdown", "id": "2c1abd4b-d7d6-4cd2-86ac-d417021b8734", "metadata": {}, "source": [ "#### Squishing Models Together\n", "\n", "So far, we have shown how wrappers can modify the inputs to and the outputs from a model. 
But we can also squish models together into a single object so that they act like one model! Forward and backward passes still take as long as running each model separately (wrapping models together does not compress them), but having a single object can be more manageable." ] }, { "cell_type": "code", "execution_count": 20, "id": "dc4ebecd-4daf-4232-9392-38f852b30161", "metadata": {}, "outputs": [], "source": [ "class SquishWrapper(torch.nn.Module):\n", "    def __init__(self, models):\n", "        super(SquishWrapper, self).__init__()\n", "        # ModuleList (not a plain list) registers the models as submodules\n", "        self.models = torch.nn.ModuleList(models)\n", "\n", "    def forward(self, X):\n", "        # Run every model on the same input and concatenate along the last axis\n", "        return torch.cat([model(X) for model in self.models], dim=-1)" ] }, { "cell_type": "markdown", "id": "0c052043-57ab-470b-ba5a-964b229dba4e", "metadata": {}, "source": [ "This wrapper gives us an object that takes in an input, applies a series of models to it, and returns the concatenated predictions from all of the models. Let's test it out with three BPNet models and one ChromBPNet model."
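, "\n", "Before loading the real models, here is a small self-contained sketch with hypothetical toy models. `ToyCountModel`, `ToySquishWrapper`, and `PlainListSquish` are illustrative names and are not part of tangermeme or bpnet-lite. The sketch illustrates one design choice when squishing models together. Holding the models in a `torch.nn.ModuleList` registers them as submodules, so `.parameters()`, `.to(device)`, and `.eval()` on the wrapper reach every model. A plain Python list leaves them invisible to the wrapper." ] }, { "cell_type": "code", "execution_count": null, "id": "toy-squish-modulelist", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Hypothetical stand-in for a count head: (batch, 4, length) -> (batch, 1)\n", "class ToyCountModel(torch.nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.conv = torch.nn.Conv1d(4, 1, kernel_size=5)\n", "\n", "    def forward(self, X):\n", "        return self.conv(X).mean(dim=-1)\n", "\n", "# Models stored in a ModuleList are registered as submodules of the wrapper\n", "class ToySquishWrapper(torch.nn.Module):\n", "    def __init__(self, models):\n", "        super().__init__()\n", "        self.models = torch.nn.ModuleList(models)\n", "\n", "    def forward(self, X):\n", "        return torch.cat([model(X) for model in self.models], dim=-1)\n", "\n", "# Same wrapper, but models in a plain Python list are not registered as submodules\n", "class PlainListSquish(torch.nn.Module):\n", "    def __init__(self, models):\n", "        super().__init__()\n", "        self.models = models\n", "\n", "    def forward(self, X):\n", "        return torch.cat([model(X) for model in self.models], dim=-1)\n", "\n", "toy_models = [ToyCountModel() for _ in range(4)]\n", "toy_squish = ToySquishWrapper(toy_models)\n", "plain_squish = PlainListSquish(toy_models)\n", "\n", "X_toy = torch.zeros(2, 4, 100)\n", "X_toy[:, 0] = 1  # two hypothetical all-A sequences\n", "\n", "y_toy_squish = toy_squish(X_toy)  # one count per model, concatenated: shape (2, 4)\n", "n_module_list = sum(p.numel() for p in toy_squish.parameters())  # 4 models x 21 parameters each\n", "n_plain_list = sum(p.numel() for p in plain_squish.parameters())  # the wrapper sees no parameters\n", "print(y_toy_squish.shape, n_module_list, n_plain_list)"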
] }, { "cell_type": "code", "execution_count": 21, "id": "af2ae7a5-1a9e-4e20-a409-9094c36c4e47", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 4]), torch.Size([1, 1]))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from bpnetlite import BPNet\n", "\n", "model0 = torch.load(\"../../../../models/bpnet/GATA2.torch\", weights_only=False)\n", "model0 = CountWrapper(ControlWrapper(model0))\n", "\n", "model1 = torch.load(\"../../../../models/bpnet/SOX6.torch\", weights_only=False)\n", "model1 = CountWrapper(ControlWrapper(model1))\n", "\n", "model2 = torch.load(\"../../../../models/bpnet/CTCF.torch\", weights_only=False)\n", "model2 = CountWrapper(ControlWrapper(model2))\n", "\n", "model3 = BPNet.from_chrombpnet(\"../../../../models/chrombpnet/fold_0/model.chrombpnet_nobias.fold_0.ENCSR868FGK.h5\")\n", "model3 = CountWrapper(model3).cuda()\n", "\n", "wrapper = SquishWrapper([model0, model1, model2, model3])\n", "\n", "predict(wrapper, X).shape, predict(model0, X).shape" ] }, { "cell_type": "markdown", "id": "3bc86208-8ea6-4064-8528-3402d8074e1e", "metadata": {}, "source": [ "Looks like it works out of the box, even with models that have different sizes and inputs. Each individual model returns only a single value (the count prediction), but the wrapper returns all four numbers together. Because the wrapper is itself a model, it can be passed into downstream functions like `saturation_mutagenesis` or `marginalize`.\n", "\n", "This example also demonstrates how wrappers can be stacked on top of each other to get the desired output without modifying the underlying model. Remember, the BPNet models here take in a control track and return both a profile output and a count output, whereas the ChromBPNet model does not need a control track. 
Using these stacks, we've removed the need to pass a control track to the models that needed one, sliced out all of the profile outputs, and concatenated the count outputs. No big deal when you're using wrappers." ] }, { "cell_type": "markdown", "id": "ba9178c6-3bcb-4d10-8285-a6db768b7995", "metadata": {}, "source": [ "#### Adding in Processing\n", "\n", "Naturally, wrappers are not limited to modifying the inputs and outputs of models. They can also process data within the wrapper itself, as we saw with the Enformer example that converts a profile into a single number.\n", "\n", "As an example of this, let's consider reverse complementing. Most regulatory genomics models are trained and evaluated on sequences from only one strand. Some training examples may be generated by reverse complementing other examples, but the model still only sees one orientation of a sequence at a time.\n", "\n", "An alternate strategy is to make predictions on both an example and its reverse complement and then average those predictions together. This can be done either during training or only during evaluation of an already-trained model. Making predictions on an example and its reverse complement by hand is, of course, a huge hassle, particularly if you want to do anything downstream with the model like calculating attributions or marginalizations.\n", "\n", "But it becomes trivial when you have a wrapper!"
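, "\n", "A reverse-complement wrapper relies on one property of one-hot encodings. Assume the channels are ordered A, C, G, T, which matches tangermeme's default alphabet. Flipping a one-hot tensor along both the length and channel axes then gives its reverse complement, because reversing the channel order swaps A with T and C with G. Below is a quick self-contained check on a hypothetical sequence. The helper names `to_one_hot` and `from_one_hot` are illustrative." ] }, { "cell_type": "code", "execution_count": null, "id": "rc-flip-check", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "ALPHABET = 'ACGT'  # assumed channel order\n", "COMPLEMENT = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}\n", "\n", "def to_one_hot(seq):\n", "    # (1, 4, length) one-hot tensor with channels in ALPHABET order\n", "    X_seq = torch.zeros(1, 4, len(seq))\n", "    for i, base in enumerate(seq):\n", "        X_seq[0, ALPHABET.index(base), i] = 1\n", "    return X_seq\n", "\n", "def from_one_hot(X_seq):\n", "    # Decode by picking the hot channel at each position\n", "    return ''.join(ALPHABET[i] for i in X_seq[0].argmax(dim=0).tolist())\n", "\n", "seq = 'GATAACTTG'\n", "rc_by_hand = ''.join(COMPLEMENT[base] for base in reversed(seq))\n", "rc_by_flip = from_one_hot(torch.flip(to_one_hot(seq), dims=(-1, -2)))\n", "print(rc_by_hand, rc_by_flip, rc_by_hand == rc_by_flip)"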
] }, { "cell_type": "code", "execution_count": 22, "id": "5d5e9870-9c4c-4403-bc00-61e2ae4406ee", "metadata": {}, "outputs": [], "source": [ "class RCWrapper(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super(RCWrapper, self).__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X):\n", "        # Average predictions on the sequence and its reverse complement (flip both axes)\n", "        return (self.model(X) + self.model(torch.flip(X, dims=(-1, -2)))) / 2.0" ] }, { "cell_type": "markdown", "id": "6659163b-20c8-4c30-bdf5-d527e70bb2c7", "metadata": {}, "source": [ "In the above wrapper, predictions from the model are averaged between the forward and reverse-complement versions of the sequence. *Importantly*, this does not correctly handle stranded outputs, which would also need to be flipped back before averaging. It works here only because these models make count predictions." ] }, { "cell_type": "code", "execution_count": 23, "id": "e5bf8348-6118-4a47-b827-af4de5b1db1b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0.3525]]), tensor([[0.3379]]), tensor([[0.3525]]))" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wrapper = RCWrapper(model0)\n", "\n", "predict(wrapper, X), predict(model0, X), (predict(model0, X) + predict(model0, torch.flip(X, dims=(-1, -2)))) / 2" ] }, { "cell_type": "markdown", "id": "515802dc-8600-43e1-ab0c-2543710cb014", "metadata": {}, "source": [ "Looks like the predictions from the wrapper match averaging the predictions on the forward and reverse-complement sequences by hand. This is probably not particularly surprising, but having a single object that does all of that processing makes downstream applications a whole lot easier."
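, "\n", "For a stranded profile output, the prediction made on the reverse complement must be mapped back onto forward-strand coordinates before averaging. This means reversing it along the length axis and swapping the two strand channels, which for a two-strand output is again `torch.flip` over the last two axes. Below is a minimal sketch under stated assumptions. The wrapped model is assumed to return a single `(batch, 2, length)` tensor whose channels are the (+, -) strands and whose output window is centered on the input. A real BPNet model returns a pair of tensors, so its profile head would first need to be sliced out. `StrandedRCWrapper` and `ToyStrandedModel` are hypothetical names. The toy model is already strand-symmetric, so a correct wrapper should leave its prediction unchanged, while the naive average used in `RCWrapper` does not." ] }, { "cell_type": "code", "execution_count": null, "id": "stranded-rc-sketch", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "class StrandedRCWrapper(torch.nn.Module):\n", "    # Hypothetical sketch: assumes the model returns one (batch, 2, length) tensor\n", "    # with (+, -) strand channels and an output window centered on the input\n", "    def __init__(self, model):\n", "        super().__init__()\n", "        self.model = model\n", "\n", "    def forward(self, X):\n", "        y_fwd = self.model(X)\n", "        y_rc = self.model(torch.flip(X, dims=(-1, -2)))\n", "        # Map back to forward coordinates: reverse the length axis and swap strands\n", "        y_rc = torch.flip(y_rc, dims=(-1, -2))\n", "        return (y_fwd + y_rc) / 2.0\n", "\n", "class ToyStrandedModel(torch.nn.Module):\n", "    # Hypothetical strand-symmetric toy: '+' profile = A channel, '-' profile = T channel\n", "    def forward(self, X):\n", "        return X[:, [0, 3]]\n", "\n", "# AAACGT as a (1, 4, 6) one-hot tensor in ACGT order\n", "X_stranded_toy = F.one_hot(torch.tensor([[0, 0, 0, 1, 2, 3]]), 4).permute(0, 2, 1).float()\n", "toy_stranded = ToyStrandedModel()\n", "\n", "y_fwd = toy_stranded(X_stranded_toy)\n", "y_stranded = StrandedRCWrapper(toy_stranded)(X_stranded_toy)\n", "y_naive = (toy_stranded(X_stranded_toy) + toy_stranded(torch.flip(X_stranded_toy, dims=(-1, -2)))) / 2.0\n", "print(torch.allclose(y_stranded, y_fwd), torch.allclose(y_naive, y_fwd))"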
] }, { "cell_type": "code", "execution_count": 24, "id": "e0a5370e-22a1-4ae8-ba07-554383ca8193", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0.3525]]), tensor([[0.6588]]))" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tangermeme.marginalize import marginalize\n", "\n", "marginalize(wrapper, X, \"GATAAC\")" ] }, { "cell_type": "markdown", "id": "a2f186cc-acd5-439e-9037-7f50beb7e07f", "metadata": {}, "source": [ "Here, we get the predictions before and after substituting a GATA-like motif into the sequence and see that the GATA model makes a higher prediction after substitution. But... remember that the wrapper is not making a single prediction. It averages the forward and reverse-complement predictions, both before and after the motif is substituted in. For our single example, this one line runs a total of four forward passes." ] }, { "cell_type": "markdown", "id": "0da5f881-ab52-4910-b720-86fbcadad393", "metadata": {}, "source": [ "### Conclusions\n", "\n", "Wrappers are productivity hacks. Rather than spending (potentially significant amounts of) time modifying or even rewriting code that someone else has written to accommodate the oddities of your model, you can write a few lines of code that make your model work within the assumptions of the function you want to use. Because wrappers are so lightweight and do not modify the original model itself, you can write them on the fly for any particular function or analysis you encounter and even include several of them in the same notebook, as we have done here!\n", "\n", "This flexibility has saved me a significant amount of time. In the past, I would spend a lot of time looking for model implementations that matched my assumptions exactly, or even retrain models if I found out that the alphabet order didn't match (looking at you, DeepSEA/Beluga with your AGCT alphabet). 
Wrappers are flexible enough to correct *any* issue in the input and output, as well as add in processing that would be very challenging to account for outside the context of a wrapper." ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:root]", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }