快捷方式

torch.testing

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source][source]

斷言 actualexpected 是否接近。

如果 actualexpected 是跨步的 (strided)、非量化的、實數值的且有限的,則如果滿足以下條件,它們被認為是接近的:

actualexpectedatol+rtolexpected\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

非有限值 (-infinf) 只有在它們相等時才被認為是接近的。 只有當 equal_nanTrue 時,NaN 才被認為彼此相等。

此外,只有當它們具有相同的以下屬性時,它們才被認為是接近的:

  • device (如果 check_deviceTrue),

  • dtype (如果 check_dtypeTrue),

  • layout (如果 check_layoutTrue), 和

  • stride (如果 check_strideTrue)。

如果 actualexpected 是一個 meta tensor,則只會執行屬性檢查。

如果 actualexpected 是稀疏的 (sparse) (具有 COO、CSR、CSC、BSR 或 BSC 布局),則它們的跨步成員 (strided members) 會被個別檢查。 索引 (indices),即 COO 的 indices、CSR 和 BSR 的 crow_indicescol_indices,或 CSC 和 BSC 布局的 ccol_indicesrow_indices,總是檢查是否相等,而根據上述定義,數值會檢查是否接近。

如果 actualexpected 是量化的 (quantized),如果它們具有相同的 qscheme(),並且根據上述定義,dequantize() 的結果是接近的,則認為它們是接近的。

actualexpected 可以是 Tensor,或是任何 tensor-or-scalar-likes,可以透過 torch.as_tensor() 從中建構 torch.Tensor。 除了 Python 純量之外,輸入類型必須直接相關。 此外,actualexpected 可以是 SequenceMapping,在這種情況下,如果它們的結構匹配且它們的所有元素根據上述定義被認為是接近的,則認為它們是接近的。

注意

Python 數值型別 (scalars) 是型別關係要求的例外,因為它們的 type(),也就是 intfloatcomplex,等同於類張量 (tensor-like) 的 dtype。 因此,可以檢查不同型別的 Python 數值,但需要設定 check_dtype=False

參數
  • actual (Any) – 實際輸入。

  • expected (Any) – 預期輸入。

  • allow_subclasses (bool) – 如果為 True (預設值),且除了 Python 數值型別外,允許直接相關型別的輸入。 否則,需要型別相等。

  • rtol (Optional[float]) – 相對容忍度。 如果指定,則也必須指定 atol。 如果省略,則會根據 dtype 使用下表中的預設值。

  • atol (Optional[float]) – 絕對容忍度。 如果指定,則也必須指定 rtol。 如果省略,則會根據 dtype 使用下表中的預設值。

  • equal_nan (Union[bool, str]) – 如果為 True,則兩個 NaN 值將被視為相等。

  • check_device (bool) – 如果為 True (預設值),斷言對應的張量位於相同的 device 上。 如果停用此檢查,則會先將不同 device 上的張量移動到 CPU,然後再進行比較。

  • check_dtype (bool) – 如果為 True (預設值),斷言對應的張量具有相同的 dtype。 如果停用此檢查,則具有不同 dtype 的張量會先升級到一個共同的 dtype (根據 torch.promote_types()),然後再進行比較。

  • check_layout (bool) – 如果為 True (預設值),斷言對應的張量具有相同的 layout。 如果停用此檢查,則具有不同 layout 的張量會先轉換為跨步 (strided) 張量,然後再進行比較。

  • check_stride (bool) – 如果為 True 且對應的張量是跨步 (strided) 張量,則斷言它們具有相同的步幅 (stride)。

  • msg (Optional[Union[str, Callable[[str], str]]]) – 如果在比較過程中發生錯誤,則可以使用的可選錯誤訊息。 也可以作為可呼叫物件傳遞,在這種情況下,它將使用產生的訊息進行呼叫,並且應該返回新訊息。

引發
  • ValueError – 如果無法從輸入建構 torch.Tensor

  • ValueError – 如果僅指定 rtolatol

  • AssertionError – 如果對應的輸入不是 Python 數值型別,並且它們沒有直接關係。

  • AssertionError – 如果 allow_subclassesFalse,但對應的輸入不是 Python 數值型別,並且具有不同的型別。

  • AssertionError – 如果輸入是 Sequence,但它們的長度不匹配。

  • AssertionError – 如果輸入是 Mapping,但它們的鍵集合不匹配。

  • AssertionError – 如果對應的張量沒有相同的 shape

  • AssertionError – 如果 check_layoutTrue,但對應的張量沒有相同的 layout

  • AssertionError – 如果只有一個對應的張量被量化。

  • AssertionError – 如果對應的張量被量化,但具有不同的 qscheme()

  • AssertionError – 如果 check_deviceTrue,但對應的張量不在相同的 device 上。

  • AssertionError – 如果 check_dtypeTrue,但對應的張量沒有相同的 dtype

  • AssertionError – 如果 check_strideTrue,但對應的 stride 張量沒有相同的 stride。

  • AssertionError – 如果根據上述定義,對應張量的值不接近。

下表顯示了不同 dtype 的預設 rtolatol。如果 dtype 不匹配,則使用兩個容差中的最大值。

dtype

rtol

atol

float16

1e-3

1e-5

bfloat16

1.6e-2

1e-5

float32

1.3e-6

1e-5

float64

1e-7

1e-7

complex32

1e-3

1e-5

complex64

1.3e-6

1e-5

complex128

1e-7

1e-7

quint8

1.3e-6

1e-5

quint2x4

1.3e-6

1e-5

quint4x2

1.3e-6

1e-5

qint8

1.3e-6

1e-5

qint32

1.3e-6

1e-5

other

0.0

0.0

注意

assert_close() 具有高度可配置性,並具有嚴格的預設設定。 鼓勵使用者 partial() 以適應他們的使用案例。 例如,如果需要相等性檢查,則可以定義一個 assert_equal,該 assert_equal 預設為每個 dtype 使用零容差

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Expected 1e-10 but got 1e-09.
Absolute difference: 9.000000000000001e-10
Relative difference: 9.0

範例

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type
<class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
and <class 'torch.Tensor'>.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Expected nan but got nan.
Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default error message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # If msg is a callable, it can be used to augment the generated message with
>>> # extra information
>>> torch.testing.assert_close(
...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
... )
Traceback (most recent call last):
...
AssertionError: Header

Tensor-likes are not close!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)

Footer
torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source][source]

建立具有給定 shapedevicedtype 的張量,並使用從 [low, high) 均勻抽樣的值填充。

如果指定了 lowhigh,並且它們超出了 dtype 的可表示有限值的範圍,則它們將分別被鉗制到最低或最高的可表示有限值。 如果 None,則下表描述了 lowhigh 的預設值,這取決於 dtype

dtype

low

high

boolean type

0

2

unsigned integral type

0

10

signed integral types

-9

10

floating types

-9

9

complex types

-9

9

參數
  • shape (Tuple[int, ...]) – 定義輸出張量形狀的單個整數或整數序列。

  • dtype (torch.dtype) – 返回張量的資料類型。

  • device (Union[str, torch.device]) – 返回張量的裝置。

  • low (Optional[Number]) – 設定給定範圍的下限(包含)。如果提供一個數字,它將被鉗制為給定 dtype 的最小可表示有限值。當 None (預設) 時,此值基於 dtype 決定(請參閱上表)。預設值:None

  • high (Optional[Number]) –

    設定給定範圍的上限(不包含)。如果提供一個數字,它將被鉗制為給定 dtype 的最大可表示有限值。當 None (預設) 時,此值基於 dtype 決定(請參閱上表)。預設值:None

    自 2.1 版本起已棄用: 傳遞 low==highmake_tensor() 以用於浮點數或複數類型已自 2.1 版本起棄用,並將在 2.3 版本中移除。請改用 torch.full()

  • requires_grad (Optional[bool]) – 如果 autograd 應該記錄返回的 tensor 上的操作。預設值:False

  • noncontiguous (Optional[bool]) – 如果為 True,則返回的 tensor 將會是不連續的。 如果建構的 tensor 少於兩個元素,則忽略此參數。與 memory_format 互斥。

  • exclude_zero (Optional[bool]) – 如果 True,則根據 dtype,零會被 dtype 的小正值替換。 對於布林和整數類型,零會被一替換。對於浮點數類型,它被 dtype 的最小正正規數(dtypefinfo() 物件的 “tiny” 值)替換,對於複數類型,它被一個實部和虛部都是複數類型可表示的最小正正規數的複數替換。預設值 False

  • memory_format (Optional[torch.memory_format]) – 返回的 tensor 的記憶體格式。與 noncontiguous 互斥。

引發
  • ValueError – 如果為整數 dtype 傳遞了 requires_grad=True

  • ValueError – 如果 low >= high

  • ValueError – 如果 lowhighnan

  • ValueError – 如果同時傳遞了 noncontiguousmemory_format

  • TypeError – 如果此函數不支援 dtype

回傳類型

Tensor

範例

>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source][source]

警告

torch.testing.assert_allclose() 已自 1.12 版本起棄用,並將在未來的版本中移除。請改用 torch.testing.assert_close()。 您可以在 這裡找到詳細的升級說明。

文件

存取 PyTorch 的完整開發者文件

檢視文件

教學課程

取得初學者和高級開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得解答

檢視資源