深度學習之多層感知器(MLP)之經典mnist數字識別

news/2024/7/6 4:39:38 标签: keras, sklearn, 神经网络, 深度学习, mlp

目录

  • 前言
  • 簡介
  • 思考與推導
  • 實戰
  • 總結

前言

在上一篇文章中用mlp解决了一个好壞質檢二分類问题,这次我们依然用多層感知器mlp来解决經典mnist數字識別


回顧一下前文,但是具體理論還是看前文深度學習之多層感知器(MLP)

簡介

多层感知器(MLP,Multilayer Perceptron)是一种前馈人工神经网络模型,其将输入的多个数据集映射到单一的输出的数据集上。

在这里插入图片描述

思考與推導

如果学过生物的话,应该都学过人的神经元结构,其中有两个很重要的部分(树突和轴突末端)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
当人看到图片时,神经元间就会进行一些信息传递进而判断出图片类别。所以当我们模仿人体神经元时,可以想象将神经元信号数值化

在这里插入图片描述
即是说给通过输入数据,再经过相关公式,最后生成一个新的数值信号。看到这儿,是不是感觉和逻辑回归有一些相似之处,那就再回顾一下逻辑回归模型框架

在这里插入图片描述

可以看到左边是我们的输入数据,中间是我们的公式,右边就是我们求得的概率了。
虽然有些相似,但人的神经元有一个不同的点就是它并不是一个单一的神经元,也可以理解成它并不是一个单一的逻辑回归结构,而是由很多个神经元组成的神经网络。那么我们能把逻辑回归组成一个网络吗?

在这里插入图片描述

这里我们可以这样理解,从输入神经元开始到第一个隐含神经元这样一个计算过程,是通过逻辑回归计算的,然后第一个隐含神经元和第二个隐含神经元间也是通过逻辑回归计算的(这里第一个隐含神经元的相关数据可以看作是新的输入),同理,第二个隐含神经元和输出神经元间也是通过逻辑回归计算的。

还是一样,来看一个简化的MLP模型结构

在这里插入图片描述

来回忆一下逻辑回归模型的一些数学公式

在这里插入图片描述

这里z可以是x乘以θ

那这个简化的结构的y怎么计算呢?

我们是先计算a的值然后再计算y的值(这里x0是常数项)
在这里插入图片描述
这里就和前面的z=xθ相同了

在这里插入图片描述

話不多説直接進入實戰

實戰

基於mnist數據集,建立mlp模型,實現0-9數字的十分類task:

  1. 實現mnist數據載入,可視化圖形數字
  2. 完成數據預處理:圖像數據維度轉換與歸一化、輸出結果格式轉換
  3. 計算模型在預測數據集的準確率
  4. 模型結構:兩層隱藏層,每層有392個神經元
# load the mnist data
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

加載數據,keras為我們提供了這個經典數據集

在这里插入图片描述

# visualize the data
img1 = X_train[0]
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize=(3,3))
plt.imshow(img1)
plt.title(y_train[0])
plt.show()

隨便找張圖可視化看一下是什麽
在这里插入图片描述
可以看的出來是數字5

#fromat the input data
feature_size = img1.shape[0] * img1.shape[1]
X_train_format = X_train.reshape(X_train.shape[0],feature_size)
X_test_format = X_test.reshape(X_test.shape[0],feature_size)

print(X_train_format.shape)

格式化一下
在这里插入图片描述

# normalize the input data
X_train_normal = X_train_format/255
X_test_normal = X_test_format/255

print(X_train_normal[0])

歸一化
在这里插入图片描述

# format the output data(labels)
from keras.utils import to_categorical
y_train_format = to_categorical(y_train)
y_test_format = to_categorical(y_test)

print(y_train[0]) # 5
print(y_train_format[0]) # 在下標為5 的地方為1

轉爲one-hot編碼(獨熱編碼)

独热编码即 One-Hot 编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。
這裏不詳細説明,看注釋應該也看得懂這個編碼是什麽效果

在这里插入图片描述

在这里插入图片描述
建一個這樣的模型

# set up the model
from keras.models import Sequential
from keras.layers import Dense,Activation

mlp = Sequential()
mlp.add(Dense(units=392,activation='sigmoid',input_dim=feature_size)) # 第一層
mlp.add(Dense(units=392,activation='sigmoid')) # 第二層
mlp.add(Dense(units=10,activation='softmax')) # 輸出層
mlp.summary()

設置模型並summary()看一下模型結構,第一層是需要告訴input_dim的,後面幾層是根據上一層的所以就不用,激活函數用的sigmoid

在这里插入图片描述
可以看出有兩個隱藏層,而且訓練數據量也挺大的,會挺耗時

# comfigure the model
mlp.compile(loss='categorical_crossentropy',optimizer='adam')

配置模型

mlp.fit(X_train_normal,y_train_format,epochs=10)

模型訓練

在这里插入图片描述
迭代10次大概用了一分鐘

评估模型

y_train_predict = np.argmax(mlp.predict(X_train_normal), 1) # np.argmax()獲取對應整數, [5 0 4 ... 5 6 8]
print(y_train_predict)

在这里插入图片描述

from sklearn.metrics import accuracy_score
accuracy_train = accuracy_score(y_train,y_train_predict)
print(accuracy_train)

看一下準確率,發現還不錯

在这里插入图片描述

y_test_predict = np.argmax(mlp.predict(X_test_normal), 1)
accuracy_test = accuracy_score(y_test,y_test_predict)
print(accuracy_test)

看一下測試數據準確率,也是還可以,沒有過擬合問題

在这里插入图片描述

img2 = X_test[123]
fig1 = plt.figure(figsize=(3,3))
plt.imshow(img2)
plt.title(y_test_predict[123])
plt.show()

可視化看一下,標題是預測數字,圖片是本來的圖片

在这里插入图片描述
可以看出準確預測出是6了,也可以用別的數字試試看結果都預測的還可以

總結

圖像數字多分類實戰summary:

  1. 通過mlp模型,實現了基於圖像數據的數字自動識別分類
  2. 完成了圖像的數字化處理與可視化
  3. mlp模型的輸入、輸出數據格式有了更深的認識,完成了數據預處理與格式轉換
  4. 建立了結構更爲複雜的mlp模型
  5. mnist數據集地址:http://yann.lecun.com/exdb/mnist/

這就是本次學習深度學習之多層感知器(MLP)的筆記
附上本次實戰的數據集和源碼:
鏈接:https://github.com/fbozhang/python/tree/master/jupyter


http://www.niftyadmin.cn/n/155111.html

相关文章

多来客课堂:抖音本地生活服务商平台入驻门槛跟流程解析

首先—— 想要入驻平台,成为本地生活服务商是有一定门槛的,从官方给出的条件来看,这个门槛其实是不低的,具体有以下几点: 1、需要你有公司,是法人,并且营业执照的注册资金不能低于50万&#xf…

D. Game on Axis(很有意思的图论题)

Problem - D - Codeforces 有n个点&#xff0c;1,2,n&#xff0c;每个点i上都有一个数字a。你在跟他们玩游戏。一开始&#xff0c;你在点1。当你在第i点时&#xff0c;采取以下步骤:如果 1<i<n&#xff0c;转到 i ai&#xff0c;否则&#xff0c;游戏结束。在游戏开始前&…

【Java进阶篇】—— 常用类和基础API

一、String类 1.1 String的特性 java.lang.String 类代表字符串&#xff0c;由final关键字修饰&#xff0c;在赋值后不能改变&#xff08;常量&#xff09;&#xff0c;不能继承String类String 对象的字符内容是存储在一个字符数组 value[]中的 我们来看一下String在JDK8中的…

PMP第十一章重要知识点

第11章 项目风险管理 项目风险管理的目标在于提高正面风险的概率和&#xff08;或&#xff09;影响&#xff0c;降低负面风险的概率和&#xff08;或&#xff09;影响&#xff0c;从而提高项目成功的可能性。 项目的独特性带来风险。 风险三要素&#xff1a;风险事件、概率、…

一文读懂!跨境支付业务详解

随着全球电子商务的增长&#xff0c;面向进出口贸易的跨境支付应运而生。据华创微课统计&#xff0c;预计2022年&#xff0c;跨境支付交易总额将突破156万亿美元&#xff0c;跨境服务市场有巨大的发展潜力。在跨境服务的链路中&#xff0c;支付是关键的一环。像shopee、亚马逊、…

JS从0到1——653. 钞票

文章目录QuestionIdeasCodeQuestion 在这个问题中&#xff0c;你需要读取一个整数值并将其分解为多张钞票的和&#xff0c;每种面值的钞票可以使用多张&#xff0c;并要求所用的钞票数量尽可能少。 请你输出读取值和钞票清单。 钞票的可能面值有 100,50,20,10,5,2,1 。 经过…

算法小抄7-二分枚举

二分枚举是二分查找的一种应用(这是我自己起的名字hhh,可别在外面说这是二分枚举的题),这类题相对于二分查找趣味性会更强一些,但是同时也需要更理解二分法的本质--枚举 爱吃香蕉的珂珂 题目链接 题目大意 珂珂想用最慢的速度,在警卫回来之前吃掉所有的香蕉,数组中的每一个数…

Eclipse 快捷键(更新中)

Eclipse 快捷键菜单栏快捷键 序号Eclipse里的快捷键快捷键说明1 Alt Shift N 新建项目/类2 Ctrl W / Ctrl Shift W 关闭当前编辑器界面 / 关闭全部打开的编辑器界面3 Ctrl S 保存 4 F5 刷新 编辑器快捷键 序号Eclipse里的快捷键快捷键说明1 Ctrl Alt / 输入关键字的…