torch.tril¶
- torch.tril(input, diagonal=0, *, out=None) Tensor ¶
傳回矩陣(2 維張量)或矩陣批次
input
的下三角部分,結果張量out
的其他元素設定為 0。矩陣的下三角部分定義為對角線上和對角線下的元素。
參數
diagonal
控制要考慮的對角線。 如果diagonal
= 0,則保留主對角線及其以下的所有元素。 正值包含與主對角線之上相同數量的對角線,類似地,負值排除與主對角線之下相同數量的對角線。 主對角線是索引的集合 對於 其中 是矩陣的維度。範例
>>> a = torch.randn(3, 3) >>> a tensor([[-1.0813, -0.8619, 0.7105], [ 0.0935, 0.1380, 2.2112], [-0.3409, -0.9828, 0.0289]]) >>> torch.tril(a) tensor([[-1.0813, 0.0000, 0.0000], [ 0.0935, 0.1380, 0.0000], [-0.3409, -0.9828, 0.0289]]) >>> b = torch.randn(4, 6) >>> b tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) >>> torch.tril(b, diagonal=1) tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) >>> torch.tril(b, diagonal=-1) tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]])