"No one is harder on a talented person than the person themselves" - Linda Wilkinson ; "Trust your guts and don't follow the herd" ; "Validate direction not destination" ;

December 09, 2023

Simplifying Neural Network Training Under Class Imbalance


  • Small batch size. Class-imbalanced settings are where small batch sizes shine.
  • Data augmentations have an amplified impact on performance under class imbalance, especially on minority-class accuracy
  • Adding a self-supervised loss during training can improve feature representations
  • Label smoothing, especially on minority-class examples, helps prevent overfitting. We adapt label smoothing for the class-imbalanced setting by applying more smoothing to minority-class examples than to majority-class examples.
  • A small modification of Sharpness-Aware Minimization (SAM) pulls decision boundaries away from minority samples and significantly improves minority-group accuracy (see the SAM sketch after this list).
  • Loss reweighting. Reweighting methods assign different weights to the majority- and minority-class losses, increasing the influence of minority samples, which would otherwise play little role in the loss function.
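
The paper's SAM modification itself isn't reproduced here, but for orientation, the following is a minimal PyTorch sketch of one standard SAM training step (Foret et al.): compute the gradient, perturb the weights by rho along its direction, recompute the gradient at the perturbed point, and apply that gradient from the original weights. The function and its arguments are illustrative, not the paper's code.

import torch

def sam_step(model, loss_fn, x, y, optimizer, rho=0.05):
    # First pass: gradient at the current weights.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()

    params = [p for p in model.parameters() if p.grad is not None]
    grads = [p.grad.detach().clone() for p in params]
    grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    scale = rho / (grad_norm + 1e-12)

    # Ascend to the locally sharpest nearby point.
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.add_(g * scale)

    # Second pass: gradient at the perturbed weights.
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()

    # Restore the original weights, then step with the SAM gradient.
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.sub_(g * scale)
    optimizer.step()
    return loss.item()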

Label smoothing is a technique often used in training deep learning models, particularly for classification tasks. It modifies the target labels, making them a blend of the original hard labels and some uniform or prior distribution. This can lead to better generalization by preventing the model from becoming too confident about its predictions. In a class-imbalanced setting, where some classes have significantly more examples than others, label smoothing can help by reducing the model's bias towards the more frequent classes.

Label smoothing for the class-imbalanced setting: Python example
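
Below is a minimal NumPy sketch of class-dependent label smoothing in the spirit described above; the function name and the eps_per_class values are illustrative assumptions, with more smoothing assigned to the minority class.

import numpy as np

def class_balanced_smooth_labels(y, num_classes, eps_per_class):
    # y             : (n,) integer class labels
    # num_classes   : total number of classes
    # eps_per_class : per-class smoothing strength; larger for minority classes
    eps = np.asarray(eps_per_class)[y][:, None]   # (n, 1) smoothing per example
    onehot = np.eye(num_classes)[y]               # (n, num_classes) hard targets
    # Move eps of the probability mass from the hard label
    # to a uniform distribution over all classes.
    return onehot * (1.0 - eps) + eps / num_classes

# Toy example: class 1 is the minority and gets more smoothing.
y_toy = np.array([0, 0, 0, 1])
soft = class_balanced_smooth_labels(y_toy, num_classes=2, eps_per_class=[0.05, 0.20])
print(soft)  # class-0 rows become [0.975, 0.025]; the class-1 row becomes [0.10, 0.90]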


In practice, label smoothing does not change the dataset's inherent imbalance; it softens the target distributions by moving a portion of the probability mass from the peak (the hard label) to the other classes, which helps keep the model from becoming overly confident on the majority classes during training.



Loss reweighting for the class-imbalanced setting: Python example


From the Keras fit documentation: class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only).

First, import the modules (NumPy is needed for np.unique):

import numpy as np
from sklearn.utils import class_weight

Next, compute one weight per class. In recent scikit-learn versions the arguments to compute_class_weight are keyword-only:

class_weights = class_weight.compute_class_weight(
    class_weight='balanced', classes=np.unique(y_train), y=y_train)

Finally, compute_class_weight returns an array, but Keras expects a dictionary mapping class indices to weights, so convert it before passing it to model.fit:

class_weight_dict = {int(c): w for c, w in zip(np.unique(y_train), class_weights)}
model.fit(X_train, y_train, class_weight=class_weight_dict)
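
As a quick sanity check, the 'balanced' heuristic sets each weight to n_samples / (n_classes * np.bincount(y)), so the minority class receives the larger weight. A toy example (labels assumed for illustration):

import numpy as np
from sklearn.utils import class_weight

y_toy = np.array([0, 0, 0, 1])  # class 1 is the minority
weights = class_weight.compute_class_weight(
    class_weight='balanced', classes=np.unique(y_toy), y=y_toy)
print(weights)  # approximately [0.667, 2.0]: 4/(2*3) for class 0, 4/(2*1) for class 1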

Keep Exploring!!!
