注意
點擊這裡下載完整的範例程式碼
Nested Tensors 入門¶
建立於: 2022 年 8 月 2 日 | 最後更新: 2024 年 5 月 3 日 | 最後驗證: 2024 年 11 月 5 日
Nested tensors 推廣了常規密集 tensors 的形狀,允許表示參差不齊大小的資料。
對於常規 tensor,每個維度都是常規的並且具有大小
對於 nested tensor,並非所有維度都具有常規大小;其中一些是不規則的
Nested tensors 是在各種領域內表示循序資料的自然解決方案
在 NLP 中,句子可以有不同的長度,因此一批句子會形成一個 nested tensor
在 CV 中,影像可以有不同的形狀,因此一批影像會形成一個 nested tensor
在本教學中,我們將示範 nested tensors 的基本用法,並透過真實世界的範例激發它們在處理不同長度的循序資料方面的用處。特別是,它們對於構建可以有效地處理不規則循序輸入的 transformers 非常寶貴。以下,我們提供了一個使用 nested tensors 的多頭注意力機制實作,結合使用 torch.compile
,其效能優於天真地在具有 padding 的 tensors 上進行操作。
Nested tensors 目前是原型功能,可能會有所變更。
import numpy as np
import timeit
import torch
import torch.nn.functional as F
from torch import nn
torch.manual_seed(1)
np.random.seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Nested tensor 初始化¶
從 Python 前端,可以從 tensors 的清單建立 nested tensor。我們將 nt[i] 表示為 nestedtensor 的第 i 個 tensor 成分。
nt = torch.nested.nested_tensor([torch.arange(12).reshape(
2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
print(f"{nt=}")
透過將每個底層 tensor 填充到相同的形狀,nestedtensor 可以轉換為常規 tensor。
padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
print(f"{padded_out_tensor=}")
所有 tensors 都具有一個屬性,用於確定它們是否為 nested;
print(f"nt is nested: {nt.is_nested}")
print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")
從不規則形狀的 tensors 批次建構 nestedtensors 是很常見的。即,維度 0 假定為批次維度。索引維度 0 會傳回第一個底層 tensor 成分。
print("First underlying tensor component:", nt[0], sep='\n')
print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')
# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
print(f"First underlying tensor component is nested: {nt[0].is_nested}")
一個重要的注意事項是,尚未支援在維度 0 中進行切片。這表示目前無法建構組合底層 tensor 成分的視圖。
Nested Tensor 操作¶
由於必須為 nestedtensors 顯式實作每個操作,因此 nestedtensors 的操作涵蓋範圍目前比常規 tensors 的範圍窄。目前,僅涵蓋基本操作,例如 index、dropout、softmax、transpose、reshape、linear、bmm。但是,涵蓋範圍正在擴大。如果您需要某些操作,請提出一個 issue 以協助我們確定涵蓋範圍的優先順序。
reshape
reshape 操作用於變更 tensor 的形狀。其常規 tensors 的完整語意可以在這裡找到。對於常規 tensors,當指定新的形狀時,單個維度可能是 -1,在這種情況下,它會從剩餘的維度和元素數量推斷出來。
nestedtensors 的語意類似,除了 -1 不再推斷。相反,它繼承了舊的大小(此處 nt[0]
為 2,nt[1]
為 3)。-1 是為不規則維度指定唯一合法的尺寸。
nt_reshaped = nt.reshape(2, -1, 2, 3)
print(f"{nt_reshaped=}")
transpose
transpose 操作用於交換 tensor 的兩個維度。其完整語意可以在這裡找到。請注意,對於 nestedtensors,維度 0 是特殊的;它被假定為批次維度,因此不支援涉及 nestedtensor 維度 0 的轉置。
nt_transposed = nt_reshaped.transpose(1, 2)
print(f"{nt_transposed=}")
其他
其他操作與常規 tensors 具有相同的語意。在 nestedtensor 上應用該操作等效於將該操作應用於底層 tensor 成分,結果也是 nestedtensor。
nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")
nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")
nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")
為什麼要使用 Nested Tensor¶
當資料是循序時,通常每個樣本具有不同的長度。例如,在一批句子中,每個句子具有不同數量的單字。處理不同序列的常見技術是手動將每個資料 tensor 填充到相同的形狀,以便形成批次。例如,我們有 2 個長度不同的句子和一個詞彙表。為了將其表示為單個 tensor,我們使用 0 填充到批次中的最大長度。
sentences = [["goodbye", "padding"],
["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
"embrace": 3.0, "nested": 4.0, "tensor": 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
[3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")
這種將一批資料填充到其最大長度的技術並非最佳。填充的資料不是計算所必需的,並且透過分配比必要的更大的 tensors 來浪費記憶體。此外,並非所有操作在應用於填充資料時都具有相同的語意。對於矩陣乘法,為了忽略填充的條目,需要使用 0 填充,而對於 softmax,則必須使用 -inf 填充以忽略特定條目。nested tensor 的主要目標是使用標準 PyTorch tensor UX 促進對不規則資料的操作,從而消除對效率低下且複雜的填充和遮罩的需求。
padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
[3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
讓我們看看一個實際的範例: Transformers 中使用的多頭注意力機制組件。 我們可以透過某種方式實作此功能,使其可以在填充或 nested tensors 上運作。
class MultiHeadAttention(nn.Module):
"""
Computes multi-head attention. Supports nested or padded tensors.
Args:
E_q (int): Size of embedding dim for query
E_k (int): Size of embedding dim for key
E_v (int): Size of embedding dim for value
E_total (int): Total embedding dim of combined heads post input projection. Each head
has dim E_total // nheads
nheads (int): Number of heads
dropout_p (float, optional): Dropout probability. Default: 0.0
"""
def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
nheads: int, dropout_p: float = 0.0):
super().__init__()
self.nheads = nheads
self.dropout_p = dropout_p
self.query_proj = nn.Linear(E_q, E_total)
self.key_proj = nn.Linear(E_k, E_total)
self.value_proj = nn.Linear(E_v, E_total)
E_out = E_q
self.out_proj = nn.Linear(E_total, E_out)
assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
self.E_head = E_total // nheads
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
"""
Forward pass; runs the following process:
1. Apply input projection
2. Split heads and prepare for SDPA
3. Run SDPA
4. Apply output projection
Args:
query (torch.Tensor): query of shape (N, L_t, E_q)
key (torch.Tensor): key of shape (N, L_s, E_k)
value (torch.Tensor): value of shape (N, L_s, E_v)
Returns:
attn_output (torch.Tensor): output of shape (N, L_t, E_q)
"""
# Step 1. Apply input projection
# TODO: demonstrate packed projection
query = self.query_proj(query)
key = self.key_proj(key)
value = self.value_proj(value)
# Step 2. Split heads and prepare for SDPA
# reshape query, key, value to separate by head
# (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
# (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
# (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
# Step 3. Run SDPA
# (N, nheads, L_t, E_head)
attn_output = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=True)
# (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
attn_output = attn_output.transpose(1, 2).flatten(-2)
# Step 4. Apply output projection
# (N, L_t, E_total) -> (N, L_t, E_out)
attn_output = self.out_proj(attn_output)
return attn_output
設定遵循 Transformer 論文 的超參數
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8
除了 dropout 機率:設定為 0 以進行正確性檢查
dropout_p = 0.0
讓我們從 Zipf 定律生成一些真實的假資料。
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
# generate fake corpus by unigram Zipf distribution
# from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
sentence_lengths = np.empty(batch_size, dtype=int)
for ibatch in range(batch_size):
sentence_lengths[ibatch] = 1
word = np.random.zipf(alpha)
while word != 3 and word != 386 and word != 858:
sentence_lengths[ibatch] += 1
word = np.random.zipf(alpha)
return torch.tensor(sentence_lengths)
建立巢狀張量批次輸入
def gen_batch(N, E_q, E_k, E_v, device):
# generate semi-realistic data using Zipf distribution for sentence lengths
sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)
# Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
# dimension and works with torch.compile. The batch items each have shape (B, S*, D)
# where B = batch size, S* = ragged sequence length, and D = embedding dimension.
query = torch.nested.nested_tensor([
torch.randn(l.item(), E_q, device=device)
for l in sentence_lengths
], layout=torch.jagged)
key = torch.nested.nested_tensor([
torch.randn(s.item(), E_k, device=device)
for s in sentence_lengths
], layout=torch.jagged)
value = torch.nested.nested_tensor([
torch.randn(s.item(), E_v, device=device)
for s in sentence_lengths
], layout=torch.jagged)
return query, key, value, sentence_lengths
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
產生 query、key、value 的填充形式以進行比較
def jagged_to_padded(jt, padding_val):
# TODO: do jagged -> padded directly when this is supported
return torch.nested.to_padded_tensor(
torch.nested.nested_tensor(list(jt.unbind())),
padding_val)
padded_query, padded_key, padded_value = (
jagged_to_padded(t, 0.0) for t in (query, key, value)
)
構建模型
mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)
檢查正確性和效能
def benchmark(func, *args, **kwargs):
torch.cuda.synchronize()
begin = timeit.default_timer()
output = func(*args, **kwargs)
torch.cuda.synchronize()
end = timeit.default_timer()
return output, (end - begin)
output_nested, time_nested = benchmark(mha, query, key, value)
output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)
# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
output_padded[i, entry_length:] = 0.0
print("=== without torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
print("nested tensor multi-head attention takes", time_nested, "seconds")
print("padded tensor multi-head attention takes", time_padded, "seconds")
# warm up compile first...
compiled_mha = torch.compile(mha)
compiled_mha(query, key, value)
# ...now benchmark
compiled_output_nested, compiled_time_nested = benchmark(
compiled_mha, query, key, value)
# warm up compile first...
compiled_mha(padded_query, padded_key, padded_value)
# ...now benchmark
compiled_output_padded, compiled_time_padded = benchmark(
compiled_mha, padded_query, padded_key, padded_value)
# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
compiled_output_padded[i, entry_length:] = 0.0
print("=== with torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")
請注意,如果沒有 torch.compile
,python 子類別巢狀張量的 overhead 可能使其比在填充張量上進行的等效計算更慢。但是,一旦啟用 torch.compile
,對巢狀張量進行操作會帶來數倍的加速。隨著批次中填充百分比的增加,避免在填充上浪費計算變得更有價值。
print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")
結論¶
在本教程中,我們學習了如何使用巢狀張量執行基本操作,以及如何以避免在填充上進行計算的方式實現 transformer 的多頭注意力機制。 如需更多資訊,請查看 torch.nested 命名空間的文件。
腳本總運行時間: ( 0 分鐘 0.000 秒)