捷徑

torch.cond

torch.cond(pred, true_fn, false_fn, operands=())[來源]

有條件地應用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的原型功能。它對輸入和輸出類型的支援有限,目前不支援訓練。請期待未來版本的 PyTorch 中更穩定的實作。在以下網址閱讀有關功能分類的更多資訊:https://pytorch.dev.org.tw/blog/pytorch-feature-classification-changes/#prototype

cond 是一種結構化的控制流程運算子。也就是說,它就像 Python 的 if 語句,但對 true_fnfalse_fnoperands 有限制,這使其能夠使用 torch.compile 和 torch.export 進行捕獲。

假設滿足對 cond 引數的約束,則 cond 等效於以下內容:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
參數
  • pred (Union[bool, torch.Tensor]) – 一個布林運算式或具有一個元素的張量,指示要應用哪個分支函式。

  • true_fn (Callable) – 一個可呼叫的函式 (a -> b),它位於正在追蹤的範圍內。

  • false_fn (Callable) – 一個可呼叫的函式 (a -> b),它位於正在追蹤的範圍內。true 分支和 false 分支必須具有一致的輸入和輸出,這表示輸入必須相同,並且輸出必須具有相同的類型和形狀。

  • operands (Tuple of 可能為巢狀的 dict/list/tuple of torch.Tensor) – true/false 函式的輸入元組。如果 true_fn/false_fn 不需要輸入,則可以為空。預設值為 ()。

回傳類型

Any

範例

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制
  • 條件語句(又稱 pred)必須滿足以下約束之一

    • 它是一個只有一個元素的 torch.Tensor,並且 dtype 為 torch.bool

    • 它是一個布林運算式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函式(又稱 true_fn/false_fn)必須滿足以下所有約束

    • 函式簽名必須與 operands 匹配。

    • 函式必須回傳具有相同元資料(例如形狀、dtype 等)的張量。

    • 函式不能對輸入或全域變數進行就地變更。(注意:在分支中允許使用就地張量運算,例如 add_,用於中間結果)

警告

暫時性限制

  • 分支的輸出必須是單個張量。未來將支援張量的 Pytree。

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適合初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得您問題的解答

檢視資源