torch_argmax_mask

torch_argmax_mask(q: torch.Tensor, dim: int)

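Returns a boolean mask with the same shape as q that marks one maximal entry of q in each slice along dim. When several entries tie for the maximum, the marked entry is chosen at random.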

Example

>>> import torch
>>> torch.manual_seed(1337)
>>> q = torch.ones(3, 2)
>>> torch_argmax_mask(q, 1)
# tensor([[False,  True],
#         [ True, False],
#         [ True, False]])
>>> torch_argmax_mask(q, 1)  # every row is a tie, so a second call may pick differently
# tensor([[False,  True],
#         [False,  True],
#         [ True, False]])
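The same behaviour can be reproduced with plain PyTorch operations. The sketch below illustrates the idea under stated assumptions and is not the actual implementation of torch_argmax_mask. The name random_argmax_mask is hypothetical, and the sketch assumes the mask holds exactly one True per slice along dim, as in the example above. It flags every entry equal to the slice maximum, gives each flagged entry a random score, and keeps the highest-scoring one.

```python
import torch


def random_argmax_mask(q: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical sketch: one-hot argmax mask along `dim` with random tie-breaking."""
    # True wherever an entry equals the maximum of its slice along `dim`.
    is_max = q == q.amax(dim=dim, keepdim=True)
    # Uniform scores in [0, 1); non-maximal entries get -1 so they can never win.
    scores = torch.rand(q.shape, device=q.device).masked_fill(~is_max, -1.0)
    # Index of the highest random score among the tied maxima, one per slice.
    idx = scores.argmax(dim=dim, keepdim=True)
    # Turn the indices into a boolean mask with the same shape as `q`.
    return torch.zeros_like(q, dtype=torch.bool).scatter_(dim, idx, True)


q = torch.ones(3, 2)
print(random_argmax_mask(q, 1))
```

Because torch.rand draws from [0, 1), filling non-maximal positions with -1 keeps them out of the argmax, and the tied maxima are picked with (essentially) equal probability. Seeding with torch.manual_seed makes the draws reproducible, but the masks need not match the outputs shown above.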