admin 管理员组文章数量: 1184232
曲线拟合
基于pytorch的深度学习曲线拟合–动图版
在jupyter notebook上无法显示动图
网络存在梯度爆炸的可能,可以通过调节学习率,或者多运行几次避免不好的结果(知道曲线拟合为止)
,为了防止结果重复,就不固定随机种子了。
import torch
from torch import nn, optim
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn.functional as Fimport randomos.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"#如果虚拟环境和基础环境下存在相同的dll文件可能会报错,这个函数的功能是允许副本文件的存在# 建立神经网络框架
class CurveRegression(nn.Module):def __init__(self, n_feature=1, n_hidden=2, n_output=1): # 构造函数super(CurveRegression, self).__init__() # 继承nn.Module# 定义每层用什么样的形式self.hidden = nn.Linear(n_feature, n_hidden) # 建立隐藏层self.w = nn.Linear(n_hidden, n_feature) # 第二层神经元,为防止梯度爆炸前向传播的时候没用到self.hidden1 = nn.Linear(n_feature, n_hidden) # 建立隐藏层self.predict = nn.Linear(n_hidden, n_output) # 建立输出层# 实现前向传播功能def forward(self, x):x = self.hidden(x)# x = F.relu(x) # 激活函数# w = self.w(x)#网络层太深容易造成梯度爆炸,如果使用的话可以多运行几次,或者调一下学习率# x = self.hidden1(w)x = F.relu(x) # 激活函数,使网络变得非线性x = self.predict(x)return xdef setup_seed(seed):#设置随机种子使网络可以复现torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truedef main():x_train = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)y_train = x_train.pow(3) +x_train.pow(2) + 0.2 * torch.rand(x_train.size())device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))# setup_seed(10)为了防止出现梯度爆炸后结果重复,就不固定随机种子了。model =CurveRegression().to(device)criterion = nn.MSELoss() # 定义损失函数:均方误差optimizer = optim.SGD(model.parameters(), lr=0.2) # 使用随机梯度下降进行优化# 开始训练模型num_epochs = 1500plt.ion()for epoch in range(num_epochs):plt.cla()out = model(x_train) # 得到网络前向传播地结果loss = criterion(out, y_train) # 得到损失函数# backward# 归零梯度,做反向传播和更新参数# 每次做反向传播之前,都要归零梯度# 不然梯度会累加,造成结果不收敛optimizer.zero_grad()loss.backward()optimizer.step()# 在训练过程中,每隔一段时间就将损失函数的值打印出来,确保误差越来越小if (epoch + 1) % 20 == 0:print('Epoch[{}/{}],loss {:.6f}'.format(epoch + 1, num_epochs, loss.data))# 将测试数据放入网络做前向传播plt.scatter(x_train.data.numpy(), y_train.data.numpy(),)plt.plot(x_train.data.numpy(), out.data.numpy(), 'r-', label='Original Data')plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})plt.text(0.5, 1.1, 'Epoch=%d' % epoch, fontdict={'size': 10, 'color': 'red'})plt.pause(0.001)plt.ioff()plt.show()# 预测结果if __name__ == "__main__":main()
本文标签: 曲线拟合
版权声明:本文标题:曲线拟合 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://roclinux.cn/p/1697355130a267516.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论