ChengYue's Cat

torch.eye()函数

用来生层对角矩阵,即对角线上的元素都为1,其余都为0.

import torch as t

t.eye(3,3)
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])

t.eye(3,4)
# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.]])

t.eye(3)
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])

squeeze()unsqueeze()函数

import torch as t

t = t.rand(1,2,10)
# tensor([[[0.9775, 0.5134, 0.7421, 0.0982, 0.3329, 0.6887, 0.7929, 0.9276, 0.4891, 0.5787],
#          [0.1926, 0.7890, 0.4820, 0.0320, 0.1211, 0.7013, 0.8026, 0.7744, 0.5264, 0.4967]]])

t.squeeze(dim=0)
# tensor([[0.9775, 0.5134, 0.7421, 0.0982, 0.3329, 0.6887, 0.7929, 0.9276, 0.4891, 0.5787],
#         [0.1926, 0.7890, 0.4820, 0.0320, 0.1211, 0.7013, 0.8026, 0.7744, 0.5264, 0.4967]])

t.unsqueeze(dim=0)
# tensor([[[[0.9775, 0.5134, 0.7421, 0.0982, 0.3329, 0.6887, 0.7929, 0.9276, 0.4891, 0.5787],
#           [0.1926, 0.7890, 0.4820, 0.0320, 0.1211, 0.7013, 0.8026, 0.7744, 0.5264, 0.4967]]]])


t.unsqueeze(dim=2)
print(t)
# tensor([[[0.9775, 0.5134, 0.7421, 0.0982, 0.3329, 0.6887, 0.7929, 0.9276, 0.4891, 0.5787],
#          [0.1926, 0.7890, 0.4820, 0.0320, 0.1211, 0.7013, 0.8026, 0.7744, 0.5264, 0.4967]]])
# print(t.size())