捷徑

torch

torch 套件包含用於多維張量的資料結構,並定義了對這些張量的數學運算。 此外,它還提供了許多用於有效序列化張量和任意類型的實用程式,以及其他有用的實用程式。

它具有 CUDA 對應項,使您能夠在運算能力 >= 3.0 的 NVIDIA GPU 上執行張量運算。

張量 (Tensors)

is_tensor

如果 obj 是 PyTorch 張量,則返回 True。

is_storage

如果 obj 是 PyTorch 儲存物件,則返回 True。

is_complex

如果 input 的資料類型是複數資料類型,即 torch.complex64torch.complex128 其中之一,則返回 True。

is_conj

如果 input 是一個共軛張量,也就是說,它的共軛位元設定為 True,則返回 True。

is_floating_point

如果 input 的資料類型是浮點資料類型,即 torch.float64torch.float32torch.float16torch.bfloat16 其中之一,則返回 True。

is_nonzero

如果 input 是一個單一元素張量,且在類型轉換後不等於零,則返回 True。

set_default_dtype

將預設浮點 dtype 設定為 d

get_default_dtype

取得目前的預設浮點 torch.dtype

set_default_device

設定預設的 torch.Tensordevice 上分配。

get_default_device

取得預設的 torch.Tensor 將在 device 上分配

set_default_tensor_type

numel

返回 input 張量中的元素總數。

set_printoptions

設定列印選項。

set_flush_denormal

在 CPU 上停用非正規化的浮點數。

建立運算 (Creation Ops)

注意

隨機取樣建立運算列在 隨機取樣 (Random sampling) 下,包括:torch.rand() torch.rand_like() torch.randn() torch.randn_like() torch.randint() torch.randint_like() torch.randperm()。您也可以將 torch.empty()原地隨機取樣 (In-place random sampling) 方法一起使用,以建立從更廣泛分布取樣的 torch.Tensor

tensor

藉由複製 data 來建構一個沒有 autograd 歷史記錄的張量(也稱為「葉張量」,請參閱 Autograd 機制)。

sparse_coo_tensor

使用給定的 indices 建立具有指定值的 COO(rdinate) 格式的稀疏張量

sparse_csr_tensor

使用給定的 crow_indicescol_indices 建立具有指定值的 CSR (Compressed Sparse Row) 格式的稀疏張量

sparse_csc_tensor

使用給定的 ccol_indicesrow_indices 建立具有指定值的 CSC (Compressed Sparse Column) 格式的稀疏張量

sparse_bsr_tensor

使用給定的 crow_indicescol_indices 建立具有指定二維區塊的 BSR (Block Compressed Sparse Row) 格式的稀疏張量

sparse_bsc_tensor

使用給定的 ccol_indicesrow_indices 建立具有指定二維區塊的 BSC (Block Compressed Sparse Column) 格式的稀疏張量

asarray

obj 轉換為張量。

as_tensor

data 轉換為張量,盡可能共享資料並保留 autograd 歷史記錄。

as_strided

使用指定的 sizestridestorage_offset 建立現有 torch.Tensor input 的檢視。

from_file

建立一個 CPU 張量,其儲存由記憶體對應檔案支援。

from_numpy

numpy.ndarray 建立一個 Tensor

from_dlpack

將來自外部函式庫的 tensor 轉換為 torch.Tensor

frombuffer

從實作 Python buffer protocol 的物件建立一個一維的 Tensor

zeros

返回一個填滿純量值 0 的 tensor,其形狀由變數引數 size 定義。

zeros_like

返回一個填滿純量值 0 的 tensor,其大小與 input 相同。

ones

返回一個填滿純量值 1 的 tensor,其形狀由變數引數 size 定義。

ones_like

返回一個填滿純量值 1 的 tensor,其大小與 input 相同。

arange

返回大小為 endstartstep\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil 的 1-D tensor,其值來自區間 [start, end),以共同差 stepstart 開始取值。

range

返回大小為 endstartstep+1\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1 的 1-D tensor,其值從 startend,步長為 step

linspace

建立一個大小為 steps 的一維 tensor,其值從 startend 均勻間隔,包含邊界值。

logspace

建立一個大小為 steps 的一維 tensor,其值從 basestart{{\text{{base}}}}^{{\text{{start}}}}baseend{{\text{{base}}}}^{{\text{{end}}}} 均勻間隔,包含邊界值,以 base 為底的對數刻度。

eye

返回一個 2-D tensor,其對角線上為 1,其他地方為 0。

empty

返回一個填滿未初始化資料的 tensor。

empty_like

返回一個與 input 大小相同的未初始化 tensor。

empty_strided

建立一個具有指定 sizestride 並填滿未定義資料的 tensor。

full

建立一個大小為 size 並填滿 fill_value 的 tensor。

full_like

返回一個與 input 具有相同大小的張量,並以 fill_value 填充。

quantize_per_tensor

將浮點張量轉換為具有給定縮放比例和零點的量化張量。

quantize_per_channel

將浮點張量轉換為具有給定縮放比例和零點(每個通道)的量化張量。

dequantize

通過反量化量化的張量,返回一個 fp32 張量。

complex

建構一個複數張量,其實部等於 real,虛部等於 imag

polar

建構一個複數張量,其元素是笛卡爾座標,對應於具有絕對值 abs 和角度 angle 的極座標。

heaviside

計算 input 中每個元素的海維賽德階躍函數。

索引、切片、連接、變更操作

adjoint

返回張量的共軛視圖,並將最後兩個維度轉置。

argwhere

返回一個包含 input 所有非零元素的索引的張量。

cat

在給定的維度中連接 tensors 中給定的張量序列。

concat

torch.cat() 的別名。

concatenate

torch.cat() 的別名。

conj

返回 input 的視圖,並翻轉共軛位元。

chunk

嘗試將張量分割成指定數量的塊。

dsplit

根據 indices_or_sections 將具有三個或更多維度的張量 input 按深度分割成多個張量。

column_stack

通過水平堆疊 tensors 中的張量來建立一個新的張量。

dstack

按深度(沿第三個軸)依序堆疊張量。

gather

沿著由 dim 指定的軸收集值。

hsplit

根據 indices_or_sections 將具有一個或多個維度的張量 input 水平分割成多個張量。

hstack

依序水平(按列)堆疊張量。

index_add

有關函數描述,請參閱 index_add_()

index_copy

有關函數描述,請參閱 index_add_()

index_reduce

有關函數描述,請參閱 index_reduce_()

index_select

返回一個新張量,該張量使用 LongTensorindex 中的條目沿維度 dim 索引 input 張量。

masked_select

返回一個新的 1 維張量,該張量根據布林遮罩 mask 索引 input 張量,其中 mask 是一個 BoolTensor

movedim

input 的維度從 source 中的位置移動到 destination 中的位置。

moveaxis

torch.movedim() 的別名。

narrow

返回一個新張量,該張量是 input 張量的縮小版本。

narrow_copy

Tensor.narrow() 相同,只是這個返回副本而不是共享儲存。

nonzero

permute

返回原始張量 input 的一個視圖,其維度已置換。

reshape

返回一個與 input 具有相同資料和元素數量的張量,但具有指定的形狀。

row_stack

torch.vstack() 的別名。

select

沿著給定索引處的所選維度對 input 張量進行切片。

scatter

非原地版本的 torch.Tensor.scatter_()

diagonal_scatter

src 張量的值嵌入到 input 中,沿著 input 的對角線元素,相對於 dim1dim2

select_scatter

src 張量的值嵌入到給定索引的 input 中。

slice_scatter

src 張量的值嵌入到給定維度的 input 中。

scatter_add

torch.Tensor.scatter_add_() 的異地 (out-of-place) 版本

scatter_reduce

torch.Tensor.scatter_reduce_() 的異地 (out-of-place) 版本

split

將張量分割成多個區塊。

squeeze

返回一個張量,其中移除了 input 所有大小為 1 的指定維度。

stack

沿著一個新的維度串聯一系列的張量。

swapaxes

torch.transpose() 的別名。

swapdims

torch.transpose() 的別名。

t

期望 input 為 <= 2-D 的張量,並轉置維度 0 和 1。

take

返回一個新的張量,其中包含給定索引處的 input 元素。

take_along_dim

沿著給定的 dim,從 indices 中的一維索引在 input 中選擇值。

tensor_split

將一個張量沿著 dim 維度分割成多個子張量,所有這些子張量都是 input 的視圖,分割的方式取決於 indices_or_sections 指定的索引或段數。

tile

通過重複 input 的元素來構造一個張量。

transpose

返回一個張量,該張量是 input 的轉置版本。

unbind

移除一個張量維度。

unravel_index

將扁平索引的張量轉換為坐標張量的元組,這些坐標張量索引到指定形狀的任意張量中。

unsqueeze

返回一個新的張量,該張量在指定位置插入一個大小為一的維度。

vsplit

根據 indices_or_sections 將具有兩個或多個維度的張量 input 垂直分割成多個張量。

vstack

依序垂直 (以列為主) 堆疊張量。

where

根據 condition,返回一個從 inputother 中選擇元素的張量。

加速器

在 PyTorch 儲存庫中,我們將「加速器」定義為與 CPU 一起使用以加速計算的 torch.device。這些設備使用非同步執行方案,使用 torch.Streamtorch.Event 作為其執行同步的主要方式。我們還假設在給定主機上一次只能有一個這樣的加速器可用。這允許我們將當前加速器用作相關概念(例如釘選記憶體、Stream device_type、FSDP 等)的預設設備。

截至今日,加速器設備依序為 “CUDA”“MTIA”“XPU” 和 PrivateUse1 (許多設備不在 PyTorch 儲存庫本身中)。

Stream

以先進先出 (FIFO) 順序非同步執行各自任務的有序佇列。

Event

查詢並記錄 Stream 狀態,以識別或控制跨 Stream 的相依性並測量計時。

產生器

Generator

建立並返回一個產生器物件,該物件管理產生虛擬亂數的演算法的狀態。

隨機抽樣

seed

將所有設備上產生亂數的種子設定為非決定性亂數。

manual_seed

設定所有設備上產生亂數的種子。

initial_seed

傳回作為 Python long 的產生亂數的初始種子。

get_rng_state

torch.ByteTensor 形式傳回亂數產生器狀態。

set_rng_state

設定亂數產生器狀態。

torch.default_generator 傳回 預設 CPU torch.Generator

bernoulli

從 Bernoulli 分佈中抽取二進位亂數(0 或 1)。

multinomial

傳回一個張量,其中每一列包含從多項式 (更嚴格的定義是多元,請參閱 torch.distributions.multinomial.Multinomial 取得更多詳細資訊) 機率分佈中抽取的 num_samples 索引,該分佈位於張量 input 的對應列中。

normal

傳回一個從獨立常態分佈中抽取的亂數張量,這些分佈的平均值和標準差是給定的。

poisson

傳回一個與 input 大小相同的張量,其中每個元素都從 Poisson 分佈中抽取,其速率參數由 input 中的對應元素給定,即

rand

返回一個張量,其元素為從區間 [0,1)[0, 1) 上均勻分布的隨機數。

rand_like

返回一個與 input 大小相同的張量,其元素為從區間 [0,1)[0, 1) 上均勻分布的隨機數。

randint

返回一個張量,其元素為介於 low (包含) 和 high (不包含) 之間均勻生成的隨機整數。

randint_like

返回一個形狀與 Tensor input 相同的張量,其元素為介於 low (包含) 和 high (不包含) 之間均勻生成的隨機整數。

randn

返回一個張量,其元素為來自平均數為 0 且變異數為 1 的常態分布 (也稱為標準常態分布) 的隨機數。

randn_like

返回一個與 input 大小相同的張量,其元素為來自平均數為 0 且變異數為 1 的常態分布的隨機數。

randperm

返回從 0n - 1 的整數的隨機排列。

原地 (In-place) 隨機抽樣

還有一些定義在張量上的原地隨機抽樣函數。點擊連結以參考它們的文件

準隨機抽樣

quasirandom.SobolEngine

torch.quasirandom.SobolEngine 是一個用於生成 (scrambled) Sobol 序列的引擎。

序列化

save

將一個物件儲存到磁碟檔案。

load

從檔案載入使用 torch.save() 儲存的物件。

平行化

get_num_threads

返回用於平行化 CPU 運算的執行緒數量

set_num_threads

設定用於 CPU 上運算內 (intraop) 平行化的執行緒數量。

get_num_interop_threads

返回用於 CPU 上運算間 (inter-op) 平行化的執行緒數量 (例如

set_num_interop_threads

設定用於運算間平行化的執行緒數量 (例如

局部停用梯度計算

上下文管理器 torch.no_grad()torch.enable_grad()torch.set_grad_enabled() 有助於局部停用和啟用梯度計算。 有關其使用的更多詳細資訊,請參閱局部停用梯度計算。 這些上下文管理器是執行緒局部的,因此如果您使用 threading 模組等將工作發送到另一個執行緒,它們將無法工作。

範例

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)  # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

no_grad

禁用梯度計算的上下文管理器。

enable_grad

啟用梯度計算的上下文管理器。

autograd.grad_mode.set_grad_enabled

開啟或關閉梯度計算的上下文管理器。

is_grad_enabled

如果目前已啟用 grad 模式,則傳回 True。

autograd.grad_mode.inference_mode

啟用或停用推論模式的上下文管理器。

is_inference_mode_enabled

如果目前已啟用推論模式,則傳回 True。

數學運算

常數

inf

一個浮點正無限大。 是 math.inf 的別名。

nan

一個浮點「非數字」值。 此值不是合法的數字。 是 math.nan 的別名。

逐元素運算

abs

計算 input 中每個元素的絕對值。

absolute

torch.abs() 的別名

acos

計算 input 中每個元素的反餘弦。

arccos

torch.acos() 的別名。

acosh

返回一個新的 tensor,其中包含 input 元素的反雙曲餘弦值。

arccosh

torch.acosh() 的別名。

add

other 乘以 alpha 後,加到 input

addcdiv

執行 tensor1 除以 tensor2 的元素級別除法,將結果乘以純量 value,然後加到 input

addcmul

執行 tensor1 乘以 tensor2 的元素級別乘法,將結果乘以純量 value,然後加到 input

angle

計算給定 input tensor 的元素級別角度(以弧度為單位)。

asin

返回一個新的 tensor,其中包含 input 元素的反正弦值。

arcsin

torch.asin() 的別名。

asinh

返回一個新的 tensor,其中包含 input 元素的反雙曲正弦值。

arcsinh

torch.asinh() 的別名。

atan

返回一個新的 tensor,其中包含 input 元素的反正切值。

arctan

torch.atan() 的別名。

atanh

返回一個新的 tensor,其中包含 input 元素的反雙曲正切值。

arctanh

torch.atanh() 的別名。

atan2

考慮象限,求元素級別的 inputi/otheri\text{input}_{i} / \text{other}_{i} 的反正切值。

arctan2

torch.atan2() 的別名。

bitwise_not

計算給定輸入 tensor 的位元反相。

bitwise_and

計算 inputother 的位元 AND。

bitwise_or

計算 inputother 的位元 OR。

bitwise_xor

計算 inputother 的位元 XOR。

bitwise_left_shift

計算 input 左移 other 個位元的算術左移。

bitwise_right_shift

計算 input 右移 other 個位元的算術右移。

ceil

返回一個新的 tensor,其中包含 input 元素的上限,即大於或等於每個元素的最小整數。

clamp

input 中的所有元素鉗制到範圍 [ min, max ] 內。

clip

torch.clamp() 的別名。

conj_physical

計算給定 input tensor 的元素級別共軛複數。

copysign

創建一個新的浮點 tensor,其中包含 input 的量值和 other 的符號,逐元素。

cos

返回一個新的 tensor,其中包含 input 元素的餘弦值。

cosh

返回一個新的 tensor,其中包含 input 元素的雙曲餘弦值。

deg2rad

返回一個新的 tensor,其中包含從角度(以度為單位)轉換為弧度的 input 的每個元素。

div

將輸入 input 的每個元素除以 other 的對應元素。

divide

torch.div() 的別名。

digamma

torch.special.digamma() 的別名。

erf

torch.special.erf() 的別名。

erfc

torch.special.erfc() 的別名。

erfinv

torch.special.erfinv() 的別名。

exp

返回一個新的 tensor,其中包含輸入 tensor input 中元素的指數。

exp2

torch.special.exp2() 的別名。

expm1

torch.special.expm1() 的別名。

fake_quantize_per_channel_affine

返回一個新的 tensor,其中包含使用 scalezero_pointquant_minquant_max,跨 axis 指定的通道,對 input 中的資料進行 fake 量化的結果。

fake_quantize_per_tensor_affine

返回一個新的 tensor,其中包含使用 scalezero_pointquant_minquant_maxinput 中的資料進行 fake 量化的結果。

fix

torch.trunc() 的別名

float_power

以雙精度方式,逐元素地將 input 提升到 exponent 的冪次方。

floor

返回一個新的 tensor,其中包含 input 中元素的 floor,即小於或等於每個元素的最大整數。

floor_divide

fmod

逐個元素地應用 C++ 的 std::fmod

frac

計算 input 中每個元素的小數部分。

frexp

input 分解為 mantissa 和 exponent tensors,使得 input=mantissa×2exponent\text{input} = \text{mantissa} \times 2^{\text{exponent}}

gradient

使用 二階精確中心差分法,以及邊界上的一階或二階估計,估計函數 g:RnRg : \mathbb{R}^n \rightarrow \mathbb{R} 在一維或多維中的梯度。

imag

返回一個新的 tensor,其中包含 self tensor 的虛數值。

ldexp

input 乘以 2 ** other

lerp

根據純量或 tensor weight 對兩個 tensors start(由 input 給定)和 end 進行線性插值,並返回結果 out tensor。

lgamma

計算 input 上 gamma 函數絕對值的自然對數。

log

返回一個新的 tensor,其中包含 input 中元素的自然對數。

log10

返回一個新的 tensor,其中包含 input 中以 10 為底的元素的對數。

log1p

返回一個新的 tensor,其中包含 (1 + input) 的自然對數。

log2

返回一個新的 tensor,其中包含 input 中以 2 為底的元素的對數。

logaddexp

輸入指數的總和的對數。

logaddexp2

以 2 為底的輸入指數總和的對數。

logical_and

計算給定輸入 tensors 的元素級邏輯 AND。

logical_not

計算給定輸入 tensor 的元素級邏輯 NOT。

logical_or

計算給定輸入 tensors 的元素級邏輯 OR。

logical_xor

計算給定輸入 tensors 的元素級邏輯 XOR。

logit

torch.special.logit() 的別名。

hypot

給定直角三角形的兩條直角邊,返回其斜邊。

i0

torch.special.i0() 的別名。

igamma

torch.special.gammainc() 的別名。

igammac

torch.special.gammaincc() 的別名。

mul

input 乘以 other

multiply

torch.mul() 的別名。

mvlgamma

torch.special.multigammaln() 的別名。

nan_to_num

NaN、正無窮大和負無窮大值在 input 中分別替換為 nanposinfneginf 指定的值。

neg

返回一個新張量,其中包含 input 的元素取負數後的值。

negative

torch.neg() 的別名。

nextafter

返回 input 之後,朝向 other 的下一個浮點數值,逐元素計算。

polygamma

torch.special.polygamma() 的別名。

positive

返回 input

pow

input 中的每個元素取 exponent 次方,並返回包含結果的張量。

quantized_batch_norm

對 4D (NCHW) 量化張量應用批次正規化。

quantized_max_pool1d

對由多個輸入平面組成的輸入量化張量應用 1D 最大池化。

quantized_max_pool2d

對由多個輸入平面組成的輸入量化張量應用 2D 最大池化。

rad2deg

返回一個新張量,其中 input 的每個元素都從弧度制的角度轉換為角度制的角度。

real

返回一個包含 self 張量的實數值的新張量。

reciprocal

返回一個新張量,其中包含 input 元素的倒數。

remainder

逐元素計算 Python 的模數運算

round

input 的元素四捨五入到最接近的整數。

rsqrt

返回一個新張量,其中包含 input 的每個元素的平方根的倒數。

sigmoid

torch.special.expit() 的別名。

sign

返回一個新張量,其中包含 input 的元素的符號。

sgn

此函數是 torch.sign() 對複數張量的擴展。

signbit

測試 input 的每個元素的符號位是否已設置。

sin

返回一個新張量,其中包含 input 的元素的正弦值。

sinc

torch.special.sinc() 的別名。

sinh

返回一個新張量,其中包含 input 的元素的雙曲正弦值。

softmax

torch.nn.functional.softmax() 的別名。

sqrt

返回一個新張量,其中包含 input 的元素的平方根。

square

返回一個新張量,其中包含 input 的元素的平方。

sub

input 中減去 other,並按 alpha 縮放。

subtract

torch.sub() 的別名。

tan

返回一個新張量,其中包含 input 的元素的正切值。

tanh

返回一個新張量,其中包含 input 的元素的雙曲正切值。

true_divide

torch.div() 的別名,且 rounding_mode=None

trunc

返回一個新張量,其中包含 input 的元素的截斷整數值。

xlogy

torch.special.xlogy() 的別名。

Reduction Ops

argmax

返回 input 張量中所有元素的最大值的索引。

argmin

返回展平張量或沿維度的最小值(們)的索引

amax

返回給定維度 diminput 張量的每個切片的最大值。

amin

返回給定維度 diminput 張量的每個切片的最小值。

aminmax

計算 input 張量的最小值和最大值。

all

測試 input 中的所有元素是否都評估為 True

any

測試 input 中是否有任何元素評估為 True

max

返回 input 張量中所有元素的最大值。

min

返回 input 張量中所有元素的最小值。

dist

傳回 (input - other) 的 p 範數。

logsumexp

傳回 input 張量在指定維度 dim 上,每列元素取指數後總和的對數值。

mean

nanmean

計算指定維度上所有非 NaN 元素的平均值。

median

傳回 input 中數值的中位數。

nanmedian

傳回 input 中數值的中位數,忽略 NaN 值。

mode

傳回一個具名元組 (values, indices),其中 valuesinput 張量在指定維度 dim 上,每列元素出現次數最多的值 (眾數),而 indices 則是每個眾數值的索引位置。

norm

傳回給定張量的矩陣範數或向量範數。

nansum

傳回所有元素的總和,並將非數字 (NaN) 視為零。

prod

傳回 input 張量中所有元素的乘積。

quantile

計算 input 張量在維度 dim 上,每列元素的 q 分位數。

nanquantile

這是 torch.quantile() 的變體,它會 "忽略" NaN 值,並計算分位數 q,如同 input 中的 NaN 值不存在一樣。

std

計算由 dim 指定維度上的標準差。

std_mean

計算由 dim 指定維度上的標準差和平均值。

sum

傳回 input 張量中所有元素的總和。

unique

傳回輸入張量的唯一元素。

unique_consecutive

從每個等價元素的連續群組中,消除除第一個元素之外的所有元素。

var

計算由 dim 指定維度上的變異數。

var_mean

計算由 dim 指定維度上的變異數和平均值。

count_nonzero

計算張量 input 在給定 dim 上非零值的數量。

比較運算子

allclose

此函數檢查 inputother 是否滿足以下條件

argsort

傳回沿給定維度按值升序對張量排序的索引。

eq

計算逐元素相等性

equal

如果兩個張量具有相同的大小和元素,則為 True,否則為 False

ge

計算逐元素的 inputother\text{input} \geq \text{other}

greater_equal

torch.ge() 的別名。

gt

計算逐元素的 input>other\text{input} > \text{other}

greater

torch.gt() 的別名。

isclose

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否與 other 的對應元素 "接近"。

isfinite

傳回一個新的張量,其中包含布林元素,表示每個元素是否為 有限 值。

isin

測試 elements 的每個元素是否在 test_elements 中。

isinf

測試 input 的每個元素是否為無限大(正或負無限大)。

isposinf

測試 input 的每個元素是否為正無限大。

isneginf

測試 input 的每個元素是否為負無限大。

isnan

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否為 NaN。

isreal

傳回一個新的張量,其中包含布林元素,表示 input 的每個元素是否為實數值。

kthvalue

傳回一個具名元組 (values, indices),其中 valuesinput 張量在給定維度 dim 上,每列元素的第 k 個最小值。

le

逐元素計算 inputother\text{input} \leq \text{other}

less_equal

torch.le() 的別名。

lt

逐元素計算 input<other\text{input} < \text{other}

less

torch.lt() 的別名。

maximum

計算 inputother 的逐元素最大值。

minimum

計算 inputother 的逐元素最小值。

fmax

計算 inputother 的逐元素最大值。

fmin

計算 inputother 的逐元素最小值。

ne

逐元素計算 inputother\text{input} \neq \text{other}

not_equal

torch.ne() 的別名。

sort

沿著給定的維度,依值以遞增順序對 input 張量的元素進行排序。

topk

沿著給定的維度,返回給定 input 張量的 k 個最大元素。

msort

沿著其第一個維度,依值以遞增順序對 input 張量的元素進行排序。

頻譜運算

stft

短時傅立葉變換 (STFT)。

istft

反短時傅立葉變換。

bartlett_window

Bartlett 視窗函數。

blackman_window

Blackman 視窗函數。

hamming_window

Hamming 視窗函數。

hann_window

Hann 視窗函數。

kaiser_window

計算 Kaiser 視窗,其視窗長度為 window_length,形狀參數為 beta

其他運算

atleast_1d

返回每個零維輸入張量的 1 維視圖。

atleast_2d

返回每個零維輸入張量的 2 維視圖。

atleast_3d

返回每個零維輸入張量的 3 維視圖。

bincount

計算非負整數陣列中每個值的頻率。

block_diag

從提供的張量建立分塊對角矩陣。

broadcast_tensors

根據廣播語義廣播給定的張量。

broadcast_to

input 廣播到 shape 形狀。

broadcast_shapes

broadcast_tensors() 類似,但用於形狀。

bucketize

返回每個 input 中的值所屬的桶的索引,其中桶的邊界由 boundaries 設定。

cartesian_prod

執行給定張量序列的笛卡爾積。

cdist

計算批次的 p 範數距離,其距離為兩組行向量的每一對。

clone

返回 input 的副本。

combinations

計算給定張量的長度為 rr 的組合。

corrcoef

估計由 input 矩陣給定的變數的 Pearson 積差相關係數矩陣,其中行是變數,列是觀察值。

cov

估計由 input 矩陣給定的變數的共變異數矩陣,其中行是變數,列是觀察值。

cross

返回 inputotherdim 維度中向量的叉積。

cummax

返回一個名為 (values, indices) 的 namedtuple,其中 valuesinputdim 維度中的元素累積分母。

cummin

返回一個名為 (values, indices) 的 namedtuple,其中 valuesinputdim 維度中的元素累積最小值。

cumprod

返回 input 中沿著維度 dim 的元素的累計乘積。

cumsum

返回 input 中沿著維度 dim 的元素的累計總和。

diag

  • 如果 input 是一個向量(1 維張量),則返回一個 2 維正方形張量。

diag_embed

創建一個張量,其某些 2D 平面(由 dim1dim2 指定)的對角線由 input 填充。

diagflat

  • 如果 input 是一個向量(1 維張量),則返回一個 2 維正方形張量。

diagonal

返回 input 的部分視圖,其對角線元素相對於 dim1dim2,並作為一個維度附加在形狀的末尾。

diff

計算沿給定維度的 n 階前向差分。

einsum

沿著使用基於愛因斯坦求和約定的符號指定的維度,對輸入 operands 的元素乘積求和。

flatten

通過將 input 重塑為一維張量來將其扁平化。

flip

沿著 dims 中給定的軸反轉 n 維張量的順序。

fliplr

在左右方向翻轉張量,返回一個新張量。

flipud

在上下方向翻轉張量,返回一個新張量。

kron

計算 inputother 的 Kronecker 積,用 \otimes 表示。

rot90

在由 dims 軸指定的平面中將 n 維張量旋轉 90 度。

gcd

計算 inputother 的元素級最大公因數 (GCD)。

histc

計算張量的直方圖。

histogram

計算張量中值的直方圖。

histogramdd

計算張量中值的多維直方圖。

meshgrid

創建由 attr:tensors 中的 1D 輸入指定的座標網格。

lcm

計算 inputother 的元素級最小公倍數 (LCM)。

logcumsumexp

返回 input 中沿著維度 dim 的元素的指數運算的累計總和的對數。

ravel

返回一個連續的扁平化張量。

renorm

返回一個張量,其中 input 沿著維度 dim 的每個子張量都被正規化,使得子張量的 p-範數低於值 maxnorm

repeat_interleave

重複張量的元素。

roll

沿著給定的維度滾動張量 input

searchsorted

sorted_sequence 的*最內層*維度中找到索引,這樣,如果 values 中的相應值在排序後插入到索引之前,則 sorted_sequence 中相應的*最內層*維度的順序將被保留。

tensordot

返回 a 和 b 在多個維度上的收縮。

trace

返回輸入 2 維矩陣的對角線元素的總和。

tril

返回矩陣(2 維張量)或矩陣批次 input 的下三角部分,結果張量 out 的其他元素設置為 0。

tril_indices

返回 row-by- col 矩陣的下三角部分的索引,以 2-by-N 張量表示,其中第一行包含所有索引的行座標,第二行包含列座標。

triu

返回矩陣(2 維張量)或矩陣批次 input 的上三角部分,結果張量 out 的其他元素設置為 0。

triu_indices

返回 row by col 矩陣的上三角部分的索引,以 2-by-N 張量表示,其中第一行包含所有索引的行座標,第二行包含列座標。

unflatten

在多個維度上展開輸入張量的維度。

vander

生成一個 Vandermonde 矩陣。

view_as_real

返回 input 作為實數張量的視圖。

view_as_complex

返回 input 作為複數張量的視圖。

resolve_conj

如果 input 的共軛位設置為 True,則返回一個具有實體化共軛的新張量,否則返回 input

resolve_neg

如果 input 的負數位設置為 True,則返回一個具有實體化取負的新張量,否則返回 input

BLAS 和 LAPACK 運算

addbmm

執行儲存在 batch1batch2 中的矩陣的批次矩陣乘積,並帶有簡化的加法步驟(所有矩陣乘法沿第一個維度累加)。

addmm

執行矩陣 mat1mat2 的矩陣乘法。

addmv

執行矩陣 mat 與向量 vec 的矩陣向量乘積。

addr

執行向量 vec1vec2 的外積,並將其加到矩陣 input 上。

baddbmm

執行 batch1batch2 中矩陣的批次矩陣乘積。

bmm

執行儲存在 inputmat2 中矩陣的批次矩陣乘積。

chain_matmul

返回 NN 個 2-D 張量的矩陣乘積。

cholesky

計算對稱正定矩陣 AA 的 Cholesky 分解,或者對批次的對稱正定矩陣進行計算。

cholesky_inverse

計算給定其 Cholesky 分解的複 Hermitian 或實對稱正定矩陣的逆矩陣。

cholesky_solve

計算具有複 Hermitian 或實對稱正定左側,且給定其 Cholesky 分解的線性方程組的解。

dot

計算兩個 1D 張量的點積。

geqrf

這是一個底層函數,用於直接調用 LAPACK 的 geqrf。

ger

torch.outer() 的別名。

inner

計算 1D 張量的點積。

inverse

torch.linalg.inv() 的別名

det

torch.linalg.det() 的別名

logdet

計算方陣或批次方陣的對數行列式。

slogdet

torch.linalg.slogdet() 的別名

lu

計算矩陣或批次矩陣 A 的 LU 分解。

lu_solve

使用來自 lu_factor() 的 A 的部分主元 LU 分解,返回線性系統 Ax=bAx = b 的 LU 解。

lu_unpack

lu_factor() 返回的 LU 分解解包為 P, L, U 矩陣。

matmul

兩個張量的矩陣乘積。

matrix_power

torch.linalg.matrix_power() 的別名

matrix_exp

torch.linalg.matrix_exp() 的別名。

mm

執行矩陣 inputmat2 的矩陣乘法。

mv

執行矩陣 input 與向量 vec 的矩陣向量乘積。

orgqr

torch.linalg.householder_product() 的別名。

ormqr

計算 Householder 矩陣的乘積與一般矩陣的矩陣乘法。

outer

inputvec2 的外積。

pinverse

torch.linalg.pinv() 的別名

qr

計算矩陣或一批矩陣 input 的 QR 分解,並返回張量的 namedtuple (Q, R),使得 input=QR\text{input} = Q R,其中 QQ 是正交矩陣或一批正交矩陣,RR 是上三角矩陣或一批上三角矩陣。

svd

計算矩陣或批次矩陣 input 的奇異值分解 (singular value decomposition)。

svd_lowrank

回傳矩陣、批次矩陣或稀疏矩陣 AA 的奇異值分解 (U, S, V),使得 AUdiag(S)VHA \approx U \operatorname{diag}(S) V^{\text{H}}

pca_lowrank

對低秩矩陣、此類矩陣的批次或稀疏矩陣執行線性主成分分析 (PCA)。

lobpcg

使用無矩陣 LOBPCG 方法,找到對稱正定廣義特徵值問題的 k 個最大(或最小)特徵值以及對應的特徵向量。

trapz

torch.trapezoid() 的別名。

trapezoid

沿著 dim 計算梯形法則

cumulative_trapezoid

沿著 dim 累計計算梯形法則

triangular_solve

求解具有方形上三角或下三角可逆矩陣 AA 和多個右手邊 bb 的方程式系統。

vdot

沿著一個維度計算兩個 1D 向量的點積。

Foreach 運算

警告

此 API 處於 Beta 階段,可能會在未來進行變更。不支援前向模式 AD。

_foreach_abs

torch.abs() 應用於輸入列表的每個 Tensor。

_foreach_abs_

torch.abs() 應用於輸入列表的每個 Tensor。

_foreach_acos

torch.acos() 應用於輸入列表的每個 Tensor。

_foreach_acos_

torch.acos() 應用於輸入列表的每個 Tensor。

_foreach_asin

torch.asin() 應用於輸入列表的每個 Tensor。

_foreach_asin_

torch.asin() 應用於輸入列表的每個 Tensor。

_foreach_atan

torch.atan() 應用於輸入列表的每個 Tensor。

_foreach_atan_

torch.atan() 應用於輸入列表的每個 Tensor。

_foreach_ceil

torch.ceil() 應用於輸入列表的每個 Tensor。

_foreach_ceil_

torch.ceil() 應用於輸入列表的每個 Tensor。

_foreach_cos

torch.cos() 應用於輸入列表的每個 Tensor。

_foreach_cos_

torch.cos() 應用於輸入列表的每個 Tensor。

_foreach_cosh

torch.cosh() 應用於輸入列表的每個 Tensor。

_foreach_cosh_

torch.cosh() 應用於輸入列表的每個 Tensor。

_foreach_erf

torch.erf() 應用於輸入列表的每個 Tensor。

_foreach_erf_

torch.erf() 應用於輸入列表的每個 Tensor。

_foreach_erfc

torch.erfc() 應用於輸入列表的每個 Tensor。

_foreach_erfc_

torch.erfc() 應用於輸入列表的每個 Tensor。

_foreach_exp

torch.exp() 應用於輸入列表的每個 Tensor。

_foreach_exp_

torch.exp() 應用於輸入列表的每個 Tensor。

_foreach_expm1

torch.expm1() 應用於輸入列表的每個 Tensor。

_foreach_expm1_

torch.expm1() 應用於輸入列表的每個 Tensor。

_foreach_floor

torch.floor() 應用於輸入列表的每個 Tensor。

_foreach_floor_

torch.floor() 應用於輸入列表的每個 Tensor。

_foreach_log

torch.log() 應用於輸入列表的每個 Tensor。

_foreach_log_

torch.log() 應用於輸入列表的每個 Tensor。

_foreach_log10

torch.log10() 應用於輸入列表的每個 Tensor。

_foreach_log10_

torch.log10() 應用於輸入列表的每個 Tensor。

_foreach_log1p

torch.log1p() 應用於輸入列表的每個 Tensor。

_foreach_log1p_

torch.log1p() 應用於輸入列表的每個 Tensor。

_foreach_log2

torch.log2() 應用於輸入列表的每個 Tensor。

_foreach_log2_

torch.log2() 應用於輸入列表的每個 Tensor。

_foreach_neg

torch.neg() 應用於輸入列表的每個 Tensor。

_foreach_neg_

torch.neg() 應用於輸入列表的每個 Tensor。

_foreach_tan

torch.tan() 應用於輸入列表的每個 Tensor。

_foreach_tan_

torch.tan() 應用於輸入列表的每個 Tensor。

_foreach_sin

torch.sin() 應用於輸入列表的每個 Tensor。

_foreach_sin_

torch.sin() 應用於輸入列表的每個 Tensor。

_foreach_sinh

torch.sinh() 應用於輸入列表的每個 Tensor。

_foreach_sinh_

torch.sinh() 應用於輸入列表的每個 Tensor。

_foreach_round

torch.round() 應用於輸入列表中的每個 Tensor。

_foreach_round_

torch.round() 應用於輸入列表中的每個 Tensor。

_foreach_sqrt

torch.sqrt() 應用於輸入列表中的每個 Tensor。

_foreach_sqrt_

torch.sqrt() 應用於輸入列表中的每個 Tensor。

_foreach_lgamma

torch.lgamma() 應用於輸入列表中的每個 Tensor。

_foreach_lgamma_

torch.lgamma() 應用於輸入列表中的每個 Tensor。

_foreach_frac

torch.frac() 應用於輸入列表中的每個 Tensor。

_foreach_frac_

torch.frac() 應用於輸入列表中的每個 Tensor。

_foreach_reciprocal

torch.reciprocal() 應用於輸入列表中的每個 Tensor。

_foreach_reciprocal_

torch.reciprocal() 應用於輸入列表中的每個 Tensor。

_foreach_sigmoid

torch.sigmoid() 應用於輸入列表中的每個 Tensor。

_foreach_sigmoid_

torch.sigmoid() 應用於輸入列表中的每個 Tensor。

_foreach_trunc

torch.trunc() 應用於輸入列表中的每個 Tensor。

_foreach_trunc_

torch.trunc() 應用於輸入列表中的每個 Tensor。

_foreach_zero_

torch.zero() 應用於輸入列表中的每個 Tensor。

工具程式

compiled_with_cxx11_abi

傳回 PyTorch 是否使用 _GLIBCXX_USE_CXX11_ABI=1 建置

result_type

傳回對提供的輸入張量執行算術運算後產生的 torch.dtype

can_cast

判斷在型別提升 文件中描述的 PyTorch 轉換規則下是否允許型別轉換。

promote_types

傳回具有最小大小和純量種類的 torch.dtype,該型別的大小或種類都不小於 type1type2

use_deterministic_algorithms

設定 PyTorch 運算是否必須使用「確定性」演算法。

are_deterministic_algorithms_enabled

如果全域確定性旗標已開啟,則傳回 True。

is_deterministic_algorithms_warn_only_enabled

如果全域確定性旗標設定為僅警告,則傳回 True。

set_deterministic_debug_mode

設定確定性運算的偵錯模式。

get_deterministic_debug_mode

傳回確定性運算的偵錯模式的目前值。

set_float32_matmul_precision

設定 float32 矩陣乘法的內部精確度。

get_float32_matmul_precision

傳回 float32 矩陣乘法精確度的目前值。

set_warn_always

當此旗標為 False (預設值) 時,某些 PyTorch 警告可能只會在每個程序中顯示一次。

get_device_module

傳回與給定裝置相關聯的模組 (例如,torch.device('cuda')、"mtia:0"、"xpu" 等)。

is_warn_always_enabled

如果已開啟全域 warn_always 旗標,則傳回 True。

vmap

vmap 是向量化的對應;vmap(func) 會傳回一個新的函式,該函式會在輸入的某些維度上對應 func

_assert

Python assert 的包裝函式,可進行符號追蹤。

符號數字

class torch.SymInt(node)[source][source]

類似於 int (包括 magic methods),但會將對包裝節點的所有運算重新導向。 這特別用於符號記錄符號形狀工作流程中的運算。

as_integer_ratio()[source][source]

將此 int 表示為精確的整數比率

傳回類型

Tuple[SymInt, int]

class torch.SymFloat(node)[source][source]

類似於 float (包括 magic methods),但會將對包裝節點的所有運算重新導向。 這特別用於符號記錄符號形狀工作流程中的運算。

as_integer_ratio()[source][source]

將此 float 表示為精確的整數比率

傳回類型

Tuple[int, int]

conjugate()[source][source]

傳回 float 的複共軛。

傳回類型

SymFloat

hex()[source][source]

傳回 float 的十六進位表示法。

傳回類型

str

is_integer()[source][source]

如果 float 是整數,則傳回 True。

class torch.SymBool(node)[原始碼][原始碼]

如同布林值 (包含魔術方法),但將所有操作重新導向至封裝的節點。 這特別用於以符號方式記錄符號形狀工作流程中的操作。

與常規布林值不同,常規布林運算子會強制執行額外的保護,而不是以符號方式評估。 請改用位元運算子來處理此問題。

sym_float

用於浮點數轉換的 SymInt 感知工具。

sym_fresh_size

sym_int

用於整數轉換的 SymInt 感知工具。

sym_max

用於 max 的 SymInt 感知工具,避免在 a < b 時進行分支。

sym_min

用於 min() 的 SymInt 感知工具。

sym_not

用於邏輯否定的 SymInt 感知工具。

sym_ite

sym_sum

N 元加法,對於長列表,其計算速度比迭代二元加法更快。

匯出路徑

警告

此功能是一個原型,未來可能會發生不相容的變更。

export generated/exportdb/index

控制流程

警告

此功能是一個原型,未來可能會發生不相容的變更。

cond

有條件地套用 true_fnfalse_fn

最佳化

compile

使用 TorchDynamo 和指定的後端來最佳化給定的模型/函式。

torch.compile 文件

運算子標籤

class torch.Tag

成員

core

data_dependent_output

dynamic_output_shape

flexible_layout

generated

inplace_view

maybe_aliasing_or_mutating

needs_fixed_stride_order

nondeterministic_bitwise

nondeterministic_seeded

pointwise

pt2_compliant_tag

view_copy

property name

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得針對初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源