优草派 > 问答 > Python

pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换

作者:xujincan     

PyTorch是一种基于Python的科学计算库,它是一个用于构建深度神经网络的开源机器学习库。PyTorch具有高度可扩展性,可用于处理各种类型的数据,包括张量(tensor)、图片、CPU、GPU、数组等。在本文中,我们将深入探讨如何使用PyTorch实现这些数据类型的转换。一、张量tensor的转换

张量是PyTorch中最基本的数据类型。它类似于Numpy中的数组,但可以在GPU上运行,这使得PyTorch比Numpy更快。下面是如何在PyTorch中创建张量的代码:

```python

import torch

# 创建一个大小为5x3的未初始化张量

x = torch.empty(5, 3)

print(x)

# 创建一个大小为5x3的随机张量

x = torch.rand(5, 3)

print(x)

# 创建一个大小为5x3的全0张量,数据类型为long

x = torch.zeros(5, 3, dtype=torch.long)

print(x)

```

可以使用`size()`方法来查看张量的大小:

```python

print(x.size())

```

张量可以在CPU和GPU之间转换。下面是将张量从CPU转移到GPU的代码:

```python

# 在GPU上创建一个大小为5x3的随机张量

x = torch.rand(5, 3).cuda()

# 将张量从GPU转移到CPU

x = x.cpu()

# 将张量从CPU转移到GPU

x = x.cuda()

```

二、图片的转换

在PyTorch中,可以使用`torchvision`模块处理图像。下面是如何将图像转换为张量的代码:

```python

import torch

import torchvision

import torchvision.transforms as transforms

# 定义变换

transform = transforms.Compose(

[transforms.ToTensor()])

# 加载数据集

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,

download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,

shuffle=True, num_workers=2)

# 显示图像

import matplotlib.pyplot as plt

import numpy as np

# 定义类别

classes = ('plane', 'car', 'bird', 'cat',

'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 获取随机数据

dataiter = iter(trainloader)

images, labels = dataiter.next()

# 显示图像

def imshow(img):

img = img / 2 + 0.5 # 非标准化

npimg = img.numpy()

plt.imshow(np.transpose(npimg, (1, 2, 0)))

plt.show()

imshow(torchvision.utils.make_grid(images))

print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

```

三、CPU和GPU的转换

在PyTorch中,可以在CPU和GPU之间转换数据。下面是如何将数据从CPU转移到GPU的代码:

```python

import torch

# 在GPU上创建一个大小为5x3的随机张量

x = torch.rand(5, 3).cuda()

# 将张量从GPU转移到CPU

x = x.cpu()

# 将张量从CPU转移到GPU

x = x.cuda()

```

四、数组的转换

在PyTorch中,可以使用`numpy()`方法将张量转换为Numpy数组。下面是如何将张量转换为Numpy数组的代码:

```python

import torch

# 创建一个大小为5x3的随机张量

x = torch.rand(5, 3)

# 将张量转换为Numpy数组

y = x.numpy()

# 将Numpy数组转换为张量

z = torch.from_numpy(y)

```

五、

5天短视频训练营
新手入门剪辑课程,零基础也能学
分享变现渠道,助你兼职赚钱
限时特惠:0元
立即抢
新手剪辑课程 (精心挑选,简单易学)
第一课
新手如何学剪辑视频? 开始学习
第二课
短视频剪辑培训班速成是真的吗? 开始学习
第三课
不需要付费的视频剪辑软件有哪些? 开始学习
第四课
手机剪辑app哪个好? 开始学习
第五课
如何做短视频剪辑赚钱? 开始学习
第六课
视频剪辑接单网站APP有哪些? 开始学习
第七课
哪里可以学短视频运营? 开始学习
第八课
做短视频运营需要会什么? 开始学习
相关问题
sql判断字段是否存在
python键值对
for循环可以遍历字典吗
怎么使用vscode
查看更多

客服热线:0731-85127885

湘ICP备19005950号-1  

工商营业执照信息

违法和不良信息举报

举报电话:0731-85127885 举报邮箱:tousu@csai.cn

优草派  版权所有 © 2024