EXPONENTIAL MOVING AVERAGE



Weight EMA - smoothing техника, которая призвана улучшить generalisation performance модели. Эта техника пришла к нам из times series.



Основная схема: держим копию весов модели, при апдейте делаем апдейт как weighted average между текущими весами и весами до этого степа (post-optimization step).



Что за проблема есть: при каждом training step мы делаем апдейт весов относительно одного батча, но даже при хорошей подготовке данных и хорошем BatchSampler, один батч может быть шумовым и lead to getting stuck in poor local minima or oscillating between different parameter values



Как эта проблема решается?

Формула: decay_factor * new_weights + (1-decay_factor) * old_weights

decay_factor следит за тем, чтобы апдейт был как раз с небольшим влиянием прошлых весов. Из формулы следует:

decay close to 0 - больше влияния на старые веса

decay close to 1 - больше влияния на новые веса



реализация в tensorflow



Реализация Pytorch:

реализация объяснение для лайтнинг подхода через callback:

(там необходимо только поменять функцию on_save_checkpoint и сохранять ema чекпоинты как дикт, как вот тут)



У нас есть веса модели, которые мы изначально инициализируем любым образом, это мб рандом, а может быть распределение. парочка видов инициализации весов модели. Далее в колбэке мы храним веса ema, и апдейтим на каждом степе, в конце сохраняем как чекпойнт key. Заметьте, что когда происходит степ, веса изначальной модели также должны меняться. Подгрузить можно сначала сделав torch.load(checkpoint_path), а потом проиндексировавшись по ema весам. При обучении сильно не расстраивайтесь, вначале метрики будут ниже чем при обычном обучении и обучение при этом будет медленнее, но это плата за stability.



Если вы делаете на чистом pytorch, то можете создать такой же класс и просто делать step и сохранить эти веса в конце тренировки.



первое видео про ema

обзорно про time series, reinforcement

объяснение от умного Lei Mao с матешей



#grokaem_dl