12

I used my training dataset to fit clusters with the kmeans function:

fit <- kmeans(ca.data, 2);

How can I use the fit object to predict cluster membership for a new dataset?

Thanks

user333

4 Answers

17

One option is to use cl_predict from the clue package (note: I found this by googling "kmeans R predict").
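As a rough sketch of how that might look, assuming the clue package is installed: the toy data and the names train, new and pred below are made up for illustration and are not from the question.

```r
# Toy data: two well-separated groups (illustrative only).
set.seed(1)
train <- rbind(matrix(rnorm(100, sd = 0.3), ncol = 2),
               matrix(rnorm(100, mean = 1, sd = 0.3), ncol = 2))
new   <- rbind(matrix(rnorm(10, sd = 0.3), ncol = 2),
               matrix(rnorm(10, mean = 1, sd = 0.3), ncol = 2))

fit <- kmeans(train, centers = 2)

# Only runs if clue is available; cl_predict returns class ids by default.
if (requireNamespace("clue", quietly = TRUE)) {
  pred <- clue::cl_predict(fit, newdata = new)
  print(as.integer(unclass(pred)))
}
```

The new data must have the same columns, in the same order, as the data passed to kmeans.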

Nick Sabbe
6

Check this complete answer. The code you need is:

clusters <- function(x, centers) {
  # compute squared euclidean distance from each sample to each cluster center
  tmp <- sapply(seq_len(nrow(x)),
                function(i) apply(centers, 1,
                                  function(v) sum((x[i, ]-v)^2)))
  max.col(-t(tmp))  # find index of min distance
}

# create a simple data set with two clusters
set.seed(1)
x <- rbind(matrix(rnorm(100, sd = 0.3), ncol = 2),
           matrix(rnorm(100, mean = 1, sd = 0.3), ncol = 2))
colnames(x) <- c("x", "y")
x_new <- rbind(matrix(rnorm(10, sd = 0.3), ncol = 2),
               matrix(rnorm(10, mean = 1, sd = 0.3), ncol = 2))
colnames(x_new) <- c("x", "y")

cl <- kmeans(x, centers=2)

all.equal(cl[["cluster"]], clusters(x, cl[["centers"]]))
# [1] TRUE
clusters(x_new, cl[["centers"]])
# [1] 2 2 2 2 2 1 1 1 1 1
Pablo Casas
    It's been a while since I wrote this answer; I now recommend building a predictive model (such as a random forest) with the cluster variable as the target. I got better results in practice with this approach. For example, in clustering all variables are equally important, while a predictive model can automatically pick the variables that best predict the cluster. This approach is also easy to deploy in production (i.e. predicting which cluster a new case belongs to). – Pablo Casas Jun 20 '17 at 16:07
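A minimal sketch of the idea in the comment above, assuming the randomForest package is installed. This is not the commenter's own code; the toy data and the names train_df, rf and pred_rf are invented for illustration.

```r
# Toy data: two groups, then a k-means fit whose labels become the target.
set.seed(1)
x <- rbind(matrix(rnorm(100, sd = 0.3), ncol = 2),
           matrix(rnorm(100, mean = 1, sd = 0.3), ncol = 2))
colnames(x) <- c("x", "y")
x_new <- rbind(matrix(rnorm(10, sd = 0.3), ncol = 2),
               matrix(rnorm(10, mean = 1, sd = 0.3), ncol = 2))
colnames(x_new) <- c("x", "y")

cl <- kmeans(x, centers = 2)

# Cluster label as a factor so the forest treats this as classification.
train_df <- data.frame(x, cluster = factor(cl$cluster))

if (requireNamespace("randomForest", quietly = TRUE)) {
  rf <- randomForest::randomForest(cluster ~ ., data = train_df)
  pred_rf <- predict(rf, newdata = data.frame(x_new))
  print(pred_rf)
}
```

Unlike nearest-center assignment, the classifier's predictions are not guaranteed to agree with the k-means partition, so it is worth checking its accuracy on the training labels.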
4

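You could write an S3 method to predict the classes for a new dataset. The following assigns each new observation to the center with the smallest sum of squared differences. It is used like other predict methods: newdata should match the structure of the data you passed to kmeans, and the method argument works as it does for fitted.kmeans ("centers" returns the matching center for each observation, "classes" returns the cluster index).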

predict.kmeans <- function(object,
                           newdata,
                           method = c("centers", "classes")) {
  method <- match.arg(method)

  centers <- object$centers
  # one column per center: sum of squared differences to each row of newdata
  ss_by_center <- apply(centers, 1, function(x) {
    colSums((t(newdata) - x) ^ 2)
  })
  # index of the closest center for each row of newdata
  best_clusters <- apply(ss_by_center, 1, which.min)

  if (method == "centers") {
    centers[best_clusters, ]
  } else {
    best_clusters
  }
}
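A short usage sketch for such a method, using base R only. So that it runs on its own, it restates the method in condensed form (with drop = FALSE added as a small safety assumption); the toy data and the names classes and centers_for_new are made up for illustration.

```r
# Condensed restatement of the S3 method so this snippet is self-contained.
predict.kmeans <- function(object, newdata,
                           method = c("centers", "classes")) {
  method <- match.arg(method)
  centers <- object$centers
  ss <- apply(centers, 1, function(ctr) colSums((t(newdata) - ctr)^2))
  best <- apply(ss, 1, which.min)
  if (method == "centers") centers[best, , drop = FALSE] else best
}

# Toy data: two groups for training, ten new points.
set.seed(1)
x <- rbind(matrix(rnorm(100, sd = 0.3), ncol = 2),
           matrix(rnorm(100, mean = 1, sd = 0.3), ncol = 2))
x_new <- rbind(matrix(rnorm(10, sd = 0.3), ncol = 2),
               matrix(rnorm(10, mean = 1, sd = 0.3), ncol = 2))

fit <- kmeans(x, centers = 2)

# predict() dispatches to predict.kmeans because class(fit) is "kmeans".
classes <- predict(fit, x_new, method = "classes")
centers_for_new <- predict(fit, x_new, method = "centers")
print(classes)
```

Note that this version expects newdata to have at least two rows; a single-row input would need extra handling.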

I wish there was a predict.kmeans in the existing stats namespace.

Russ Hyde
3

Another option is to use the predict method from the flexclust package, after converting your stats::kmeans model to its kcca class (for example with as.kcca).
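A hedged sketch of that conversion, assuming the flexclust package is installed. The toy data and the names train, new, fit_kcca and pred are invented for illustration.

```r
# Toy data: two groups for training, ten new points.
set.seed(1)
train <- rbind(matrix(rnorm(100, sd = 0.3), ncol = 2),
               matrix(rnorm(100, mean = 1, sd = 0.3), ncol = 2))
new   <- rbind(matrix(rnorm(10, sd = 0.3), ncol = 2),
               matrix(rnorm(10, mean = 1, sd = 0.3), ncol = 2))

fit <- kmeans(train, centers = 2)

if (requireNamespace("flexclust", quietly = TRUE)) {
  # Attach the package so predict() finds the method for kcca objects.
  library(flexclust)
  # as.kcca needs the original training data alongside the kmeans fit.
  fit_kcca <- as.kcca(fit, data = train)
  pred <- predict(fit_kcca, newdata = new)
  print(pred)
}
```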

Augusto