52matlab technical site: MATLAB tutorials, MATLAB installation guides, MATLAB downloads

Using t-SNE dimensionality reduction to visualize classification results

Posted on 2020-2-26 14:57:01
Last edited by matlab的旋律 on 2020-3-9 15:42

#####################################Load libraries#################################
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt

##################################################################################

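digits = load_digits()  # load the 8x8 handwritten-digit dataset (1797 samples, 64 features)
embeddings = TSNE().fit_transform(digits.data)  # t-SNE reduction; n_components defaults to 2
vis_x = embeddings[:, 0]  # first embedding dimension
vis_y = embeddings[:, 1]  # second embedding dimension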

############################Collect the sample indices of each label#################
# digits.target[i] is the true digit (0-9) of sample i
index0 = [i for i in range(len(digits.target)) if digits.target[i] == 0]
index1 = [i for i in range(len(digits.target)) if digits.target[i] == 1]
index2 = [i for i in range(len(digits.target)) if digits.target[i] == 2]
index3 = [i for i in range(len(digits.target)) if digits.target[i] == 3]
index4 = [i for i in range(len(digits.target)) if digits.target[i] == 4]
index5 = [i for i in range(len(digits.target)) if digits.target[i] == 5]
index6 = [i for i in range(len(digits.target)) if digits.target[i] == 6]
index7 = [i for i in range(len(digits.target)) if digits.target[i] == 7]
index8 = [i for i in range(len(digits.target)) if digits.target[i] == 8]
index9 = [i for i in range(len(digits.target)) if digits.target[i] == 9]

###################################################################################


#######################################Plot#########################################
# One color and one marker per digit; 'orange' replaces 'yellow', which is the same color as 'y'.
# cmap is omitted: it has no effect when c is a single color.
colors = ['b', 'c', 'y', 'm', 'r', 'g', 'k', 'orange', 'yellowgreen', 'wheat']
plt.scatter(vis_x[index0], vis_y[index0], c=colors[0], marker='h', label='0')
plt.scatter(vis_x[index1], vis_y[index1], c=colors[1], marker='<', label='1')
plt.scatter(vis_x[index2], vis_y[index2], c=colors[2], marker='x', label='2')
plt.scatter(vis_x[index3], vis_y[index3], c=colors[3], marker='.', label='3')
plt.scatter(vis_x[index4], vis_y[index4], c=colors[4], marker='p', label='4')
plt.scatter(vis_x[index5], vis_y[index5], c=colors[5], marker='>', label='5')
plt.scatter(vis_x[index6], vis_y[index6], c=colors[6], marker='^', label='6')
plt.scatter(vis_x[index7], vis_y[index7], c=colors[7], marker='d', label='7')
plt.scatter(vis_x[index8], vis_y[index8], c=colors[8], marker='s', label='8')
plt.scatter(vis_x[index9], vis_y[index9], c=colors[9], marker='o', label='9')

plt.title('t-SNE')
plt.legend()
plt.show()
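The ten list comprehensions and ten scatter calls above can also be written as a single loop over the distinct labels. This is a minimal sketch, not the poster's code: the helper name `plot_embedding_2d` is hypothetical, and it assumes `embeddings` is an (n_samples, 2) array with one class label per row in `labels`. `np.flatnonzero(labels == label)` returns the same indices as `[i for i in range(len(labels)) if labels[i] == label]`. The `__main__` part assumes scikit-learn is installed.

```python
import numpy as np
import matplotlib.pyplot as plt

# One color and one marker per class, following the listing above
# ('orange' stands in for 'yellow', which duplicates 'y').
COLORS = ['b', 'c', 'y', 'm', 'r', 'g', 'k', 'orange', 'yellowgreen', 'wheat']
MARKERS = ['h', '<', 'x', '.', 'p', '>', '^', 'd', 's', 'o']


def plot_embedding_2d(embeddings, labels, ax=None):
    """Scatter an (n_samples, 2) embedding with one series per distinct label.

    Hypothetical helper: returns the Axes and a dict mapping each label to
    the indices of its samples (the equivalent of index0 ... index9).
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    if ax is None:
        ax = plt.figure().add_subplot(111)
    groups = {}
    for k, label in enumerate(np.unique(labels)):
        idx = np.flatnonzero(labels == label)  # positions where labels[i] == label
        groups[label] = idx
        ax.scatter(embeddings[idx, 0], embeddings[idx, 1],
                   c=COLORS[k % len(COLORS)], marker=MARKERS[k % len(MARKERS)],
                   label=str(label))
    ax.set_title('t-SNE')
    ax.legend()
    return ax, groups


if __name__ == '__main__':
    # Assumes scikit-learn is installed; t-SNE on all 1797 digits may take a while.
    from sklearn.datasets import load_digits
    from sklearn.manifold import TSNE

    digits = load_digits()
    embeddings = TSNE(n_components=2, random_state=0).fit_transform(digits.data)
    plot_embedding_2d(embeddings, digits.target)
    plt.show()
```

Passing an existing `ax` lets the same helper draw into one panel of a larger figure, for example to put the raw-data embedding next to the embedding of a classifier's features.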



Original poster | Posted on 2020-2-26 16:30:16 (three-dimensional version)
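#####################################Load libraries#################################
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (required on older Matplotlib)
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # a font that can display Chinese characters; only needed for Chinese titles/labels
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs as ASCII hyphens, which CJK fonts can render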
##################################################################################

digits = load_digits()  # load the dataset
embeddings = TSNE(n_components=3).fit_transform(digits.data)  # t-SNE reduction, here to three dimensions
vis_x = embeddings[:, 0]  # first dimension
vis_y = embeddings[:, 1]  # second dimension
vis_z = embeddings[:, 2]  # third dimension

############################Collect the sample indices of each label#################
# digits.target[i] is the true digit (0-9) of sample i
index0 = [i for i in range(len(digits.target)) if digits.target[i] == 0]
index1 = [i for i in range(len(digits.target)) if digits.target[i] == 1]
index2 = [i for i in range(len(digits.target)) if digits.target[i] == 2]
index3 = [i for i in range(len(digits.target)) if digits.target[i] == 3]
index4 = [i for i in range(len(digits.target)) if digits.target[i] == 4]
index5 = [i for i in range(len(digits.target)) if digits.target[i] == 5]
index6 = [i for i in range(len(digits.target)) if digits.target[i] == 6]
index7 = [i for i in range(len(digits.target)) if digits.target[i] == 7]
index8 = [i for i in range(len(digits.target)) if digits.target[i] == 8]
index9 = [i for i in range(len(digits.target)) if digits.target[i] == 9]
###################################################################################


#######################################Plot#########################################
# 'orange' replaces 'yellow', which is the same color as 'y'; cmap is omitted because c is a single color
colors = ['b', 'c', 'y', 'm', 'r', 'g', 'k', 'orange', 'yellowgreen', 'wheat']
fig = plt.figure()
# add_subplot attaches the 3-D Axes to the figure; Axes3D(fig) no longer does this in recent Matplotlib
ax4 = fig.add_subplot(111, projection='3d')
ax4.scatter(vis_x[index0], vis_y[index0], vis_z[index0], c=colors[0], marker='h', label='0')
ax4.scatter(vis_x[index1], vis_y[index1], vis_z[index1], c=colors[1], marker='<', label='1')
ax4.scatter(vis_x[index2], vis_y[index2], vis_z[index2], c=colors[2], marker='x', label='2')
ax4.scatter(vis_x[index3], vis_y[index3], vis_z[index3], c=colors[3], marker='.', label='3')
ax4.scatter(vis_x[index4], vis_y[index4], vis_z[index4], c=colors[4], marker='p', label='4')
ax4.scatter(vis_x[index5], vis_y[index5], vis_z[index5], c=colors[5], marker='>', label='5')
ax4.scatter(vis_x[index6], vis_y[index6], vis_z[index6], c=colors[6], marker='^', label='6')
ax4.scatter(vis_x[index7], vis_y[index7], vis_z[index7], c=colors[7], marker='d', label='7')
ax4.scatter(vis_x[index8], vis_y[index8], vis_z[index8], c=colors[8], marker='s', label='8')
ax4.scatter(vis_x[index9], vis_y[index9], vis_z[index9], c=colors[9], marker='o', label='9')

ax4.grid(False)  # hide the grid lines
ax4.patch.set_facecolor(color="green")  # green background for the axes area

ax4.set_title('t-SNE dimensionality reduction: classification result')
ax4.legend()
plt.show()
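A single static view of a 3-D scatter can hide clusters that overlap from that particular angle. As a hedged sketch, the embedding can be drawn from several camera angles with `Axes3D.view_init`. The helper `plot_3d_views` and its default angles are hypothetical choices, and integer labels 0-9 (as in `load_digits`) are assumed. Here the colors come from one `scatter` call with `c=labels` and `cmap='tab10'`, which is the case where `cmap` actually has an effect; the `__main__` part assumes scikit-learn is installed.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers '3d' on old Matplotlib


def plot_3d_views(embeddings, labels, azimuths=(30, 120, 210), elev=20):
    """Draw the same (n_samples, 3) embedding once per camera azimuth.

    Hypothetical helper. Assumes integer labels 0-9 (as in load_digits), so
    vmin/vmax below give each digit exactly one of the 10 'tab10' colors.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    fig = plt.figure(figsize=(4 * len(azimuths), 4))
    axes = []
    for i, azim in enumerate(azimuths, start=1):
        ax = fig.add_subplot(1, len(azimuths), i, projection='3d')
        # Colors are looked up from the labels through the colormap.
        sc = ax.scatter(embeddings[:, 0], embeddings[:, 1], embeddings[:, 2],
                        c=labels, cmap='tab10', vmin=-0.5, vmax=9.5, s=5)
        ax.view_init(elev=elev, azim=azim)  # rotate the camera for this panel
        ax.set_title(f'azim={azim}')
        axes.append(ax)
    fig.colorbar(sc, ax=axes, ticks=range(10), label='digit')
    return fig, axes


if __name__ == '__main__':
    # Assumes scikit-learn is installed; 3-D t-SNE on all digits may take a while.
    from sklearn.datasets import load_digits
    from sklearn.manifold import TSNE

    digits = load_digits()
    embeddings = TSNE(n_components=3, random_state=0).fit_transform(digits.data)
    plot_3d_views(embeddings, digits.target)
    plt.show()
```

In an interactive window the 3-D axes can also be rotated with the mouse; fixed angles are mainly useful when the figure is saved to a file.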


