快捷鍵

torch.linalg.householder_product

torch.linalg.householder_product(A, tau, *, out=None) Tensor

計算 Householder 矩陣乘積的前 n 行。

Let K\mathbb{K} be R\mathbb{R} or C\mathbb{C}, and let AKm×nA \in \mathbb{K}^{m \times n} be a matrix with columns aiKma_i \in \mathbb{K}^m for i=1,,mi=1,\ldots,m with mnm \geq n. Denote by bib_i the vector resulting from zeroing out the first i1i-1 components of aia_i and setting to 1 the ii-th. For a vector τKk\tau \in \mathbb{K}^k with knk \leq n, this function computes the first nn columns of the matrix

H1H2...HkwithHi=ImτibibiHH_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i b_i b_i^{\text{H}}

where Im\mathrm{I}_m is the m-dimensional identity matrix and bHb^{\text{H}} is the conjugate transpose when bb is complex, and the transpose when bb is real-valued. The output matrix is the same size as the input matrix A.

See Representation of Orthogonal or Unitary Matrices for further details.

支援 float、double、cfloat 和 cdouble dtypes 的輸入。也支援矩陣批次,如果輸入是矩陣批次,則輸出也具有相同的批次維度。

另請參閱

torch.geqrf() can be used together with this function to form the Q from the qr() decomposition.

torch.ormqr() is a related function that computes the matrix multiplication of a product of Householder matrices with another matrix. However, that function is not supported by autograd.

警告

只有在 τi1ai2\tau_i \neq \frac{1}{||a_i||^2}. If this condition is not met, no error will be thrown, but the gradient produced may contain NaN.

參數
  • A (Tensor) – 形狀為 (*, m, n) 的張量,其中 * 為零或多個批次維度。

  • tau (Tensor) – 形狀為 (*, k) 的張量,其中 * 為零或多個批次維度。

關鍵字參數

out (Tensor, optional) – 輸出張量。如果為 None 則忽略。預設值:None。

引發

RuntimeError – 如果 A 不滿足 m >= n 的要求,或 tau 不滿足 n >= k 的要求。

範例

>>> A = torch.randn(2, 2)
>>> h, tau = torch.geqrf(A)
>>> Q = torch.linalg.householder_product(h, tau)
>>> torch.dist(Q, torch.linalg.qr(A).Q)
tensor(0.)

>>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
>>> tau = torch.randn(3, 1, dtype=torch.complex128)
>>> Q = torch.linalg.householder_product(h, tau)
>>> Q
tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
        [-0.6853+0.7953j,  2.0790+0.5620j]],

        [[ 1.4581+1.6989j, -1.5360+0.1193j],
        [ 1.3877-0.6691j,  1.3512+1.3024j]],

        [[ 1.4766+0.5783j,  0.0361+0.6587j],
        [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)

文件

存取 PyTorch 的完整開發者文件

查看文件

教學

為初學者和進階開發者提供深入教學

查看教學

資源

尋找開發資源並獲得問題解答

查看資源