2024 年 8 月 27 日,斯坦福大学、加州大学圣地亚哥分校、加州大学伯克利分校和 Meta 的研究人员联合提出了一种名为 TTT(测试时间训练层)的新架构,旨在取代传统的 Transformer 和 Mamba 模型。TTT 模型通过对输入 token 进行梯度下降来压缩上下文,直接替代了注意力机制,解锁了具有表现力记忆的线性复杂度架构。研究表明,TTT-Linear 和 TTT-MLP 在性能上超越了现有的最强模型。
TTT 模型的核心创新在于其测试时间训练层(Test-Time-Training layers),通过对输入 token 进行梯度下降来压缩上下文,直接替代了传统的注意力机制。这种方法不仅降低了计算复杂度,还提高了模型的表现力和记忆能力。TTT-Linear 和 TTT-MLP 是该架构的两个主要变体,分别在不同的任务和数据集上展示了性能。
研究人员在论文中详细描述了 TTT 模型的架构和算法,并提供了大量实验数据来验证其有效性。实验结果显示,TTT-Linear 和 TTT-MLP 在 125M 到 1.3B 参数规模上,与 Transformer 和现代 RNN Mamba 进行比较,结果显示 TTT-Linear 和 TTT-MLP 在性能上匹敌或超越了基准模型。
具体来说,TTT-Linear 在处理短文本任务时表现尤为出色,而 TTT-MLP 则在长文本和复杂任务中展示了更大的潜力。研究人员指出,TTT 模型在 8k 上下文中已经比 Transformer 更快,并且在墙钟时间上与 Mamba 匹敌。TTT-MLP 在长上下文中显示出更大的潜力,能够在更大的隐藏状态中压缩更多信息。
GitHub 上也迅速出现了多个 TTT 模型的实现版本,方便开发者进行实验和应用。例如,test-time-training 团队在 GitHub 上发布了 TTT-Linear 和 TTT-MLP 的 PyTorch 实现,基于 Huggingface Transformers 库,可以用于加载模型和生成文本。该实现版本支持在 GPU 和 Cloud TPU VMs 上运行,适用于 Python 3.11。开发者可以通过以下链接访问这些实现版本:
- TTT-Linear 和 TTT-MLP 的 PyTorch 实现
- TTT-Linear 和 TTT-MLP 的 JAX 实现
- TTT-Linear 和 TTT-MLP 的快速实现
这些实现版本不仅提供了模型的训练和推理代码,还包含了重现论文中吞吐量结果的脚本,方便开发者进行性能测试和优化。
此外,一些开发者还在 Medium 等平台上撰写了关于 TTT 模型的技术文章,详细介绍了 TTT 模型的架构、算法和应用场景。例如,Medium 上的一篇文章详细介绍了 TTT 层如何利用其更大的隐藏状态在长上下文中压缩更多信息,并指出 TTT-MLP 在长上下文中表现优于 TTT-Linear。文章链接如下: