捷徑

torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[來源][來源]

從 Gumbel-Softmax 分佈中取樣 (連結 1 連結 2),並可選擇進行離散化。

參數
  • logits (Tensor) – […, num_features] 未正規化的 log 機率

  • tau (float) – 非負純量溫度 (non-negative scalar temperature)

  • hard (bool) – 如果 True,回傳的樣本會被離散化為 one-hot 向量,但在 autograd 中會被微分,如同它是 soft 樣本一樣

  • dim (int) – 計算 softmax 的維度。預設值:-1。

回傳值

從 Gumbel-Softmax 分佈取樣的 Tensor,形狀與 logits 相同。 如果 hard=True,則回傳的樣本將為 one-hot 向量,否則它們將為機率分佈,其總和在 dim 維度上為 1。

回傳類型

Tensor

注意

此函式在此僅為相容舊版程式碼,未來可能會從 nn.Functional 中移除。

注意

hard 的主要技巧是執行 y_hard - y_soft.detach() + y_soft

它實現了兩件事:- 使輸出值完全 one-hot(因為我們先加上然後減去 y_soft 值)- 使梯度等於 y_soft 梯度(因為我們剝離了所有其他梯度)

範例:
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學

取得初學者和進階開發人員的深入教學

檢視教學

資源

尋找開發資源並獲得問題解答

檢視資源