torch.set_default_dtype¶
- torch.set_default_dtype(d, /)[來源][來源]¶
將預設浮點數 dtype 設定為
d
。支援浮點數 dtype 作為輸入。其他 dtype 會導致 torch 引發例外。當 PyTorch 初始化時,其預設的浮點數 dtype 為 torch.float32,而 set_default_dtype(torch.float64) 的目的是為了方便進行類似 NumPy 的型別推斷。預設的浮點數 dtype 用於
隱式地決定預設的複數 dtype。當預設的浮點數型別為 float16 時,預設的複數 dtype 為 complex32。對於 float32,預設的複數 dtype 為 complex64。對於 float64,則為 complex128。對於 bfloat16,將會引發例外,因為 bfloat16 沒有對應的複數型別。
推斷使用 Python 浮點數或複數建立的張量的 dtype。請參閱下面的範例。
決定布林值和整數張量與 Python 浮點數和複數之間的型別提升結果。
- 參數
d (
torch.dtype
) – 作為預設值的浮點數 dtype。
範例
>>> # initial default for floating point is torch.float32 >>> # Python floats are interpreted as float32 >>> torch.tensor([1.2, 3]).dtype torch.float32 >>> # initial default for floating point is torch.complex64 >>> # Complex Python numbers are interpreted as complex64 >>> torch.tensor([1.2, 3j]).dtype torch.complex64
>>> torch.set_default_dtype(torch.float64) >>> # Python floats are now interpreted as float64 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float64 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex128
>>> torch.set_default_dtype(torch.float16) >>> # Python floats are now interpreted as float16 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float16 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex32