torch.testing¶
- torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source][source]¶
斷言
actual
和expected
是否接近。如果
actual
和expected
是跨步的 (strided)、非量化的、實數值的且有限的,則如果滿足以下條件,它們被認為是接近的:非有限值 (
-inf
和inf
) 只有在它們相等時才被認為是接近的。 只有當equal_nan
為True
時,NaN
才被認為彼此相等。此外,只有當它們具有相同的以下屬性時,它們才被認為是接近的:
device
(如果check_device
為True
),dtype
(如果check_dtype
為True
),layout
(如果check_layout
為True
), 和stride (如果
check_stride
為True
)。
如果
actual
或expected
是一個 meta tensor,則只會執行屬性檢查。如果
actual
和expected
是稀疏的 (sparse) (具有 COO、CSR、CSC、BSR 或 BSC 布局),則它們的跨步成員 (strided members) 會被個別檢查。 索引 (indices),即 COO 的indices
、CSR 和 BSR 的crow_indices
和col_indices
,或 CSC 和 BSC 布局的ccol_indices
和row_indices
,總是檢查是否相等,而根據上述定義,數值會檢查是否接近。如果
actual
和expected
是量化的 (quantized),如果它們具有相同的qscheme()
,並且根據上述定義,dequantize()
的結果是接近的,則認為它們是接近的。actual
和expected
可以是Tensor
,或是任何 tensor-or-scalar-likes,可以透過torch.as_tensor()
從中建構torch.Tensor
。 除了 Python 純量之外,輸入類型必須直接相關。 此外,actual
和expected
可以是Sequence
或Mapping
,在這種情況下,如果它們的結構匹配且它們的所有元素根據上述定義被認為是接近的,則認為它們是接近的。注意
Python 數值型別 (scalars) 是型別關係要求的例外,因為它們的
type()
,也就是int
、float
和complex
,等同於類張量 (tensor-like) 的dtype
。 因此,可以檢查不同型別的 Python 數值,但需要設定check_dtype=False
。- 參數
actual (Any) – 實際輸入。
expected (Any) – 預期輸入。
allow_subclasses (bool) – 如果為
True
(預設值),且除了 Python 數值型別外,允許直接相關型別的輸入。 否則,需要型別相等。rtol (Optional[float]) – 相對容忍度。 如果指定,則也必須指定
atol
。 如果省略,則會根據dtype
使用下表中的預設值。atol (Optional[float]) – 絕對容忍度。 如果指定,則也必須指定
rtol
。 如果省略,則會根據dtype
使用下表中的預設值。check_device (bool) – 如果為
True
(預設值),斷言對應的張量位於相同的device
上。 如果停用此檢查,則會先將不同device
上的張量移動到 CPU,然後再進行比較。check_dtype (bool) – 如果為
True
(預設值),斷言對應的張量具有相同的dtype
。 如果停用此檢查,則具有不同dtype
的張量會先升級到一個共同的dtype
(根據torch.promote_types()
),然後再進行比較。check_layout (bool) – 如果為
True
(預設值),斷言對應的張量具有相同的layout
。 如果停用此檢查,則具有不同layout
的張量會先轉換為跨步 (strided) 張量,然後再進行比較。check_stride (bool) – 如果為
True
且對應的張量是跨步 (strided) 張量,則斷言它們具有相同的步幅 (stride)。msg (Optional[Union[str, Callable[[str], str]]]) – 如果在比較過程中發生錯誤,則可以使用的可選錯誤訊息。 也可以作為可呼叫物件傳遞,在這種情況下,它將使用產生的訊息進行呼叫,並且應該返回新訊息。
- 引發
ValueError – 如果無法從輸入建構
torch.Tensor
。ValueError – 如果僅指定
rtol
或atol
。AssertionError – 如果對應的輸入不是 Python 數值型別,並且它們沒有直接關係。
AssertionError – 如果
allow_subclasses
為False
,但對應的輸入不是 Python 數值型別,並且具有不同的型別。AssertionError – 如果輸入是
Sequence
,但它們的長度不匹配。AssertionError – 如果輸入是
Mapping
,但它們的鍵集合不匹配。AssertionError – 如果對應的張量沒有相同的
shape
。AssertionError – 如果
check_layout
為True
,但對應的張量沒有相同的layout
。AssertionError – 如果只有一個對應的張量被量化。
AssertionError – 如果對應的張量被量化,但具有不同的
qscheme()
。AssertionError – 如果
check_device
為True
,但對應的張量不在相同的device
上。AssertionError – 如果
check_dtype
為True
,但對應的張量沒有相同的dtype
。AssertionError – 如果
check_stride
為True
,但對應的 stride 張量沒有相同的 stride。AssertionError – 如果根據上述定義,對應張量的值不接近。
下表顯示了不同
dtype
的預設rtol
和atol
。如果dtype
不匹配,則使用兩個容差中的最大值。dtype
rtol
atol
float16
1e-3
1e-5
bfloat16
1.6e-2
1e-5
float32
1.3e-6
1e-5
float64
1e-7
1e-7
complex32
1e-3
1e-5
complex64
1.3e-6
1e-5
complex128
1e-7
1e-7
quint8
1.3e-6
1e-5
quint2x4
1.3e-6
1e-5
quint4x2
1.3e-6
1e-5
qint8
1.3e-6
1e-5
qint32
1.3e-6
1e-5
other
0.0
0.0
注意
assert_close()
具有高度可配置性,並具有嚴格的預設設定。 鼓勵使用者partial()
以適應他們的使用案例。 例如,如果需要相等性檢查,則可以定義一個assert_equal
,該assert_equal
預設為每個dtype
使用零容差>>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0
範例
>>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer
- torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source][source]¶
建立具有給定
shape
、device
和dtype
的張量,並使用從[low, high)
均勻抽樣的值填充。如果指定了
low
或high
,並且它們超出了dtype
的可表示有限值的範圍,則它們將分別被鉗制到最低或最高的可表示有限值。 如果None
,則下表描述了low
和high
的預設值,這取決於dtype
。dtype
low
high
boolean type
0
2
unsigned integral type
0
10
signed integral types
-9
10
floating types
-9
9
complex types
-9
9
- 參數
shape (Tuple[int, ...]) – 定義輸出張量形狀的單個整數或整數序列。
dtype (
torch.dtype
) – 返回張量的資料類型。device (Union[str, torch.device]) – 返回張量的裝置。
low (Optional[Number]) – 設定給定範圍的下限(包含)。如果提供一個數字,它將被鉗制為給定 dtype 的最小可表示有限值。當
None
(預設) 時,此值基於dtype
決定(請參閱上表)。預設值:None
。high (Optional[Number]) –
設定給定範圍的上限(不包含)。如果提供一個數字,它將被鉗制為給定 dtype 的最大可表示有限值。當
None
(預設) 時,此值基於dtype
決定(請參閱上表)。預設值:None
。自 2.1 版本起已棄用: 傳遞
low==high
給make_tensor()
以用於浮點數或複數類型已自 2.1 版本起棄用,並將在 2.3 版本中移除。請改用torch.full()
。requires_grad (Optional[bool]) – 如果 autograd 應該記錄返回的 tensor 上的操作。預設值:
False
。noncontiguous (Optional[bool]) – 如果為 True,則返回的 tensor 將會是不連續的。 如果建構的 tensor 少於兩個元素,則忽略此參數。與
memory_format
互斥。exclude_zero (Optional[bool]) – 如果
True
,則根據dtype
,零會被 dtype 的小正值替換。 對於布林和整數類型,零會被一替換。對於浮點數類型,它被 dtype 的最小正正規數(dtype
的finfo()
物件的 “tiny” 值)替換,對於複數類型,它被一個實部和虛部都是複數類型可表示的最小正正規數的複數替換。預設值False
。memory_format (Optional[torch.memory_format]) – 返回的 tensor 的記憶體格式。與
noncontiguous
互斥。
- 引發
ValueError – 如果為整數 dtype 傳遞了
requires_grad=True
ValueError – 如果
low >= high
。ValueError – 如果
low
或high
為nan
。ValueError – 如果同時傳遞了
noncontiguous
和memory_format
。TypeError – 如果此函數不支援
dtype
。
- 回傳類型
範例
>>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0')
- torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source][source]¶
警告
torch.testing.assert_allclose()
已自1.12
版本起棄用,並將在未來的版本中移除。請改用torch.testing.assert_close()
。 您可以在 這裡找到詳細的升級說明。