design

tangermeme.design.greedy_substitution(model, X, y=None, motifs=None, loss=MSELoss(), reverse_complement=True, input_mask=None, output_mask=None, tol=0.001, max_iter=-1, args=None, alphabet=['A', 'C', 'G', 'T'], batch_size=32, device='cuda', verbose=False)

Greedily add motifs to achieve a desired goal.

This design function greedily adds motifs to achieve a desired output from the model. In each round, the function iterates through all possible motifs, substitutes each one at every allowed position in the sequence, and keeps the substitution that yields the smallest loss. This process continues until either the maximum number of iterations is reached (at which point max_iter motifs will have been inserted into the sequence) or the improvement in loss from a round falls below tol.
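The round structure described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not tangermeme's implementation: it works on strings rather than one-hot tensors, does not batch predictions, and the hypothetical `score` callable stands in for running the model and applying the loss (lower is better). The stopping rule follows the descriptions of tol and max_iter on this page.

```python
def substitute(seq, motif, start):
    """Return a copy of seq with motif written over positions [start, start + len(motif))."""
    return seq[:start] + motif + seq[start + len(motif):]


def greedy_substitution_sketch(seq, motifs, score, tol=1e-3, max_iter=-1):
    """Hypothetical sketch of the greedy loop; `score` maps a sequence to a loss."""
    current_loss = score(seq)
    iteration = 0
    while max_iter == -1 or iteration < max_iter:
        best_seq, best_loss = None, float("inf")
        # Try every motif at every start position where it fits in the sequence.
        for motif in motifs:
            for start in range(len(seq) - len(motif) + 1):
                candidate = substitute(seq, motif, start)
                loss = score(candidate)
                if loss < best_loss:
                    best_seq, best_loss = candidate, loss
        # Stop once the best edit of this round improves the loss by less than tol.
        if best_seq is None or current_loss - best_loss < tol:
            break
        seq, current_loss = best_seq, best_loss
        iteration += 1
    return seq


# Toy objective (hypothetical): reward each non-overlapping occurrence of GATA.
def toy_score(seq):
    return -float(seq.count("GATA"))


designed = greedy_substitution_sketch("AAAAAAAAAA", ["GATA"], toy_score)
print(designed)  # GATAGATAAA
```

Two GATA sites fill a 10-character sequence, so the third round finds no improvement and the loop stops on the tol check rather than on max_iter.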

Accordingly, the choice of loss function and desired output from the model is crucial for good design. A Euclidean-style loss, such as the default mean squared error, is often sufficient, but for models with more complex outputs or for subtle design tasks one may want to use something else, such as the Jensen-Shannon divergence.
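To see why the choice matters, compare mean squared error with the Jensen-Shannon divergence on a predicted profile that has the right shape but twice the target's scale. The sketch below illustrates the two formulas on plain Python lists; it is not a loss object that can be passed as loss (that would need to operate on PyTorch tensors), and the function names are hypothetical.

```python
import math


def mse(y, y_hat):
    """Mean squared error between two equal-length lists."""
    return sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y)


def jensen_shannon(y, y_hat):
    """Jensen-Shannon divergence (natural log) after normalizing both inputs to sum to 1."""
    y_total, y_hat_total = sum(y), sum(y_hat)
    p = [a / y_total for a in y]
    q = [b / y_hat_total for b in y_hat]
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(u, v):
        # Terms with u_i == 0 contribute 0 by convention.
        return sum(a * math.log(a / b) for a, b in zip(u, v) if a > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


target = [1.0, 2.0, 1.0]
prediction = [2.0, 4.0, 2.0]  # same shape, twice the scale
print(mse(target, prediction))             # 2.0
print(jensen_shannon(target, prediction))  # 0.0
```

Mean squared error penalizes the scale mismatch, while the Jensen-Shannon divergence compares only the shapes of the two normalized profiles.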

Parameters

model: torch.nn.Module

A PyTorch model to use for making predictions. These models can take in any number of inputs and make any number of outputs. The additional inputs must be specified in the args parameter.

X: torch.Tensor, shape=(1, len(alphabet), length)

A one-hot encoded sequence to use as the base for design. This must be a single sequence; the leading dimension of size 1 is kept for broadcasting.

y: torch.Tensor or list of torch.Tensors or None

A tensor or list of tensors providing the desired output from the model. The type and shape must be compatible with the provided loss function and comparable to the output from the model. Each tensor should have shape (1, n), where n is the number of outputs from the model; the first dimension is 1 so that broadcasting works correctly. If None, choose the edit that yields the strongest response from the model. Default is None.
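The difference between providing y and leaving it as None can be sketched as a selection rule over candidate predictions. This hypothetical helper assumes one scalar prediction per candidate and takes "strongest response" to mean the largest prediction; tangermeme's handling of multi-output models may differ.

```python
def pick_best_candidate(predictions, y=None, loss=None):
    """Return the index of the candidate edit to keep (hypothetical helper)."""
    indices = range(len(predictions))
    if y is None:
        # No target: keep the edit with the largest (strongest) model output.
        return max(indices, key=lambda i: predictions[i])
    # Target given: keep the edit whose prediction is closest to y under the loss.
    return min(indices, key=lambda i: loss(y, predictions[i]))


predictions = [0.1, 0.9, 0.5]
print(pick_best_candidate(predictions))  # 1
print(pick_best_candidate(predictions, y=0.45, loss=lambda t, p: (t - p) ** 2))  # 2
```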

motifs: list of strings or None

A list of strings where each string is a motif that can be inserted into the sequence. These strings will be one-hot encoded according to the provided alphabet. If None, the characters of the alphabet are used as the motifs, so each substitution changes a single character. Default is None.
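With motifs=None, every candidate is a single-character substitution, so the number of candidates per round is the alphabet size times the number of start positions. The hypothetical helper below makes that enumeration explicit. It ignores input_mask and reverse complements, and it keeps substitutions that leave the sequence unchanged (for example, writing A over an A).

```python
def candidate_edits(seq, motifs=None, alphabet=("A", "C", "G", "T")):
    """List (motif, start) pairs for every motif at every start where it fits."""
    if motifs is None:
        # Fall back to single characters, i.e., point mutations.
        motifs = list(alphabet)
    return [(m, s) for m in motifs for s in range(len(seq) - len(m) + 1)]


print(len(candidate_edits("ACGTACGT")))                   # 32
print(len(candidate_edits("ACGTACGT", motifs=["GATA"])))  # 5
```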

loss: function, optional

This function must take in y and y_hat, where y is the desired output from the model and y_hat is the model's prediction for the sequence with a candidate substitution. Default is torch.nn.MSELoss().

reverse_complement: bool, optional

Whether to augment the provided list of motifs with their reverse complements. This doubles the number of candidate substitutions evaluated in each round and therefore roughly doubles the runtime. Default is True.
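Augmenting the motif list with reverse complements can be sketched as below, assuming the DNA alphabet A/C/G/T; reverse_complement and augment_with_reverse_complements are hypothetical helper names. A palindromic motif such as CACGTG is its own reverse complement, so in this sketch it appears twice.

```python
COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def reverse_complement(motif):
    """Reverse the motif and complement each base (DNA alphabet assumed)."""
    return "".join(COMPLEMENT[base] for base in reversed(motif))


def augment_with_reverse_complements(motifs):
    """Return the original motifs followed by their reverse complements."""
    return list(motifs) + [reverse_complement(m) for m in motifs]


print(augment_with_reverse_complements(["GATA", "CACGTG"]))
# ['GATA', 'CACGTG', 'TATC', 'CACGTG']
```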

input_mask: torch.Tensor or None, optional

A boolean mask over input positions marking where a substitution may start. True means that a motif can be substituted in starting at that position, and False means that it cannot. A motif can start at any allowed position even if the contiguous run of True values beginning there is shorter than the motif. Default is None.
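A minimal sketch of how input_mask restricts start positions, assuming (as in the other sketches here) that a motif must still fit inside the sequence and that no mask means every position is allowed. The helper name allowed_starts is hypothetical. In the example, position 2 is allowed even though the mask is True only at positions 1 and 2, so a four-character motif starting there overlaps the masked-out positions 3 to 5.

```python
def allowed_starts(seq_len, motif_len, input_mask=None):
    """Start positions where a motif of length motif_len may be substituted."""
    if input_mask is None:
        # Assumption: no mask means every position is allowed.
        input_mask = [True] * seq_len
    # Only the start position is checked against the mask, not the whole span.
    return [i for i in range(seq_len - motif_len + 1) if input_mask[i]]


mask = [False, True, True, False, False, False]
print(allowed_starts(6, 4))        # [0, 1, 2]
print(allowed_starts(6, 4, mask))  # [1, 2]
```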

output_mask: torch.Tensor or None, optional

A mask on the outputs from the model to consider. True means to include the outputs in the loss, False means to exclude those outputs from the loss. If None, use all outputs. Default is None.
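One way to picture output_mask is as a filter applied before the loss is reduced. The sketch below does this for mean squared error on plain lists; masked_mse is a hypothetical name, and with a real model the mask would be applied to PyTorch tensors instead.

```python
def masked_mse(y, y_hat, output_mask=None):
    """MSE computed only over outputs where output_mask is True (hypothetical helper)."""
    if output_mask is None:
        output_mask = [True] * len(y)
    kept = [(a, b) for a, b, keep in zip(y, y_hat, output_mask) if keep]
    return sum((a - b) ** 2 for a, b in kept) / len(kept)


y = [1.0, 0.0, 5.0]
y_hat = [1.0, 0.0, 0.0]
print(masked_mse(y, y_hat))                       # about 8.33 (25 / 3)
print(masked_mse(y, y_hat, [True, True, False]))  # 0.0
```

Excluding the third output removes the only mismatch, so the masked loss is zero.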

tol: float, optional

The minimum improvement in loss required to continue. The procedure stops once the improvement achieved in a round falls below this value. Default is 1e-3.

max_iter: int, optional

The maximum number of iterations to run before terminating the procedure. Set to -1 for no limit. Default is -1.
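The interaction between tol and max_iter can be checked on a hypothetical trace of per-round best losses. The helper below counts how many substitutions would be accepted, assuming a round is accepted only when it lowers the loss by at least tol and that max_iter=-1 means no limit.

```python
def rounds_accepted(loss_trace, tol=1e-3, max_iter=-1):
    """Count accepted rounds given [initial_loss, best_loss_round_1, ...]."""
    accepted = 0
    for previous, new in zip(loss_trace, loss_trace[1:]):
        if max_iter != -1 and accepted >= max_iter:
            break  # iteration budget exhausted
        if previous - new < tol:
            break  # improvement too small; stop without applying this edit
        accepted += 1
    return accepted


trace = [1.0, 0.4, 0.1, 0.0995]  # improvements: 0.6, 0.3, 0.0005
print(rounds_accepted(trace))              # 2
print(rounds_accepted(trace, max_iter=1))  # 1
```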

args: tuple or list or None, optional

An optional set of additional arguments to pass into the model. If provided, each element of the tuple or list is one input to the model and must have the same batch size as X. If None, no additional arguments are passed into the forward function. Default is None.

alphabet: set or tuple or list, optional

A pre-defined alphabet used to one-hot encode the motifs, where the ordering of the symbols is the same as the index into the encoded tensor, i.e., for the alphabet ['A', 'B'] the encoded tensor will have a 1 at index 0 wherever the character is 'A'. Characters outside the alphabet are ignored, and none of the indices are set to 1 at that position. This is not necessary or used if a one-hot encoded tensor is provided for the motif. Default is ['A', 'C', 'G', 'T'].
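The encoding described for alphabet can be illustrated with a small pure-Python function that produces a (len(alphabet), length) nested list, the same layout as X without its leading batch dimension. It is a hypothetical stand-in, not tangermeme's own encoder; the character N, which is outside the default alphabet, yields an all-zero column.

```python
def one_hot(seq, alphabet=("A", "C", "G", "T")):
    """Return a len(alphabet) x len(seq) nested list of 0/1 values."""
    index = {char: i for i, char in enumerate(alphabet)}
    encoding = [[0] * len(seq) for _ in alphabet]
    for pos, char in enumerate(seq):
        if char in index:  # characters outside the alphabet stay all-zero
            encoding[index[char]][pos] = 1
    return encoding


for row in one_hot("ACN"):
    print(row)
# [1, 0, 0]
# [0, 1, 0]
# [0, 0, 0]
# [0, 0, 0]
```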

batch_size: int, optional

The number of examples to make predictions for at a time. Default is 32.
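A rough estimate of the work in one round helps when picking batch_size. The sketch assumes every motif is tried at every start position where it fits (no input_mask) and that reverse_complement=True doubles the candidate count, as stated above; the helper names are hypothetical.

```python
import math


def candidates_per_round(seq_len, motif_lengths, reverse_complement=True):
    """Number of (motif, start) substitutions evaluated in one round."""
    n = sum(seq_len - k + 1 for k in motif_lengths)
    return 2 * n if reverse_complement else n


def batches_per_round(n_candidates, batch_size=32):
    """Number of model calls needed to score n_candidates in chunks of batch_size."""
    return math.ceil(n_candidates / batch_size)


n = candidates_per_round(1000, [6, 6])  # two 6-bp motifs on a 1,000-bp sequence
print(n)                     # 3980
print(batches_per_round(n))  # 125
```

Larger batches reduce the number of forward passes per round at the cost of more memory on the chosen device.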

device: str or torch.device, optional

The device to move the model and batches to when making predictions. If set to 'cuda' on a machine without a GPU, this function will raise an error; set it to 'cpu' instead. Default is 'cuda'.

verbose: bool, optional

Whether to display a progress bar during predictions. Default is False.

Returns

X: torch.Tensor, shape=(-1, len(alphabet), length)

The edited sequence.