捷徑

torch.set_default_dtype

torch.set_default_dtype(d, /)[來源][來源]

將預設浮點數 dtype 設定為 d。支援浮點數 dtype 作為輸入。其他 dtype 會導致 torch 引發例外。

當 PyTorch 初始化時,其預設的浮點數 dtype 為 torch.float32,而 set_default_dtype(torch.float64) 的目的是為了方便進行類似 NumPy 的型別推斷。預設的浮點數 dtype 用於

  1. 隱式地決定預設的複數 dtype。當預設的浮點數型別為 float16 時,預設的複數 dtype 為 complex32。對於 float32,預設的複數 dtype 為 complex64。對於 float64,則為 complex128。對於 bfloat16,將會引發例外,因為 bfloat16 沒有對應的複數型別。

  2. 推斷使用 Python 浮點數或複數建立的張量的 dtype。請參閱下面的範例。

  3. 決定布林值和整數張量與 Python 浮點數和複數之間的型別提升結果。

參數

d (torch.dtype) – 作為預設值的浮點數 dtype。

範例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> # Python floats are now interpreted as float16
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float16
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex32

文件

獲取 PyTorch 的全面開發者文檔

查看文檔

教學

取得適合初學者和高級開發人員的深入教學

查看教學

資源

尋找開發資源並獲得您的問題解答

查看資源