Базированная модель, часть 2



Первая часть: ссылка. В ней мы определили основную проблему не-трансформерных моделей: запоминание повторяющихся N-грамм. Остался второй вопрос...



Что делать?

Simple linear attention language models balance the recall-throughput tradeoff, статья, пост 1, пост 2, твиттер, репо.



В этой статье авторы представляют новую архитектуру под названием Based, которая (как обычно) должна заменить трансформеры.



Для начала, напомню: основная цель большинства не-трансформеров сделать быстрее при том же качестве. В этой и прошлой статье у трансформеров лучше перплексия. Единственное место, где это не так — первый пост про Based. Кажется, они замерили там некорректные числа, но к выпуску основной статьи исправились 💀

Ещё раз: за исключением первого старого поста со старыми числами, здесь нигде не заявляется превосходство по перплексии над трансформерами. Зато заявляется превосходство по пропускной способности и скорости генерации, даже против FlashAttention-2 😳



У Based есть 3 типа слоёв, смешивающих эмеддинги токенов:

- BaseConv из предыдущего поста

- Линейное внимание

- Полное внимание на маленьком скользящем окне

В одном слое только один конкретный смеситель, например в первом слое линейное внимание, во втором скользящее окно, в третьем BaseConv, и так далее.

Кроме смесителей по токенам есть и стандартные MLP для смешивания каналов, одинаковые для каждого типа слоёв.



Что за линейное внимание? Сама концепция не нова: убираем софтмакс и получаем рекуррентную форму. Софтмакс убираем аппроксимацией exp(qk) через ряд Тейлора до второго порядка малости: 1 + qk + (qk)^2 / 2. Подробнее об этом написано в другой статье, но тут всё равно показывают, что другие варианты аппроксимаций хуже.



Псевдокод вычисления:

qk0 = [1]



# Считаем слагаемые от q

q_first = [q1, ..., qd] # q

q_second = [q1 * q1, ..., q1 * qd, ..., qd * q1, ... qd * qd] # qq^T

q_new = cat(qk0, q_first, q_second / sqrt(2))



# Считаем слагаемые от k

k_first = [k1, ..., kd] # k

k_second = [k1 * k1, ..., k1 * kd, ..., kd * k1, ... kd * kd] # kk^T

k_new = cat(qk0, k_first, k_second / sqrt(2))



# Разложение экспоненты до второго порядка малости: 1 + qk + (qk)^2 / 2

y = (q_new * k_new).sum()

Мотивация добавления чего-либо к BaseConv всё та же: улучшить вспоминание информации, добавив входозависимые смесители. При этом, в отличие от трансформеров, у которых KV-кэш растёт линейно от длины последовательности, мы можем варьировать объём доступной нам памяти через размер скрытого состояния в линейном внимании. И таким образом можем разменивать точность вспоминания на скорость.



В статье есть фундаментальная теорема: решение MQAR (а значит и языкового моделирования на реальных текстах) требует размер скрытого состояния, линейный от длины последовательности. Для внимания таким скрытым состоянием является KV-кэш, который имеет размер O(Nd). Для Based это KV-состояние для числителя и K-сотояние для знаменателя в симуляции софтмакса, с общей размерностью O(d^3), но в статье эту размерность несколько уменьшают проекциями. Это всё означает, что чем длиннее последовательность мы хотим корректно обрабатывать, тем жирнее нам нужно делать размерность модели 😭



Для того, чтобы конкурировать с эффективными реализациями внимания, ребята написали кастомные ядра, которые стараются выполнять операции в SRAM, быстрой памяти GPU, по аналогии с FlashAttention. Получилось и правда быстро 😘



В итоге имеем модель, которая по качеству на уровне трансформеров, но пропускная способность которой в 25 раз больше.



Что мне кажется сомнительным во всей этой истории:

- А что произошло с перплексией в первом посте? Что авторы изначально сделали не так?

- Нафига было разрабатывать BaseConv, чтобы потом вернуться к линейному вниманию? Зачем нам теперь вообще BaseConv нужен? Нет, в статье конечно есть секция C, в которой добавление BaseConv обосновывается выигрышем в перплексии, но это ничего не объясняет.

- Почему только 1b параметров? Спонсоров вроде много, люди богатые.



P.S. Немного покринжевал со слова "смеситель" в этом контексте, но решил оставить. Тоже смешивает же.