昇思25天学习打卡营第9天|MindSpore-Vision Transformer图像分类

Vision Transformer图像分类

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

下面将通过代码实例来详细解释基于ViT实现ImageNet分类任务。

注意,本教程在CPU上运行时间过长,不建议使用CPU运行。

环境准备与数据读取¶

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。

首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。

运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

.dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

实践环境安装

Python环境为 3.9.19

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

# or 
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple mindspore==2.2.14

数据准备

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)


import os

import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms


data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1,𝑥2,𝑥3),其中的𝑥1,𝑥2,𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

2. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。

3. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

Self-Attention的全部过程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/767050.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

传输线在阻抗匹配时串联端接电阻为什么要靠近发送端

传输线在阻抗匹配时串联端接电阻为什么要靠近发送端 在进行阻抗匹配的时候我们可以在电阻源端放置一个串联端接电阻,但是有时候受到空间的限制可能会把电阻摆的稍微远一点,那么这个时候大家可能会有疑问,电阻离发送端远一点或者电阻放置在接…

java+mysql教师管理系统

完整源码地址 教师信息管理系统使用命令行交互的方式及数据库连接实现教师信息管理系统,该系统旨在实现教师信息的管理,并根据需要进行教师信息展示。该软件的功能有如下功能 (1)基本信息管理(教师号、姓名、性别、出生年月、职称、学历、学位、教师类型…

【Git 学习笔记】1.3 Git 的三个阶段

1.3 Git 的三个阶段 由于远程代码库后续存在新的提交,因此实操过程中的结果与书中并不完全一致。根据书中 HEAD 指向的 SHA-1:34acc370b4d6ae53f051255680feaefaf7f7850d,可通过以下命令切换到对应版本,并新建一个 newdemo 分支来…

【STM32 RTC实时时钟如何配置!超详细的解析和超简单的配置,附上寄存器操作】

STM32 里面RTC模块和时钟配置系统(RCC_BDCR寄存器)处于后备区域,即在系统复位或从待机模式唤醒后,RTC的设置和时间维持不变。因为系统对后备寄存器和RTC相关寄存器有写保护,所以如果想要对后备寄存器和RTC进行访问,则需要通过操作…

社交媒体优化的智能顾问:Kompas.ai如何提升品牌社交表现

在社交媒体盛行的数字时代,品牌必须在社交平台上保持活跃和互动,以增强品牌社交互动和提升在线可见性。社交媒体优化不仅能够扩大品牌的影响力,还能够加深与消费者的联系。Kompas.ai,作为一款智能社交媒体顾问工具,能够…

【前端项目笔记】7 商品管理

商品管理 效果展示: 在功能开发之前,创建商品列表的子分支 git branch 查看所有分支 git checkout -b goods_list 创建并切换到新分支goods_list git push -u origin goods_list 将新分支goods_list推送到云端仓库origin并命名为goods_list保存 通过…

LLM学习记录

概述 语言模型的发展 语言模型经历过四个阶段的发展,依次从统计语言模型到神经网络语言模型(NLM),到出现以 BERT 和 Transformer 架构为代表的预训练语言模型(PLM),最终到大型语言模型阶段&am…

竞赛选题 交通目标检测-行人车辆检测流量计数 - 竞赛选题

文章目录 0 前言1\. 目标检测概况1.1 什么是目标检测?1.2 发展阶段 2\. 行人检测2.1 行人检测简介2.2 行人检测技术难点2.3 行人检测实现效果2.4 关键代码-训练过程 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 毕业设计…

【Java环境配置过程详解(包括IDEA配置Java)】

目录 一、JDK下载安装 1. 官网下载JDK 2. 本地安装JDK 3. 配置环境变量 4. 验证是否安装成功 ​编辑二、IDEA进行安装下载 1. 官网下载 IDEA 2、IDEA进行Java开发 1. 创建Java项目 2. 程序测试 一、JDK下载安装 1. 官网下载JDK 1)官网链接: https://www.o…

PTrade如何获取技术值班?如get_RSI - 相对强弱指标;PTrade量化软件如何获取?

get_RSI - 相对强弱指标 get_RSI(close, n6) 使用场景 该函数仅在回测、交易模块可用 接口说明 获取相对强弱指标RSI指标的计算结果 PTrade是恒生公司开发的一款专业量化软件,部分合作券商可提供,↑↑↑! 参数 close:价格…

C语言的数据结构:图的基本概念

前言 之前学过了其它的数据结构,如: 集合 \color{#5ecffd}集合 集合 —— 数据元素属于一个集合。 线型结构 \color{#5ecffd}线型结构 线型结构 —— 一个对一个,如线性表、栈、队列,每一个节点和其它节点之间的关系 一个对一个…

rpm包下载

内网无法下载、选择外网的一台机器下载rpm包 下载后上传rpm包 1、创建下载目录 mkdir /data/asap/test 2、下载能留存包的工具 sudo yum install yum-utils -y 报错就是环境问题没下载成功,我换了个环境正常的机器就可以了 3、下载rpm包到指定目录/data/asa…

一文彻底搞懂Transformer - Input(输入)

一、输入嵌入(Input Embedding) 词嵌入(Word Embedding):词嵌入是最基本的嵌入形式,它将词汇表中的每个单词映射到一个固定大小的向量上。这个向量通常是通过训练得到的,能够捕捉单词之间的语义…

GAMES104:04游戏引擎中的渲染系统1:游戏渲染基础-学习笔记

文章目录 概览:游戏引擎中的渲染系统四个课时概览 一,渲染管线流程二,了解GPUSIMD 和 SIMTGPU 架构CPU到GPU的数据传输GPU性能限制 三,可见性Renderable可渲染对象提高渲染效率Visibility Culling 可见性裁剪 四,纹理压…

如何在《中小学电教》期刊上发表论文?

如何在《中小学电教》期刊上发表论文? 《中小学电教》知网 学术期刊 教育厅25年下半年 3版 ①其他学科 不收甘肃和幼儿园 ②数学、英语、历史、政治(道德与法治)、音体美、科学学科的稿件 全bao 全bao不带课题 文章需要和信息…

【TS】TypeScript 原始数据类型深度解析

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 TypeScript 原始数据类型深度解析一、引言二、基础原始数据类型2.1 boolean2.2 …

数据治理体系建设方案

数据治理体系建设方案 在当前的大数据时代,数据已经成为企业核心资产之一,其管理与治理的重要性愈加凸显。有效的数据治理体系不仅能提升数据质量和数据使用的效率,还能为企业创造更多的商业价值。本文将详细阐述数据治理的重要性、核心要素…

SpringBoot 如何处理跨域请求?你说的出几种方法?

引言:在现代的Web开发中,跨域请求(Cross-Origin Resource Sharing,CORS)是一个常见的挑战。随着前后端分离架构的流行,前端应用通常运行在一个与后端 API 不同的域名或端口上,这就导致了浏览器的…

AI生成电商模特图应用定制

🌟 广州AI生成电商模特图应用定制案例剖析— 触站AI,绘制智能图像的未来 🚀 🎨 触站AI,让创意与智能共绘辉煌 🎨在这座充满创新活力的广州城,触站AI以其尖端AI技术,开启了企业AI图像…

动态代理--通俗易懂

程序为什么需要代理?代理长什么样? 例子 梳理 代理对象(接口):要包含被代理的对象的方法 ---Star 被代理对象:要实现代理对象(接口) ---SuperStar 代理工具类:创建一个代理,返回值用代理对象&#xff0c…