聚类结果可视化

R
可视化
factoextra
ggplot2
ggforce
Author

Rui

Published

October 15, 2022

使用 factoextra 可视化聚类结果

导入程序包和数据

使用鸢尾花(iris)数据集,剔除最后一列(物种标签 Species),只保留 4 个数值型变量用于聚类。

Code
library(tidyverse)
library(cluster)
library(factoextra)

# Keep only the four numeric measurements; drop the Species label column
data <- iris[, names(iris) != "Species"]

确定聚类数

Code
# Elbow method: total within-cluster sum of squares (WSS) for k = 1..10.
# k-means minimises squared *Euclidean* distances, so the WSS used to pick k
# must be computed with the same metric; a correlation-based distance such as
# Spearman would score the partitions by a criterion k-means never optimised.
# `nstart` is not an fviz_nbclust() argument, so it is forwarded to kmeans()
# via `...`: each k then uses the best of 25 random starts, which gives a
# smoother, more stable curve.
fviz_nbclust(
  data,
  kmeans,
  k.max = 10,
  method = "wss", # within sum of squares
  diss = get_dist(data, method = "euclidean"),
  nstart = 25
)

根据手肘法则,曲线在聚类数 k=3 之前下降较快,而在 k=3 之后趋于平缓,故选择聚类数 k=3。fviz_nbclust() 返回的是 ggplot2 对象,因此可以用 ggplot2 语法进一步美化,比如添加辅助线:

Code


# Elbow plot (default Euclidean WSS) with a dashed guide line at k = 3.
# `nstart = 25` is forwarded to kmeans() through fviz_nbclust()'s `...`, so
# every point on the curve is the best of 25 random starts rather than a
# single, possibly poor, start that could make the curve bump upwards.
fviz_nbclust(
  data,
  kmeans,
  method = "wss",
  nstart = 25
) +
  geom_vline(xintercept = 3, linetype = 2)

也可以换用其他聚类方法,比如模糊聚类(cluster::fanny()):

Code
# Elbow plot for fuzzy clustering (cluster::fanny) instead of k-means.
# `verbose` matches fviz_nbclust()'s own argument, so it is not forwarded
# to fanny().
fviz_nbclust(
  x = data,
  FUNcluster = fanny,
  method = "wss",
  verbose = FALSE
) +
  geom_vline(linetype = 2, xintercept = 3)

kmeans 聚类

Code
# Fix the RNG seed so the random starts -- and therefore the arbitrary
# numbering of the clusters -- are the same on every render.
set.seed(2022)
# 3 clusters; kmeans() keeps the best of 20 random starts
k1 <- kmeans(data, centers = 3, nstart = 20)
# Print the fitted model: sizes, centres, assignments and within-cluster SS.
# NOTE: the output below was captured before the seed was fixed, so the
# cluster numbering shown may differ from a fresh render.
print(k1)
## K-means clustering with 3 clusters of sizes 50, 62, 38
## 
## Cluster means:
##   Sepal.Length Sepal.Width Petal.Length Petal.Width
## 1     5.006000    3.428000     1.462000    0.246000
## 2     5.901613    2.748387     4.393548    1.433871
## 3     6.850000    3.073684     5.742105    2.071053
## 
## Clustering vector:
##   [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
##  [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
##  [75] 2 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 3 3 3 3 2 3 3 3 3
## [112] 3 3 2 2 3 3 3 3 2 3 2 3 2 3 3 2 2 3 3 3 3 3 2 3 3 3 3 2 3 3 3 2 3 3 3 2 3
## [149] 3 2
## 
## Within cluster sum of squares by cluster:
## [1] 15.15100 39.82097 23.87947
##  (between_SS / total_SS =  88.4 %)
## 
## Available components:
## 
## [1] "cluster"      "centers"      "totss"        "withinss"     "tot.withinss"
## [6] "betweenss"    "size"         "iter"         "ifault"

整理聚类结果

Code
# Attach each observation's k-means cluster (as a factor) to the original
# measurements and convert the combined table to a tibble
result <- data %>%
  mutate(cluster = factor(unname(k1$cluster))) %>%
  as_tibble()

knitr::kable(head(result))
Sepal.Length Sepal.Width Petal.Length Petal.Width cluster
5.1 3.5 1.4 0.2 1
4.9 3.0 1.4 0.2 1
4.7 3.2 1.3 0.2 1
4.6 3.1 1.5 0.2 1
5.0 3.6 1.4 0.2 1
5.4 3.9 1.7 0.4 1

每一类的样本数量

Code
# Number of observations assigned to each cluster
table(result[["cluster"]])
## 
##  1  2  3 
## 50 62 38

可视化聚类结果

当变量多于 2 个时,fviz_cluster() 会自动对数据做主成分分析(PCA),并用前两个主成分作图,以便在二维平面上展示聚类结果。

Code
# Plot the k-means partition: points coloured by cluster, segments from each
# point to its cluster centre (star.plot), an ellipse.type = "euclid" outline
# per cluster, and ggrepel-style label placement (repel) to limit overlap.
fviz_cluster(
  k1,
  data = data,
  ellipse.type = "euclid",
  star.plot = TRUE,
  repel = TRUE,
  palette = c("#2E9FDF", "#E7B800", "#FC4E07"),
  ggtheme = theme_minimal()
)

稍作美化(只画点、使用 jco 配色,并去掉标题和坐标轴标题):

Code
# Same cluster plot drawn with points only, the "jco" palette, an empty
# title, and the axis titles removed by an extra theme() layer.
fviz_cluster(
  k1,
  data = data,
  geom = "point",
  palette = "jco",
  ellipse.type = "euclid",
  star.plot = TRUE,
  repel = TRUE,
  main = "",
  ggtheme = theme_minimal()
) +
  theme(axis.title = element_blank())

使用 ggplot2 可视化聚类结果

获取聚类中心:

Code
# Cluster centres from the k-means fit (one row per cluster), converted to a
# tibble with an extra factor column `label` holding the cluster number.
# seq_len() is used instead of 1:nrow() so the sequence is always well formed.
data_center <- k1$centers %>%
  as_tibble() %>%
  mutate(label = factor(seq_len(nrow(k1$centers))))

data_center %>% knitr::kable()
Sepal.Length Sepal.Width Petal.Length Petal.Width label
5.006000 3.428000 1.462000 0.246000 1
5.901613 2.748387 4.393548 1.433871 2
6.850000 3.073684 5.742105 2.071053 3

作图:

Code
library(ggplot2)

# Observations coloured by cluster, with each cluster centre overlaid as a
# large star (shape 8); the centre layer takes x/y from the shared mapping
# and has no colour mapping of its own.
ggplot(result, aes(x = Sepal.Length, y = Petal.Length)) +
  geom_point(aes(colour = cluster)) +
  geom_point(data = data_center, size = 6, shape = 8) +
  theme_classic()

使用 ggforce 可视化聚类结果

矩形边界

Code
library(ggplot2)
library(ggforce)

# Rectangular outline with a light fill around each cluster
ggplot(result, aes(Sepal.Length, Petal.Length, colour = cluster)) +
  geom_point() +
  geom_mark_rect(aes(fill = cluster), alpha = 0.2) +
  theme_bw()

圆形边界

Code
# Circular outline around each cluster, with enlarged plot margins
# (in points; 150 on the left) and a transparent legend background
ggplot(result, aes(Sepal.Length, Petal.Length, colour = cluster)) +
  geom_point() +
  geom_mark_circle(aes(fill = cluster), alpha = 0.2) +
  theme_bw() +
  theme(
    legend.background = element_blank(),
    plot.margin = margin(t = 50, r = 50, b = 50, l = 150)
  )

设置 coord_cartesian(clip = "off") 关闭绘图区裁剪,使超出面板范围的圆形边界也能画在边距中,从而完整显示:

Code
# Circular outlines again; clip = "off" lets any part of a circle that falls
# outside the panel be drawn into the (enlarged) margins instead of being cut
ggplot(result, aes(Sepal.Length, Petal.Length, colour = cluster)) +
  geom_point() +
  geom_mark_circle(aes(fill = cluster), alpha = 0.2) +
  theme_bw() +
  theme(
    legend.background = element_blank(),
    plot.margin = margin(t = 50, r = 50, b = 50, l = 150)
  ) +
  coord_cartesian(clip = "off")

椭圆形边界

Code
# Elliptical outline around each cluster
ggplot(result, aes(Sepal.Length, Petal.Length, colour = cluster)) +
  geom_point() +
  geom_mark_ellipse(aes(fill = cluster), alpha = 0.2) +
  theme_bw() +
  theme(
    legend.background = element_blank(),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  )

不规则边界

Code
# Irregular, hull-shaped outline around each cluster (ggforce::geom_mark_hull)
ggplot(result, aes(Sepal.Length, Petal.Length, colour = cluster)) +
  geom_point() +
  geom_mark_hull(aes(fill = cluster), alpha = 0.2) +
  theme_bw() +
  theme(
    legend.background = element_blank(),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  )

添加标签

Code
# k-means numbers its clusters arbitrarily (the numbering depends on the
# random starts), so "cluster 1 = setosa, 2 = versicolor, 3 = virginica"
# cannot be hard-coded. Instead, name each cluster after the iris species
# that makes up the majority of its members. `data` was taken from iris
# without reordering rows, so iris$Species lines up row-by-row with
# result$cluster.
cluster_species <- vapply(
  split(as.character(iris$Species), result$cluster),
  function(species) names(which.max(table(species))),
  character(1)
)

result <- result %>%
  mutate(label = factor(unname(cluster_species[as.character(cluster)])))

# Elliptical outline per cluster, annotated with its species name
ggplot(data = result, aes(x = Sepal.Length, y = Petal.Length, color = label)) +
  geom_point() +
  geom_mark_ellipse(aes(fill = label, label = label), alpha = 0.2) +
  theme_bw() +
  theme(
    plot.margin = margin(10, 10, 10, 50),
    legend.background = element_blank()
  )

局部放大

Code
# Petal vs sepal length, plus an extra panel zoomed in on the x range
# covered by the points labelled "setosa"
ggplot(result, aes(x = Petal.Length, y = Sepal.Length, colour = label)) +
  geom_point() +
  facet_zoom(x = label == "setosa")

使用 base R 中的 plot 函数

Code
# Scatterplot matrix of all four measurements (plot() on a data frame with
# more than two columns), points coloured by their k-means cluster number
plot(
  x = data,
  col = k1$cluster,
  pch = 19,
  main = "K-means with k = 3",
  frame = FALSE
)

Code
# Sepal.Length vs Petal.Length coloured by k-means cluster, with each
# cluster centre overlaid as a large cross (pch = 4). Points use
# col = cluster number and centre i uses col = i, so each centre shares its
# cluster's palette colour. The number of clusters is read from the fitted
# model rather than hard-coded.
plot_vars <- c("Sepal.Length", "Petal.Length")
n_clusters <- nrow(k1$centers)

plot(
  data[, plot_vars],
  col = k1$cluster,
  pch = 19,
  frame.plot = FALSE, # spelled out; `frame` only worked via partial matching
  main = paste0("K-means with k = ", n_clusters)
)

points(
  # drop = FALSE keeps a 1-row matrix (k = 1) from collapsing to a vector
  k1$centers[, plot_vars, drop = FALSE],
  col = seq_len(n_clusters),
  pch = 4,
  cex = 3
)

使用 clusplot()

clusplot() 同样会自动用主成分分析将数据降到 2 维后再作图:

Code
library(cluster)

# Bivariate cluster plot from the cluster package: colour = TRUE shades the
# cluster ellipses, lines = 0 draws no lines between clusters.
# NOTE(review): per ?clusplot, labels = 1 means points/ellipses can be
# identified interactively (see identify); confirm this is intended for a
# non-interactive render (labels = 2 would label everything).
clusplot(
  x = data,
  clus = k1$cluster,
  color = TRUE,
  lines = 0,
  labels = 1,
  main = "Cluster Plot"
)

参考

[1] https://mp.weixin.qq.com/s/r_4A3uiG3f89VMSTD2On3g

[2] https://www.jianshu.com/p/50e75cb66651

[3] https://zhuanlan.zhihu.com/p/511283740