Mike

Gradient Boosting - Part 1

Updated: Jan 18

Gradient boosting is a powerful technique for building predictive models. It takes a model that on its own is only a weak predictor and combines it with other models of the same type to produce a more accurate model. The idea is to build a sequence of simple decision trees, where each successive tree is fit to the prediction residuals (the errors) left by the trees that came before it.
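To make the idea of fitting each new tree to the remaining residuals concrete, here is a minimal sketch in plain Python. It assumes a regression problem with squared-error loss, a single numeric feature, and one-split trees ("stumps") as the weak learners; `fit_stump`, `gradient_boost`, `n_trees`, and `learning_rate` are hypothetical names chosen for illustration, not XGBoost's API.

```python
# Minimal gradient-boosting sketch (regression, squared-error loss).
# Assumptions, not from the post: one numeric feature, depth-1 trees
# ("stumps") as weak learners, and a fixed learning rate.

def fit_stump(xs, targets):
    """Fit a one-split tree to targets; each side predicts its mean target."""
    best = None
    for threshold in sorted(set(xs))[:-1]:
        left = [t for x, t in zip(xs, targets) if x <= threshold]
        right = [t for x, t in zip(xs, targets) if x > threshold]
        left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)
        sse = sum((t - left_mean) ** 2 for t in left) + sum((t - right_mean) ** 2 for t in right)
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    if best is None:  # all x values identical: no split possible
        mean = sum(targets) / len(targets)
        return lambda x: mean
    _, threshold, left_mean, right_mean = best
    return lambda x: left_mean if x <= threshold else right_mean

def gradient_boost(xs, ys, n_trees=50, learning_rate=0.1):
    base = sum(ys) / len(ys)  # start from a constant prediction
    trees, predictions = [], [base] * len(ys)
    for _ in range(n_trees):
        # For squared error, the residuals are the negative gradient (up to a
        # constant factor): whatever the ensemble so far has not yet explained.
        residuals = [y - p for y, p in zip(ys, predictions)]
        tree = fit_stump(xs, residuals)
        trees.append(tree)
        predictions = [p + learning_rate * tree(x) for p, x in zip(predictions, xs)]
    return lambda x: base + learning_rate * sum(t(x) for t in trees)

# Toy data (made up for illustration): a step from about 1 to about 3.
xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.2, 0.9, 1.1, 3.0, 3.2, 2.9, 3.1]
model = gradient_boost(xs, ys)
print(round(model(2), 2), round(model(7), 2))
```

Each stump on its own is a weak model, but adding up many small corrections lets the ensemble follow the step in the data. Shrinking each correction by the learning rate means no single tree dominates the ensemble.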

In gradient boosting, the weak model is called a weak learner, and a combination of several machine learning models is called an ensemble. In XGBoost, a widely used gradient boosting library, the weak learner is typically a decision tree.

Therefore, we need to understand how a decision tree works before we can understand boosting.

Decision Trees Defined

A decision tree is a graphical representation of all the possible solutions to a decision based on certain conditions. Decision trees can be viewed as a set of rules built from simple clauses combined into one larger expression. They are a series of if/then statements that segment your data into similar groups. The rules are easy to understand and can reveal patterns that are not always obvious from a simple inspection of the data.

It's called a decision tree because it starts with a single box (or root), which then branches off into several solutions, just like a tree. This definition is a little abstract, but several examples will make the concept clearer. Let’s define a few core terms first. The start of the tree is called the root node. This node splits on the most important attribute in your dataset. The nodes after the initial split are decision nodes. They are called decision nodes because a split in the data has been made that causes the tree to branch in two separate directions; in other words, a decision has been made on how best to split the data. The leaf nodes at the ends of the branches hold the answer, the predicted value. Please keep in mind these examples are contrived and for instructional purposes only.
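The root, decision, and leaf nodes map naturally onto a small data structure. The sketch below assumes binary splits of the form "value <= threshold"; the `Leaf`, `DecisionNode`, and `predict` names and the weather example are hypothetical, chosen only to show how a prediction walks from the root down to a leaf.

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class Leaf:
    prediction: str  # the answer stored at the end of a branch

@dataclass
class DecisionNode:
    feature: str      # attribute this node splits on
    threshold: float  # samples with value <= threshold go left, others go right
    left: "Tree"
    right: "Tree"

Tree = Union[Leaf, DecisionNode]

def predict(node: Tree, sample: dict) -> str:
    # Start at the root and follow one branch per decision until a leaf is reached.
    while isinstance(node, DecisionNode):
        node = node.left if sample[node.feature] <= node.threshold else node.right
    return node.prediction

# Hypothetical tree: the root node splits on rain, a decision node then splits on temperature.
root = DecisionNode(
    feature="rain_mm", threshold=0.5,
    left=DecisionNode(feature="temperature_c", threshold=10.0,
                      left=Leaf("stay in"), right=Leaf("play outside")),
    right=Leaf("stay in"),
)
print(predict(root, {"rain_mm": 0.0, "temperature_c": 20.0}))  # play outside
```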

Let’s look at one of the most famous datasets used in machine learning, the Titanic dataset. The RMS Titanic was a British passenger liner that sank in the North Atlantic Ocean in the early morning hours of 15 April 1912, after it collided with an iceberg during its maiden voyage from Southampton to New York City. There were an estimated 2,224 passengers and crew aboard the ship, and more than 1,500 died, making it one of the deadliest commercial peacetime maritime disasters in modern history. One of the top reasons that the sinking resulted in such loss of life was that there were not enough lifeboats for the passengers and crew. Although there was some element of luck involved in surviving the sinking, some groups of people were more likely to survive than others. The picture below shows the attributes in the Titanic dataset. The target variable, the attribute we want to predict, is Survived. Because the target variable is either 1 (survived) or 0 (did not survive), this is a binary classification problem.

Let’s attempt to answer some questions based on our data. How would a decision tree be applied to this dataset?

The very first thing the decision tree algorithm does is to look for the attribute that is the most important in the dataset.

The phrase "most important" is relative; it is decided by some mathematics behind the scenes, which we return to below. The dataset records gender in an attribute named sex. The algorithm decided that the single most important attribute was sex. Specifically, is the passenger male or not?

The root node is split into two groups, male and female. After the first split has been made, the algorithm continues to analyze the other attributes in the dataset. The algorithm decided that the age of the passenger was the second most important attribute. Specifically, is the passenger older than nine and a half years of age? That answer led to another split on another attribute, sibsp, which is the number of siblings and spouses the passenger had aboard. With each successive division, the members of the resulting groups become more and more like each other. Creating a tree involves deciding which features to choose and what conditions to use for splitting, along with knowing when to stop.
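Since a decision tree is a series of if/then statements, the Titanic tree can be written out by hand. This is a sketch, not the learned model: the order of the attributes follows the text, but the sibsp cutoff of 2.5 and the died/survived outcome at each leaf are assumptions made for illustration, and `titanic_tree` is a hypothetical helper.

```python
def titanic_tree(sex: str, age: float, sibsp: int) -> str:
    """Hand-written if/then version of the Titanic tree described in the text.

    The attributes follow the text (sex, then age > 9.5, then sibsp). The sibsp
    threshold of 2.5 and the leaf outcomes are assumptions for illustration.
    """
    if sex == "male":        # root node: is the passenger male?
        if age > 9.5:        # decision node: older than nine and a half?
            return "died"
        if sibsp > 2.5:      # decision node: more than 2 siblings/spouses aboard?
            return "died"
        return "survived"
    return "survived"        # leaf for female passengers

print(titanic_tree("female", 29, 0), titanic_tree("male", 40, 0), titanic_tree("male", 6, 1))
# survived died survived
```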

Decision trees are helpful not only because they give you a visual representation of what you are thinking, but also because designing one requires a documented thought process. Decision trees help formalize the brainstorming process so we can identify more potential solutions.

Recursive Binary Splitting

Recursive binary splitting is a numerical procedure in which all the values of an attribute are lined up and different split points are tried and tested using a cost function. A cost function is something we want to minimize. Let’s define a cost function in the simplest terms. Whenever you train a model on your data, it produces new, predicted values for a specific attribute. However, that attribute already has real values in the dataset.

The closer the predicted values are to the corresponding real values, the better the model.
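One common way to score how close predictions are to the real values is mean squared error. The post does not name a specific cost function, so treat this as an assumed example; `mean_squared_error` here is a hand-written helper and the numbers are made up.

```python
def mean_squared_error(predicted, actual):
    """Average squared gap between predicted and real values; lower is better."""
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)

actual = [3.0, 5.0, 2.0]
close = [2.5, 5.0, 2.5]  # predictions near the real values
far = [1.0, 7.0, 4.0]    # predictions 2 units off everywhere
print(mean_squared_error(close, actual), mean_squared_error(far, actual))
```

Squaring the gaps keeps errors in either direction from cancelling out and penalizes large misses more heavily than small ones.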

We use a cost function to measure how close the predicted values are to their corresponding real values. The split with the best cost (the lowest cost, because we want to minimize cost in this situation) is selected. All input variables and all possible split points are evaluated, and the best one is chosen in a greedy manner.
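For a classification target such as Survived, a common cost for a split is Gini impurity, weighted by group size. The sketch below assumes that choice (the post does not name one) and tries every midpoint between consecutive distinct values of every attribute, keeping the lowest-cost split. The rows are made-up toy passengers with columns `[is_male, age]`, not real Titanic records, and `gini` and `best_split` are hypothetical helpers.

```python
def gini(labels):
    """Gini impurity: 0 for a pure group, higher when classes are mixed."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def best_split(rows, labels):
    """Greedy search over every feature and every candidate threshold."""
    best = None  # (cost, feature_index, threshold)
    for f in range(len(rows[0])):
        values = sorted(set(r[f] for r in rows))
        # Candidate thresholds: midpoints between consecutive distinct values.
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2
            left = [y for r, y in zip(rows, labels) if r[f] <= t]
            right = [y for r, y in zip(rows, labels) if r[f] > t]
            # Cost = size-weighted Gini impurity of the two groups.
            cost = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or cost < best[0]:
                best = (cost, f, t)
    return best

# Made-up passengers: [is_male, age]; label 1 = survived, 0 = died.
rows = [[1, 22], [1, 35], [1, 54], [1, 4], [0, 38], [0, 26], [0, 35], [0, 2]]
labels = [0, 0, 0, 1, 1, 1, 1, 1]
print(best_split(rows, labels))  # (0.1875, 0, 0.5)
```

On this toy data the split on the first column (is_male) wins, mirroring the root split described for the Titanic tree.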

Recall the tree learned from the Titanic dataset. The very first split occurred at the root node. At this juncture, every attribute in the dataset is considered as a candidate, and the split that costs the least is chosen, which in our example is the sex of the passenger. The training data is then divided into groups based on that split. The algorithm is recursive in nature because each group formed can be subdivided using the same strategy. The decision tree is a greedy algorithm: at each step it makes the choice that looks best at that moment, without looking ahead to how that choice affects later splits.

One of the issues with this approach is overfitting. Overfitting is like learning through memorization. Rather than understanding the concepts and making informed decisions, you simply recall what you’ve seen before and find the closest match to what you memorized. This means that while your model will perform very well on the training set, having memorized what to do with each input, when it is faced with an input it has never seen before it won’t have any general concepts to fall back on.

The complexity of a decision tree is defined as the number of splits in the tree. Simpler trees are preferred: they are easier to understand, and they are less likely to overfit your data. The question now becomes: how do you stop a tree from growing? The more attributes you have, the more splits, and the larger the tree. We need a mechanism to stop this growth. One way is to set a minimum number of training samples for each leaf. For example, we can require at least 7 passengers to reach a decision, either died or survived, and reject any split that would leave fewer than 7 passengers in a leaf. Another way is to set the maximum depth of your model. Maximum depth refers to the length of the longest path from the root to a leaf.
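The recursion and both stopping rules can be combined into one small tree builder. This is a hedged sketch: it assumes Gini impurity as the cost and stores the tree as nested dictionaries, and `build_tree`, `tree_depth`, `max_depth`, and `min_samples_leaf` are illustrative names (libraries such as scikit-learn expose parameters with similar names, but this is not their implementation).

```python
from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def build_tree(rows, labels, depth=0, max_depth=3, min_samples_leaf=7):
    """Recursive binary splitting with max-depth and min-samples-per-leaf stopping rules."""
    majority = Counter(labels).most_common(1)[0][0]
    if depth >= max_depth or len(set(labels)) == 1:
        return majority  # leaf: the tree is deep enough, or the group is pure
    best = None
    for f in range(len(rows[0])):
        values = sorted(set(r[f] for r in rows))
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2
            left = [i for i, r in enumerate(rows) if r[f] <= t]
            right = [i for i, r in enumerate(rows) if r[f] > t]
            if len(left) < min_samples_leaf or len(right) < min_samples_leaf:
                continue  # this split would create a leaf that is too small
            cost = (len(left) * gini([labels[i] for i in left])
                    + len(right) * gini([labels[i] for i in right]))
            if best is None or cost < best[0]:
                best = (cost, f, t, left, right)
    if best is None:
        return majority  # no allowed split: stop growing here
    _, f, t, left, right = best

    def grow(idx):
        # Recurse on one side of the split, one level deeper.
        return build_tree([rows[i] for i in idx], [labels[i] for i in idx],
                          depth + 1, max_depth, min_samples_leaf)

    return {"feature": f, "threshold": t, "left": grow(left), "right": grow(right)}

def tree_depth(tree):
    """Length of the longest root-to-leaf path (a lone leaf has depth 0)."""
    if not isinstance(tree, dict):
        return 0
    return 1 + max(tree_depth(tree["left"]), tree_depth(tree["right"]))

# Made-up data where the label alternates, so an unrestricted tree would keep splitting.
rows = [[x] for x in range(20)]
labels = [x % 2 for x in range(20)]
print(tree_depth(build_tree(rows, labels, max_depth=5, min_samples_leaf=7)))  # 1
print(tree_depth(build_tree(rows, labels, max_depth=2, min_samples_leaf=1)))  # 2
```

With a 7-sample minimum per leaf, the tree on 20 points can split only once, because any further split would leave a leaf with fewer than 7 samples. In the second call, the depth limit is what stops growth.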

The performance of a decision tree can be further enhanced through pruning. Pruning removes branches that rely on features of low importance. This reduces the complexity of the tree, which increases its predictive power on new data by reducing overfitting.

By pruning we mean that the lower ends of the tree, its leaves, are snipped off until the tree is much smaller.

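Leaf nodes are removed only if doing so lowers the overall cost function. That cost is typically measured on data held out from training, because on the training data itself removing leaves never lowers the error. Please note the diagram is for instructional purposes only.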
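Pruning can be implemented in several ways. The sketch below uses reduced-error pruning, a common method that is swapped in here (the post does not say which method its diagram uses): working bottom-up, it replaces a split with a leaf whenever that leaf makes no more mistakes on held-out data than the subtree did. The nested-dictionary tree, its "majority" field (the majority training label at each node), and the held-out rows are all hypothetical.

```python
def predict(tree, row):
    """Walk from the root to a leaf; leaves are plain class labels."""
    while isinstance(tree, dict):
        tree = tree["left"] if row[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree

def errors(tree, rows, labels):
    return sum(predict(tree, r) != y for r, y in zip(rows, labels))

def prune(tree, rows, labels):
    """Reduced-error pruning, bottom-up, using held-out rows and labels."""
    if not isinstance(tree, dict) or not rows:
        return tree  # a leaf, or no held-out data reaches this node
    f, t = tree["feature"], tree["threshold"]
    goes_left = [r[f] <= t for r in rows]
    left = prune(tree["left"],
                 [r for r, g in zip(rows, goes_left) if g],
                 [y for y, g in zip(labels, goes_left) if g])
    right = prune(tree["right"],
                  [r for r, g in zip(rows, goes_left) if not g],
                  [y for y, g in zip(labels, goes_left) if not g])
    tree = {**tree, "left": left, "right": right}  # copy; the input tree is not modified
    leaf = tree["majority"]  # majority training label at this node (assumed stored at build time)
    # Collapse the split if a single leaf does no worse on the held-out data.
    return leaf if errors(leaf, rows, labels) <= errors(tree, rows, labels) else tree

# Hypothetical tree over [is_male, age]; 1 = survived, 0 = died.
tree = {"feature": 0, "threshold": 0.5, "majority": 1, "left": 1,
        "right": {"feature": 1, "threshold": 9.5, "majority": 0, "left": 1,
                  "right": {"feature": 1, "threshold": 40.0, "majority": 0,
                            "left": 0, "right": 1}}}
held_out_rows = [[1, 50], [1, 30], [1, 5], [0, 20], [1, 60]]
held_out_labels = [0, 0, 1, 1, 0]
print(prune(tree, held_out_rows, held_out_labels))
```

Here the deepest split, which claimed that men over 40 survived, is cut back to a single "died" leaf, while the splits that still help on the held-out data are kept.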

Part 2

