学习笔记—Optimization algorithms

news/2024/7/10 4:55:26 标签: 深度学习, 优化, coursera机器学习笔记

这是这门课程第二周的内容。
深度学习遇到大数据(样本量在十万级以上)时,我们之前的常规操作在这时可能会变得很笨重。

1. Mini-batch gradient descent

之前在模型训练过程时,每一轮迭代都需要遍历整个训练集样本,当样本集非常大时,这样的每一轮都将经历漫长的时间。为了应对这一难题,有人提出了Mini-batch gradient descent,与之对应的是batch gradient descent。
batch gradient descent即我们之前常用到的梯度下降算法,它在每一轮计算梯度时考虑所有的训练样本集;
mini-batch gradient descent的特殊之处在于,每一轮计算梯度时,只考虑一部分(m)训练样本集;一种极端情况,每轮只考虑一个训练样本,这样的优化算法又称之为Stochastic Gradient Descent。

如:将整个训练集划分为64个样本为单个mini_batch。
这里写图片描述

由于mini-batch每次更新梯度时只考虑一部分训练样本,因此更新所得参数并不能保证使训练集总体的损失函数值一直处于下降的趋势,所得到的结果一般是训练集的损失函数值在震荡中逐渐降低。
这里写图片描述
mini-batch的优点是训练速度快,对资源的要求低;缺点是往往很难收敛。
这里写图片描述
常用的mini-batch大小:64(2^6), 128(2^7),256(2^8), 1024(2^10)

2. Gradient descent with momentum

上一节提到,mini-batch gradient descent每次更新参数并不能保证训练集的损失值一直变小,表现出来的结果会比较震荡。可以这样理解:mini-batch在一些我们比较不关心的方向上过度学习,并且拖缓了我们希望它前进的方法的速度。
为了应对这一挑战,我们采用 exponentially weighted average的思路来计算梯度。即我们希望新的梯度由当前batch样本计算的梯度和以往的梯度组合而成,组合方式就是上面提到的指数权重平均法。
下图可以帮我们很好地理解exponentially weighted average。
这里写图片描述
其中 vt 为本轮的梯度, vt1 为上一轮的梯度, θt 为本轮样本计算的梯度。从图上可以发现,当将公式分解后,第 t 轮的梯度vt由从第一轮开始到当前轮所有的梯度分别乘以不同的权重组合而成。

这里写图片描述

Exponentially weighted average

指数权重平均法可以有效地降低数据的波动性,并且代码实现简单,运行时不需要缓存过多数据,在深度学习算法中很受欢迎。
下图为采用指数权重平均法描述一年当中温度的变化情况,其中红线 β=0.9 ,绿线 β=0.98
这里写图片描述

Momentum

使用指数权重平均法更新梯度的梯度下降算法便可称之为 Gradient descent with momentum,那么它的梯度更新公式如下:

{vdW[l]=βvdW[l]+(1β)dW[l]W[l]=W[l]αvdW[l]

{vdb[l]=βvdb[l]+(1β)db[l]b[l]=b[l]αvdb[l]

如下:Momentum的梯度用红色箭头表示,蓝线代表当前batch计算的梯度,它只是影响红色箭头的走向,红色箭头并不完全按它的方向前进。
这里写图片描述

β 的值一般取0.9,当然也可以选择其它的值; β 值越大梯度的更新越平滑,因为它要考虑更多的过去的梯度值。

Momentum在更新梯度时考虑了以前的梯度值,这样它的梯度更新变得更平滑,Momentum可以应用在batch gradient descent, mini-batch gradient descent 或 stochastic gradient descent。

3. Gradient descent with RMSprop

RMSprop是另一种可以加速梯度下降的方法。可以认为RMSprop和Momentum两种方法的目的都是一样的——加速梯度下降,只是具体实现策略不同。
这里写图片描述

4. Adam

Adam是用于训练神经网络的最有效方法之一,它结合了Momentum和RMSprop两者的思想。我们来看一下Adam在更新梯度时的策略:
这里写图片描述

Adam是怎样工作的?在每一轮进行梯度更新时:
首先它会计算出前面梯度值的指数权重平均值,并保存为变量 v vcorrect;(Momentum)
然后它会计算出前面梯度值的指数权重平均值的平方值,并保存为变量 s scorrect;(RMSprop)
最后基于前两步的结果更新梯度。(Combination)

其中更新 W 的公式如下:

vdW[l]=β1vdW[l]+(1β1)JW[l]vcorrecteddW[l]=vdW[l]1(β1)tsdW[l]=β2sdW[l]+(1β2)(JW[l])2scorrecteddW[l]=sdW[l]1(β1)tW[l]=W[l]αvcorrecteddW[l]scorrecteddW[l]+ε

更新参数 b 的公式与上同理。

Adam的优点:
- 相对较小的内存需求(当然比gradient descent和gradient descent with momentum还是要高一些);
- 收敛非常迅速。

5. Learning-rate-decay

有时候我们在训练模型时会遇到模型无法在全局最优点收敛的问题。如下蓝线所示:
这里写图片描述
然而如果设置α的值随着每次迭代变小就能极大改善上述现象,如上图绿线所示。
通常用于缩小 α 的方法如下:
这里写图片描述

注:如无特殊说明,以上所有图片均截选自吴恩达在Coursera开设的神经网络系列课程的讲义。


http://www.niftyadmin.cn/n/1693394.html

相关文章

生成对抗网络学习笔记3----论文unsupervised representation learning with deep convolutional generative adversarial

论文原文:地址 论文译文:地址 1、阅读论文 Radford A, Metz L, Chintala S. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks[J]. Computer Science, 2015. 2、翻译论文 摘要 近年来,使…

学习笔记-batch normalization

这是Deep learning 第二门课的第三周课程的学习笔记。 1. Hyperparameter tuning 针对深度学习,不推荐使用grid search来寻找最优的参数值。因为深度学习的计算量实在太大了,grid search方法太耗资源也太慢了。 对于深度学习的调参,吴老师…

生成对抗网络学习笔记4----GAN(Generative Adversarial Nets)的实现

首先是各种参考博客、链接等,表示感谢。 1、参考博客1:地址 2、参考博客2:地址 ——以下,开始正文。 1、GAN的简单总结 见上一篇博客。 2、利用GAN生成1维正态分布 首先,我们创建“真实”数据分布,一…

NTT

1 问题描述FFT问题解决的是复数域上的卷积。如果现在的问题是这样:给出两个整数数列$Ai,Bj,0\leq i\leq n-1,0\leq j\leq m-1$,以及素数$P$,计算新数列$Ci(\sum_{k}A_{i-k}B_{k})\%P$。不在$A,B$定义域内的值均为0.NTT就是解决这样在模意义下的卷积问题。2 预备知识…

生成对抗网络学习笔记5----DCGAN(unsupervised representation learning with deep convolutional generative adv)的实现

首先是各种参考博客、链接等,表示感谢。 1、参考博客1:地址 ——以下,开始正文。 2017/12/12 更新 解决训练不收敛的问题。 更新在最后面部分。 1、DCGAN的简单总结 稳定的深度卷积GAN 架构指南: 所有的pooling层使用步幅卷积(判别网络…

Algorithms, Part I

前言 目前我已暂停学习深度学习课程,虽然之前的学习过程很愉快,但我现在想要挑战一下自己。机器学习在很多大学都被分在计算机科学底下,而我又了解到要学习compute science几乎不可能绕过数据结构和算法这一部分内容,这一部分的知…

TensorFlow学习笔记13----TensorFlow Serving

原文教程:tensorflow官方教程 记录关键内容与学习感受。未完待续。。 TensorFlow Serving ——这一部分最后再来看。先放着。 1、介绍 tensorflow服务器对于机器学习模型来说,是一个灵活的、高效能的服务系统,用来设计生产环境。tensorfl…

TensorFlow学习笔记14----Convolutional Neural Networks

原文教程:tensorflow官方教程 记录关键内容与学习感受。未完待续。。 Convolutional Neural Networks 1、Overview 1.1 Goals 1.2 Highlights of the Tutorial 1.3 Model Architecture 2、Code Organization 3、CIFAR-10 Model 3.1 Model Inputs 3.2 Model…