Балансировка классов в torch с помощью WeightedRandomSampler
В машинном обучении постоянно встречается эта проблема - в датасете, на котором ты обучаешь нейросеть (или любую другую ML модель) - разное количество записей для разных классов. Иногда прям сильно разное. Если оставить это как есть, то нейросеть толком не выучит редкие классы и не научится их отличать.
Два основных подхода к решению этой проблемы - умножать loss для редких классов пропорционально их редкости, чтобы высокий loss заставлял сетку учиться на них и чаще читать записи редких классов, чтобы они перестали быть редкими. Эта статья про второй подход.
Итак, подготовка.
Создаем тензор датасет, в котором есть ровно 1000 нулей и 50 единичек. Нули и единички, понятное дело, представляют разные классы. Очевидно, классы не сбалансированы.
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
ds = TensorDataset(torch.Tensor([ 0 for _ in range(1000)] + [ 1 for _ in range(50)]))
Теперь будем просить батчи из этого несбалансированного датасета разными способами. Батчи - это небольшие куски данных, коллекции определенной длины.
Первый способ, простейший
dl = DataLoader(ds, batch_size=50)
for _ in range(10):
print(next(iter(dl))[0].mean())
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
Мы попросили из самого дефолтного даталоадера 10 батчей, в каждом батче по 50 цифр. Чтобы понять, из каких цифр состоят батчи - посчитали среднее для каждого из этих 10 батчей.
Очевидно, что сейчас здесь всегда нули - самый дефолтный даталоадер бежит по порядку по всем элементам. В начале у нас тысяча нулей, вот мы и получаем каждый раз в среднем 0. Просто!
Второй способ, чуть лучше
dl = DataLoader(ds, batch_size=50, shuffle=True)
for _ in range(10):
print(next(iter(dl))[0].mean())
# tensor(0.0400)
# tensor(0.1000)
# tensor(0.0800)
# tensor(0.1400)
# tensor(0.1000)
# tensor(0.0200)
# tensor(0.0200)
# tensor(0.0400)
# tensor(0.0600)
# tensor(0.0400)
Сейчас мы получили не нули, так как мы перемешали все элементы с помощью shuffle=True
и теперь в батчи иногда попадаются единички. Тоже просто!
Но средние по батчам далеки от 0.5, которое было бы, если бы мы получали нули и единички с одинаковой частотой, то есть если бы единички перестали быть редкими.
Третий способ и достижение баланса
Но все не так просто, нужна некоторая подготовка. Нам нужно знать - какая частота появления у какого класса в нашем датасете, чтобы правильно выставить веса для каждого значения:
from collections import Counter
# i.item() - вытаскивает питонье int значение из тензора
counter = Counter(i.item() for i, in ds)
weights = [1/counter.get(i.item()) for i, in ds]
Обратите внимание: изящный анпэкинг for i, in ds
, мы уже писали про него ранее.
Теперь counter это Counter({0: 1000, 1: 50})
, он просто посчитал уникальные элементы в датасете, а weights - это список весов для каждой цифры - можно сказать, “вероятностей” вытащить из датасета эту конкретную цифру.
Вероятность в кавычках, потому что сумма весов не обязана быть равна 1 - хитрый объект из следующего блока кода, в честь которого и названа эта статья там сам поделит на сумму весов, чтобы все правильно работало.
В нашем случае len(weights) == 1050
(ну, это понятно) и sum(weights) == 2
(в общем случае, она будет равна количеству классов, у нас класса 2 - нули и единицы)
Все, теперь мы готовы решить задачку
sampler = WeightedRandomSampler(weights, num_samples=len(weights))
dl = DataLoader(ds, batch_size=50, sampler=sampler)
for _ in range(10):
print(next(iter(dl))[0].mean())
# tensor(0.7000)
# tensor(0.5000)
# tensor(0.2000)
# tensor(0.6000)
# tensor(0.4000)
# tensor(0.6000)
# tensor(0.6000)
# tensor(0.4000)
# tensor(0.6000)
# tensor(0.6000)
Теперь у нас в каждом батче, который мы просим, среднее получается всегда примерно 0.5 - а это значит, что нет больших редких классов, все они одинаково хорошо представлены в обучении. Вы великолепны!
Откуда это все?
То самое соревнование про симпсонов
Та же самая статья, только Kaggle ноутбук
Больше про рандом сэмплер тут