A

AdamDaitu

V1

2022/12/31阅读:27主题:科技蓝

R语言——支持向量机手写数字分类

R语言——支持向量机手写数字分类

1: 支持向量机模型简介

支持向量机(SVM)分类的基本思想是求解能够正确划分数据集并且几何间隔最大的分离超平面,利用该超平面使得任何一类的数据划分的相当均匀。对于线性可分的训练数据而言,线性可分离超平面有无穷多个,但是几何间隔最大的分离超平面是唯一的。

间隔最大化的直观解释是:对训练数据集找到几何间隔最大的超平面,意味着以充分大的确信度对训练数据进行分类。而最大间隔是由支持向量来决定的,针对二分类问题,支持向量是指距离划分超平面最近的正类的点和负类的点。

使用支持向量机算法时,由于并不是所有的问题都是线性可分的,这就需要使用不同的核函数。正是因为核函数的引入才使支持向量机能够训练出任意形状的超平面。使用核函数的方式又称为核技巧,核技巧可以将需要处理的问题映射到一个更高维度的空间,从而对在低维不好处理的问题转在高纬空间中进行处理,进而得到精度更高的分类器。常用的核函数有线性核函数多项式核函数径向基核函数sigmoid核函数等

支持向量机的实际使用中,很少会有一个超平面将不同类别的数据完全分开,所以对划分边界近似线性的数据使用软间隔的方法,允许数据跨过划分超平面,这样就会使得一些样本分类错误。通过对分类错误的样本施加惩罚,可在最大间隔和确保划分超平面边缘的正确分类之间寻找一个平衡。

R语言 可使用 e1071 包实现支持向量机的分类、回归、异常值的识别,及其可视化分析等,下面将会介绍如何使用SVM算法对手写数字数据进行分类。

2: 手写数字数据准备

使用SVM对数据分类之前,先导入会使用到的包,并读取数据的训练集和测试集,程序如下:

library(e1071);library(readr);library(Metrics);library(Rtsne)
## 导入数据
train_digit <- read_csv("data/chap12/digit_train.csv",col_names = FALSE)
test_digit <- read_csv("data/chap12/digit_test.csv",col_names = FALSE)
nrow(train_digit)
nrow(test_digit)
## [1] 3823
## [1] 1797

从上面的程序输出可知,手写数字数据集,一共有3823个训练样本,1797个测试样本。下面从训练集中随机的挑选300个样本,对手写字体图像进行可视化查看,运行下面的程序程序可获得图像1。

## 可视化训练集中的几个样本
set.seed(123)
index <- sample(1000,300)
par(mfrow = c(15,20),mai=c(0.01,0.01,0.01,0.01))
for(ii in 1:length(index)){
  im <- matrix(unname(unlist(train_digit[index[ii],1:64])),
               nrow=8,ncol = 8,byrow = F)
  image(im,col = gray(seq(01, length = 256)),xaxt= "n", yaxt= "n")
}

图1 手写数字样本可视化
图1 手写数字样本可视化

3: 支持向量机分类

下面使用训练数据集训练一个支持向量机分类器,使用svm()函数,通过参数kernel ="radial"指定使用径向基核函数,并计算获得的模型在训练集和测试集上的预测精度,程序和输出如下:

## 使用原始数据建立支持向量机模型并预测精度
train_digit$X65 <- as.factor(train_digit$X65)
test_digit$X65 <- as.factor(test_digit$X65)
digitsvm <- svm(X65 ~., data = train_digit,kernel ="radial",scale = FALSE)
digitsvm
## Call:
## svm(formula = X65 ~ ., data = train_digit, kernel = "radial", scale = FALSE)
## Parameters:
##    SVM-Type:  C-classification 
##  SVM-Kernel:  radial 
##        cost:  1 
## Number of Support Vectors:  3811
## 对训练集和测试集进行预测,查看模型的精度
train_pre <- predict(digitsvm,train_digit,type = "class")
test_pre <- predict(digitsvm,test_digit,type = "class")
sprintf("支持向量机训练集上预测精度:%4f",accuracy(train_digit$X65,train_pre))
sprintf("支持向量机测试集上预测精度:%4f",accuracy(test_digit$X65,test_pre))
## [1] "支持向量机训练集上预测精度:1.000000"
## [1] "支持向量机测试集上预测精度:0.562048"

从上面的输出结果中,可以发现获得的模型在训练集上的精度为百分之百,但是在测试集上的精度并不高。可能的原因是:没有对数据进行数据标准化或者特征提取等操作。下面使用t-SNE算法获取数据的降维特征,然后再训练新的SVM模型,并查看模型的效果。

4: SVM对t-SNE特征分类

下面的程序是首先利用t-SNE算法,将手写数字数据集降维到2D空间中,然后再使用相同的参数训练一个SVM分类器,并可视化新的分类器再训练集上的分界面,程序和输出如下所示:

## 利用TSNE算法将训练集和测试集降维到二维空间中
digit_all <- rbind(train_digit,test_digit)##合并训练数据和测试数据
system.time(   # 可以获取TSNE算法的消耗时间
digit_tsne <- Rtsne(digit_all[,1:64],dims = 2,pca = FALSE
                    perplexity = 50,theta = 0.0,max_iter = 500)
)
##    user  system elapsed 
## 257.343  82.169 340.113
## 将提取的TSNE特征切分为训练数据和测试数据
train_digit_tsne <- as.data.frame(digit_tsne$Y[1:nrow(train_digit),])
train_digit_tsne$label <- as.factor(train_digit$X65)
test_digit_tsne <- as.data.frame(digit_tsne$Y[-c(1:nrow(train_digit)),])
test_digit_tsne$label <- as.factor(test_digit$X65)
## 训练支持向量机模型
set.seed(123)   # radial核SVM分类器
digitsvm <- svm(label ~., data = train_digit_tsne,kernel ="radial",scale = FALSE)
digitsvm
## Call:
## svm(formula = label ~ ., data = train_digit_tsne, kernel = "radial", 
##     scale = FALSE)
## Parameters:
##    SVM-Type:  C-classification 
##  SVM-Kernel:  radial 
##        cost:  1 
## Number of Support Vectors:  1431
## 使用训练数据可视化获得的SVM分类器对数据的切分情况
par(mfrow = c(1,1))
plot(digitsvm,data = train_digit_tsne,V1~V2,
     symbolPalette = rainbow(10),color.palette = terrain.colors)
 
图2 SVM分界面可视化
图2 SVM分界面可视化

从图2所示的分界面可以看出,针对降维后的数据特征,使用SVM算法能够很好的将不同类的数据进行划分。

下面是计算新的SVM分类器,在训练集和测试集上的预测精度的程序,从输出结果中可以发现,模型在训练集和测试集上的预测精度都很高,预测准确率接近百分之百。

## 对训练集和测试集进行预测,查看模型的精度
train_pre <- predict(digitsvm,train_digit_tsne,type = "class")
test_pre <- predict(digitsvm,test_digit_tsne,type = "class")
sprintf("支持向量机训练集上预测精度:%4f",
        accuracy(train_digit_tsne$label,train_pre))
sprintf("支持向量机测试集上预测精度:%4f",
        accuracy(test_digit_tsne$label,test_pre))
## [1] "支持向量机训练集上预测精度:0.991106"
## [1] "支持向量机测试集上预测精度:0.989983"

欢迎关注我们

欢迎加入我们的QQ交流群获取使用的数据:837977579

欢迎关注我们的微信公众号获取更多内容

今天的分享就到这里了,敬请期待下一篇!

最后欢迎大家分享转发,您的点赞是对我的鼓励和肯定!

往期回顾

拓展阅读

1:R语言——常用的数据可视化包总结

2:R语言——ggTimeSeries包可视化日历热力图

3:R语言——数据可视化方法

4:R语言——Ridge和Lasso回归分析

5:R语言——大数据分析实战系列丛书推荐

分类:

人工智能

标签:

机器学习

作者介绍

A
AdamDaitu
V1