
Multi-Class Classification: XGBoost

The Titanic model was a binary classification problem. The answer was either yes or no, which in machine speak is a 1 or a 0. The next model we will build is also a classification problem; however, it’s a multi-class classification model. That means the target variable has more than two possible classes, and the model picks one of them.


The Iris flower data set is a multivariate data set introduced by the British statistician and biologist Ronald Fisher in his 1936 paper The use of multiple measurements in taxonomic problems. The dataset consists of 50 samples from each of three species of Iris (Iris setosa, Iris virginica, and Iris versicolor). Four features were measured from each sample: the length and the width of the sepals and petals, in centimeters.


The target variable in this dataset is species.


The target variable has three possible values: setosa, versicolor, and virginica.


Recall that XGBoost only accepts numerical inputs, so we will need to convert those textual values to numbers.
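
As a quick illustration of what that conversion looks like (this snippet is separate from the model code, and the species list here is just a hand-typed sample), scikit-learn's LabelEncoder maps text labels to integers:

from sklearn.preprocessing import LabelEncoder

# Hand-typed sample of species labels, for illustration only
species = ["setosa", "versicolor", "virginica", "setosa"]

encoder = LabelEncoder()
print(encoder.fit_transform(species))  # [0 1 2 0]
print(encoder.classes_)                # ['setosa' 'versicolor' 'virginica']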

Let’s import the libraries we will use on this problem. This will be the first cell in our Jupyter Notebook. Take note that scikit-learn has several built-in datasets we can use, and one of those datasets is the iris dataset.


Import Libraries


import xgboost as xgb

from sklearn import datasets

from sklearn.model_selection import train_test_split

from sklearn.preprocessing import LabelEncoder

from xgboost import XGBClassifier

from sklearn.metrics import accuracy_score


Load and Prepare Data


In the next line of code, let’s create a variable iris and load the iris dataset into it. The dataset comes from the datasets module within scikit-learn.


iris = datasets.load_iris()
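
If you’d like to peek at what load_iris returns before going further (purely optional), you can print a few of its attributes:

print(iris.feature_names)  # ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']
print(iris.data.shape)     # (150, 4): 150 samples, 4 features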


In the next lines of code, we are specifying the attributes (features) we want the model to use, X, and the target variable, y.


X = iris.data

y = iris.target


Separate Data


In the next line of code, we are separating our data into training and testing sets. We’ve passed the parameter test_size=0.20 to create an 80/20 training and testing split of our data. You can specify any size you’d like; common splits are 80/20 and 70/30.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=5)
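
A quick sanity check (optional) confirms the 80/20 split of the 150 samples: 120 rows for training and 30 for testing.

print(X_train.shape, X_test.shape)  # (120, 4) (30, 4)
print(y_train.shape, y_test.shape)  # (120,) (30,)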

In the next line of code, we are converting our target variable into integers using label encoding. Recall that label encoding refers to converting text labels into numeric form so the model can work with them.


lc = LabelEncoder()

lc = lc.fit(y)

lc_y = lc.transform(y)


Note: This step was not necessary with the iris dataset in scikit-learn. That dataset is already cleaned and model-ready, and its target is stored as the integers 0, 1, and 2. I wanted to include this step to show you how encoding the target variable is done.
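
If you do run the encoder on the iris target, you can confirm it is effectively a pass-through here, because the labels are already the integers 0, 1, and 2:

print(lc.classes_)  # [0 1 2]
print(lc_y[:5])     # [0 0 0 0 0], same as y[:5]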


In the next two lines of code we create a variable called model to hold our classifier and fit our model to the training dataset.


Define and Fit Model


model = XGBClassifier()

model.fit(X_train, y_train)
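
XGBoost infers the multi-class objective from the data on its own, but you can also spell it out and set a few common hyperparameters explicitly. The values below are just illustrative starting points, not tuned settings:

# Explicit multi-class objective with a few common hyperparameters (illustrative values)
model = XGBClassifier(objective="multi:softprob", n_estimators=100, max_depth=3, learning_rate=0.1)
model.fit(X_train, y_train)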

In the next cell, we create a variable y_pred and run our model against the testing data. Additionally, a variable predictions is created to round the results to whole numbers; with a classifier, predict already returns whole class labels, so the rounding here is just a safeguard.


y_pred = model.predict(X_test)

predictions = [round(value) for value in y_pred]
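
Because this is a multi-class model, predict returns a class index (0, 1, or 2) for each test row. If you’d rather see the class probabilities or the species names, a sketch like this works:

probs = model.predict_proba(X_test)   # one probability per class, per test row
print(probs[0])                       # probabilities for the first test flower
print(iris.target_names[y_pred[:5]])  # map predicted indices back to species names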


In this step, we create a variable accuracy, score the predictions against the true test labels with accuracy_score, and print the result as a percentage.


accuracy = accuracy_score(y_test, predictions)

print("Accuracy: %.2f%%" % (accuracy * 100.0))


The complete code is below.


import xgboost as xgb

from sklearn import datasets

from sklearn.model_selection import train_test_split

from sklearn.preprocessing import LabelEncoder

from xgboost import XGBClassifier

from sklearn.metrics import accuracy_score

iris = datasets.load_iris()

X = iris.data

y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=5)

lc = LabelEncoder()

lc = lc.fit(y)

lc_y = lc.transform(y)

model = XGBClassifier()

model.fit(X_train, y_train)

y_pred = model.predict(X_test)

predictions = [round(value) for value in y_pred]

accuracy = accuracy_score(y_test, predictions)

print("Accuracy: %.2f%%" % (accuracy * 100.0))
