线性回归算法
在这一篇博客中,我会表达我对今天学习的线性回归算法的理解。
OK,目前要解决的问题就是,我们得到了一个数据集,也称做训练集。
而机器学习的过程就是:
- 计算机读取数据集
- 计算机通过某种算法分析数据集
- 计算机分析过后得出了一个线性函数h(x)
- 这个h(x)可以很好地拟合数据集地数据,从而可以方便人们对某个特定的数据进行预测。
上面的线性函数h(x)叫做假设函数,“某种算法”这里用的是梯度下降法。如果听不懂也没有关系,接下来我将介绍如何实现该算法。
假设函数:h ( x ) = θ 0 + θ 1 x h(x)=\theta_0+\theta_1xh(x)=θ0+θ1x
θ 0 \theta_0θ0和θ 1 \theta_1θ1被称为模型参数,我们希望计算机得到特定的θ 0 \theta_0θ0和θ 1 \theta_1θ1,使得h(x)这条直线能够很好地拟合数据集中的数据。
接下来得问题就是如何得到特定的θ 0 \theta_0θ0和θ 1 \theta_1θ1呢?我们假设数据集中一共6个数据,然后根据得到的直线h(x),可以计算真实数据(也就是数据集中的数据)与理论数据的偏差。
对于一个数据而言,它的偏差是h(xi ^ii)-yi ^ii,把每个数据的偏差全部加起来,就得到可以描述计算机拟合真实数据好坏的指标。由于h(xi ^ii)-yi ^ii可能为正也可能为负,所以平方之后再相加是合适的。
假设数据集一共有m条数据,最后得到代价函数,也称作损失函数:J ( θ 0 , θ 1 ) = 1 2 m ∑ i = 1 m [ h ( x i ) − y i ] 2 J(\theta_0,\theta_1)=\frac1{2m}\sum_{i=1}^m[h(x^i)-y^i]^2J(θ0,θ1)=2m1i=1∑m[h(xi)−yi]2
至于为什么要乘一个1 2 m \frac1{2m}2m1,具体我也没有深究,只知道这是更合适的做法。
接下来线性回归算法就只剩一个问题了,就是如何求特定的θ 0 \theta_0θ0和θ 1 \theta_1θ1,使得代价函数有最小值。
这里求代价函数最小值的办法叫做:梯度下降法
梯度下降法的过程如下:
- 随便给定θ 0 \theta_0θ0和θ 1 \theta_1θ1的初值,比如θ 0 = 0 \theta_0=0θ0=0,θ 1 = 0 \theta_1=0θ1=0
- 然后一点点改变θ 0 \theta_0θ0和θ 1 \theta_1θ1,这种改变会使得代价函数J ( θ 0 , θ 1 ) J(\theta_0,\theta_1)J(θ0,θ1)的值逐渐变小,直到找到代价函数的最小值。
使得代价函数的值逐渐变小的公式为θ 0 = θ 0 − α ∂ J ( θ 0 , θ 1 ) ∂ θ 0 \theta_0=\theta_0-\alpha\frac{\partial{J(\theta_0,\theta_1)}}{\partial\theta_0}θ0=θ0−α∂θ0∂J(θ0,θ1)
θ 1 = θ 1 − α ∂ J ( θ 0 , θ 1 ) ∂ θ 1 \theta_1=\theta_1-\alpha\frac{\partial{J(\theta_0,\theta_1)}}{\partial\theta_1}θ1=θ1−α∂θ1∂J(θ0,θ1)
这里的“=”是赋值的意思,不是数学中判断相等的意思。α \alphaα被称为学习速率,直观来看就是下降的速率。
其实学过微积分的,看到这个公式应该很快能明白其中的含义,这里我用我的理解来说明一下这个公式。
∂ J ( θ 0 , θ 1 ) ∂ θ 1 \frac{\partial{J(\theta_0,\theta_1)}}{\partial\theta_1}∂θ1∂J(θ0,θ1)
这个是代价函数对θ 1 \theta_1θ1的偏导数,当只有一个参数时,该公式变为d J ( θ 1 ) d θ 1 \frac{dJ(\theta_1)}{d\theta_1}dθ1dJ(θ1)
假设θ 0 \theta_0θ0固定等于0,只有θ 1 \theta1θ1一个参数,此时得到的代价函数类似为二次函数的图像,在二次函数上任意取一个点,假设该点在二次函数最低点的左边,就如上图所示。
此时该公式表示二次函数在θ 1 \theta1θ1这个点处的斜率,此时斜率为负。
减去一个负数,就是加上一个正数,此时左边的点开始向右边移动,而越接近最小值,斜率越接近0,则加上的正数会越小,移动的长度就会越短,直到斜率为0时,加上的数就为0,θ 1 \theta1θ1不再改变,此时得到的θ 1 \theta1θ1值使得代价函数有最小值。
同理若从右边开始,斜率为正,同样会下降到最低点。
有多个θ \thetaθ参数时时,代价函数从二维函数逐渐向高纬度转变,超过三维的图像就很难可视化了,但是用这个公式计算是不会错的。
这篇博客就到这里了,如果有什么不懂得地方或者觉得我写得不对的地方可以向我留言哦,我会一一和你们讨论学习的。