Understanding Multi-class Classification Confusion Matrix in Python

Nicole Sim
3 min read · Jun 24, 2021

This article covers (1) how to read the confusion matrix output in Python for a multi-class classification problem, (2) the code for visualizing the plain matrix output, and (3) the different F1-scores used for multi-class classification.

In a multi-class classification problem there are multiple classes (e.g., Ant, Bird, Cat), but each sample is assigned to exactly one class. The confusion matrix output in Python helps us understand how the model performs on each class.

Let’s use a simple example to understand:

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
labels = ["ant", "bird", "cat"]

cm = confusion_matrix(y_true, y_pred, labels=labels)
print(cm)
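With rows as the actual classes and columns as the predicted classes (both in the order ant, bird, cat), this prints:

[[2 0 0]
 [0 0 1]
 [1 0 2]]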

We could tabulate this by hand for a small dataset, but that is not feasible for a large one. How can we visualize the confusion matrix better in Python?

df_cm = pd.DataFrame(cm, labels, labels)
ax = sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, square=True, cbar=False, fmt='g')
ax.set_ylim(0, 3)  # manually corrects the heatmap cut-off issue in matplotlib 3.1.1
plt.xlabel("Predicted")
plt.ylabel("Actual")
ax.invert_yaxis() #optional
plt.show()
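On newer versions of scikit-learn (1.0 or later), ConfusionMatrixDisplay offers a one-call alternative that skips seaborn entirely; a minimal sketch, assuming that version is available:

from sklearn.metrics import ConfusionMatrixDisplay

# builds and annotates the heatmap in one call; labels fixes the row/column order
ConfusionMatrixDisplay.from_predictions(y_true, y_pred, labels=labels)
plt.show()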

How do we calculate the F1-score for multi-class classification? There are several ways to average it.

  1. Micro-F1
  • Total TP = 2 + 0 + 2 = 4
  • Total FP = (0+1) + (0+0) + (0+1) = 2
  • Total FN = (0+0) + (0+1) + (1+0) = 2
  • Precision = 4 / (4+2) = 0.67
  • Recall = 4 / (4 + 2) = 0.67
  • Micro-F1 = 2*(0.67 * 0.67) / (0.67+0.67) = 0.67
print('Micro Precision: {:.2f}'.format(precision_score(y_true, y_pred, average='micro')))
print('Micro Recall: {:.2f}'.format(recall_score(y_true, y_pred, average='micro')))
print('Micro F1-score: {:.2f}\n'.format(f1_score(y_true, y_pred, average='micro')))
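As a sanity check, the same totals can be read straight off cm with numpy (assuming it is imported as np):

import numpy as np

TP = np.diag(cm)             # true positives per class: [2, 0, 2]
FP = cm.sum(axis=0) - TP     # column totals minus the diagonal: [1, 0, 1]
FN = cm.sum(axis=1) - TP     # row totals minus the diagonal: [0, 1, 1]
print(TP.sum(), FP.sum(), FN.sum())  # 4 2 2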

2. Macro-F1

Macro-F1 first calculates the F1 score for each class separately, then takes the unweighted mean.

  • Class Ant = 0.80
  • Class Bird = 0
  • Class Cat = 0.67
  • Macro-F1 = (0.80+0+0.67) / 3 = 0.49
print('\nClassification Report\n')
print(classification_report(y_true, y_pred))
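The same per-class scores can also be pulled out programmatically with f1_score: average=None returns one score per label and average='macro' their unweighted mean. A small sketch:

# per-class F1 in the order given by labels: roughly [0.8, 0.0, 0.67]
print('Per-class F1:', f1_score(y_true, y_pred, labels=labels, average=None))
print('Macro F1-score: {:.2f}'.format(f1_score(y_true, y_pred, average='macro')))  # 0.49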

3. Weighted-F1

Weighted-F1 takes the weighted mean of each class's F1 score, using the support (the number of true instances per class) as the weight. We have 2 ants, 1 bird and 3 cats.

  • Weighted-F1 = (0.8*2 + 0*1 + 0.67*3) / 6 = 0.60
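scikit-learn gives this number directly with average='weighted', which uses each class's support from y_true as the weight:

print('Weighted F1-score: {:.2f}'.format(f1_score(y_true, y_pred, average='weighted')))  # ≈ 0.60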

Which type of F1-score should we use?

It depends on your problem.

  • If you want to correct for a class imbalance issue, weighted-F1 or macro-F1 might be better. For macro-F1, all classes contribute equally to the final metric, while for weighted-F1 each class's contribution is scaled by its support (the number of true instances for that label).
  • If the assumption that all classes are equally important does not hold, micro-F1 might be better, as the metric pools all the samples together and weights each sample, rather than each class, equally.
