torch_argmax_mask
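torch_argmax_mask(q: torch.Tensor, dim: int)
Returns a boolean argmax mask over dimension dim with random tie-breaking. The mask has the same shape as q and marks one maximal entry in each slice along dim. When several entries share the maximum, the marked entry is chosen at random, so repeated calls on the same input can return different masks.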
Example
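>>> import torch
>>> _ = torch.manual_seed(1337)  # manual_seed returns a Generator; discard it
>>> q = torch.ones(3, 2)  # every entry is tied for the maximum
>>> torch_argmax_mask(q, 1)
# tensor([[False,  True],
#         [ True, False],
#         [ True, False]])
>>> torch_argmax_mask(q, 1)  # same input, different random tie-breaks
# tensor([[False,  True],
#         [False,  True],
#         [ True, False]])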
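The source does not show how torch_argmax_mask is implemented. The sketch below shows one way to get the behaviour described above. It assumes that ties should be broken uniformly at random. The name argmax_mask_sketch is hypothetical, and this is not the library's own code.

```python
import torch


def argmax_mask_sketch(q: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical sketch of a random tie-breaking argmax mask (not the library's code)."""
    # Entries equal to the maximum along `dim` are the tied candidates.
    is_max = q == q.amax(dim=dim, keepdim=True)
    # Give each candidate a random score in [0, 1); non-candidates get -1.0,
    # so they can never win the argmax below.
    scores = torch.rand(q.shape, device=q.device).masked_fill(~is_max, -1.0)
    # The candidate with the highest random score wins: a uniform pick among ties.
    winner = scores.argmax(dim=dim, keepdim=True)
    # Build a boolean mask with a single True per slice along `dim`.
    mask = torch.zeros_like(q, dtype=torch.bool)
    mask.scatter_(dim, winner, torch.ones_like(winner, dtype=torch.bool))
    return mask


torch.manual_seed(1337)
q = torch.ones(3, 2)
print(argmax_mask_sketch(q, 1))
```

The -1.0 sentinel works because torch.rand draws from [0, 1), so every tied entry outscores every non-maximal one. Without ties, the result matches a plain argmax along dim.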