Совсем недавно вышла бета версия библиотеки TorchMultimodal (ссылка), в которой авторы постарались собрать все лучшие техники и фичи обучения SoTA мультизадачных мультимодальных (М2) архитектур:

слои, обработчики для разных модальностей, лосс функции (Contrastive Loss, Codebook слои, Shifted-window Attention, Components for CLIP, Multimodal GPT, Multi Head Attention)

SoTA архитектуры (FLAVA, DETR, …)

скрипты обучения и инференса

примеры использования



Всё это позволит ставить быстрые и удобные эксперименты для обучения М2 моделей.



В довесок авторы сделали интересный пост о распределённом обучении (ссылка), где на примере модели FLAVA (мультимодальный late fusion трансформер) показали, как можно её масштабировать с 350M параметров до 10B. Рассмотрели два ключевых подхода:

1. Distributed Data Parallel - нарезка датасета по воркерам, градиенты синхронизируются ДО обновления весов, по сути вся модель «реплицируется»

2. Fully Sharded Data Parallel - параметры, градиенты и состояния оптимизатора нарезаются (шардируются) по воркерам (а-ля ZeRO-3), перед forward и backward propagation шарды объединяются.



Сравнение производительности (среднее число сэмплов в секунду за исключением первых 100 на warmup) можно оценить на графике.





github

статья про TorchMultimodal

статья про Scaling Multimodal Foundation Models



@complete_ai