日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 > 编程资源 > 编程问答 >内容正文

编程问答

PyTorch框架学习三——张量操作

發布時間:2024/7/23 编程问答 40 豆豆
生活随笔 收集整理的這篇文章主要介紹了 PyTorch框架学习三——张量操作 小編覺得挺不錯的,現在分享給大家,幫大家做個參考.

PyTorch框架學習三——張量操作

  • 一、拼接
    • 1.torch.cat()
    • 2.torch.stack()
  • 二、切分
    • 1.torch.chunk()
    • 2.torch.split()
  • 三、索引
    • 1.torch.index_select()
    • 2.torch.masked_select()
  • 四、變換
    • 1.torch.reshape()
    • 2.torch.transpace()
    • 3.torch.t()
    • 4.torch.squeeze()
    • 5.torch.unsqueeze()

一、拼接

1.torch.cat()

功能:將tensor按照維度dim進行拼接,除了需要拼接的維度外,其余維度尺寸得是相同的。

torch.cat(tensors, dim=0, out=None)

看一下所有的參數:

  • tensors:需要被拼接的張量序列。
  • dim:(int,可選)被拼接的維度,默認為0。
  • >>> x = torch.randn(2, 3) >>> x tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790, 0.1497]]) >>> torch.cat((x, x, x), 0) tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790, 0.1497],[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790, 0.1497],[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790, 0.1497]]) >>> torch.cat((x, x, x), 1) tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,-1.0969, -0.4614],[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,-0.5790, 0.1497]])

    2.torch.stack()

    功能:在新創建的維度dim上進行拼接,所有的張量必須是相同的維度。

    torch.stack(tensors, dim=0, out=None)


    注意:stack()會創建一個新的維度。

    t = torch.ones((2, 3)) t_stack = torch.stack([t, t, t], dim=2) print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))


    原來t的維度是(2, 3),本來是沒有第三維的,但是stack()會構建新的dim=2,就是先構建第三維dim=2,然后在該維度上進行拼接。

    二、切分

    1.torch.chunk()

    功能:將tensor按維度dim進行平均切分。如果不能整除,最后一份tensor在該維度上的長度小于其他tensor。

    torch.chunk(input, chunks, dim=0)

  • input:要切分的張量。
  • chunks:要切分的份數。
  • dim:要切分的維度,默認為0。
  • a = torch.ones((2, 7)) # 7 list_of_tensors = torch.chunk(a, dim=1, chunks=3) # 3for idx, t in enumerate(list_of_tensors):print("第{}個張量:{}, shape is {}".format(idx+1, t, t.shape))

    2.torch.split()

    功能:將tensor按dim進行切分。

    torch.split(tensor, split_size_or_sections, dim=0)

  • tensor:要切分的張量。
  • split_size_or_sections:(int或list(int))為int時,表示每一份的長度,如果不能整除,最后一份的長度要小于其他的張量,為list時,按list元素來切分。
  • dim:同上。
  • >>> a = torch.arange(10).reshape(5,2) >>> a tensor([[0, 1],[2, 3],[4, 5],[6, 7],[8, 9]]) >>> torch.split(a, 2) (tensor([[0, 1],[2, 3]]),tensor([[4, 5],[6, 7]]),tensor([[8, 9]])) >>> torch.split(a, [1,4]) (tensor([[0, 1]]),tensor([[2, 3],[4, 5],[6, 7],[8, 9]]))

    三、索引

    1.torch.index_select()

    功能:在dim上,按照index索引數據,返回一個依據index索引數據拼接的張量。

    torch.index_select(input, dim, index, out=None)

  • input:要索引的張量。
  • dim:被索引的維度。
  • index:一維張量,包括了要索引的數據序號。(long,不能是float)
  • out:輸出張量(可選)。
  • >>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],[-0.4664, 0.2647, -0.1228, -1.1068],[-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009],[-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414],[-0.4664, -0.1228],[-1.1734, 0.7230]])

    2.torch.masked_select()

    功能:按照mask中的True進行索引,返回一個一維張量。

    torch.masked_select(input, mask, out=None)

    >>> x = torch.randn(3, 4) >>> x tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],[-1.2035, 1.2252, 0.5002, 0.6248],[ 0.1307, -2.0608, 0.1244, 2.0139]]) >>> mask = x.ge(0.5) >>> mask tensor([[False, False, False, False],[False, True, True, True],[False, False, False, True]]) >>> torch.masked_select(x, mask) tensor([ 1.2252, 0.5002, 0.6248, 2.0139])

    四、變換

    1.torch.reshape()

    功能:變換張量的形狀。

    torch.reshape(input, shape)

  • input:輸入張量。
  • shape:新張量的形狀。當某個維度為-1時,表示該維度不用關心,可以從別的維度計算得到。
  • >>> a = torch.arange(4.) >>> torch.reshape(a, (2, 2)) tensor([[ 0., 1.],[ 2., 3.]]) >>> b = torch.tensor([[0, 1], [2, 3]]) >>> torch.reshape(b, (-1,)) tensor([ 0, 1, 2, 3])

    2.torch.transpace()

    功能:交換tensor的兩個維度。

    torch.transpose(input, dim0, dim1)

  • input:輸入張量。
  • dim0和dim1:要交換的兩個維度。
  • >>> x = torch.randn(2, 3) >>> x tensor([[ 1.0028, -0.9893, 0.5809],[-0.1669, 0.7299, 0.4942]]) >>> torch.transpose(x, 0, 1) tensor([[ 1.0028, -0.1669],[-0.9893, 0.7299],[ 0.5809, 0.4942]])

    3.torch.t()

    功能:2維tensor轉置,對矩陣而言。等價于torch.transpose(input, 0, 1)。

    torch.t(input) >>> x = torch.randn(()) >>> x tensor(0.1995) >>> torch.t(x) tensor(0.1995) >>> x = torch.randn(3) >>> x tensor([ 2.4320, -0.4608, 0.7702]) >>> torch.t(x) tensor([ 2.4320, -0.4608, 0.7702]) >>> x = torch.randn(2, 3) >>> x tensor([[ 0.4875, 0.9158, -0.5872],[ 0.3938, -0.6929, 0.6932]]) >>> torch.t(x) tensor([[ 0.4875, 0.3938],[ 0.9158, -0.6929],[-0.5872, 0.6932]])

    注意:只對矩陣會轉置,對標量和向量都不會。

    4.torch.squeeze()

    功能:壓縮長度為1的維度(軸)。

    torch.squeeze(input, dim=None, out=None)

  • input:輸入張量。
  • dim:(可選)若為None,移除所有長度為1的軸,若指定軸,當且僅當該軸長度為1時移除。
  • out:輸出張量。
  • >>> x = torch.zeros(2, 1, 2, 1, 2) >>> x.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x) >>> y.size() torch.Size([2, 2, 2]) >>> y = torch.squeeze(x, 0) >>> y.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x, 1) >>> y.size() torch.Size([2, 2, 1, 2])

    5.torch.unsqueeze()

    功能:返回一個新的張量,對輸入的指定位置插入維度 1。

    torch.unsqueeze(input, dim) >>> x = torch.tensor([1, 2, 3, 4]) >>> torch.unsqueeze(x, 0) tensor([[ 1, 2, 3, 4]]) >>> torch.unsqueeze(x, 1) tensor([[ 1],[ 2],[ 3],[ 4]])

    總結

    以上是生活随笔為你收集整理的PyTorch框架学习三——张量操作的全部內容,希望文章能夠幫你解決所遇到的問題。

    如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。