快捷方式

pad

class tensordict.pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0)

使用常數值沿著批次維度填充 tensordict 中的所有張量,並傳回新的 tensordict。

參數:
  • tensordict (TensorDict) – 要填充的 tensordict

  • pad_size (Sequence[int]) – 填充大小,用於從第一個維度開始並向前移動,填充 tensordict 的一些批次維度。[len(pad_size) / 2] 個批次大小的維度將被填充。 例如,僅填充第一個維度,則 pad 的形式為 (padding_left, padding_right)。 要填充兩個維度,(padding_left, padding_right, padding_top, padding_bottom) 等等。 pad_size 必須是偶數,且小於或等於批次維度數量的兩倍。

  • value (float, optional) – 填充時使用的填充值,預設為 0.0

傳回:

沿著批次維度填充的新 TensorDict

範例

>>> from tensordict import TensorDict, pad
>>> import torch
>>> td = TensorDict({'a': torch.ones(3, 4, 1),
...     'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4])
>>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
>>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0)
>>> print(padded_td.batch_size)
torch.Size([4, 6])
>>> print(padded_td.get("a").shape)
torch.Size([4, 6, 1])
>>> print(padded_td.get("b").shape)
torch.Size([4, 6, 1, 1])

文件

取得 PyTorch 的完整開發者文件

檢視文件

教學

取得初學者和進階開發者的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源