快捷方式

torch.masked

簡介

動機

警告

遮罩張量的 PyTorch API 處於原型階段,未來可能會或可能不會更改。

MaskedTensor 作為 torch.Tensor 的擴展,為使用者提供能力

  • 使用任何遮罩語義 (例如,可變長度張量、nan* 運算符等)。

  • 區分 0 和 NaN 梯度。

  • 各種稀疏應用 (請參閱下面的教學)。

「已指定」和「未指定」在 PyTorch 中有很長的歷史,但沒有正式的語義,當然也沒有一致性;事實上,MaskedTensor 的誕生源於 vanilla torch.Tensor 類別無法妥善解決的問題累積。因此,MaskedTensor 的主要目標是成為 PyTorch 中「已指定」和「未指定」值的真實來源,這些值是一等公民,而不是事後才考慮。反過來,這應該進一步釋放稀疏性的潛力、實現更安全和更一致的運算符,並為用戶和開發人員提供更流暢和更直觀的體驗。

什麼是 MaskedTensor?

MaskedTensor 是一個張量子類別,由 1) 輸入 (資料) 和 2) 遮罩組成。遮罩告訴我們應該包含或忽略輸入中的哪些條目。

舉例來說,假設我們要遮罩掉所有等於 0 的值 (以灰色表示) 並取最大值。

_images/tensor_comparison.jpg

上面是 vanilla 張量範例,而下面是 MaskedTensor,其中所有的 0 都被遮罩掉。 這顯然會產生不同的結果,具體取決於我們是否有遮罩,但這種靈活的結構允許使用者有系統地忽略他們在計算過程中想要忽略的任何元素。

我們已經編寫了許多現有的教學,以幫助使用者入門,例如:

支援的運算符

一元運算符

一元運算符是只包含單個輸入的運算符。 將它們應用於 MaskedTensors 相對簡單:如果在給定索引處資料被遮罩掉,我們就應用運算符,否則我們將繼續遮罩掉資料。

可用的一元運算符包括:

abs

計算 input 中每個元素的絕對值。

absolute

別名,請參閱 torch.abs()

acos

計算 input 中每個元素的反餘弦值。

arccos

別名,請參閱 torch.acos()

acosh

傳回一個新張量,其中包含 input 元素的反雙曲餘弦值。

arccosh

別名,請參閱 torch.acosh()

angle

計算給定 input 張量的元素角度 (以弧度表示)。

asin

傳回一個新張量,其中包含 input 元素的反正弦值。

arcsin

別名,請參閱 torch.asin()

asinh

傳回一個新張量,其中包含 input 元素的反雙曲正弦值。

arcsinh

別名,請參閱 torch.asinh()

atan

傳回一個新張量,其中包含 input 元素的反正切值。

arctan

別名,請參閱 torch.atan()

atanh

傳回一個新張量,其中包含 input 元素的反雙曲正切值。

arctanh

別名,請參閱 torch.atanh()

bitwise_not

計算給定輸入張量的位元 NOT。

ceil

傳回一個新張量,其中包含 input 元素的 ceiling,即大於或等於每個元素的最小整數。

clamp

input 中的所有元素鉗制到範圍 [ min, max ] 中。

clip

別名,請參閱 torch.clamp()

conj_physical

計算給定 input 張量的元素共軛。

cos

傳回一個新張量,其中包含 input 元素的餘弦值。

cosh

傳回一個新張量,其中包含 input 元素的雙曲餘弦值。

deg2rad

傳回一個新張量,其中包含從角度(以度為單位)轉換為弧度的 input 的每個元素。

digamma

別名,請參閱 torch.special.digamma()

erf

別名,請參閱 torch.special.erf()

erfc

別名,請參閱 torch.special.erfc()

erfinv

別名,請參閱 torch.special.erfinv()

exp

返回一個新的張量,其元素為輸入張量 input 的指數。

exp2

torch.special.exp2() 的別名。

expm1

torch.special.expm1() 的別名。

fix

torch.trunc() 的別名

floor

返回一個新的張量,其元素為 input 的 floor,即小於或等於每個元素的最大整數。

frac

計算 input 中每個元素的小數部分。

lgamma

計算 input 上伽瑪函數絕對值的自然對數。

log

返回一個新的張量,其元素為 input 的自然對數。

log10

返回一個新的張量,其元素為 input 的以 10 為底的對數。

log1p

返回一個新的張量,其元素為 (1 + input) 的自然對數。

log2

返回一個新的張量,其元素為 input 的以 2 為底的對數。

logit

torch.special.logit() 的別名。

i0

torch.special.i0() 的別名。

isnan

返回一個新的張量,其布林元素表示 input 的每個元素是否為 NaN。

nan_to_num

NaN、正無窮大和負無窮大值在 input 中分別替換為 nanposinfneginf 指定的值。

neg

返回一個新的張量,其元素為 input 的負數。

negative

torch.neg() 的別名

positive

返回 input

pow

input 中的每個元素與 exponent 相乘,並返回一個包含結果的張量。

rad2deg

返回一個新的張量,其元素為 input 中的每個元素從弧度轉換為角度。

reciprocal

返回一個新的張量,其元素為 input 的倒數

round

input 的元素四捨五入到最接近的整數。

rsqrt

返回一個新的張量,其元素為 input 的每個元素的平方根的倒數。

sigmoid

torch.special.expit() 的別名。

sign

返回一個新的張量,其元素為 input 的符號。

sgn

此函數是 torch.sign() 對複數張量的擴展。

signbit

測試 input 的每個元素是否設置了符號位。

sin

返回一個新的張量,其元素為 input 的正弦。

sinc

torch.special.sinc() 的別名。

sinh

返回一個新的張量,其元素為 input 的雙曲正弦。

sqrt

返回一個新的張量,其元素為 input 的平方根。

square

返回一個新的張量,其元素為 input 的平方。

tan

返回一個新的張量,其元素為 input 的正切。

tanh

返回一個新的張量,其元素為 input 的雙曲正切。

trunc

返回一個新的張量,其元素為 input 的截斷整數值。

可用的 inplace 一元運算符是以上所有 除了

angle

計算給定 input 張量的元素角度 (以弧度表示)。

positive

返回 input

signbit

測試 input 的每個元素是否設置了符號位。

isnan

返回一個新的張量,其布林元素表示 input 的每個元素是否為 NaN。

二元運算符

您可能在教學中已經看到,MaskedTensor 也實作了二元運算,但前提是兩個 MaskedTensor 中的遮罩必須匹配,否則會引發錯誤。 如錯誤中所述,如果您需要對特定運算符的支持,或者對它們應該如何表現有建議的語義,請在 GitHub 上提出 issue。 目前,我們決定採用最保守的實作方式,以確保使用者確切地知道發生了什麼,並且有意識地決定使用遮罩語義。

可用的二元運算符有

add

將按 alpha 縮放的 other 加到 input

atan2

逐元素計算 inputi/otheri\text{input}_{i} / \text{other}_{i} 的反正切函數,並考量象限。

arctan2

torch.atan2() 的別名。

bitwise_and

計算 inputother 的位元 AND 運算。

bitwise_or

計算 inputother 的位元 OR 運算。

bitwise_xor

計算 inputother 的位元 XOR 運算。

bitwise_left_shift

計算 input 左移 other 位元的算術位移。

bitwise_right_shift

計算 input 右移 other 位元的算術位移。

div

將輸入 input 的每個元素除以 other 的對應元素。

divide

torch.div() 的別名。

floor_divide

fmod

逐元素應用 C++ 的 std::fmod

logaddexp

輸入指數之和的對數。

logaddexp2

以 2 為底數,輸入指數之和的對數。

mul

input 乘以 other

multiply

torch.mul() 的別名。

nextafter

逐元素返回 input 之後朝向 other 的下一個浮點數值。

remainder

逐元素計算 Python 的模數運算

sub

input 中減去按 alpha 縮放的 other

subtract

torch.sub() 的別名。

true_divide

torch.div()rounding_mode=None 的別名。

eq

逐元素計算相等性

ne

逐元素計算 inputother\text{input} \neq \text{other}

le

逐元素計算 inputother\text{input} \leq \text{other}

ge

逐元素計算 inputother\text{input} \geq \text{other}

greater

torch.gt() 的別名。

greater_equal

torch.ge() 的別名。

gt

逐元素計算 input>other\text{input} > \text{other}

less_equal

torch.le() 的別名。

lt

逐元素計算 input<other\text{input} < \text{other}

less

torch.lt() 的別名。

maximum (最大值)

計算 inputother 的逐元素最大值。

minimum (最小值)

計算 inputother 的逐元素最小值。

fmax

計算 inputother 的逐元素最大值。

fmin

計算 inputother 的逐元素最小值。

not_equal (不相等)

torch.ne() 的別名。

所有可用的原地二元運算符都是上述的,除了

logaddexp

輸入指數之和的對數。

logaddexp2

以 2 為底數,輸入指數之和的對數。

equal (相等)

如果兩個張量具有相同的大小和元素,則為 True,否則為 False

fmin

計算 inputother 的逐元素最小值。

minimum (最小值)

計算 inputother 的逐元素最小值。

fmax

計算 inputother 的逐元素最大值。

縮減 (Reductions)

以下縮減可用(具有自動微分支持)。 更多信息,概述 教學詳細介紹了一些縮減的範例,而 進階語義 教學則更深入地討論了我們如何決定某些縮減語義。

sum (總和)

返回 input 張量中所有元素的總和。

mean (平均值)

amin (最小值)

返回給定維度 diminput 張量每個切片的最小值。

amax (最大值)

返回給定維度 diminput 張量每個切片的最大值。

argmin (最小值的索引)

返回扁平化張量或沿維度的最小值(們)的索引

argmax (最大值的索引)

返回 input 張量中所有元素的最大值的索引。

prod (乘積)

返回 input 張量中所有元素的乘積。

all (全部)

測試 input 中的所有元素是否評估為 True

norm (範數)

返回給定張量的矩陣範數或向量範數。

var (變異數)

計算由 dim 指定的維度上的變異數。

std (標準差)

計算由 dim 指定的維度上的標準差。

視圖和選擇函數 (View and select functions)

我們還包含了一些視圖和選擇函數; 直觀地說,這些運算符將同時應用於數據和遮罩,然後將結果包裝在 MaskedTensor 中。 舉一個簡單的例子,考慮 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

目前支持以下操作

atleast_1d

返回每個零維輸入張量的 1 維視圖。

broadcast_tensors

根據 廣播語義 (Broadcasting semantics) 廣播給定的張量。

broadcast_to

input 廣播到形狀 shape

cat (串聯)

在給定的維度中串聯 tensors 中給定的張量序列。

chunk (分塊)

嘗試將張量拆分為指定數量的塊。

column_stack (堆疊成列)

通過水平堆疊 tensors 中的張量來創建一個新張量。

dsplit (深度分割)

根據 indices_or_sectionsinput(一個具有三個或更多維度的張量)深度地分割成多個張量。

flatten (扁平化)

通過將 input 重塑為一維張量來將其扁平化。

hsplit (水平分割)

根據 indices_or_sectionsinput(一個具有一個或多個維度的張量)水平地分割成多個張量。

hstack (水平堆疊)

按順序水平(按列)堆疊張量。

kron (克羅內克積)

計算 inputother 的克羅內克積,表示為 \otimes

meshgrid

創建由 attr:tensors 中 1D 輸入指定的坐標網格。

narrow (縮小)

返回一個新張量,它是 input 張量的縮小版本。

nn.functional.unfold

從批次輸入張量中提取滑動局部塊。

ravel (攤平)

返回一個連續的扁平化張量。

select (選擇)

沿給定索引處的選定維度對 input 張量進行切片。

split (分割)

將張量分割成塊。

stack (堆疊)

沿新維度串聯一系列張量。

t (轉置)

期望 input 為 <= 2-D 張量並轉置維度 0 和 1。

transpose (轉置)

返回一個張量,它是 input 的轉置版本。

vsplit (垂直分割)

根據 indices_or_sectionsinput(一個具有兩個或更多維度的張量)垂直地分割成多個張量。

vstack (垂直堆疊)

按順序垂直(按行)堆疊張量。

Tensor.expand (擴展)

返回 self 張量的新視圖,其中單例維度擴展到更大的大小。

Tensor.expand_as (擴展為)

將此張量擴展為與 other 相同的大小。

Tensor.reshape (重塑)

返回一個與 self 具有相同數據和元素數的張量,但具有指定的形狀。

Tensor.reshape_as (重塑為)

將此張量返回為與 other 相同的形狀。

Tensor.unfold

返回原始張量的視圖,該視圖包含來自 self 張量在 dimension 維度中大小為 size 的所有切片。

Tensor.view

返回一個新的張量,其數據與 self 張量相同,但具有不同的 shape

文件

存取 PyTorch 的全面開發者文件

檢視文件

教學

取得針對初學者和進階開發人員的深入教學課程

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源