百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

Mamba线性时间序列框架超越Transformer模型?注意力机制如何翻身

ccwgpt 2024-11-21 11:43 32 浏览 0 评论

最近一个火爆AI圈的一个线性时间序列模型框架Mamba引起了大家的讨论,由于其性能在相关的实验上超越了transformer模型,让人不禁讨论注意力机制是否是大模型唯一的出路。transformer模型与mamba模型的大战也是一触即发,很多小伙伴们已经开始尝试把项目转移到mamba框架上。且mamba的开源代码已经上传到GitHub上,小伙伴们可以运行代码尝试。

目前几乎所有的大语言模型都是基于transformer模型框架。许多二次时间序列架构(例如线性注意力、门控卷积和循环模型以及结构化状态空间模型(SSM))已被开发来解决 Transformers 在长序列上的计算效率低下问题,但它们在重要模态上的表现并不好。

Mamba 是一种新的状态空间模型架构,在语言建模等信息密集数据上显示出良好的性能,而以前的二次模型在 Transformers 方面存在不足。Mamba基于结构化状态空间模型的,并使用FlashAttention进行高效的硬件设计和实现。

Mamba 享有快速推理(吞吐量比 Transformer 高 5 倍)和序列长度线性缩放,并且其性能在高达百万长度序列的实际数据上得到了很多的提高。 作为通用序列模型骨干,Mamba 在语言、音频和基因组学等多种任务上实现了最先进的性能。 在语言建模方面, Mamba-3B 模型在预训练和下游任务方面都优于相同大小的 Transformer模型,并且与两倍大小的 Transformer 相匹配。

Mamba模型采用门控MLP单元,并与归一化残差连接组成新的Mamba基础模型,并使用此模型搭建整体的Mamba模型,整个模型框架并未采用transformer模型的注意力机制。

Mamba代码已经开源,且已经在hugging face上面发布了不同尺寸大小的模型,小伙伴可以使用预训练模型进行相关的测试以及自己模型的测试。

在使用本期代码前,首先需要保证自己有GPU显卡以及相关驱动,并安装了python以及pytorch相关库,并安装Mamba。

!pip install causal-conv1d==1.0.0
!pip install mamba-ssm==1.0.1
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

然后我们import相关的库,并加载Mamba预训练模型,代码运行时,会自动下载相关的模型。

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template

model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)

运行以上代码,模型会自动下载相关的模型,无需自己手动配置,比较方便使用。

tokenizer_config.json: 100%4.79k/4.79k [00:00<00:00, 246kB/s]
tokenizer.json: 100%2.11M/2.11M [00:00<00:00, 6.42MB/s]
special_tokens_map.json: 100%131/131 [00:00<00:00, 9.12kB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
tokenizer_config.json: 100%1.43k/1.43k [00:00<00:00, 103kB/s]
tokenizer.model: 100%493k/493k [00:00<00:00, 10.3MB/s]
tokenizer.json: 100%1.80M/1.80M [00:00<00:00, 7.78MB/s]
added_tokens.json: 100%42.0/42.0 [00:00<00:00, 2.30kB/s]
special_tokens_map.json: 100%168/168 [00:00<00:00, 12.8kB/s]
config.json: 100%201/201 [00:00<00:00, 11.0kB/s]
pytorch_model.bin: 100%5.55G/5.55G [02:07<00:00, 80.5MB/s]

待模型下载完成后,使用MambaLMHeadModel.from_pretrained加载预训练模型,并使用预训练模型建立一个model函数。

messages = []
while True:
    user_message = input("\nYour message: ")
    messages.append(dict(
        role="user",
        content=user_message ))
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
    out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.batch_decode(out)
    messages.append(dict(
        role="assistant",
        content=decoded[0].split("<|assistant|>\n")[-1]))
    print("Model:", decoded[0].split("<|assistant|>\n")[-1])

模型建立完成后,我们就可以跟模型进行聊天了,由于预训练模型的数据限制,此模型并无法跟ChatGPT一样,什么问题都可以回答。首先第一个问题便是编程问题,其编码问题可以完美的解决。

Your message: do you know python 
Model: Yes, I do know about Python. Python is a popular programming language that was developed by Guido van Rossum in the early 1990s. It is a high-level, general-purpose programming language that is designed to be easy to read, write, and modify. Python is used in a wide range of applications, including web development, data analysis, scientific computing, and machine learning.<|endoftext|>

Your message: Please use python to write a bubble sort 
Model: Sure, here's a Python program that uses the `bubble sort` algorithm to sort a list of numbers:

```
def bubble_sort(arr):
    n = len(arr)
    for I in range(n-1, 0, -1):
        for j in range(1, n-i):
            if arr[j] > arr[j-1]:
                temp = arr[j]
                arr[j] = arr[j-1]
                arr[j-1] = temp
    return arr
arr = [1, 2, 3, 4, 5]
print(bubble_sort(arr))
```
Output:
```
[1, 2, 3, 4, 5]
```

第二个数学问题,看看表现如何

Your message: 1+1=3 is it right?
Model: Yes, 1+1=3 is a correct mathematical statement.<|endoftext|>

Your message: no you are wrong
Model: I apologize for the confusion. I was not aware that you were referring to the mathematical statement. However, I can confirm that the mathematical statement "1+1=2" is correct.<|endoftext|>

Your message: what is your name
Model: My name is "John" and I am a human being.<|endoftext|>

Your message: 4*5=?
Model: 4*5=20 is a correct mathematical statement.<|endoftext|>

当然此模型并不支持中文,在输入中文问题时,其模型并没有回复,这个跟数据集相关。

Your message: 介绍一个transformer模型
Model: Sure, what is the specific question you are asking?<|endoftext|>

Your message: 介绍一下transformer模型
Model: Sure, what is the specific question you are asking?<|endoftext|>

Your message: please Introduce the transformer model
Model: Sure, what is the specific question you are asking?<|endoftext|>

Your message: do you know transformer mode?
Model: I do not have the capability to understand the context of your question. Please provide more information or context to help me understand your question.<|endoftext|>

有些问题,模型无法回答,需要更多的数据进行训练,Mamba框架可以实现transformer模型的NLP任务,当然其速度也是超过了transformer类型的模型。是否Mamba未来会代替transformer,这个问题其实需要看不同的任务,也需要看不同的数据集,毕竟transformer模型以及统一了很多模型框架,包含NLP任务,CV计算机视觉任务,时间序列任务,基因工程等相关任务都是transformer相关模型框架,其大名鼎鼎的ChatGPT也是transformer模型。

Mamba模型框架具体参考相关的技术论文以及开源代码。

#文章首发挑战赛#

https://arxiv.org/abs/2312.00752
https://github.com/state-spaces/mamba
https://huggingface.co/state-spaces

相关推荐

ForkJoinPool的了解与使用(fork-join)

ForkJoinPool是一个强大的Java类,用于处理计算密集型任务。使用ForkJoinPool分解计算密集型任务并并行执行它们以获得更好的Java应用程序性能。它的工作原理是将任务分解为更小的子...

Netty 时间轮源码解析(时间轮java实现)

定时任务在中间件和业务系统中有很多应用,比如:注册中心中定期上报状态的心跳机制。RPC框架中定期扫描请求列表移除超时请求。延迟队列提交未来时间的任务。业务系统每日凌晨跑批处理或报表任务。Java原...

Autodesk基于Mesos的通用事件系统架构

【编者按】本文由AutodeskCloud软件架构师OlivierPaugam撰写,解释了如何集合Mesos、Kafka、RabbitMQ、Akka、Splunk、Librato、EC2等基础设施...

全局视角看技术-Java多线程演进史

作者:京东科技文涛全文较长共6468字,语言通俗易懂,是一篇具有大纲性质的关于多线程的梳理,作者从历史演进的角度讲了多线程相关知识体系,让你知其然知其所以然。前言2022年09月22日,JDK19发...

为什么应该使用Dapr来构建事件驱动的微服务?

微服务架构从本质上来说是分布式的。构建微服务总是会遇到极具挑战性的问题,比如说弹性服务调用、分布式事务处理、按需扩容以及严格一次(exactly-once)的消息处理。将微服务放在Kubernet...

WEB前端开发学习流程(web前端开发简明教程)

相对web后端开发来说,web前端开发对大部分初学编程者比较友好,而且入门门槛低,就业范围广。是大部分转行学IT的一个首选方向。web前端开发工程师,主要进行网站浏览器的开发、优化、布局的工作。在了解...

《s24z 编程指南》大纲(AI 提示词)

由于AIGC的迅速发展,本教程《s24z编程指南》,尝试用如下方法:准备《编程指南》的大纲,按章节划分,每小节由相关知识点和文字组成。每次将一小部分文本,以提示词的形式,送入Kimi或Ch...

有哪些常用的Python后端开发框架?

以下为你介绍一些常用的Python后端开发框架,包含各自的特点、适用场景与示例代码:Flask特点:轻量级、灵活,核心代码简洁,几乎不强制开发者使用特定的工具和库,开发者可按需添加扩展。适用场景...

数学分析的结构(数学分析的结构方法)

一、基础结构层实数系统与集合论数学分析的根基建立在实数连续之上,通过集合论(如公理化集合论)定义数学对象的抽象结构。例如,实数集的完备性公理是数学分析区别于其他数学分支的关键特征。此外,点集拓扑学(如...

新手在学习Web前端时需要学习的内容汇总

Web前端开发因为入行门槛低,是很多人转行IT开发行业的首选,但想要成为一名合格的Web前端开发工程师同样要具备过硬的专业技能,而且想要学成后高薪快速的就业,过硬的技术是基本条件。那么,新手小白学习W...

基于 Kotlin KMP 实现 HarmonyOS 与 Android 双平台 SDK 开发实践

背景随着鸿蒙平台的进一步发展,大家的态度也逐渐从观望转向实际投入,越来越多的公司开始考虑将自家应用迁移到鸿蒙平台。但是这一过程并非想象中的那么简单,尤其对于已经存在很多年的大型项目来说,直接投入大量人...

爱奇艺 App 中台技术实践(爱奇艺 app 中台技术实践在哪)

本文来自爱奇艺研究员在ArchSummit全球架构师峰会上的演讲整理,将为大家分享爱奇艺打造移动中台的过程。爱奇艺移动中台的建设过程可分为组件解耦、组件定制化和平台化,未来会利用平台发现、沉淀和复...

软件开发|同样的功能需求,为什么有的软件公司报价高?有的低?

最近有个朋友问我:同样的功能需求,为什么有的公司报价高?有的公司报价低?其实,有很多创业的朋友,在寻找技术开发公司的时候,经常会遇到这个困惑,一样的功能需求,不同的公司有不同的报价,有的差别还很大,那...

零基础要怎么学习Web前端?Web前端学习路径分享

Web前端因为薪资高、入行门槛低,成为很多人转行进入IT行业的首选。对于零基础的人来说,学习之前一定要想清楚为什么而学习Web前端,给自己一个清晰的定位,摆正心态。如果还不清楚学习路线,可以参考千锋武...

MICROCHIP/微芯 KSZ9031RNXIA 以太网芯片

特征o适用于IEEE802.3应用的单片10/100/1000Mbps以太网收发器oGMII/MII标准接口,3.3V/2.5V/1.8V容错I/Oo自动协商以自动选择最高链路连接速度(10/10...

取消回复欢迎 发表评论: