torch.cond¶
- torch.cond(pred, true_fn, false_fn, operands=())[來源]¶
有條件地應用 true_fn 或 false_fn。
警告
torch.cond 是 PyTorch 中的原型功能。它對輸入和輸出類型的支援有限,目前不支援訓練。請期待未來版本的 PyTorch 中更穩定的實作。在以下網址閱讀有關功能分類的更多資訊:https://pytorch.dev.org.tw/blog/pytorch-feature-classification-changes/#prototype
cond 是一種結構化的控制流程運算子。也就是說,它就像 Python 的 if 語句,但對 true_fn、false_fn 和 operands 有限制,這使其能夠使用 torch.compile 和 torch.export 進行捕獲。
假設滿足對 cond 引數的約束,則 cond 等效於以下內容:
def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands)
- 參數
pred (Union[bool, torch.Tensor]) – 一個布林運算式或具有一個元素的張量,指示要應用哪個分支函式。
true_fn (Callable) – 一個可呼叫的函式 (a -> b),它位於正在追蹤的範圍內。
false_fn (Callable) – 一個可呼叫的函式 (a -> b),它位於正在追蹤的範圍內。true 分支和 false 分支必須具有一致的輸入和輸出,這表示輸入必須相同,並且輸出必須具有相同的類型和形狀。
operands (Tuple of 可能為巢狀的 dict/list/tuple of torch.Tensor) – true/false 函式的輸入元組。如果 true_fn/false_fn 不需要輸入,則可以為空。預設值為 ()。
- 回傳類型
範例
def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- 限制
條件語句(又稱 pred)必須滿足以下約束之一
它是一個只有一個元素的 torch.Tensor,並且 dtype 為 torch.bool
它是一個布林運算式,例如 x.shape[0] > 10 或 x.dim() > 1 and x.shape[1] > 10
分支函式(又稱 true_fn/false_fn)必須滿足以下所有約束
函式簽名必須與 operands 匹配。
函式必須回傳具有相同元資料(例如形狀、dtype 等)的張量。
函式不能對輸入或全域變數進行就地變更。(注意:在分支中允許使用就地張量運算,例如 add_,用於中間結果)
警告
暫時性限制
分支的輸出必須是單個張量。未來將支援張量的 Pytree。