• 文件 >
  • torch.nn.functional
快捷方式

torch.nn.functional

卷積函數

conv1d

將 1D 卷積應用於由多個輸入平面組成的輸入訊號。

conv2d

將 2D 卷積應用於由多個輸入平面組成的輸入圖像。

conv3d

將 3D 卷積應用於由多個輸入平面組成的輸入圖像。

conv_transpose1d

將 1D 轉置卷積運算符應用於由多個輸入平面組成的輸入訊號,有時也稱為「反卷積」。

conv_transpose2d

將 2D 轉置卷積運算符應用於由多個輸入平面組成的輸入圖像,有時也稱為「反卷積」。

conv_transpose3d

將 3D 轉置卷積運算符應用於由多個輸入平面組成的輸入圖像,有時也稱為「反卷積」。

unfold

從批次輸入張量中提取滑動局部區塊。

fold

將滑動局部區塊的陣列組合到一個大的包含張量中。

池化函數

avg_pool1d

將 1D 平均池化應用於由多個輸入平面組成的輸入訊號。

avg_pool2d

kH×kWkH \times kW 區域上應用 2D 平均池化操作,步長為 sH×sWsH \times sW

avg_pool3d

kT×kH×kWkT \times kH \times kW 區域上應用 3D 平均池化操作,步長為 sT×sH×sWsT \times sH \times sW

max_pool1d

對由多個輸入平面組成的輸入訊號應用 1D 最大池化。

max_pool2d

對由多個輸入平面組成的輸入訊號應用 2D 最大池化。

max_pool3d

對由多個輸入平面組成的輸入訊號應用 3D 最大池化。

max_unpool1d

計算 MaxPool1d 的部分反向操作。

max_unpool2d

計算 MaxPool2d 的部分反向操作。

max_unpool3d

計算 MaxPool3d 的部分反向操作。

lp_pool1d

對由多個輸入平面組成的輸入訊號應用 1D 幂平均池化。

lp_pool2d

對由多個輸入平面組成的輸入訊號應用 2D 幂平均池化。

lp_pool3d

對由多個輸入平面組成的輸入訊號應用 3D 幂平均池化。

adaptive_max_pool1d

對由多個輸入平面組成的輸入訊號應用 1D 自適應最大池化。

adaptive_max_pool2d

對由多個輸入平面組成的輸入訊號應用 2D 自適應最大池化。

adaptive_max_pool3d

對由多個輸入平面組成的輸入訊號應用 3D 自適應最大池化。

adaptive_avg_pool1d

對由多個輸入平面組成的輸入訊號應用 1D 自適應平均池化。

adaptive_avg_pool2d

對由多個輸入平面組成的輸入訊號應用 2D 自適應平均池化。

adaptive_avg_pool3d

對由多個輸入平面組成的輸入訊號應用 3D 自適應平均池化。

fractional_max_pool2d

對由多個輸入平面組成的輸入訊號應用 2D 分數最大池化。

fractional_max_pool3d

對由多個輸入平面組成的輸入訊號應用 3D 分數最大池化。

注意力機制

torch.nn.attention.bias 模組包含設計用於 scaled_dot_product_attention 的 attention_biases。

scaled_dot_product_attention

scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,

非線性激活函數

threshold

將閾值應用於輸入 Tensor 的每個元素。

threshold_

threshold() 的原地 (in-place) 版本。

relu

逐元素應用修正線性單元函數。

relu_

relu() 的原地 (in-place) 版本。

hardtanh

逐元素應用 HardTanh 函數。

hardtanh_

hardtanh() 的原地 (in-place) 版本。

hardswish

逐元素應用 hardswish 函數。

relu6

逐元素應用函數 ReLU6(x)=min(max(0,x),6)\text{ReLU6}(x) = \min(\max(0,x), 6)

elu

逐元素應用指數線性單元 (ELU) 函數。

elu_

原地 (In-place) 版本的 elu()

selu

逐元素 (element-wise) 應用 SELU(x)=scale(max(0,x)+min(0,α(exp(x)1)))\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))),其中 α=1.6732632423543772848170429916717\alpha=1.6732632423543772848170429916717scale=1.0507009873554804934193349852946scale=1.0507009873554804934193349852946

celu

逐元素 (element-wise) 應用 CELU(x)=max(0,x)+min(0,α(exp(x/α)1))\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

leaky_relu

逐元素 (element-wise) 應用 LeakyReLU(x)=max(0,x)+negative_slopemin(0,x)\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

leaky_relu_

原地 (In-place) 版本的 leaky_relu()

prelu

逐元素套用函數 PReLU(x)=max(0,x)+weightmin(0,x)\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x),其中 weight 是一個可學習的參數。

rrelu

隨機修正線性單元 (Randomized leaky ReLU)。

rrelu_

原地 (In-place) 版本的 rrelu()

glu

門控線性單元 (The gated linear unit)。

gelu

當 approximate 引數為 'none' 時,它會逐元素套用函數 GELU(x)=xΦ(x)\text{GELU}(x) = x * \Phi(x)

logsigmoid

逐元素套用 LogSigmoid(xi)=log(11+exp(xi))\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)

hardshrink

逐元素套用硬收縮函數 (hard shrinkage function)

tanhshrink

逐元素套用,Tanhshrink(x)=xTanh(x)\text{Tanhshrink}(x) = x - \text{Tanh}(x)

softsign

逐元素套用函數 SoftSign(x)=x1+x\text{SoftSign}(x) = \frac{x}{1 + |x|}

softplus

逐元素套用函數 Softplus(x)=1βlog(1+exp(βx))\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

softmin

套用 softmin 函數。

softmax

套用 softmax 函數。

softshrink

逐元素套用 soft shrinkage 函數

gumbel_softmax

從 Gumbel-Softmax 分佈中取樣 (連結 1 連結 2),並且可以選擇離散化。

log_softmax

套用 softmax 後接 logarithm。

tanh

逐元素套用,Tanh(x)=tanh(x)=exp(x)exp(x)exp(x)+exp(x)\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}

sigmoid

對每個元素應用 Sigmoid 函數 Sigmoid(x)=11+exp(x)\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}

hardsigmoid

對每個元素應用 Hardsigmoid 函數。

silu

對每個元素應用 Sigmoid Linear Unit (SiLU) 函數。

mish

對每個元素應用 Mish 函數。

batch_norm

對一批資料的每個通道應用批次正規化 (Batch Normalization)。

group_norm

對最後幾個維度應用群組正規化 (Group Normalization)。

instance_norm

在批次中的每個資料樣本中,對每個通道獨立應用實例正規化 (Instance Normalization)。

layer_norm

對最後幾個維度應用層正規化 (Layer Normalization)。

local_response_norm

對輸入訊號應用局部響應正規化 (local response normalization)。

rms_norm

應用均方根層正規化 (Root Mean Square Layer Normalization)。

normalize

在指定維度上執行輸入的 LpL_p 正規化。

線性函數

linear

對傳入的資料應用線性轉換:y=xAT+by = xA^T + b.

bilinear

對傳入的資料應用雙線性轉換:y=x1TAx2+by = x_1^T A x_2 + b

Dropout 函數

dropout

在訓練期間,以機率 p 隨機將輸入張量的一些元素歸零。

alpha_dropout

將 alpha dropout 應用於輸入。

feature_alpha_dropout

隨機遮罩整個通道(一個通道是一個特徵圖)。

dropout1d

隨機將整個通道歸零(一個通道是一個 1D 特徵圖)。

dropout2d

隨機將整個通道歸零(一個通道是一個 2D 特徵圖)。

dropout3d

隨機將整個通道歸零(一個通道是一個 3D 特徵圖)。

稀疏函數

embedding

產生一個簡單的查詢表,用於在固定的字典和大小中查找嵌入向量。

embedding_bag

計算嵌入 bags 的總和、平均值或最大值。

one_hot

接受形狀為 (*) 的 LongTensor 索引值,並返回形狀為 (*, num_classes) 的張量,該張量除了最後一個維度的索引與輸入張量的相應值匹配的位置為 1 之外,其他位置均為零。

距離函數

pairwise_distance

詳情請參閱 torch.nn.PairwiseDistance

cosine_similarity

返回 x1x2 之間的餘弦相似度,沿著 dim 計算。

pdist

計算輸入中每對 row 向量之間的 p-norm 距離。

損失函數

binary_cross_entropy

衡量目標和輸入機率之間的 Binary Cross Entropy(二元交叉熵)。

binary_cross_entropy_with_logits

計算目標和輸入 logits 之間的 Binary Cross Entropy(二元交叉熵)。

poisson_nll_loss

Poisson 負對數概似損失。

cosine_embedding_loss

詳情請參閱CosineEmbeddingLoss

cross_entropy

計算輸入 logits 和目標之間的交叉熵損失。

ctc_loss

應用 Connectionist Temporal Classification(連接主義時間分類)損失。

gaussian_nll_loss

Gaussian 負對數概似損失。

hinge_embedding_loss

詳情請參閱HingeEmbeddingLoss

kl_div

計算 KL Divergence(KL 散度)損失。

l1_loss

計算元素級絕對值差的均值的函數。

mse_loss

衡量元素級均方誤差,可選擇加權。

margin_ranking_loss

詳情請參閱MarginRankingLoss

multilabel_margin_loss

詳情請參閱MultiLabelMarginLoss

multilabel_soft_margin_loss

詳情請參閱MultiLabelSoftMarginLoss

multi_margin_loss

詳情請參閱MultiMarginLoss

nll_loss

計算負對數概似損失。

huber_loss

計算 Huber 損失,可選擇加權。

smooth_l1_loss

計算 Smooth L1 損失。

soft_margin_loss

詳情請參閱SoftMarginLoss

triplet_margin_loss

計算給定輸入張量和一個大於 0 的邊距之間的 triplet 損失。

triplet_margin_with_distance_loss

使用自定義距離函數計算輸入張量的 triplet margin 損失。

視覺函數

pixel_shuffle

將形狀為 (,C×r2,H,W)(*, C \times r^2, H, W) 的張量重新排列為形狀為 (,C,H×r,W×r)(*, C, H \times r, W \times r) 的張量,其中 r 是 upscale_factor

pixel_unshuffle

透過重新排列形狀為 (,C,H×r,W×r)(*, C, H \times r, W \times r) 的張量中的元素,反轉 PixelShuffle 運算,變成形狀為 (,C×r2,H,W)(*, C \times r^2, H, W) 的張量,其中 r 為 downscale_factor

pad

填充張量。

interpolate

對輸入進行降採樣/升採樣。

upsample

對輸入進行升採樣。

upsample_nearest

使用最近鄰的像素值對輸入進行升採樣。

upsample_bilinear

使用雙線性升採樣對輸入進行升採樣。

grid_sample

計算網格採樣。

affine_grid

給定一批仿射矩陣 theta,生成 2D 或 3D 流場(採樣網格)。

DataParallel 函數 (multi-GPU, 分散式)

data_parallel

torch.nn.parallel.data_parallel

跨 device_ids 中給定的 GPU 平行評估 module(input)。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源