• 文件 >
  • 具名張量運算符覆蓋範圍
捷徑

具名張量運算符覆蓋範圍

請先閱讀 具名張量 以取得具名張量的簡介。

此文件是有關名稱推斷的參考,名稱推斷是一個定義具名張量如何

  1. 使用名稱來提供額外的自動執行階段正確性檢查

  2. 將名稱從輸入張量傳播到輸出張量

以下是所有支援具名張量的運算及其相關名稱推斷規則的清單。

如果您在此處未看到列出的操作,但它有助於您的使用案例,請搜尋是否已提交問題,如果沒有,請提交一個問題

警告

具名張量 API 是一個實驗性的 API,可能會發生變更。

支援的操作

API

名稱推論規則

Tensor.abs(), torch.abs()

保留輸入名稱

Tensor.abs_()

保留輸入名稱

Tensor.acos(), torch.acos()

保留輸入名稱

Tensor.acos_()

保留輸入名稱

Tensor.add(), torch.add()

統一來自輸入的名稱

Tensor.add_()

統一來自輸入的名稱

Tensor.addmm(), torch.addmm()

收縮维度

Tensor.addmm_()

收縮维度

Tensor.addmv(), torch.addmv()

收縮维度

Tensor.addmv_()

收縮维度

Tensor.align_as()

請參閱文件

Tensor.align_to()

請參閱文件

Tensor.all(), torch.all()

Tensor.any(), torch.any()

Tensor.asin(), torch.asin()

保留輸入名稱

Tensor.asin_()

保留輸入名稱

Tensor.atan(), torch.atan()

保留輸入名稱

Tensor.atan2(), torch.atan2()

統一來自輸入的名稱

Tensor.atan2_()

統一來自輸入的名稱

Tensor.atan_()

保留輸入名稱

Tensor.bernoulli(), torch.bernoulli()

保留輸入名稱

Tensor.bernoulli_()

Tensor.bfloat16()

保留輸入名稱

Tensor.bitwise_not(), torch.bitwise_not()

保留輸入名稱

Tensor.bitwise_not_()

Tensor.bmm(), torch.bmm()

收縮维度

Tensor.bool()

保留輸入名稱

Tensor.byte()

保留輸入名稱

torch.cat()

統一來自輸入的名稱

Tensor.cauchy_()

Tensor.ceil(), torch.ceil()

保留輸入名稱

Tensor.ceil_()

Tensor.char()

保留輸入名稱

Tensor.chunk(), torch.chunk()

保留輸入名稱

Tensor.clamp(), torch.clamp()

保留輸入名稱

Tensor.clamp_()

Tensor.copy_()

out 函數和 in-place 變體

Tensor.cos(), torch.cos()

保留輸入名稱

Tensor.cos_()

Tensor.cosh(), torch.cosh()

保留輸入名稱

Tensor.cosh_()

Tensor.acosh(), torch.acosh()

保留輸入名稱

Tensor.acosh_()

Tensor.cpu()

保留輸入名稱

Tensor.cuda()

保留輸入名稱

Tensor.cumprod(), torch.cumprod()

保留輸入名稱

Tensor.cumsum(), torch.cumsum()

保留輸入名稱

Tensor.data_ptr()

Tensor.deg2rad(), torch.deg2rad()

保留輸入名稱

Tensor.deg2rad_()

Tensor.detach(), torch.detach()

保留輸入名稱

Tensor.detach_()

Tensor.device, torch.device()

Tensor.digamma(), torch.digamma()

保留輸入名稱

Tensor.digamma_()

Tensor.dim()

Tensor.div(), torch.div()

統一來自輸入的名稱

Tensor.div_()

統一來自輸入的名稱

Tensor.dot(), torch.dot()

Tensor.double()

保留輸入名稱

Tensor.element_size()

torch.empty()

工廠函數 (Factory functions)

torch.empty_like()

工廠函數 (Factory functions)

Tensor.eq(), torch.eq()

統一來自輸入的名稱

Tensor.erf(), torch.erf()

保留輸入名稱

Tensor.erf_()

Tensor.erfc(), torch.erfc()

保留輸入名稱

Tensor.erfc_()

Tensor.erfinv(), torch.erfinv()

保留輸入名稱

Tensor.erfinv_()

Tensor.exp(), torch.exp()

保留輸入名稱

Tensor.exp_()

Tensor.expand()

保留輸入名稱

Tensor.expm1(), torch.expm1()

保留輸入名稱

Tensor.expm1_()

Tensor.exponential_()

Tensor.fill_()

Tensor.flatten(), torch.flatten()

請參閱文件

Tensor.float()

保留輸入名稱

Tensor.floor(), torch.floor()

保留輸入名稱

Tensor.floor_()

Tensor.frac(), torch.frac()

保留輸入名稱

Tensor.frac_()

Tensor.ge(), torch.ge()

統一來自輸入的名稱

Tensor.get_device(), torch.get_device()

Tensor.grad

Tensor.gt(), torch.gt()

統一來自輸入的名稱

Tensor.half()

保留輸入名稱

Tensor.has_names()

請參閱文件

Tensor.index_fill(), torch.index_fill()

保留輸入名稱

Tensor.index_fill_()

Tensor.int()

保留輸入名稱

Tensor.is_contiguous()

Tensor.is_cuda

Tensor.is_floating_point(), torch.is_floating_point()

Tensor.is_leaf

Tensor.is_pinned()

Tensor.is_shared()

Tensor.is_signed(), torch.is_signed()

Tensor.is_sparse

Tensor.is_sparse_csr

torch.is_tensor()

Tensor.item()

Tensor.itemsize

Tensor.kthvalue(), torch.kthvalue()

移除維度 (Removes dimensions)

Tensor.le(), torch.le()

統一來自輸入的名稱

Tensor.log(), torch.log()

保留輸入名稱

Tensor.log10(), torch.log10()

保留輸入名稱

Tensor.log10_()

Tensor.log1p(), torch.log1p()

保留輸入名稱

Tensor.log1p_()

Tensor.log2(), torch.log2()

保留輸入名稱

Tensor.log2_()

Tensor.log_()

Tensor.log_normal_()

Tensor.logical_not(), torch.logical_not()

保留輸入名稱

Tensor.logical_not_()

Tensor.logsumexp(), torch.logsumexp()

移除維度 (Removes dimensions)

Tensor.long()

保留輸入名稱

Tensor.lt(), torch.lt()

統一來自輸入的名稱

torch.manual_seed()

Tensor.masked_fill(), torch.masked_fill()

保留輸入名稱

Tensor.masked_fill_()

Tensor.masked_select(), torch.masked_select()

對齊遮罩以符合輸入,然後從輸入張量統一名稱

Tensor.matmul(), torch.matmul()

收縮维度

Tensor.mean(), torch.mean()

移除維度 (Removes dimensions)

Tensor.median(), torch.median()

移除維度 (Removes dimensions)

Tensor.nanmedian(), torch.nanmedian()

移除維度 (Removes dimensions)

Tensor.mm(), torch.mm()

收縮维度

Tensor.mode(), torch.mode()

移除維度 (Removes dimensions)

Tensor.mul(), torch.mul()

統一來自輸入的名稱

Tensor.mul_()

統一來自輸入的名稱

Tensor.mv(), torch.mv()

收縮维度

Tensor.names

請參閱文件

Tensor.narrow(), torch.narrow()

保留輸入名稱

Tensor.nbytes

Tensor.ndim

Tensor.ndimension()

Tensor.ne(), torch.ne()

統一來自輸入的名稱

Tensor.neg(), torch.neg()

保留輸入名稱

Tensor.neg_()

torch.normal()

保留輸入名稱

Tensor.normal_()

Tensor.numel(), torch.numel()

torch.ones()

工廠函數 (Factory functions)

Tensor.pow(), torch.pow()

統一來自輸入的名稱

Tensor.pow_()

Tensor.prod(), torch.prod()

移除維度 (Removes dimensions)

Tensor.rad2deg(), torch.rad2deg()

保留輸入名稱

Tensor.rad2deg_()

torch.rand()

工廠函數 (Factory functions)

torch.rand()

工廠函數 (Factory functions)

torch.randn()

工廠函數 (Factory functions)

torch.randn()

工廠函數 (Factory functions)

Tensor.random_()

Tensor.reciprocal(), torch.reciprocal()

保留輸入名稱

Tensor.reciprocal_()

Tensor.refine_names()

請參閱文件

Tensor.register_hook()

Tensor.register_post_accumulate_grad_hook()

Tensor.rename()

請參閱文件

Tensor.rename_()

請參閱文件

Tensor.requires_grad

Tensor.requires_grad_()

Tensor.resize_()

僅允許不改變形狀的調整大小

Tensor.resize_as_()

僅允許不改變形狀的調整大小

Tensor.round(), torch.round()

保留輸入名稱

Tensor.round_()

Tensor.rsqrt(), torch.rsqrt()

保留輸入名稱

Tensor.rsqrt_()

Tensor.select(), torch.select()

移除維度 (Removes dimensions)

Tensor.short()

保留輸入名稱

Tensor.sigmoid(), torch.sigmoid()

保留輸入名稱

Tensor.sigmoid_()

Tensor.sign(), torch.sign()

保留輸入名稱

Tensor.sign_()

Tensor.sgn(), torch.sgn()

保留輸入名稱

Tensor.sgn_()

Tensor.sin(), torch.sin()

保留輸入名稱

Tensor.sin_()

Tensor.sinh(), torch.sinh()

保留輸入名稱

Tensor.sinh_()

Tensor.asinh(), torch.asinh()

保留輸入名稱

Tensor.asinh_()

Tensor.size()

Tensor.softmax(), torch.softmax()

保留輸入名稱

Tensor.split(), torch.split()

保留輸入名稱

Tensor.sqrt(), torch.sqrt()

保留輸入名稱

Tensor.sqrt_()

Tensor.squeeze(), torch.squeeze()

移除維度 (Removes dimensions)

Tensor.std(), torch.std()

移除維度 (Removes dimensions)

torch.std_mean()

移除維度 (Removes dimensions)

Tensor.stride()

Tensor.sub(), torch.sub()

統一來自輸入的名稱

Tensor.sub_()

統一來自輸入的名稱

Tensor.sum(), torch.sum()

移除維度 (Removes dimensions)

Tensor.tan(), torch.tan()

保留輸入名稱

Tensor.tan_()

Tensor.tanh(), torch.tanh()

保留輸入名稱

Tensor.tanh_()

Tensor.atanh(), torch.atanh()

保留輸入名稱

Tensor.atanh_()

torch.tensor()

工廠函數 (Factory functions)

Tensor.to()

保留輸入名稱

Tensor.topk(), torch.topk()

移除維度 (Removes dimensions)

Tensor.transpose(), torch.transpose()

置換維度

Tensor.trunc(), torch.trunc()

保留輸入名稱

Tensor.trunc_()

Tensor.type()

Tensor.type_as()

保留輸入名稱

Tensor.unbind(), torch.unbind()

移除維度 (Removes dimensions)

Tensor.unflatten()

請參閱文件

Tensor.uniform_()

Tensor.var(), torch.var()

移除維度 (Removes dimensions)

torch.var_mean()

移除維度 (Removes dimensions)

Tensor.zero_()

torch.zeros()

工廠函數 (Factory functions)

保留輸入名稱

所有逐點一元函數都遵循此規則,以及其他一些一元函數。

  • 檢查名稱:無

  • 傳播名稱:輸入張量的名稱會傳播到輸出。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')

移除維度

所有縮減運算,例如sum(),都會透過縮減所需的維度來移除維度。其他運算,例如select()squeeze(),也會移除維度。

無論何時可以將整數維度索引傳遞給運算子,也可以傳遞維度名稱。 接受維度索引列表的函數也可以接受維度名稱列表。

  • 檢查名稱:如果將 dimdims 作為名稱列表傳入,請檢查這些名稱是否存在於 self 中。

  • 傳播名稱:如果輸入張量的維度由 dimdims 指定,且不存在於輸出張量中,則這些維度對應的名稱不會出現在 output.names 中。

>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')

>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')

# Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')

統一來自輸入的名稱

所有二元算術運算都遵循此規則。廣播運算仍然從右側進行位置廣播,以保持與未命名張量的相容性。若要按名稱執行顯式廣播,請使用 Tensor.align_as()

  • 檢查名稱:所有名稱必須從右側進行位置匹配。也就是說,在 tensor + other 中,對於 (-min(tensor.dim(), other.dim()) + 1, -1] 中的所有 imatch(tensor.names[i], other.names[i]) 必須為 true。

  • 檢查名稱:此外,所有命名的維度必須從右側對齊。在匹配期間,如果我們將命名的維度 A 與未命名的維度 None 匹配,則 A 不得在具有未命名維度的張量中出現。

  • 傳播名稱:統一來自兩個張量的右側名稱對,以產生輸出名稱。

例如,

# tensor: Tensor[   N, None]
# other:  Tensor[None,    C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')

檢查名稱

  • match(tensor.names[-1], other.names[-1])True

  • match(tensor.names[-2], tensor.names[-2])True

  • 因為我們在 tensor 中將 None'C' 匹配,請檢查以確保 'C' 不存在於 tensor 中(它不存在)。

  • 檢查以確保 'N' 不存在於 other 中(它不存在)。

最後,輸出名稱使用 [unify('N', None), unify(None, 'C')] = ['N', 'C'] 計算。

更多範例

# Dimensions don't match from the right:
# tensor: Tensor[N, C]
# other:  Tensor[   N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.

# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
# tensor: Tensor[N, None]
# other:  Tensor[      N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.

注意

在最後兩個範例中,都可以通過名稱對齊張量,然後執行加法。使用 Tensor.align_as() 按名稱對齊張量,或使用 Tensor.align_to() 將張量對齊到自訂維度排序。

排列維度

某些運算,例如 Tensor.t(),會排列維度的順序。維度名稱附加到個別維度,因此它們也會被排列。

如果運算符採用位置索引 dim,它也可以將維度名稱作為 dim

  • 檢查名稱:如果 dim 作為名稱傳遞,請檢查它是否存在於張量中。

  • 傳播名稱:以與要排列的維度相同的方式排列維度名稱。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')

收縮維度

矩陣乘法函數遵循此規則的某些變體。讓我們首先瀏覽 torch.mm(),然後推廣批量矩陣乘法的規則。

對於 torch.mm(tensor, other)

  • 檢查名稱:無

  • 傳播名稱:結果名稱為 (tensor.names[-2], other.names[-1])

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')

矩陣乘法本質上對兩個維度執行點積,將它們摺疊。當兩個張量進行矩陣相乘時,收縮的維度會消失,並且不會出現在輸出張量中。

torch.mv()torch.dot() 的工作方式類似:名稱推斷不檢查輸入名稱,並刪除點積中涉及的維度

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)

現在,讓我們看看 torch.matmul(tensor, other)。假設 tensor.dim() >= 2other.dim() >= 2

  • 檢查名稱:檢查輸入的批次維度是否對齊且可廣播。有關輸入對齊的含義,請參閱 統一來自輸入的名稱

  • 傳播名稱:結果名稱通過統一批次維度並刪除收縮的維度來獲得:unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])

範例

# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
# 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')

最後,有許多 matmul 函數的融合 add 版本。也就是說,addmm()addmv()。這些被視為組合的名稱推斷,例如,mm()add() 的名稱推斷。

工廠函數

工廠函數現在採用一個新的 names 參數,該參數將名稱與每個維度關聯。

>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))

out 函數和 in-place 變體

指定為 out= 張量的張量具有以下行為

  • 如果它沒有命名的維度,則從操作計算出的名稱會傳播到它。

  • 如果它有任何命名的維度,則從操作計算出的名稱必須與現有名稱完全相等。否則,該操作會出錯。

所有原地 (in-place) 方法都會修改輸入,使其名稱等於從名稱推斷計算出的名稱。例如:

>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)

>>> x += y
>>> x.names
('N', 'C')

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

取得針對初學者和進階開發者的深度教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源