Базированная модель, часть 2
Первая часть: ссылка. В ней мы определили основную проблему не-трансформерных моделей: запоминание повторяющихся N-грамм. Остался второй вопрос...
Что делать?
Simple linear attention language models balance the recall-throughput tradeoff, статья, пост 1, пост 2, твиттер, репо.
В этой статье авторы представляют новую архитектуру под названием Based, которая (как обычно) должна заменить трансформеры.
Для начала, напомню: основная цель большинства не-трансформеров сделать быстрее при том же качестве. В этой и прошлой статье у трансформеров лучше перплексия. Единственное место, где это не так — первый пост про Based. Кажется, они замерили там некорректные числа, но к выпуску основной статьи исправились💀
Ещё раз: за исключением первого старого поста со старыми числами, здесь нигде не заявляется превосходство по перплексии над трансформерами. Зато заявляется превосходство по пропускной способности и скорости генерации, даже против FlashAttention-2😳
У Based есть 3 типа слоёв, смешивающих эмеддинги токенов:
- BaseConv из предыдущего поста
- Линейное внимание
- Полное внимание на маленьком скользящем окне
В одном слое только один конкретный смеситель, например в первом слое линейное внимание, во втором скользящее окно, в третьем BaseConv, и так далее.
Кроме смесителей по токенам есть и стандартные MLP для смешивания каналов, одинаковые для каждого типа слоёв.
Что за линейное внимание? Сама концепция не нова: убираем софтмакс и получаем рекуррентную форму. Софтмакс убираем аппроксимацией exp(qk) через ряд Тейлора до второго порядка малости: 1 + qk + (qk)^2 / 2. Подробнее об этом написано в другой статье, но тут всё равно показывают, что другие варианты аппроксимаций хуже.
Псевдокод вычисления:
В статье есть фундаментальная теорема: решение MQAR (а значит и языкового моделирования на реальных текстах) требует размер скрытого состояния, линейный от длины последовательности. Для внимания таким скрытым состоянием является KV-кэш, который имеет размер O(Nd). Для Based это KV-состояние для числителя и K-сотояние для знаменателя в симуляции софтмакса, с общей размерностью O(d^3), но в статье эту размерность несколько уменьшают проекциями. Это всё означает, что чем длиннее последовательность мы хотим корректно обрабатывать, тем жирнее нам нужно делать размерность модели😭
Для того, чтобы конкурировать с эффективными реализациями внимания, ребята написали кастомные ядра, которые стараются выполнять операции в SRAM, быстрой памяти GPU, по аналогии с FlashAttention. Получилось и правда быстро😘
В итоге имеем модель, которая по качеству на уровне трансформеров, но пропускная способность которой в 25 раз больше.
Что мне кажется сомнительным во всей этой истории:
- А что произошло с перплексией в первом посте? Что авторы изначально сделали не так?
- Нафига было разрабатывать BaseConv, чтобы потом вернуться к линейному вниманию? Зачем нам теперь вообще BaseConv нужен? Нет, в статье конечно есть секция C, в которой добавление BaseConv обосновывается выигрышем в перплексии, но это ничего не объясняет.
- Почему только 1b параметров? Спонсоров вроде много, люди богатые.
P.S. Немного покринжевал со слова "смеситель" в этом контексте, но решил оставить. Тоже смешивает же.
Первая часть: ссылка. В ней мы определили основную проблему не-трансформерных моделей: запоминание повторяющихся N-грамм. Остался второй вопрос...
Что делать?
Simple linear attention language models balance the recall-throughput tradeoff, статья, пост 1, пост 2, твиттер, репо.
В этой статье авторы представляют новую архитектуру под названием Based, которая (как обычно) должна заменить трансформеры.
Для начала, напомню: основная цель большинства не-трансформеров сделать быстрее при том же качестве. В этой и прошлой статье у трансформеров лучше перплексия. Единственное место, где это не так — первый пост про Based. Кажется, они замерили там некорректные числа, но к выпуску основной статьи исправились
Ещё раз: за исключением первого старого поста со старыми числами, здесь нигде не заявляется превосходство по перплексии над трансформерами. Зато заявляется превосходство по пропускной способности и скорости генерации, даже против FlashAttention-2
У Based есть 3 типа слоёв, смешивающих эмеддинги токенов:
- BaseConv из предыдущего поста
- Линейное внимание
- Полное внимание на маленьком скользящем окне
В одном слое только один конкретный смеситель, например в первом слое линейное внимание, во втором скользящее окно, в третьем BaseConv, и так далее.
Кроме смесителей по токенам есть и стандартные MLP для смешивания каналов, одинаковые для каждого типа слоёв.
Что за линейное внимание? Сама концепция не нова: убираем софтмакс и получаем рекуррентную форму. Софтмакс убираем аппроксимацией exp(qk) через ряд Тейлора до второго порядка малости: 1 + qk + (qk)^2 / 2. Подробнее об этом написано в другой статье, но тут всё равно показывают, что другие варианты аппроксимаций хуже.
Псевдокод вычисления:
qk0 = [1]Мотивация добавления чего-либо к BaseConv всё та же: улучшить вспоминание информации, добавив входозависимые смесители. При этом, в отличие от трансформеров, у которых KV-кэш растёт линейно от длины последовательности, мы можем варьировать объём доступной нам памяти через размер скрытого состояния в линейном внимании. И таким образом можем разменивать точность вспоминания на скорость.
# Считаем слагаемые от q
q_first = [q1, ..., qd] # q
q_second = [q1 * q1, ..., q1 * qd, ..., qd * q1, ... qd * qd] # qq^T
q_new = cat(qk0, q_first, q_second / sqrt(2))
# Считаем слагаемые от k
k_first = [k1, ..., kd] # k
k_second = [k1 * k1, ..., k1 * kd, ..., kd * k1, ... kd * kd] # kk^T
k_new = cat(qk0, k_first, k_second / sqrt(2))
# Разложение экспоненты до второго порядка малости: 1 + qk + (qk)^2 / 2
y = (q_new * k_new).sum()
В статье есть фундаментальная теорема: решение MQAR (а значит и языкового моделирования на реальных текстах) требует размер скрытого состояния, линейный от длины последовательности. Для внимания таким скрытым состоянием является KV-кэш, который имеет размер O(Nd). Для Based это KV-состояние для числителя и K-сотояние для знаменателя в симуляции софтмакса, с общей размерностью O(d^3), но в статье эту размерность несколько уменьшают проекциями. Это всё означает, что чем длиннее последовательность мы хотим корректно обрабатывать, тем жирнее нам нужно делать размерность модели
Для того, чтобы конкурировать с эффективными реализациями внимания, ребята написали кастомные ядра, которые стараются выполнять операции в SRAM, быстрой памяти GPU, по аналогии с FlashAttention. Получилось и правда быстро
В итоге имеем модель, которая по качеству на уровне трансформеров, но пропускная способность которой в 25 раз больше.
Что мне кажется сомнительным во всей этой истории:
- А что произошло с перплексией в первом посте? Что авторы изначально сделали не так?
- Нафига было разрабатывать BaseConv, чтобы потом вернуться к линейному вниманию? Зачем нам теперь вообще BaseConv нужен? Нет, в статье конечно есть секция C, в которой добавление BaseConv обосновывается выигрышем в перплексии, но это ничего не объясняет.
- Почему только 1b параметров? Спонсоров вроде много, люди богатые.
P.S. Немного покринжевал со слова "смеситель" в этом контексте, но решил оставить. Тоже смешивает же.