pytorch中的squeeze函数、cat函数使用
1 squeeze(): 去除size为1的维度,包括行和列。
至于维度大于等于2时,squeeze()不起作用。
行、例:
>>> torch.rand(4, 1, 3) (0 ,.,.) = 0.5391 0.8523 0.9260 (1 ,.,.) = 0.2507 0.9512 0.6578 (2 ,.,.) = 0.7302 0.3531 0.9442 (3 ,.,.) = 0.2689 0.4367 0.6610 [torch.FloatTensor of size 4x1x3]
>>> torch.rand(4, 1, 3).squeeze() 0.0801 0.4600 0.1799 0.0236 0.7137 0.6128 0.0242 0.3847 0.4546 0.9004 0.5018 0.4021 [torch.FloatTensor of size 4x3]
列、例:
>>> torch.rand(4, 3, 1) (0 ,.,.) = 0.7013 0.9818 0.9723 (1 ,.,.) = 0.9902 0.8354 0.3864 (2 ,.,.) = 0.4620 0.0844 0.5707 (3 ,.,.) = 0.5722 0.2494 0.5815 [torch.FloatTensor of size 4x3x1]
>>> torch.rand(4, 3, 1).squeeze() 0.8784 0.6203 0.8213 0.7238 0.5447 0.8253 0.1719 0.7830 0.1046 0.0233 0.9771 0.2278 [torch.FloatTensor of size 4x3]
不变、例:
>>> torch.rand(4, 3, 2) (0 ,.,.) = 0.6618 0.1678 0.3476 0.0329 0.1865 0.4349 (1 ,.,.) = 0.7588 0.8972 0.3339 0.8376 0.6289 0.9456 (2 ,.,.) = 0.1392 0.0320 0.0033 0.0187 0.8229 0.0005 (3 ,.,.) = 0.2327 0.6264 0.4810 0.6642 0.8625 0.6334 [torch.FloatTensor of size 4x3x2]
>>> torch.rand(4, 3, 2).squeeze() (0 ,.,.) = 0.0593 0.8910 0.9779 0.1530 0.9210 0.2248 (1 ,.,.) = 0.7938 0.9362 0.1064 0.6630 0.9321 0.0453 (2 ,.,.) = 0.0189 0.9187 0.4458 0.9925 0.9928 0.7895 (3 ,.,.) = 0.5116 0.7253 0.0132 0.6673 0.9410 0.8159 [torch.FloatTensor of size 4x3x2]
2 cat函数
>>> t1=torch.FloatTensor(torch.randn(2,3)) >>> t1 -1.9405 1.2009 0.0018 0.9463 0.4409 -1.9017 [torch.FloatTensor of size 2x3]
>>> t2=torch.FloatTensor(torch.randn(2,2)) >>> t2 0.0942 0.1581 1.1621 1.2617 [torch.FloatTensor of size 2x2]
>>> torch.cat((t1, t2), 1) -1.9405 1.2009 0.0018 0.0942 0.1581 0.9463 0.4409 -1.9017 1.1621 1.2617 [torch.FloatTensor of size 2x5]
补充:pytorch中 max()、view()、 squeeze()、 unsqueeze()
查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。
一、torch.max()
import torch a=torch.randn(3) print("a:\n",a) print('max(a):',torch.max(a)) b=torch.randn(3,4) print("b:\n",b) print('max(b,0):',torch.max(b,0)) print('max(b,1):',torch.max(b,1))
输出:
a:
tensor([ 0.9558, 1.1242, 1.9503])
max(a): tensor(1.9503)
b:
tensor([[ 0.2765, 0.0726, -0.7753, 1.5334],
[ 0.0201, -0.0005, 0.2616, -1.1912],
[-0.6225, 0.6477, 0.8259, 0.3526]])
max(b,0): (tensor([ 0.2765, 0.6477, 0.8259, 1.5334]), tensor([ 0, 2, 2, 0]))
max(b,1): (tensor([ 1.5334, 0.2616, 0.8259]), tensor([ 3, 2, 2]))
max(a),用于一维数据,求出最大值。
max(a,0),计算出数据中一列的最大值,并输出最大值所在的行号。
max(a,1),计算出数据中一行的最大值,并输出最大值所在的列号。
print('max(b,1):',torch.max(b,1)[1])
输出:只输出行最大值所在的列号
max(b,1): tensor([ 3, 2, 2])
torch.max(b,1)[0], 只返回最大值的每个数
二、view()
a.view(i,j)表示将原矩阵转化为i行j列的形式
i为-1表示不限制行数,输出1列
a=torch.randn(3,4) print(a)
输出:
tensor([[-0.8146, -0.6592, 1.5100, 0.7615],
[ 1.3021, 1.8362, -0.3590, 0.3028],
[ 0.0848, 0.7700, 1.0572, 0.6383]])
b=a.view(-1,1)
print(b)
输出:
tensor([[-0.8146],
[-0.6592],
[ 1.5100],
[ 0.7615],
[ 1.3021],
[ 1.8362],
[-0.3590],
[ 0.3028],
[ 0.0848],
[ 0.7700],
[ 1.0572],
[ 0.6383]])
i为1,j为-1表示不限制列数,输出1行
b=a.view(1,-1) print(b)
输出:
tensor([[-0.8146, -0.6592, 1.5100, 0.7615, 1.3021, 1.8362, -0.3590,
0.3028, 0.0848, 0.7700, 1.0572, 0.6383]])
i为-1,j为2表示不限制行数,输出2列
b=a.view(-1,2) print(b)
输出:
tensor([[-0.8146, -0.6592],
[ 1.5100, 0.7615],
[ 1.3021, 1.8362],
[-0.3590, 0.3028],
[ 0.0848, 0.7700],
[ 1.0572, 0.6383]])
i为-1,j为3表示不限制行数,输出3列
i为4,j为3表示输出4行3列
b=a.view(-1,3) print(b) b=a.view(4,3) print(b)
输出:
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
三、
1.torch.squeeze()
压缩矩阵,我理解为降维
a.squeeze(i) 压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩
import torch a=torch.randn(1,3,4) print(a) b=a.squeeze(0) print(b) c=a.squeeze(1) print(c
输出:
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
一页三行4列的矩阵
第0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵
tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]])
第1维不为1,则不可以压缩
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
2.torch.unsqueeze()
unsqueeze(i) 表示将第i维设置为1
对压缩为3行4列后的矩阵b进行操作,将第0维设置为1
c=b.unsqueeze(0) print(c)
输出一个一页三行四列的矩阵
tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
[ 1.2210, -0.1084, -0.1166, -0.2379],
[-1.0012, -0.4363, 1.0057, -1.5180]]])
将第一维设置为1
c=b.unsqueeze(1) print(c)
输出一个3页,一行,4列的矩阵
tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
[[-2.3976, 0.9857, -0.3462, -0.3648]],
[[ 1.1012, -0.4659, -0.0858, 1.6631]]])
另外,squeeze、unsqueeze操作不改变原矩阵
以上为个人经验,希望能给大家一个参考,也希望大家多多支持猪先飞。
相关文章
- 这篇文章主要介绍了Tomcat配置及如何在Eclipse中启动,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-04
- 这篇文章主要介绍了Intellij IDEA连接Navicat数据库的方法,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借价值,需要的朋友可以参考下...2021-03-25
pytorch nn.Conv2d()中的padding以及输出大小方式
今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27- 这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
Linux安装Pytorch1.8GPU(CUDA11.1)的实现
这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25- 这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
- 今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
- 这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- Navicat for MySQL注册码用来激活 Navicat for MySQL 软件,只要拥有 Navicat 注册码就能激活相应的 Navicat 产品。这篇文章主要介绍了Navicat for MySQL 11注册码\激活码汇总,需要的朋友可以参考下...2020-11-23
- 这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
Jenkins+tomcat自动发布的热部署/重启及遇到的问题解决办法(推荐)
这篇文章主要介绍了Jenkins+tomcat自动发布的热部署/重启及遇到的问题解决办法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2020-07-10- 过滤器Filter是定义于tomcat的servlet-api.jar中的一个接口,接口路径为javax.servlet.Filter。tomcat过滤器采用了典型的过滤器设计模式,过滤器链FilterChain由tomcat维持,链条是可以支持多个过滤器的...2021-06-26
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04Pytorch 的损失函数Loss function使用详解
今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02tomcat启动完成执行 某个方法 定时任务(Spring)操作
这篇文章主要介绍了tomcat启动完成执行 某个方法 定时任务(Spring)操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-09-25- 这篇文章主要介绍了Tomcat正常访问localhost报404问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-31
- 今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
使用Maven 搭建 Spring MVC 本地部署Tomcat的详细教程
这篇文章主要介绍了使用Maven 搭建 Spring MVC 本地部署Tomcat,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-08-16- 这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25