torch.nn.functional.cosine_similarity¶
- torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8) Tensor ¶
傳回
x1
和x2
之間的 cosine 相似度,沿著維度 dim 計算。x1
和x2
必須可以廣播到一個共同形狀。dim
指的是這個共同形狀中的維度。輸出結果的維度dim
會被壓縮 (參見torch.squeeze()
),導致輸出張量的維度減少 1。支援 型別提升。
- 參數
範例
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output)