product
- tangermeme.product.apply_product(func, model, X, args, batch_size=32, device='cuda', additional_func_kwargs={}, verbose=False, **kwargs)
Apply a function across the cartesian product of X and each element of args.
This function will take the provided function and apply it in a batched manner across the cartesian product of X and each of the arguments provided in args. Because this is a cartesian product, the number of examples that need to be processed will quickly grow with respect to the number of arguments being passed in. Each of the tensors in args must be one input to model, in the order that they are specified by the forward function.
This function can accept any other function, be it predictions, attributions, or marginalizations. If the provided function itself has parameters that need to be specified, you can provide them directly to this function in the order that they appear in the provided function.
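To give a sense of how quickly the number of examples grows, the sketch below counts the combinations for a small hypothetical case using only the standard library. The sizes (4 sequences, and two extra inputs with 3 and 5 rows) are made up for illustration, and the ordering shown is that of `itertools.product`, which is not necessarily the order apply_product uses internally.

```python
from itertools import product

# Hypothetical sizes: 4 sequences in X, plus two extra model inputs
# containing 3 and 5 rows respectively. These numbers are illustrative only.
n_X = 4
arg_sizes = [3, 5]

# One entry per combination of a row of X with one row from each argument.
# itertools.product varies the last index fastest; the real function may
# order its examples differently.
index_combinations = list(product(range(n_X), *(range(n) for n in arg_sizes)))

# The total is the product of all the sizes: 4 * 3 * 5.
n_examples = len(index_combinations)
print(n_examples)
```

Adding a third argument with 10 rows would multiply the count by 10 again, so it is worth estimating this number before launching a large run.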
Parameters
- func: function
A function, likely implemented in tangermeme, to apply in a batched manner across the product of examples.
- model: torch.nn.Module
The PyTorch model to use to make predictions.
- X: torch.Tensor, shape=(-1, len(alphabet), length)
A one-hot encoded set of sequences to make predictions for.
- args: tuple or list
A set of additional arguments to pass into the model. Each element in args should be one tensor that is input to the model. The elements do not need to be the same size as each other as a product will be constructed over all of them, as well as with X. If you only want to use one value for an argument across all function applications
- batch_size: int, optional
The number of examples to make predictions for at a time. Default is 32.
- device: str or torch.device
The device to move the model and batches to when making predictions. If set to 'cuda' on a machine without a GPU, this function will crash; set it to 'cpu' instead. Default is 'cuda'.
- additional_func_kwargs: dict, optional
Additional named arguments to pass into the function when it is called. This is provided as an alternate path to route arguments into the function in case they overlap, name-wise, with those in this function, or if you want to be absolutely sure that the arguments are making their way into the function. Default is {}.
- verbose: bool, optional
Whether to display a progress bar while the examples are being processed. Default is False.
- kwargs: optional
Additional named arguments that will get passed into the function when it is called. Default is no arguments are passed in.
Returns
- y: torch.Tensor or list/tuple of torch.Tensors
The output from the model for each input example. The precise format is determined by the model. If the model outputs a single tensor, y is a single tensor concatenated across all batches. If the model outputs multiple tensors, y is a list of tensors which are each concatenated across all batches.
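The batching and concatenation behaviour described above can be sketched in plain Python, with lists standing in for tensors. This is a conceptual illustration only, not tangermeme's implementation: `batched_product_apply`, `add_inputs`, and `add_and_multiply` are hypothetical names, and the example ordering is an assumption.

```python
from itertools import chain, product


def batched_product_apply(func, X, args, batch_size=32):
    """Hypothetical sketch: apply func in batches over the product of X and args."""
    combos = list(product(X, *args))
    outputs = []
    for start in range(0, len(combos), batch_size):
        batch = combos[start:start + batch_size]
        # Regroup the batch into one list per input: X first, then each arg.
        columns = [list(col) for col in zip(*batch)]
        outputs.append(func(columns[0], columns[1:]))

    if not outputs:
        return []
    if isinstance(outputs[0], tuple):
        # Multiple outputs: concatenate each output separately across batches.
        n_outputs = len(outputs[0])
        return [list(chain.from_iterable(out[i] for out in outputs))
                for i in range(n_outputs)]
    # Single output: one concatenated result across all batches.
    return list(chain.from_iterable(outputs))


def add_inputs(x_batch, arg_batches):
    # Stand-in for a function that returns a single output per example.
    return [x + a for x, a in zip(x_batch, arg_batches[0])]


def add_and_multiply(x_batch, arg_batches):
    # Stand-in for a function that returns two outputs per example.
    sums = [x + a for x, a in zip(x_batch, arg_batches[0])]
    prods = [x * a for x, a in zip(x_batch, arg_batches[0])]
    return sums, prods


X = [1, 2, 3]
args = [[10, 20]]

y_single = batched_product_apply(add_inputs, X, args, batch_size=4)
y_multi = batched_product_apply(add_and_multiply, X, args, batch_size=4)
print(len(y_single), len(y_multi))
```

With 3 values in X and 2 in the single argument, the product has 6 examples, processed here as one batch of 4 and one of 2. The single-output case yields one concatenated result, and the two-output case yields a list of two, mirroring the return format described above.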