Sunday, March 30, 2014

Gradient Descent

Earlier, I had a post on predicting Iris Flower types from their flower sizes with the gradient descent algorithm.

The algorithm is actually quite beautiful in its simplicity. I'll try to explain what the algorithm does and why it works in this post. Hopefully, this will give you a better picture of how machines can "learn" from data.

Note: I'll be splitting my explanation into two parts. This is part 1. Part 2 will be coming out soon.

--------------------------------------

In its most basic implementation, gradient descent tries to fit a line to a set of points.

So, given a set of m points in the xy coordinate system below,

x | y
1 | -1
3 | 1
...
...
x[m] | y[m]

[Figure: My excellent plotting skills]

The algorithm would want to find a line like the green one I drew.


Because our data set is linear (follows a sort of straight line), we can represent the line as a linear function in the form f(x) = mx + b (here m is the slope, not the number of points).

Following the naming convention, the line is called h(x) = θ[0] + θ[1]x, but essentially they're the same: θ[0] plays the role of b and θ[1] the role of the slope.
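To make the two notations concrete, here is a minimal Python sketch. The function names `h` and `f` and the sample values are my own choices for illustration; the only thing taken from the post is the form of the two formulas.

```python
def h(theta0, theta1, x):
    """Predicted y for input x under the line h(x) = theta0 + theta1 * x."""
    return theta0 + theta1 * x


def f(m, b, x):
    """The same line in slope-intercept form f(x) = m * x + b."""
    return m * x + b


# theta0 corresponds to the intercept b, theta1 to the slope m,
# so both calls below describe the same line and give the same value.
print(h(-2.0, 1.0, 3.0))
print(f(1.0, -2.0, 3.0))
```

The values -2 and 1 are just the line that passes through the two points listed in the table, (1, -1) and (3, 1). The rest of the data set isn't listed, so the green line may well be a different one.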

So, ready? The algorithm is as follows:

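Repeat until convergence {
    For θ[j] in every θ {

        update θ[j] to be :
            θ[j] + a * Σ(i = 1 to m) (y[i] - h(x[i])) * x[i][j]

    }
} where a is the learning rate,
m is the number of training points,
and x[i][j] is the input that θ[j] multiplies for point i (1 for θ[0], x[i] for θ[1])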

For our example, the algorithm would essentially be this:

Repeat until the values of θ don't change anymore {
     for θ[j] in the θ we have (θ[0] and θ[1]) {
          update θ[j] to be :
                 itself + the learning rate * the summation over all points i of ( (the actual value - the predicted value) * the x of that point)
                 (for θ[0], which isn't multiplied by any x, use 1 in place of the x)
     }
}  
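The loop above can be sketched in Python roughly like this. Treat it as a sketch under assumptions rather than the definitive version: the function name `gradient_descent`, the learning rate of 0.1, the starting values of 0, the stopping tolerance, and the iteration cap are all mine, and the data is only the two points actually listed in the table.

```python
def gradient_descent(points, learning_rate=0.1, tolerance=1e-12, max_iters=10000):
    """Fit h(x) = theta0 + theta1 * x to (x, y) points with the update rule above."""
    theta0, theta1 = 0.0, 0.0  # assumed starting guess
    for _ in range(max_iters):
        # (actual value - predicted value) for every point, using the current thetas
        errors = [(y - (theta0 + theta1 * x), x) for x, y in points]
        # theta0 is not multiplied by any x, so its "x" is 1
        step0 = learning_rate * sum(err * 1.0 for err, _ in errors)
        step1 = learning_rate * sum(err * x for err, x in errors)
        # both steps were computed from the old thetas, so this is a simultaneous update
        theta0 += step0
        theta1 += step1
        # "repeat until the values of theta don't change anymore"
        if max(abs(step0), abs(step1)) < tolerance:
            break
    return theta0, theta1


points = [(1.0, -1.0), (3.0, 1.0)]  # the two points listed in the table
theta0, theta1 = gradient_descent(points)
print(theta0, theta1)
```

With only these two points, the result should settle very close to θ[0] = -2 and θ[1] = 1, the line through both points. A learning rate that is too large makes the values blow up instead of settling, which is why I picked a small one here.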

 ------------------------------------

You may be wondering where these formulas come from and why they work. To understand that, you'll first have to know what it means to have a "better fitting" line.

First off, the cost function is a function that measures how badly a line fits a data set. If the value of the cost function is high, the line doesn't fit the data set very well. The goal of gradient descent is to minimize the cost function.

The cost function used here is:




where m is the number of training points. 

If you look closely, you'll find that the cost function is actually a total of the distances between the predicted line and the actual points.
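Here is a small sketch of what such a cost function could look like in Python. I'm assuming the cost is half the sum of the squared differences between the predicted and actual values, so the "distances" are squared before being totaled. I picked that form because its slope matches the update rule described in plain English earlier. The name `cost` and the sample points are my own choices.

```python
def cost(theta0, theta1, points):
    """Assumed cost: half the sum of squared (predicted - actual) over all points."""
    return 0.5 * sum((theta0 + theta1 * x - y) ** 2 for x, y in points)


points = [(1.0, -1.0), (3.0, 1.0)]  # the two points listed in the table

print(cost(0.0, 0.0, points))   # a flat line at y = 0 misses both points
print(cost(-2.0, 1.0, points))  # the line through both points
```

The flat line gets a higher cost than the line that passes through both points, which is exactly the sense in which a lower cost means a "better fitting" line.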

  

------------------------------

This is the end of part 1. Any questions or comments? Post them below.
