捷徑

torch.nn

這些是圖表的基本建構區塊

緩衝區 (Buffer)

一種不應被視為模型參數的 Tensor。

參數 (Parameter)

一種應被視為模組參數的 Tensor。

未初始化參數 (UninitializedParameter)

尚未初始化的參數。

未初始化緩衝區 (UninitializedBuffer)

尚未初始化的緩衝區。

容器 (Containers)

模組 (Module)

所有神經網路模組的基底類別。

序列 (Sequential)

一個序列容器。

模組列表 (ModuleList)

在列表中保存子模組。

模組字典 (ModuleDict)

在字典中保存子模組。

參數列表 (ParameterList)

在列表中保存參數。

參數字典 (ParameterDict)

在字典中保存參數。

模組的全域 Hook

register_module_forward_pre_hook

註冊一個適用於所有模組的前向預處理 Hook。

register_module_forward_hook

為所有模組註冊一個全域前向 Hook。

register_module_backward_hook

註冊一個適用於所有模組的反向 Hook。

register_module_full_backward_pre_hook

註冊一個適用於所有模組的反向預處理 Hook。

register_module_full_backward_hook

註冊一個適用於所有模組的反向 Hook。

register_module_buffer_registration_hook

註冊一個適用於所有模組的緩衝區註冊 Hook。

register_module_module_registration_hook

註冊一個適用於所有模組的模組註冊 Hook。

register_module_parameter_registration_hook

註冊一個適用於所有模組的參數註冊 Hook。

卷積層 (Convolution Layers)

nn.Conv1d

對由多個輸入平面組成的輸入訊號應用一維卷積。

nn.Conv2d

對由多個輸入平面組成的輸入訊號應用二維卷積。

nn.Conv3d

對由多個輸入平面組成的輸入訊號應用三維卷積。

nn.ConvTranspose1d

對由多個輸入平面組成的輸入影像應用一維轉置卷積運算。

nn.ConvTranspose2d

對由多個輸入平面組成的輸入影像應用二維轉置卷積運算。

nn.ConvTranspose3d

對由多個輸入平面組成的輸入影像應用三維轉置卷積運算。

nn.LazyConv1d

一個 torch.nn.Conv1d 模組,具有 in_channels 參數的延遲初始化。

nn.LazyConv2d

一個 torch.nn.Conv2d 模組,具有 in_channels 參數的延遲初始化。

nn.LazyConv3d

一個 torch.nn.Conv3d 模組,具有 in_channels 參數的延遲初始化。

nn.LazyConvTranspose1d

一個 torch.nn.ConvTranspose1d 模組,具有 in_channels 參數的延遲初始化。

nn.LazyConvTranspose2d

一個 torch.nn.ConvTranspose2d 模組,具有 in_channels 參數的延遲初始化。

nn.LazyConvTranspose3d

一個 torch.nn.ConvTranspose3d 模組,具有 in_channels 參數的延遲初始化。

nn.Unfold

從批次輸入 Tensor 中提取滑動的局部區塊。

nn.Fold

將滑動的局部區塊陣列合併成一個大的包含 Tensor。

池化層 (Pooling layers)

nn.MaxPool1d

對由多個輸入平面組成的輸入訊號應用一維最大池化。

nn.MaxPool2d

對由多個輸入平面組成的輸入訊號應用二維最大池化。

nn.MaxPool3d

對由多個輸入平面組成的輸入訊號應用三維最大池化。

nn.MaxUnpool1d

計算 MaxPool1d 的部分反向運算。

nn.MaxUnpool2d

計算 MaxPool2d 的部分反向運算。

nn.MaxUnpool3d

計算 MaxPool3d 的部分反向運算。

nn.AvgPool1d

對由多個輸入平面組成的輸入訊號應用一維平均池化。

nn.AvgPool2d

對由多個輸入平面組成的輸入訊號應用二維平均池化。

nn.AvgPool3d

對由多個輸入平面組成的輸入訊號應用三維平均池化。

nn.FractionalMaxPool2d

對由多個輸入平面組成的輸入訊號應用二維分數最大池化。

nn.FractionalMaxPool3d

對由多個輸入平面組成的輸入訊號應用三維分數最大池化。

nn.LPPool1d

對由多個輸入平面組成的輸入訊號應用一維冪平均池化。

nn.LPPool2d

對由多個輸入平面組成的輸入訊號應用二維冪平均池化。

nn.LPPool3d

對由多個輸入平面組成的輸入訊號應用三維冪平均池化。

nn.AdaptiveMaxPool1d

對由多個輸入平面組成的輸入訊號應用一維自適應最大池化。

nn.AdaptiveMaxPool2d

對由多個輸入平面組成的輸入訊號應用二維自適應最大池化。

nn.AdaptiveMaxPool3d

對由多個輸入平面組成的輸入訊號應用三維自適應最大池化。

nn.AdaptiveAvgPool1d

對由多個輸入平面組成的輸入訊號應用一維自適應平均池化。

nn.AdaptiveAvgPool2d

對由多個輸入平面組成的輸入訊號應用二維自適應平均池化。

nn.AdaptiveAvgPool3d

對由多個輸入平面組成的輸入訊號應用三維自適應平均池化。

填充層 (Padding Layers)

nn.ReflectionPad1d

使用輸入邊界的反射來填充輸入 Tensor。

nn.ReflectionPad2d

使用輸入邊界的反射來填充輸入 Tensor。

nn.ReflectionPad3d

使用輸入邊界的反射來填充輸入 Tensor。

nn.ReplicationPad1d

使用輸入邊界的複製來填充輸入 Tensor。

nn.ReplicationPad2d

使用輸入邊界的複製來填充輸入 Tensor。

nn.ReplicationPad3d

使用輸入邊界的複製來填充輸入 Tensor。

nn.ZeroPad1d

使用零填充輸入 Tensor 邊界。

nn.ZeroPad2d

使用零填充輸入 Tensor 邊界。

nn.ZeroPad3d

使用零填充輸入 Tensor 邊界。

nn.ConstantPad1d

使用常數值填充輸入 Tensor 邊界。

nn.ConstantPad2d

使用常數值填充輸入 Tensor 邊界。

nn.ConstantPad3d

使用常數值填充輸入 Tensor 邊界。

nn.CircularPad1d

使用輸入邊界的循環填充來填充輸入 Tensor。

nn.CircularPad2d

使用輸入邊界的循環填充來填充輸入 Tensor。

nn.CircularPad3d

使用輸入邊界的循環填充來填充輸入 Tensor。

非線性激活 (加權和,非線性) (Non-linear Activations (weighted sum, nonlinearity))

nn.ELU

逐元素地應用指數線性單元 (ELU) 函數。

nn.Hardshrink

逐元素套用 Hard Shrinkage (Hardshrink) 函數。

nn.Hardsigmoid

逐元素套用 Hardsigmoid 函數。

nn.Hardtanh

逐元素套用 HardTanh 函數。

nn.Hardswish

逐元素套用 Hardswish 函數。

nn.LeakyReLU

逐元素套用 LeakyReLU 函數。

nn.LogSigmoid

逐元素套用 Logsigmoid 函數。

nn.MultiheadAttention

允許模型共同關注來自不同表示子空間的資訊。

nn.PReLU

套用逐元素的 PReLU 函數。

nn.ReLU

逐元素套用 rectified linear unit 函數。

nn.ReLU6

逐元素套用 ReLU6 函數。

nn.RReLU

逐元素套用 randomized leaky rectified linear unit 函數。

nn.SELU

逐元素套用 SELU 函數。

nn.CELU

逐元素套用 CELU 函數。

nn.GELU

套用 Gaussian Error Linear Units 函數。

nn.Sigmoid

逐元素套用 Sigmoid 函數。

nn.SiLU

逐元素套用 Sigmoid Linear Unit (SiLU) 函數。

nn.Mish

逐元素套用 Mish 函數。

nn.Softplus

逐元素套用 Softplus 函數。

nn.Softshrink

逐元素套用 soft shrinkage 函數。

nn.Softsign

套用逐元素的 Softsign 函數。

nn.Tanh

逐元素套用 Hyperbolic Tangent (Tanh) 函數。

nn.Tanhshrink

套用逐元素的 Tanhshrink 函數。

nn.Threshold

對輸入 Tensor 的每個元素進行閾值處理。

nn.GLU

套用 gated linear unit 函數。

非線性激活函數 (其他)

nn.Softmin

對 n 維輸入 Tensor 套用 Softmin 函數。

nn.Softmax

對 n 維輸入 Tensor 套用 Softmax 函數。

nn.Softmax2d

將 SoftMax 應用於每個空間位置的特徵。

nn.LogSoftmax

log(Softmax(x))\log(\text{Softmax}(x)) 函數應用於 n 維輸入 Tensor。

nn.AdaptiveLogSoftmaxWithLoss

高效的 softmax 近似。

正規化層

nn.BatchNorm1d

對 2D 或 3D 輸入套用批次正規化。

nn.BatchNorm2d

對 4D 輸入套用批次正規化。

nn.BatchNorm3d

對 5D 輸入套用批次正規化。

nn.LazyBatchNorm1d

一個具有延遲初始化的 torch.nn.BatchNorm1d 模組。

nn.LazyBatchNorm2d

一個具有延遲初始化的 torch.nn.BatchNorm2d 模組。

nn.LazyBatchNorm3d

一個具有延遲初始化的 torch.nn.BatchNorm3d 模組。

nn.GroupNorm

對一小批輸入套用群組正規化。

nn.SyncBatchNorm

對 N 維輸入套用批次正規化。

nn.InstanceNorm1d

套用實例正規化。

nn.InstanceNorm2d

套用實例正規化。

nn.InstanceNorm3d

套用實例正規化。

nn.LazyInstanceNorm1d

一個具有 num_features 參數延遲初始化的 torch.nn.InstanceNorm1d 模組。

nn.LazyInstanceNorm2d

一個具有 num_features 參數延遲初始化的 torch.nn.InstanceNorm2d 模組。

nn.LazyInstanceNorm3d

一個具有 num_features 參數延遲初始化的 torch.nn.InstanceNorm3d 模組。

nn.LayerNorm

對一小批輸入套用層正規化。

nn.LocalResponseNorm

對輸入訊號套用局部響應正規化。

nn.RMSNorm

對一小批輸入套用 Root Mean Square Layer Normalization。

循環層

nn.RNNBase

RNN 模組 (RNN、LSTM、GRU) 的基底類別。

nn.RNN

將具有 tanh\tanhReLU\text{ReLU} 非線性函數的多層 Elman RNN 應用於輸入序列。

nn.LSTM

將多層長短期記憶 (LSTM) RNN 應用於輸入序列。

nn.GRU

將多層 gated recurrent unit (GRU) RNN 應用於輸入序列。

nn.RNNCell

具有 tanh 或 ReLU 非線性函數的 Elman RNN cell。

nn.LSTMCell

一個長短期記憶 (LSTM) cell。

nn.GRUCell

一個 gated recurrent unit (GRU) cell。

Transformer 層

nn.Transformer

一個 transformer 模型。

nn.TransformerEncoder

TransformerEncoder 是一個 N 個編碼器層的堆疊。

nn.TransformerDecoder

TransformerDecoder 是一個 N 個解碼器層的堆疊。

nn.TransformerEncoderLayer

TransformerEncoderLayer 由 self-attn 和 feedforward network 組成。

nn.TransformerDecoderLayer

TransformerDecoderLayer 由 self-attn、multi-head-attn 和 feedforward network 組成。

線性層

nn.Identity

一個不區分參數的佔位符 identity 運算符。

nn.Linear

將仿射線性轉換應用於傳入的資料: y=xAT+by = xA^T + b

nn.Bilinear

將雙線性轉換應用於傳入的資料: y=x1TAx2+by = x_1^T A x_2 + b

nn.LazyLinear

一個 torch.nn.Linear 模組,其中 in_features 會被推斷出來。

Dropout 層

nn.Dropout

在訓練期間,以機率 p 隨機將輸入張量的一些元素歸零。

nn.Dropout1d

隨機歸零整個通道。

nn.Dropout2d

隨機歸零整個通道。

nn.Dropout3d

隨機歸零整個通道。

nn.AlphaDropout

將 Alpha Dropout 應用於輸入。

nn.FeatureAlphaDropout

隨機遮罩整個通道。

稀疏層

nn.Embedding

一個簡單的查詢表,用於儲存固定字典和大小的嵌入。

nn.EmbeddingBag

計算嵌入「包」的總和或平均值,而無需實例化中間嵌入。

距離函數

nn.CosineSimilarity

傳回 x1x_1x2x_2 之間的餘弦相似度,沿著 dim 計算。

nn.PairwiseDistance

計算輸入向量之間或輸入矩陣的列之間的成對距離。

損失函數

nn.L1Loss

建立一個標準,用於測量輸入 xx 和目標 yy 中每個元素的平均絕對誤差 (MAE)。

nn.MSELoss

建立一個標準,用於測量輸入 xx 和目標 yy 中每個元素的均方誤差(平方 L2 範數)。

nn.CrossEntropyLoss

此標準計算輸入 logits 和目標之間的交叉熵損失。

nn.CTCLoss

Connectionist Temporal Classification 損失。

nn.NLLLoss

負對數概似損失。

nn.PoissonNLLLoss

具有目標 Poisson 分佈的負對數概似損失。

nn.GaussianNLLLoss

高斯負對數概似損失。

nn.KLDivLoss

Kullback-Leibler 散度損失。

nn.BCELoss

建立一個標準,用於測量目標和輸入機率之間的二元交叉熵。

nn.BCEWithLogitsLoss

此損失將 Sigmoid 層和 BCELoss 組合在一個單一類別中。

nn.MarginRankingLoss

建立一個標準,用來衡量給定輸入 x1x1x2x2,兩個 1D mini-batch 或 0D Tensors,以及一個 label 1D mini-batch 或 0D Tensor yy (包含 1 或 -1) 的損失。

nn.HingeEmbeddingLoss

衡量給定輸入 tensor xx 和 label tensor yy (包含 1 或 -1) 的損失。

nn.MultiLabelMarginLoss

建立一個標準,用來最佳化輸入 xx (一個 2D mini-batch Tensor) 和輸出 yy (一個目標類別索引的 2D Tensor) 之間的多類別多重分類 hinge loss (基於邊界的損失)。

nn.HuberLoss

建立一個標準,如果絕對的逐元素誤差小於 delta,則使用平方項;否則使用 delta 縮放的 L1 項。

nn.SmoothL1Loss

建立一個標準,如果絕對的逐元素誤差小於 beta,則使用平方項;否則使用 L1 項。

nn.SoftMarginLoss

建立一個標準,用來最佳化輸入 tensor xx 和目標 tensor yy (包含 1 或 -1) 之間的雙類別分類 logistic 損失。

nn.MultiLabelSoftMarginLoss

建立一個標準,用來最佳化基於最大熵的 multi-label one-versus-all 損失,介於輸入 xx 和大小為 (N,C)(N, C) 的目標 yy 之間。

nn.CosineEmbeddingLoss

建立一個標準,用來衡量給定輸入 tensors x1x_1x2x_2 和一個值為 1 或 -1 的 Tensor label yy 的損失。

nn.MultiMarginLoss

建立一個準則,用於優化輸入 xx (一個 2D mini-batch Tensor) 和輸出 yy (一個目標類別索引的 1D tensor,0yx.size(1)10 \leq y \leq \text{x.size}(1)-1) 之間的多類別分類 hinge loss (基於邊界的損失)。

nn.TripletMarginLoss

建立一個準則,用於測量給定輸入 tensors x1x1x2x2x3x3 和邊界值大於 00 的 triplet loss。

nn.TripletMarginWithDistanceLoss

建立一個準則,用於測量給定輸入 tensors aappnn (分別代表 anchor、positive 和 negative 範例) 的 triplet loss,以及用於計算 anchor 和 positive 範例之間關係("positive distance")和 anchor 和 negative 範例之間關係("negative distance")的非負實值函數("distance function")。

視覺層

nn.PixelShuffle

根據放大因子重新排列 tensor 中的元素。

nn.PixelUnshuffle

反轉 PixelShuffle 操作。

nn.Upsample

對給定的多通道 1D (時間)、2D (空間) 或 3D (體積) 數據進行升採樣。

nn.UpsamplingNearest2d

將 2D 最近鄰升採樣應用於由多個輸入通道組成的輸入訊號。

nn.UpsamplingBilinear2d

將 2D 雙線性升採樣應用於由多個輸入通道組成的輸入訊號。

Shuffle 層

nn.ChannelShuffle

分割並重新排列 tensor 中的通道。

DataParallel 層 (多 GPU,分散式)

nn.DataParallel

在模組級別實現數據平行化。

nn.parallel.DistributedDataParallel

在模組級別基於 torch.distributed 實現分散式數據平行化。

實用工具

來自 torch.nn.utils 模組

用於裁剪參數梯度的實用函數。

clip_grad_norm_

裁剪參數 iterable 的梯度範數。

clip_grad_norm

裁剪參數 iterable 的梯度範數。

clip_grad_value_

在指定值裁剪參數 iterable 的梯度。

get_total_norm

計算 tensors iterable 的範數。

clip_grads_with_norm_

給定預先計算的總範數和所需的 max 範數,縮放參數 iterable 的梯度。

用於將 Module 參數扁平化和非扁平化為單個向量的實用函數。

parameters_to_vector

將參數 iterable 扁平化為單個向量。

vector_to_parameters

將向量的切片複製到參數 iterable 中。

用於融合具有 BatchNorm 模組的 Module 的實用函數。

fuse_conv_bn_eval

將卷積模組和 BatchNorm 模組融合到單個新的卷積模組中。

fuse_conv_bn_weights

將卷積模組參數和 BatchNorm 模組參數融合到新的卷積模組參數中。

fuse_linear_bn_eval

將線性模組和 BatchNorm 模組融合到單個新的線性模組中。

fuse_linear_bn_weights

將線性模組參數和 BatchNorm 模組參數融合到新的線性模組參數中。

用於轉換 Module 參數記憶體格式的實用函數。

convert_conv2d_weight_memory_format

nn.Conv2d.weightmemory_format 轉換為 memory_format

convert_conv3d_weight_memory_format

nn.Conv3d.weightmemory_format 轉換為 memory_format。轉換會遞迴地應用於嵌套的 nn.Module,包括 module

用於從 Module 參數應用和移除權重正規化的實用函數。

weight_norm

將權重正規化應用於給定模組中的參數。

remove_weight_norm

從模組中移除權重正規化重新參數化。

spectral_norm

將譜正規化應用於給定模組中的參數。

remove_spectral_norm

從模組中移除譜正規化重新參數化。

用於初始化 Module 參數的實用函數。

skip_init

給定模組類別物件和 args / kwargs,實例化模組而不初始化參數 / 緩衝區。

用於修剪 Module 參數的實用類別和函數。

prune.BasePruningMethod

用於建立新修剪技術的抽象基類。

prune.PruningContainer

用於迭代剪枝的剪枝方法序列的容器。

prune.Identity

實用剪枝方法,不剪除任何單元,但會產生一個由 1 組成的遮罩的剪枝參數化。

prune.RandomUnstructured

隨機剪除張量中(目前未剪除的)單元。

prune.L1Unstructured

透過將 L1 範數最低的單元歸零,來剪除張量中(目前未剪除的)單元。

prune.RandomStructured

隨機剪除張量中整個(目前未剪除的)通道。

prune.LnStructured

根據 Ln 範數剪除張量中整個(目前未剪除的)通道。

prune.CustomFromMask

prune.identity

套用剪枝重新參數化,而不剪除任何單元。

prune.random_unstructured

透過移除隨機(目前未剪除的)單元來剪除張量。

prune.l1_unstructured

透過移除 L1 範數最低的單元來剪除張量。

prune.random_structured

透過沿指定維度移除隨機通道來剪除張量。

prune.ln_structured

透過沿指定維度移除 Ln 範數最低的通道來剪除張量。

prune.global_unstructured

透過套用指定的 pruning_method,全域剪除對應於 parameters 中所有參數的張量。

prune.custom_from_mask

透過套用 mask 中預先計算的遮罩,來剪除 module 中名為 name 的參數對應的張量。

prune.remove

從模組中移除剪枝重新參數化,並從前向鉤子中移除剪枝方法。

prune.is_pruned

透過尋找剪枝前置鉤子來檢查模組是否已剪枝。

使用 torch.nn.utils.parameterize.register_parametrization() 中的新參數化功能實現的參數化。

parametrizations.orthogonal

將正交或 unitary 參數化套用到矩陣或一批矩陣。

parametrizations.weight_norm

將權重正規化應用於給定模組中的參數。

parametrizations.spectral_norm

將譜正規化應用於給定模組中的參數。

用於在現有模組上參數化張量的實用函式。請注意,這些函式可用於參數化給定的 Parameter 或 Buffer,給定一個將輸入空間映射到參數化空間的特定函式。它們不是會將物件轉換為參數的參數化。 有關如何實現自己的參數化的更多信息,請參閱參數化教學課程

parametrize.register_parametrization

將參數化註冊到模組中的張量。

parametrize.remove_parametrizations

移除模組中張量上的參數化。

parametrize.cached

上下文管理器,啟用在 register_parametrization() 註冊的參數化中的快取系統。

parametrize.is_parametrized

確定模組是否具有參數化。

parametrize.ParametrizationList

一個循序容器,用於保存和管理參數化 torch.nn.Module 的原始參數或緩衝區。

以無狀態方式呼叫給定模組的實用函式。

stateless.functional_call

透過將模組參數和緩衝區替換為提供的參數和緩衝區,對模組執行函式呼叫。

其他模組中的實用函式

nn.utils.rnn.PackedSequence

保存壓縮序列的資料和 batch_sizes 清單。

nn.utils.rnn.pack_padded_sequence

壓縮包含可變長度填補序列的張量。

nn.utils.rnn.pad_packed_sequence

填補可變長度序列的壓縮批次。

nn.utils.rnn.pad_sequence

使用 padding_value 填補可變長度張量的清單。

nn.utils.rnn.pack_sequence

壓縮可變長度張量的清單。

nn.utils.rnn.unpack_sequence

將 PackedSequence 解壓縮為可變長度張量的清單。

nn.utils.rnn.unpad_sequence

將填補的張量解填補為可變長度張量的清單。

nn.Flatten

將連續範圍的維度扁平化為一個張量。

nn.Unflatten

解扁平化一個張量維度,將其擴展到所需的形狀。

量化函式

量化是指以低於浮點精度的位寬執行計算和儲存張量的技術。 PyTorch 支援每個張量和每個通道的非對稱線性量化。 若要了解如何在 PyTorch 中使用量化函式的更多資訊,請參閱 量化 文件。

延遲模組初始化

nn.modules.lazy.LazyModuleMixin

用於延遲初始化參數的模組的混合,也稱為「延遲模組」。

別名

以下是 torch.nn 中對應項的別名

nn.modules.normalization.RMSNorm

對一小批輸入套用 Root Mean Square Layer Normalization。

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並取得您問題的解答

檢視資源