Linear¶
- class torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)[source][source]¶
對傳入的資料套用仿射線性轉換: 。
此模組支援 TensorFloat32。
在某些 ROCm 裝置上,當使用 float16 輸入時,此模組將對 backward 使用不同的精度。
- 參數
- 形狀
輸入:,其中 表示包含空的任意維度,且 。
輸出:,其中除了最後一個維度外,所有維度的形狀都與輸入相同,且 。
- 變數
weight ( torch.Tensor ) – 模組中可學習的權重,形狀為 。數值初始化自 的均勻分佈,其中
bias – 模組的可學習偏差,形狀為 。如果
bias
為True
,則數值會從 初始化,其中
範例
>>> m = nn.Linear(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30])