优草派 > 问答 > Python

PyTorch学习:动态图和静态图的例子

作者:ws64176     

PyTorch是一种深度学习框架,它的特点是动态图和静态图都支持。这篇文章将从多个角度分析动态图和静态图的例子。

动态图和静态图的区别

在深度学习中,动态图和静态图是两种不同的计算图模式。静态图在定义模型时会预先定义好计算图结构,然后将数据输入到图中进行计算。而动态图则是在运行时动态地构建计算图。

静态图的优点是在运行时计算速度快,因为它已经预先定义好了计算图结构。但是静态图的缺点是在模型的调试和修改时非常麻烦,因为需要重新定义计算图结构。

动态图的优点是在调试和修改模型时非常方便,因为可以动态地修改计算图结构。但是动态图的缺点是在运行时计算速度比静态图慢。

动态图和静态图的例子

下面我们将从多个角度分析动态图和静态图的例子。

1. 定义模型

静态图的例子:

```

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])

y = tf.placeholder(tf.float32, [None, 10])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y_pred = tf.nn.softmax(tf.matmul(x, W) + b)

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

```

动态图的例子:

```

import torch.nn as nn

import torch.nn.functional as F

class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.fc1 = nn.Linear(784, 10)

def forward(self, x):

x = F.softmax(self.fc1(x), dim=1)

return x

net = Net()

```

可以看到,动态图的定义方式更加简单清晰。

2. 训练模型

静态图的例子:

```

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})

correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))

```

动态图的例子:

```

import torch.optim as optim

optimizer = optim.SGD(net.parameters(), lr=0.5)

for epoch in range(10):

for i, data in enumerate(trainloader, 0):

inputs, labels = data

optimizer.zero_grad()

outputs = net(inputs)

loss = F.cross_entropy(outputs, labels)

loss.backward()

optimizer.step()

correct = 0

total = 0

with torch.no_grad():

for data in testloader:

images, labels = data

outputs = net(images)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

```

动态图的训练方式更加简单直观。

3. 模型转换

静态图的例子:

```

import tensorflow as tf

with tf.Session() as sess:

saver = tf.train.Saver()

saver.restore(sess, "model.ckpt")

tf.train.write_graph(sess.graph_def, '.', 'model.pbtxt')

converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [x], [y_pred])

tflite_model = converter.convert()

open("model.tflite", "wb").write(tflite_model)

```

动态图的例子:

```

import torch

dummy_input = torch.randn(1, 784)

torch.onnx.export(net, dummy_input, "model.onnx", verbose=True)

```

可以看到,动态图的模型转换方式更加简单。

5天短视频训练营
新手入门剪辑课程,零基础也能学
分享变现渠道,助你兼职赚钱
限时特惠:0元
立即抢
新手剪辑课程 (精心挑选,简单易学)
第一课
新手如何学剪辑视频? 开始学习
第二课
短视频剪辑培训班速成是真的吗? 开始学习
第三课
不需要付费的视频剪辑软件有哪些? 开始学习
第四课
手机剪辑app哪个好? 开始学习
第五课
如何做短视频剪辑赚钱? 开始学习
第六课
视频剪辑接单网站APP有哪些? 开始学习
第七课
哪里可以学短视频运营? 开始学习
第八课
做短视频运营需要会什么? 开始学习
相关问题
sql判断字段是否存在
python键值对
for循环可以遍历字典吗
怎么使用vscode
查看更多

客服热线:0731-85127885

湘ICP备19005950号-1  

工商营业执照信息

违法和不良信息举报

举报电话:0731-85127885 举报邮箱:tousu@csai.cn

优草派  版权所有 © 2024