Animal Classifier - Approach Overview

Dataset: https://www.kaggle.com/datasets/hitanshuintern/animal151

Data Handling

  • Mounted Google Drive and loaded training, validation, and test sets using image_dataset_from_directory.
  • Applied on-the-fly data augmentation (flip, rotation, zoom, translation) to improve generalization.

Model Architecture

  • Used EfficientNetB0 (pretrained on ImageNet) as the frozen base.
  • Added global average pooling + dense softmax layer for 150-class classification.

Training Strategy

  • Compiled with Adam optimizer and categorical cross-entropy loss.
  • Used callbacks:
    • ModelCheckpoint -> save best model.
    • EarlyStopping -> stop when no improvement.
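    • ReduceLROnPlateau -> lower LR when validation stalls (defined, but not passed to model.fit in the run below).
  • Trained for 10 epochs of 155 steps each (one full pass over the training set per epoch), monitoring val_loss on 4 batches of the test set.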

Evaluation & Insights

  • Reloaded the best checkpoint (lowest monitored val_loss) and evaluated it on the held-out validation set.
  • Generated training curves, confusion matrices, a per-class classification report, and examples of misclassified images.
In [ ]:
#!pip install kaggle

Uncomment the code below if the dataset has not been downloaded in advance.

In [ ]:
# place the Kaggle API token (kaggle.json) where the Kaggle CLI expects it
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
In [ ]:
# API to fetch the dataset from Kaggle
# !kaggle datasets download -d hitanshuintern/animal151
In [ ]:
# extracting the compressed dataset
# from zipfile import ZipFile
# dataset = '/content/animal151.zip'

# with ZipFile(dataset, 'r') as zip_ref:  # avoid shadowing the built-in zip()
#   zip_ref.extractall()
#   print('The dataset is extracted')
In [26]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import layers
In [3]:
#connecting Gdrive (to load dataset stored there)
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [53]:
# Set up data inputs
IMG_SIZE = (224, 224)
BATCH_SIZE = 32  # same as the image_dataset_from_directory default

# training and testing data directories
train_dir = "/content/drive/MyDrive/animal151/train/"
test_dir = "/content/drive/MyDrive/animal151/test/"

print(f"Loading training data from: {train_dir}")
train_data = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir, label_mode="categorical", image_size=IMG_SIZE, batch_size=BATCH_SIZE)
print(f"Successfully loaded training data. Found {len(train_data.class_names)} classes.")

print(f"Loading testing data from: {test_dir}")
# shuffle=False keeps a fixed, class-ordered sequence of test images
test_data = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir, label_mode="categorical", image_size=IMG_SIZE, batch_size=BATCH_SIZE, shuffle=False)
print(f"Successfully loaded testing data. Found {len(test_data.class_names)} classes.")
Loading training data from: /content/drive/MyDrive/animal151/train/
Found 4954 files belonging to 150 classes.
Successfully loaded training data. Found 150 classes.
Loading testing data from: /content/drive/MyDrive/animal151/test/
Found 694 files belonging to 150 classes.
Successfully loaded testing data. Found 150 classes.
In [8]:
print("Printing training data information:")
print(train_data)
print(type(train_data))
print(train_data.element_spec)
Printing training data information:
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 150), dtype=tf.float32, name=None))>
<class 'tensorflow.python.data.ops.prefetch_op._PrefetchDataset'>
(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 150), dtype=tf.float32, name=None))
In [10]:
# validation data directory
val_dir = '/content/drive/MyDrive/animal151/val/'
validation = tf.keras.preprocessing.image_dataset_from_directory(
    val_dir, label_mode="categorical", image_size=IMG_SIZE, shuffle=False)
Found 589 files belonging to 150 classes.
In [17]:
#class names
class_names = train_data.class_names
print(len(class_names))
print(class_names)
150
['African Bush Elephant', 'African Lionfish', 'African Penguin', 'African Spurred Tortoise', 'Altamira Oriole', 'American Bison', 'American Cockroach', 'American Flamingo', 'American Marten', 'American Robin', 'American white ibis', 'Andean Condor', 'Ankylosaurus', 'Ant', 'Baltimore Oriole', 'Bee Hummingbird', 'Beluga', 'Bighorn Sheep', 'Black Rat', 'Black-capped Chickadee', 'Blue Jay', 'Blue Whale', 'Boto', 'Brown-throated Three-toed Sloth', 'Bullock Mountains False Toad', 'Canada Goose', 'Carolina Wren', 'Cat', 'Cheetah', 'Chicken', 'Coelacanth', 'Common Bottlenose Dolphin', 'Common Eland', 'Common House Fly', 'Common Lionfish', 'Common Ostrich', 'Corn Snake', 'Cougar', 'Crested Auklet', 'Crested Giant Gecko', 'Crocodile', "Dead Man's Fingers", 'Diplodocus', 'Domestic Cow', 'Domestic Dog', 'Dugong', 'Eastern Copperhead', 'Eastern Gray Squirrel', 'Eastern Kingbird', 'Eastern Ratsnake', 'Eastern Tiger Swallowtail', 'Emperor Penguin', 'Fossa', 'Gaur', 'Gharial', 'Giant Pacific octopus', 'Giant Panda', 'Giant Squid', 'Gila Monster', 'Golden Eagle', 'Golden Poison Dart Frog', 'Gorilla', 'Great Blue Heron', 'Great White Shark', 'Great hammerhead shark', 'Greater Roadrunner', 'Green Anaconda', 'Green Iguana', 'Green Sea Turtle', 'Grizzly Bear', 'Groove-billed Ani', 'Hippopotamus', 'Horse', 'Humpback Whale', 'Iguanadon', 'Indian Peafowl', 'Jaguar', 'Kangaroo', 'Killer Whale', 'King Cobra', 'Koala', 'Komodo Dragon', 'Leatherback Sea Turtle', 'Leopard', 'Leopard Seal', 'Lesser Blue-ringed Octopus', 'Lion', 'Mallard', 'Mediterranean Fruit Fly', 'Milk snake', 'Modern Humans', 'Monarch Butterfly', 'Moose', 'Moth', 'Narwhal', 'Nine-banded Armadillo', 'Northern Cardinal', 'Northern Flicker', 'Northern Giraffe', 'Northern Harrier', 'Northern Mockingbird', 'Okapi', 'Orangutan', 'Orchard Oriole', 'Painted Bunting', 'Painted Turtle', 'Peregrine Falcon', 'Plains Zebra', 'Platypus', 'Poison Dart Frog', 'Polar Bear', "Portuguese Man o' War", 'Pteranodon', 'Pygmy Tarsier', 'Raccoon', 
'Red Fox', 'Red Panda', 'Red-bellied Woodpecker', 'Red-breasted Merganser', 'Reticulated Python', 'Ring-tailed Lemur', 'Salmon', 'Sambar', 'Scarlet Macaw', 'Sea Otter', 'Sheep', 'Siamese Fighting Fish', 'Smilodon', 'Snowshoe Hare', 'Sooty Albatross', 'Sperm Whale', 'Spinosaurus', 'Stegosaurus', 'Straw-coloured Fruit Bat', 'Striped Bark Scorpion', 'T. Rex', 'Tapir', 'Tiger', 'Tree Frog', 'Triceratops', 'Trilobites', 'Turkey Vulture', 'Vampire Bat', 'Walrus', 'Western Honey Bee', 'Western diamondback rattlesnake', 'White Rhino', 'Wildebeest', 'Wolf', 'Woolly Mammoth']
In [43]:
# Reduce the LR 10x when val_loss has not improved by at least 0.001 for 5 epochs
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_delta=0.001)

# Save the best full model based on validation loss
model_checkpoint = ModelCheckpoint(
    'AnimalClassifierCNN.h5',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=False
)


# Stop training when val_loss has not improved by at least 0.001 for 10 epochs
model_earlyStopping = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=10)
In [44]:
# data augmentation layer to introduce variation in the training data
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomTranslation(0.2, 0.2),
], name="data_augmentation_layer")

print("Created data augmentation layer.")
Created data augmentation layer.
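The augmentation factors in the cell above are relative, not absolute. Per the Keras documentation for these layers, `RandomRotation` reads its factor as a fraction of 2π. `RandomZoom` and `RandomTranslation` read theirs as fractions of the image height and width. The sketch below converts the values used here into degrees and pixels for a 224×224 input. The helper `augmentation_ranges` is hypothetical and only does arithmetic; it does not call Keras.

```python
import math

def augmentation_ranges(rotation=0.2, zoom=0.2, translation=0.2, size=224):
    """Convert Keras-style augmentation factors into readable ranges.

    Hypothetical helper: pure arithmetic, no Keras calls.
    """
    # RandomRotation(factor): angle drawn from [-factor * 2*pi, +factor * 2*pi] radians
    max_rotation_deg = math.degrees(rotation * 2 * math.pi)
    # RandomZoom(height_factor): zoom drawn from [-factor, +factor] (positive = zoom out)
    zoom_percent = zoom * 100
    # RandomTranslation(height_factor, width_factor): shift up to factor * size pixels
    max_shift_px = translation * size
    return max_rotation_deg, zoom_percent, max_shift_px

rot, zoom_pct, shift = augmentation_ranges()
print(f"rotation: up to ±{rot:.0f} degrees")
print(f"zoom: up to ±{zoom_pct:.0f}%")
print(f"translation: up to ±{shift:.1f} px on a 224x224 image")
```

Rotations of up to ±72° are fairly aggressive for natural photos. If accuracy lags, a smaller rotation factor may be worth trying, although this notebook does not test that.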
In [45]:
# steps per training epoch: one full pass over the training set
steps_per_epoch = len(train_data)

# monitor on 20% of the test batches; test_data is unshuffled, so these are
# always the same first batches (i.e. the alphabetically first classes)
validation_steps = int(0.2 * len(test_data))

# show values
steps_per_epoch, validation_steps
Out[45]:
(155, 4)
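The `(155, 4)` output follows directly from the file counts and the default batch size of 32. Here is a short sketch of the arithmetic, using the counts printed earlier in this notebook:

```python
import math

# Counts reported by image_dataset_from_directory in this notebook
n_train, n_test, batch_size = 4954, 694, 32

steps_per_epoch = math.ceil(n_train / batch_size)   # equals len(train_data)
test_batches = math.ceil(n_test / batch_size)       # equals len(test_data)
validation_steps = int(0.2 * test_batches)          # as computed in the cell above
monitored_images = validation_steps * batch_size    # images scored per validation pass

print(steps_per_epoch, test_batches, validation_steps, monitored_images)
print(f"{monitored_images / n_test:.1%} of the test set")
```

Each validation pass therefore scores only 128 of the 694 test images. `test_data` was built with `shuffle=False`, and files are listed class by class in sorted order. As a result, those 128 images always come from the alphabetically first classes.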
In [46]:
# input shape for images
input_shape = (224, 224, 3)

# load EfficientNetB0 without top layers (pretrained on ImageNet)
base_model = tf.keras.applications.EfficientNetB0(include_top=False)
base_model.trainable = False  # freeze base model weights

# input layer
inputs = tf.keras.layers.Input(shape=input_shape, name="input_layer")

# apply data augmentation
x = data_augmentation(inputs)

# pass through base model
x = base_model(x, training=False)

# global pooling to reduce feature maps
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# final classification layer (150 classes)
outputs = tf.keras.layers.Dense(150, activation="softmax", name="output_layer")(x)

# build the model
model = tf.keras.Model(inputs, outputs)
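Because the base is frozen, only the pooling and dense head is trained. The sketch below works out the shapes and the trainable parameter count. It assumes the standard EfficientNetB0 configuration, with 1280 output channels and an overall stride of 32; `model.summary()` would confirm the actual numbers.

```python
# Assumed EfficientNetB0 (include_top=False) properties: overall stride 32,
# 1280 channels in the final feature map.
input_size = 224
stride = 32
channels = 1280
num_classes = 150

feature_map = (input_size // stride, input_size // stride, channels)
pooled = channels  # GlobalAveragePooling2D averages the 7x7 grid -> one value per channel
head_params = pooled * num_classes + num_classes  # Dense weights + biases

print("feature map:", feature_map)
print("trainable parameters in the head:", head_params)
```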
In [47]:
# compile the model
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=["accuracy"])
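A useful sanity check for the loss: with 150 classes, a head that assigns equal probability to every class scores ln(150) ≈ 5.01. Any value clearly below that means the model has learned something. The hypothetical helper below reproduces the categorical cross-entropy formula in NumPy. It is not the Keras implementation.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy for one-hot labels and softmax outputs.

    Hypothetical helper mirroring the formula behind
    tf.keras.losses.CategoricalCrossentropy (probabilities clipped to avoid log(0)).
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

num_classes = 150
y_true = np.eye(num_classes)[[0, 42, 149]]              # three one-hot labels
uniform = np.full((3, num_classes), 1.0 / num_classes)  # an untrained head guessing uniformly

print(round(categorical_crossentropy(y_true, uniform), 3))  # ln(150) ≈ 5.011
```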
In [48]:
# train the model (reduce_lr is defined above but not included in callbacks here)
history = model.fit(
    train_data,
    steps_per_epoch=steps_per_epoch,   # number of batches per epoch
    epochs=10,                         # total training epochs
    validation_data=test_data,         # test set used to monitor val_loss / val_accuracy
    validation_steps=validation_steps, # only the first 4 (unshuffled) test batches
    callbacks=[model_checkpoint, model_earlyStopping]  # save best model + early stop
)
Epoch 1/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 105ms/step - accuracy: 0.2689 - loss: 3.9860
Epoch 1: val_loss improved from inf to 1.19497, saving model to AnimalClassifierCNN.h5
WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. 

155/155 ━━━━━━━━━━━━━━━━━━━━ 30s 127ms/step - accuracy: 0.2701 - loss: 3.9799 - val_accuracy: 0.8359 - val_loss: 1.1950
Epoch 2/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 105ms/step - accuracy: 0.7908 - loss: 1.2770
Epoch 2: val_loss improved from 1.19497 to 0.62405, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 17s 112ms/step - accuracy: 0.7908 - loss: 1.2763 - val_accuracy: 0.8672 - val_loss: 0.6240
Epoch 3/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 110ms/step - accuracy: 0.8587 - loss: 0.7749
Epoch 3: val_loss improved from 0.62405 to 0.47295, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 18s 117ms/step - accuracy: 0.8587 - loss: 0.7748 - val_accuracy: 0.8984 - val_loss: 0.4729
Epoch 4/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 106ms/step - accuracy: 0.8981 - loss: 0.5792
Epoch 4: val_loss improved from 0.47295 to 0.39936, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 17s 112ms/step - accuracy: 0.8980 - loss: 0.5792 - val_accuracy: 0.9062 - val_loss: 0.3994
Epoch 5/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 109ms/step - accuracy: 0.9063 - loss: 0.4774
Epoch 5: val_loss improved from 0.39936 to 0.34819, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 22s 120ms/step - accuracy: 0.9063 - loss: 0.4774 - val_accuracy: 0.8984 - val_loss: 0.3482
Epoch 6/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 105ms/step - accuracy: 0.9194 - loss: 0.3971
Epoch 6: val_loss improved from 0.34819 to 0.28858, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 17s 111ms/step - accuracy: 0.9193 - loss: 0.3971 - val_accuracy: 0.9375 - val_loss: 0.2886
Epoch 7/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 106ms/step - accuracy: 0.9278 - loss: 0.3472
Epoch 7: val_loss improved from 0.28858 to 0.27798, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 17s 112ms/step - accuracy: 0.9278 - loss: 0.3472 - val_accuracy: 0.9219 - val_loss: 0.2780
Epoch 8/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 108ms/step - accuracy: 0.9410 - loss: 0.3078
Epoch 8: val_loss improved from 0.27798 to 0.25216, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 18s 116ms/step - accuracy: 0.9410 - loss: 0.3078 - val_accuracy: 0.9297 - val_loss: 0.2522
Epoch 9/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 108ms/step - accuracy: 0.9428 - loss: 0.2727
Epoch 9: val_loss improved from 0.25216 to 0.25166, saving model to AnimalClassifierCNN.h5
155/155 ━━━━━━━━━━━━━━━━━━━━ 18s 114ms/step - accuracy: 0.9428 - loss: 0.2727 - val_accuracy: 0.9141 - val_loss: 0.2517
Epoch 10/10
155/155 ━━━━━━━━━━━━━━━━━━━━ 0s 107ms/step - accuracy: 0.9544 - loss: 0.2447
Epoch 10: val_loss did not improve from 0.25166
155/155 ━━━━━━━━━━━━━━━━━━━━ 17s 110ms/step - accuracy: 0.9544 - loss: 0.2447 - val_accuracy: 0.9219 - val_loss: 0.2539
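`ModelCheckpoint` saves on any decrease in val_loss. `EarlyStopping` and `ReduceLROnPlateau`, however, only count a decrease larger than `min_delta=0.001` as an improvement. The sketch below replays the val_loss values from the log above. It uses `patience_trace`, a hypothetical, simplified re-implementation of the "min"-mode patience logic, not Keras's code.

```python
def patience_trace(val_losses, min_delta=0.001):
    """Simplified patience counter in 'min' mode (not Keras's actual code).

    Returns the number of consecutive non-improving epochs after each epoch.
    An epoch counts as an improvement only if it beats the best value so far
    by more than min_delta.
    """
    best = float("inf")
    wait = 0
    trace = []
    for loss in val_losses:
        if loss < best - min_delta:
            best, wait = loss, 0
        else:
            wait += 1
        trace.append(wait)
    return trace

# val_loss values from the training log above
val_losses = [1.19497, 0.62405, 0.47295, 0.39936, 0.34819,
              0.28858, 0.27798, 0.25216, 0.25166, 0.2539]

trace = patience_trace(val_losses)
print(trace)  # non-improving streak after each epoch
print("ReduceLROnPlateau (patience=5) fires:", max(trace) >= 5)
print("EarlyStopping (patience=10) fires:", max(trace) >= 10)
```

Epoch 9 improves val_loss by only 0.0005. It is therefore saved as the best model, yet the patience logic counts it as a stall. The wait counter reaches only 2 by epoch 10, so neither callback would have triggered in this 10-epoch run.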

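Training Insights

Progress

  • Epoch 1: train acc ~27%, val acc ~84% → strong transfer-learning start.
  • Epochs 2-5: train acc climbs to ~91% and val acc to ~90%, while val loss falls from 0.62 to 0.35.
  • Epochs 6-10: train acc ~92-95%, val acc ~91-94%; train loss reaches 0.24 and val loss levels off at ~0.25.
  • The train/val gap is small (95.4% vs 92.2% at epoch 10), hinting at mild overfitting. However, the val figures come from only 4 batches (128 images) of the unshuffled test set, so they are noisy and cover only the alphabetically first classes.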

Key Takeaways

  • The EfficientNet backbone gave strong performance from the first epoch.
  • Monitored val accuracy plateaued at ~91-94% from epoch 6 onward.
  • Training accuracy and loss were still improving at epoch 10 → potential to fine-tune further.

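How to Improve

  • Fine-tune the top EfficientNet layers with a low LR.
  • Regularize: Dropout / L2 weight decay.
  • Augment: color jitter, MixUp, CutMix.
  • Train longer: increase epochs (20-30) with EarlyStopping, and pass the already-defined reduce_lr callback to model.fit.
  • Monitoring: validate on the full validation set (or a shuffled subset) rather than the first 4 test batches. steps_per_epoch = len(train_data) already covers one full pass per epoch; raising it further would require repeating the dataset.
  • Check imbalance: use class weights if needed.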
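Of the augmentation ideas above, MixUp is the least self-explanatory. It blends pairs of images and their one-hot labels using a weight λ drawn from a Beta(α, α) distribution, and `CategoricalCrossentropy` accepts the resulting soft labels as-is. Below is a NumPy sketch for one batch. The function `mixup_batch` and α = 0.2 are assumptions, not values tuned for this dataset. The sketch draws one λ per example, a common variant; the original formulation uses a single λ per batch.

```python
import numpy as np

def mixup_batch(images, labels, alpha=0.2, rng=None):
    """MixUp on one batch: blend each example with a randomly chosen partner.

    Hypothetical helper. images: float array (batch, H, W, C);
    labels: one-hot array (batch, classes).
    """
    rng = np.random.default_rng() if rng is None else rng
    batch_size = images.shape[0]
    # One mixing coefficient per example, drawn from Beta(alpha, alpha)
    lam = rng.beta(alpha, alpha, size=batch_size)
    # Pair every example with another one from the same batch
    partner = rng.permutation(batch_size)
    lam_img = lam.reshape(-1, 1, 1, 1)  # broadcast over H, W, C
    lam_lbl = lam.reshape(-1, 1)        # broadcast over classes
    mixed_images = lam_img * images + (1 - lam_img) * images[partner]
    mixed_labels = lam_lbl * labels + (1 - lam_lbl) * labels[partner]
    return mixed_images, mixed_labels

rng = np.random.default_rng(0)
images = rng.uniform(0, 255, size=(4, 224, 224, 3)).astype("float32")
labels = np.eye(150)[[0, 1, 2, 3]]
mixed_x, mixed_y = mixup_batch(images, labels, rng=rng)
print(mixed_x.shape, mixed_y.shape)
print(np.allclose(mixed_y.sum(axis=1), 1.0))  # soft labels still sum to 1
```

In a tf.data pipeline, this would be applied to batched training data only, never to the validation or test sets.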
In [50]:
#create a download link for the model
from IPython.display import FileLink
FileLink(r'./AnimalClassifierCNN.h5')
Out[50]:
./AnimalClassifierCNN.h5
In [51]:
loaded_model = tf.keras.models.load_model('AnimalClassifierCNN.h5')
WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
In [52]:
#evaluate the model
loaded_model.evaluate(validation)
19/19 ━━━━━━━━━━━━━━━━━━━━ 241s 12s/step - accuracy: 0.8577 - loss: 0.5200
Out[52]:
[0.5270313024520874, 0.8573853969573975]

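Model Evaluation on Validation Set

  • Validation Loss: ~0.53
  • Validation Accuracy: ~85.7%

The validation set was not used during training: checkpointing and early stopping monitored the test set. These figures are therefore the cleanest estimate of generalization in this notebook. Accuracy across the 150 classes is about 6 points below the 91.4% val_accuracy logged for the saved (epoch 9) checkpoint. This is likely partly because that figure covered only 128 test images from a subset of classes. Further fine-tuning or longer training could help push performance higher.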

In [54]:
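# use the model to predict a single image
import matplotlib.pyplot as plt

# load the input image and resize it to the model's input size
img_path = '/content/drive/MyDrive/animal151/val/American Marten/martes-americana_26_a540a826.jpg'
img = tf.keras.utils.load_img(img_path, target_size=IMG_SIZE, interpolation="bilinear")
img = tf.keras.utils.img_to_array(img)  # float32 in [0, 255], as EfficientNet expects

# Expand dimensions to create a batch of size 1 (model expects batches)
predictions = loaded_model.predict(tf.expand_dims(img, axis=0))

# predicted label
label = class_names[predictions.argmax()]

plt.imshow(img.astype("uint8"))
plt.title(label)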
1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step
Out[54]:
Text(0.5, 1.0, 'American Marten')
[Figure: validation image of an American Marten, titled with the predicted label 'American Marten']
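`predictions.argmax()` keeps only the single most likely class. In a 150-class problem with several look-alike species, the runner-up classes and their probabilities are often more informative. Below is a small sketch using `np.argsort`. The helper `top_k_labels` and the 4-class example are hypothetical; in the notebook you would pass `predictions` and `class_names`.

```python
import numpy as np

def top_k_labels(probs, class_names, k=3):
    """Return the k most likely (label, probability) pairs for one softmax vector."""
    probs = np.asarray(probs).ravel()      # accept shape (1, n) or (n,)
    top = np.argsort(probs)[::-1][:k]      # indices by descending probability
    return [(class_names[i], float(probs[i])) for i in top]

# Hypothetical 4-class example
names = ["American Bison", "American Marten", "Red Fox", "Sea Otter"]
preds = np.array([[0.04, 0.80, 0.10, 0.06]])
for label, p in top_k_labels(preds, names):
    print(f"{label}: {p:.2f}")
```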
In [55]:
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
In [57]:
# training & validation accuracy/loss curves
plt.figure(figsize=(12,5))

# Accuracy
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

# Loss
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.show()
[Figure: Model Accuracy (left) and Model Loss (right), train vs. validation per epoch]
In [60]:
# Confusion Matrix on validation set
# collect true labels and predictions batch by batch (validation is unshuffled)
y_true = []
y_pred = []

for images, labels in validation:
    preds = loaded_model.predict(images, verbose=0)  # silence per-batch progress bars
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    y_pred.extend(np.argmax(preds, axis=1))

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# fix the label set so the matrix is always 150 x 150, even if a class is never predicted
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

plt.figure(figsize=(12,10))
sns.heatmap(cm, cmap="Blues", xticklabels=class_names, yticklabels=class_names,
            annot=False, cbar=True)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
[Figure: Confusion Matrix heatmap, True vs. Predicted, 150 classes]
In [61]:
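# Top-N Misclassified Classes
# per-class accuracy (= recall): correct predictions / true examples of each class
class_acc = np.diag(cm) / cm.sum(axis=1)
worst_idx = np.argsort(class_acc)[:10]  # 10 worst-performing classes

# confusions among these 10 classes only (errors into other classes are not shown)
cm_subset = cm[worst_idx][:, worst_idx]
subset_labels = [class_names[i] for i in worst_idx]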

plt.figure(figsize=(10,8))
sns.heatmap(cm_subset, cmap="Reds", xticklabels=subset_labels, yticklabels=subset_labels, annot=True, fmt="d")
plt.title("Confusion Matrix – Top Misclassified Classes")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
[Figure: Confusion Matrix – Top Misclassified Classes (the 10 classes with the lowest recall)]
In [62]:
# Classification Report
print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=class_names))
Classification Report:
                                 precision    recall  f1-score   support

          African Bush Elephant       1.00      1.00      1.00         3
               African Lionfish       1.00      0.67      0.80         3
                African Penguin       1.00      0.80      0.89         5
       African Spurred Tortoise       0.86      1.00      0.92         6
                Altamira Oriole       1.00      0.67      0.80         6
                 American Bison       1.00      0.67      0.80         3
             American Cockroach       1.00      0.80      0.89         5
              American Flamingo       0.83      1.00      0.91         5
                American Marten       1.00      0.83      0.91         6
                 American Robin       0.71      1.00      0.83         5
            American white ibis       1.00      1.00      1.00         3
                  Andean Condor       1.00      0.80      0.89         5
                   Ankylosaurus       1.00      1.00      1.00         3
                            Ant       1.00      1.00      1.00         3
               Baltimore Oriole       1.00      0.60      0.75         5
                Bee Hummingbird       0.67      0.67      0.67         3
                         Beluga       0.83      1.00      0.91         5
                  Bighorn Sheep       0.83      1.00      0.91         5
                      Black Rat       1.00      1.00      1.00         2
         Black-capped Chickadee       1.00      1.00      1.00         5
                       Blue Jay       1.00      1.00      1.00         5
                     Blue Whale       0.43      1.00      0.60         3
                           Boto       0.83      1.00      0.91         5
Brown-throated Three-toed Sloth       1.00      1.00      1.00         5
   Bullock Mountains False Toad       1.00      0.80      0.89         5
                   Canada Goose       1.00      0.80      0.89         5
                  Carolina Wren       0.80      0.80      0.80         5
                            Cat       0.60      1.00      0.75         3
                        Cheetah       1.00      0.67      0.80         3
                        Chicken       1.00      0.67      0.80         3
                     Coelacanth       0.33      0.33      0.33         3
      Common Bottlenose Dolphin       1.00      1.00      1.00         2
                   Common Eland       1.00      1.00      1.00         3
               Common House Fly       0.56      1.00      0.71         5
                Common Lionfish       0.83      1.00      0.91         5
                 Common Ostrich       1.00      1.00      1.00         5
                     Corn Snake       1.00      0.67      0.80         6
                         Cougar       0.67      0.67      0.67         3
                 Crested Auklet       0.75      0.60      0.67         5
            Crested Giant Gecko       0.75      1.00      0.86         6
                      Crocodile       0.67      0.67      0.67         3
             Dead Man's Fingers       0.83      1.00      0.91         5
                     Diplodocus       1.00      1.00      1.00         2
                   Domestic Cow       1.00      1.00      1.00         3
                   Domestic Dog       0.75      1.00      0.86         3
                         Dugong       1.00      0.67      0.80         3
             Eastern Copperhead       1.00      0.80      0.89         5
          Eastern Gray Squirrel       1.00      1.00      1.00         5
               Eastern Kingbird       1.00      0.60      0.75         5
               Eastern Ratsnake       0.80      0.67      0.73         6
      Eastern Tiger Swallowtail       0.75      1.00      0.86         3
                Emperor Penguin       1.00      1.00      1.00         3
                          Fossa       1.00      1.00      1.00         5
                           Gaur       0.80      0.80      0.80         5
                        Gharial       1.00      1.00      1.00         3
          Giant Pacific octopus       1.00      1.00      1.00         3
                    Giant Panda       1.00      1.00      1.00         3
                    Giant Squid       0.33      0.25      0.29         4
                   Gila Monster       1.00      1.00      1.00         3
                   Golden Eagle       1.00      0.75      0.86         4
        Golden Poison Dart Frog       1.00      0.80      0.89         5
                        Gorilla       1.00      1.00      1.00         3
               Great Blue Heron       1.00      1.00      1.00         5
              Great White Shark       1.00      1.00      1.00         3
         Great hammerhead shark       0.75      1.00      0.86         3
             Greater Roadrunner       1.00      0.80      0.89         5
                 Green Anaconda       0.75      1.00      0.86         3
                   Green Iguana       1.00      1.00      1.00         3
               Green Sea Turtle       0.75      1.00      0.86         6
                   Grizzly Bear       1.00      1.00      1.00         4
              Groove-billed Ani       1.00      1.00      1.00         5
                   Hippopotamus       1.00      0.67      0.80         3
                          Horse       0.38      1.00      0.55         3
                 Humpback Whale       0.50      0.25      0.33         4
                      Iguanadon       0.50      1.00      0.67         3
                 Indian Peafowl       1.00      1.00      1.00         5
                         Jaguar       0.80      1.00      0.89         4
                       Kangaroo       0.50      0.33      0.40         3
                   Killer Whale       1.00      0.67      0.80         3
                     King Cobra       1.00      0.67      0.80         3
                          Koala       1.00      0.80      0.89         5
                  Komodo Dragon       1.00      0.67      0.80         3
         Leatherback Sea Turtle       1.00      1.00      1.00         3
                        Leopard       0.67      0.67      0.67         3
                   Leopard Seal       0.67      0.67      0.67         3
     Lesser Blue-ringed Octopus       0.67      0.50      0.57         4
                           Lion       1.00      0.75      0.86         4
                        Mallard       1.00      1.00      1.00         5
        Mediterranean Fruit Fly       1.00      0.40      0.57         5
                     Milk snake       1.00      1.00      1.00         3
                  Modern Humans       1.00      0.50      0.67         2
              Monarch Butterfly       1.00      0.67      0.80         3
                          Moose       1.00      0.67      0.80         3
                           Moth       0.75      1.00      0.86         3
                        Narwhal       0.75      0.60      0.67         5
          Nine-banded Armadillo       1.00      1.00      1.00         5
              Northern Cardinal       0.83      1.00      0.91         5
               Northern Flicker       1.00      1.00      1.00         5
               Northern Giraffe       1.00      1.00      1.00         3
               Northern Harrier       0.56      1.00      0.71         5
           Northern Mockingbird       0.56      1.00      0.71         5
                          Okapi       1.00      0.80      0.89         5
                      Orangutan       1.00      1.00      1.00         3
                 Orchard Oriole       0.60      0.60      0.60         5
                Painted Bunting       1.00      1.00      1.00         5
                 Painted Turtle       1.00      1.00      1.00         6
               Peregrine Falcon       1.00      1.00      1.00         3
                   Plains Zebra       1.00      1.00      1.00         3
                       Platypus       1.00      1.00      1.00         3
               Poison Dart Frog       1.00      0.67      0.80         3
                     Polar Bear       1.00      1.00      1.00         3
          Portuguese Man o' War       0.56      1.00      0.71         5
                     Pteranodon       1.00      0.50      0.67         2
                  Pygmy Tarsier       1.00      0.60      0.75         5
                        Raccoon       1.00      1.00      1.00         3
                        Red Fox       0.86      1.00      0.92         6
                      Red Panda       1.00      1.00      1.00         3
         Red-bellied Woodpecker       1.00      0.83      0.91         6
         Red-breasted Merganser       1.00      1.00      1.00         5
             Reticulated Python       0.67      0.67      0.67         3
              Ring-tailed Lemur       1.00      1.00      1.00         5
                         Salmon       0.50      0.50      0.50         2
                         Sambar       0.75      0.60      0.67         5
                  Scarlet Macaw       1.00      1.00      1.00         5
                      Sea Otter       1.00      1.00      1.00         5
                          Sheep       1.00      1.00      1.00         3
          Siamese Fighting Fish       1.00      0.80      0.89         5
                       Smilodon       0.50      0.33      0.40         3
                  Snowshoe Hare       1.00      1.00      1.00         6
                Sooty Albatross       1.00      0.67      0.80         3
                    Sperm Whale       1.00      0.67      0.80         3
                    Spinosaurus       1.00      1.00      1.00         3
                    Stegosaurus       1.00      1.00      1.00         2
       Straw-coloured Fruit Bat       0.80      0.80      0.80         5
          Striped Bark Scorpion       1.00      1.00      1.00         5
                         T. Rex       1.00      0.67      0.80         3
                          Tapir       1.00      0.67      0.80         3
                          Tiger       1.00      1.00      1.00         2
                      Tree Frog       1.00      0.67      0.80         3
                    Triceratops       1.00      1.00      1.00         3
                     Trilobites       1.00      1.00      1.00         2
                 Turkey Vulture       1.00      0.80      0.89         5
                    Vampire Bat       1.00      1.00      1.00         3
                         Walrus       0.75      1.00      0.86         3
              Western Honey Bee       1.00      0.60      0.75         5
Western diamondback rattlesnake       0.43      1.00      0.60         3
                    White Rhino       0.75      1.00      0.86         3
                     Wildebeest       0.60      1.00      0.75         3
                           Wolf       1.00      1.00      1.00         3
                 Woolly Mammoth       0.75      1.00      0.86         3

                       accuracy                           0.86       589
                      macro avg       0.89      0.85      0.85       589
                   weighted avg       0.89      0.86      0.86       589
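Each row of the report follows from the confusion matrix. Precision divides the diagonal by the column sums (how often a class was predicted), and recall divides it by the row sums (the support). Below is a NumPy sketch with a hypothetical helper `per_class_metrics`. It uses a toy 2-class matrix chosen to reproduce the Blue Whale row: all 3 whales are found, but 7 images in total are predicted as whale.

```python
import numpy as np

def per_class_metrics(cm):
    """Precision, recall and F1 per class from a confusion matrix
    (rows = true class, columns = predicted class). Hypothetical helper."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    predicted = cm.sum(axis=0)   # how often each class was predicted
    actual = cm.sum(axis=1)      # support: how often each class truly occurs
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1

# Toy matrix: class 0 ("whale") has 3 true examples, all found, plus 4 false positives
cm_toy = np.array([[3, 0],
                   [4, 10]])
p, r, f = per_class_metrics(cm_toy)
print(np.round(p, 2), np.round(r, 2), np.round(f, 2))
```

Applied to the notebook's `cm`, this should match `classification_report` up to rounding. Rows with low precision but perfect recall, such as Blue Whale, Horse, and Western diamondback rattlesnake, are classes the model over-predicts.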

In [63]:
# Show a few misclassified examples
mis_idx = np.where(y_true != y_pred)[0]

if len(mis_idx) > 0:
    plt.figure(figsize=(12,12))
    unbatched = validation.unbatch()  # same order as y_true / y_pred (shuffle=False)
    for i, idx in enumerate(mis_idx[:9]):  # show first 9 misclassified
        img, _ = next(iter(unbatched.skip(int(idx)).take(1)))
        plt.subplot(3,3,i+1)
        plt.imshow(img.numpy().astype("uint8"))
        plt.title(f"True: {class_names[y_true[idx]]}\nPred: {class_names[y_pred[idx]]}")
        plt.axis("off")
    plt.show()
else:
    print("No misclassified examples found in validation set.")
[Figure: 3×3 grid of misclassified validation images, each titled with its true and predicted class]