CosineSimilarity¶
- class torch.nn.CosineSimilarity(dim=1, eps=1e-08)[source][source]¶
傳回 和 之間的餘弦相似度 (cosine similarity),沿著 dim 計算。
- 形狀
輸入1: 其中 D 位於 dim 的位置
- Input2: , 與 x1 具有相同的維度數量,且在維度 dim 上與 x1 的大小相符,
並可在其他維度上與 x1 進行廣播。
Output:
- 範例:
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) >>> output = cos(input1, input2)