Last updated: 2018-02-11

Code version: d53227e

This file generates the structure heatmaps for the rank-3 simulated data.

#' Simulate a rank-3 data matrix with block-structured loadings and factors.
#'
#' Builds a 150 x 3 loading matrix and a 240 x 3 factor matrix whose columns
#' are non-zero on disjoint blocks of rows, then adds i.i.d. Gaussian noise
#' to their product.
#'
#' @param sigma Standard deviation of the additive Gaussian noise.
#' @param seed_num Seed passed to `set.seed()` so the draw is reproducible.
#' @return A list with `Y` (150 x 240 noisy data), `L_true` (150 x 3),
#'   `F_true` (240 x 3) and `E` (150 x 240 noise), where
#'   `Y = L_true %*% t(F_true) + E`.
DataMaker <- function(sigma = 2, seed_num = 99) {
  set.seed(seed_num)
  N <- 150
  P <- 240
  K <- 3
  L_true <- array(0, dim = c(N, K))
  F_true <- array(0, dim = c(P, K))
  # The order of the rnorm() calls below is fixed on purpose: changing it
  # would change the data generated for a given seed.
  F_true[1:80, 1] <- rnorm(80, 0, 0.5)
  F_true[81:160, 2] <- rnorm(80, 0, 1)
  F_true[161:240, 3] <- rnorm(80, 0, 2)
  L_true[1:10, 1] <- rnorm(10, 0, 2)
  L_true[11:60, 2] <- rnorm(50, 0, 1)
  L_true[61:150, 3] <- rnorm(90, 0, 0.5)
  # Noise-free low-rank signal, computed once and reused for Y.
  G <- L_true %*% t(F_true)
  E <- matrix(rnorm(N * P, 0, sigma), nrow = N)
  Y <- G + E
  list(Y = Y, L_true = L_true, F_true = F_true, E = E)
}

#' Cross-validate the shrinkage parameter for PMD or softImpute.
#'
#' Rows and columns of `Y` are each randomly split into `fold` groups. In CV
#' split `i`, row group `j` is held out on column group `(i + j) mod fold`
#' (with 0 mapped to `fold`), for every `j`; over all splits each entry of `Y`
#' is therefore held out exactly once. For every candidate in `c_s` the model
#' is fitted to each masked matrix and the squared error on the held-out
#' entries is accumulated.
#'
#' @param Y Numeric data matrix. Its row and column counts are expected to be
#'   multiples of `fold`; otherwise `matrix()` recycles indices with a warning.
#' @param c_s Numeric vector of candidate shrinkage values (passed as `sumabs`
#'   to `PMD` or as `lambda` to `softImpute`).
#' @param K Rank passed to the fitting routine.
#' @param fold Number of row/column groups (and CV splits).
#' @param method `"PMD"` selects `PMD`; any other value selects `softImpute`.
#' @return A list with `opt_s`, the first candidate attaining the smallest CV
#'   error, and `output`, the CV root-mean-squared error for every candidate.
CVPMD_softImpute <- function(Y, c_s, K, fold = 10, method = "PMD") {
  N <- dim(Y)[1]
  P <- dim(Y)[2]
  # One matrix column per group; columns are drawn before rows to keep the
  # random stream identical for a given seed.
  colindex <- matrix(sample(P, P), ncol = fold)
  rowindex <- matrix(sample(N, N), ncol = fold)

  # block_col[j, i]: column group held out together with row group j in split i.
  block_col <- outer(seq_len(fold), seq_len(fold), "+") %% fold
  block_col[block_col == 0] <- fold

  missing <- array(0, dim = c(fold, N, P))
  for (i in seq_len(fold)) {
    masked <- Y
    for (j in seq_len(fold)) {
      masked[rowindex[, j], colindex[, block_col[j, i]]] <- NA
    }
    # Restore columns that ended up with no observed entry at all. Counting
    # observed entries (rather than testing whether they sum to zero) avoids
    # un-masking held-out values in columns whose observed part sums to 0.
    all_missing <- which(colSums(!is.na(masked)) == 0)
    if (length(all_missing) > 0) {
      masked[, all_missing] <- Y[, all_missing]
    }
    missing[i, , ] <- masked
  }

  n_s <- length(c_s)
  CVRMSE <- rep(0, n_s)
  minrmse <- Inf
  opt_s <- 0
  for (t_s in seq_len(n_s)) {
    # Held-out squared error of each CV split for this candidate.
    sse <- rep(0, fold)
    for (i in seq_len(fold)) {
      if (method == "PMD") {
        # capture.output() only silences PMD's console output; `out` is
        # assigned in this function's frame.
        capture.output({
          out <- PMD(missing[i, , ], sumabs = c_s[t_s], sumabsv = NULL,
                     sumabsu = NULL, K = K)
        })
      } else {
        out <- softImpute(missing[i, , ], rank.max = K, lambda = c_s[t_s])
      }
      # diag() of a length-one vector would not give a 1 x 1 matrix, so the
      # rank-one fit is reconstructed separately.
      if (length(out$d) == 1) {
        misshat <- (out$d) * out$u %*% t(out$v)
      } else {
        misshat <- out$u %*% diag(out$d) %*% t(out$v)
      }
      for (j in seq_len(fold)) {
        rows <- rowindex[, j]
        cols <- colindex[, block_col[j, i]]
        sse[i] <- sse[i] + sum((Y[rows, cols] - misshat[rows, cols])^2,
                               na.rm = TRUE)
      }
    }
    CVRMSE[t_s] <- sqrt(sum(sse) / (N * P))
    # Strict inequality: ties keep the earliest candidate.
    if (CVRMSE[t_s] < minrmse) {
      minrmse <- CVRMSE[t_s]
      opt_s <- c_s[t_s]
    }
  }
  list(opt_s = opt_s, output = CVRMSE)
}

#' Fit PMD with a cross-validated `sumabs` and return the fitted matrix.
#'
#' @param Y_data List with element `Y`, the data matrix to factorise.
#' @param ngrids Number of candidate `sumabs` values, equally spaced on
#'   [0.1, 0.9].
#' @param K Number of factors passed to `PMD`.
#' @param fold Number of folds passed to `CVPMD_softImpute`.
#' @return The fitted matrix `u %*% diag(d) %*% t(v)` from the `PMD` fit.
PMA.wrapper <- function(Y_data, ngrids = 10, K = 3, fold = 10) {
  library(PMA)
  Y <- Y_data$Y
  c_s <- seq(0.1, 0.9, length.out = ngrids)
  cvout <- CVPMD_softImpute(Y, c_s, K, fold, method = "PMD")
  # capture.output() only silences PMD's console output; `out` is assigned in
  # this function's frame.
  capture.output({
    out <- PMD(Y, sumabsu = NULL, sumabsv = NULL, sumabs = cvout$opt_s, K = K)
  })
  # `nrow` is given explicitly so a length-one `d` (K = 1) yields a 1 x 1
  # matrix; a bare diag(d) would build an identity matrix of size floor(d).
  Y_hat <- out$u %*% diag(out$d, nrow = length(out$d)) %*% t(out$v)
  Y_hat
}


#' Fit softImpute with a cross-validated `lambda` and return the fitted matrix.
#'
#' @param Y_data List with element `Y`, the data matrix to factorise.
#' @param ngrids Number of candidate `lambda` values, equally spaced on
#'   [0, 100].
#' @param K Maximum rank passed to `softImpute` as `rank.max`.
#' @param fold Number of folds passed to `CVPMD_softImpute`.
#' @return The matrix reconstructed from the fitted `u`, `d` and `v`.
softImpute.wrapper <- function(Y_data, ngrids = 10, K = 3, fold = 10) {
  library(softImpute)
  obs <- Y_data$Y
  lambda_grid <- seq(0, 100, length.out = ngrids)
  cv_res <- CVPMD_softImpute(obs, lambda_grid, K, fold, method = "softImpute")
  fit <- softImpute(obs, rank.max = K, lambda = cv_res$opt_s)
  # With more than one singular value the usual u diag(d) v' product applies.
  if (length(fit$d) != 1) {
    return(fit$u %*% diag(fit$d) %*% t(fit$v))
  }
  # Single singular value: scale the rank-one outer product directly.
  (fit$d) * fit$u %*% t(fit$v)
}

#' Rank-K truncated SVD approximation of the data matrix.
#'
#' @param Y_data List with element `Y`, a numeric matrix.
#' @param K Number of leading singular triplets to keep; must lie between 1
#'   and `min(dim(Y))`.
#' @return A matrix with the dimensions of `Y`, built from the `K` leading
#'   singular values and vectors.
SVD.wrapper <- function(Y_data, K = 3) {
  Y <- Y_data$Y
  if (K < 1 || K > min(dim(Y))) {
    stop("K must be between 1 and min(dim(Y))", call. = FALSE)
  }
  # Request exactly K singular vectors (previously fixed at 3, which broke
  # any K > 3); svd() then returns u and v as matrices with K columns, so no
  # dimension is dropped when K = 1.
  gsvd <- svd(Y, nu = K, nv = K)
  # An explicit `nrow` keeps diag() from treating a length-one vector as a
  # matrix size.
  Y_hat <- gsvd$u %*% diag(gsvd$d[seq_len(K)], nrow = K) %*% t(gsvd$v)
  Y_hat
}

#' Fit a sparse SVD with `ssvd::ssvd` and return the reconstructed matrix.
#'
#' @param Y_data List with element `Y`, the data matrix to factorise.
#' @param K Rank passed to `ssvd::ssvd` as `r`.
#' @return The product `u %*% diag(d) %*% t(v)` of the fitted components.
SSVD.wrapper <- function(Y_data, K = 3) {
  library(ssvd)
  obs <- Y_data$Y
  fit <- ssvd::ssvd(obs, method = "method", r = K)
  # Scale the left vectors by the fitted `d`, then project onto the right ones.
  scaled_u <- fit$u %*% diag(fit$d)
  scaled_u %*% t(fit$v)
}

#' Fit flash greedily, refine the fit by backfitting, and return the
#' fitted matrix.
#'
#' @param Y_data List with element `Y`, the data matrix to factorise.
#' @param K Passed as `K` to `flashr::flash_add_greedy`.
#' @return The product `EL %*% t(EF)` taken from the backfitted flash object.
flashBF.wrapper <- function(Y_data, K = 3) {
  Y <- Y_data$Y
  library(ebnm)
  library(flashr)
  data <- flashr::flash_set_data(Y)
  # FALSE is spelled out: the shorthand `F` is an ordinary variable that can
  # be reassigned.
  f_greedy <- flashr::flash_add_greedy(data, verbose = FALSE,
                                       var_type = "constant", K = K)
  # The greedy fit is used as the starting point for backfitting.
  f <- flashr::flash_backfit(data, f_greedy, var_type = "constant")
  Y_hat <- f$EL %*% t(f$EF)
  Y_hat
}

#' Fit flash greedily (no backfitting) and return the fitted matrix.
#'
#' @param Y_data List with element `Y`, the data matrix to factorise.
#' @param K Passed as `K` to `flashr::flash_add_greedy`.
#' @return The product `EL %*% t(EF)` taken from the greedy flash object.
flash.wrapper <- function(Y_data, K = 3) {
  Y <- Y_data$Y
  data <- flashr::flash_set_data(Y)
  # FALSE is spelled out: the shorthand `F` is an ordinary variable that can
  # be reassigned.
  g_flash <- flashr::flash_add_greedy(data, verbose = FALSE,
                                      var_type = "constant", K = K)
  Y_hat <- g_flash$EL %*% t(g_flash$EF)
  Y_hat
}

Average of 100 simulations

# Average the absolute values stored in the saved heatmap array over the
# simulation replicates, then reshape each of the 7 columns into a 150 x 240
# matrix and melt it into long format for plotting.
HMatrix <- readRDS("../data/simulation/rankthree/HeatMap.rds")
# Number of replicates to average. Deliberately not named `T`: that would
# overwrite the shorthand for TRUE, which other code in this file relies on.
n_sim <- 100
n_row <- 150
n_col <- 240
AVG_HM <- matrix(0, nrow = n_row * n_col, ncol = 7)
for (i in seq_len(n_sim)) {
  AVG_HM <- AVG_HM + abs(HMatrix[i, , ])
}
# Divide by the same count the loop ran over, so the two cannot drift apart.
AVG_HM <- AVG_HM / n_sim
# One 150 x 240 matrix per column of AVG_HM; the plotting calls further down
# title columns 1-6 as PMD.cv1, flash, flash_pn, SVD, SSVD and SI.cv, and
# column 7 as the true latent structure.
LF_1 <- matrix(AVG_HM[, 1], ncol = n_col, nrow = n_row)
LF_2 <- matrix(AVG_HM[, 2], ncol = n_col, nrow = n_row)
LF_3 <- matrix(AVG_HM[, 3], ncol = n_col, nrow = n_row)
LF_4 <- matrix(AVG_HM[, 4], ncol = n_col, nrow = n_row)
LF_5 <- matrix(AVG_HM[, 5], ncol = n_col, nrow = n_row)
LF_6 <- matrix(AVG_HM[, 6], ncol = n_col, nrow = n_row)
LF_true <- matrix(AVG_HM[, 7], ncol = n_col, nrow = n_row)
library(reshape2)
melted_cormat_1 <- melt(LF_1, na.rm = TRUE)
melted_cormat_2 <- melt(LF_2, na.rm = TRUE)
melted_cormat_3 <- melt(LF_3, na.rm = TRUE)
melted_cormat_4 <- melt(LF_4, na.rm = TRUE)
melted_cormat_5 <- melt(LF_5, na.rm = TRUE)
melted_cormat_6 <- melt(LF_6, na.rm = TRUE)
melted_cormat_true <- melt(LF_true, na.rm = TRUE)
# Heatmap
library(ggplot2)
#' Draw a greyscale tile heatmap from a melted matrix.
#'
#' @param title_name Plot title.
#' @param melted_cormat Long-format data with columns `Var1`, `Var2` and
#'   `value`; `Var2` is mapped to the x axis, `Var1` to the y axis and `value`
#'   to the fill colour.
#' @return A ggplot object.
plot_HP <- function(title_name = "truth", melted_cormat) {
  # Fill scale over [0, 0.82] with midpoint 0.41 and the legend hidden below.
  fill_scale <- scale_fill_gradient2(
    low = "white", high = "grey6", mid = "grey9",
    midpoint = 0.41, limit = c(0, 0.82), space = "Lab"
  )
  axis_labels <- labs(title = title_name, y = "samples", x = "variables")
  title_theme <- theme(
    legend.position = "none",
    plot.title = element_text(size = 12.9, face = "bold")
  )
  ggplot(data = melted_cormat, aes(Var2, Var1, fill = value)) +
    geom_tile(color = "white") +
    fill_scale +
    axis_labels +
    theme_minimal() +
    title_theme
}

# One heatmap panel per averaged matrix. p0 (the true latent structure) is
# built but is not part of the grid drawn below.
p0 <- plot_HP("Latent structure (truth)", melted_cormat_true)
p1 <- plot_HP("PMD.cv1", melted_cormat_1)
p2 <- plot_HP("flash", melted_cormat_2)
p3 <- plot_HP("flash_pn", melted_cormat_3)
p4 <- plot_HP("SVD", melted_cormat_4)
p5 <- plot_HP("SSVD", melted_cormat_5)
p6 <- plot_HP("SI.cv", melted_cormat_6)
# Three columns of panels, filled in the order given.
gridExtra::grid.arrange(p2, p3, p1, p6, p5, p4, ncol = 3)

Session information

sessionInfo()
R version 3.3.0 (2016-05-03)
Platform: x86_64-apple-darwin13.4.0 (64-bit)
Running under: OS X 10.13.3 (unknown)

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] denoiseR_1.0    scales_0.4.1    MASS_7.3-47     reshape2_1.4.3 
 [5] flashr_0.4-6    workflowr_0.4.0 rmarkdown_1.6   ggplot2_2.2.1  
 [9] R.matlab_3.6.1  softImpute_1.4  Matrix_1.2-11   PMA_1.0.9      
[13] impute_1.48.0   plyr_1.8.4      ssvd_1.0       

loaded via a namespace (and not attached):
 [1] ashr_2.2-3           lattice_0.20-35      colorspace_1.3-2    
 [4] htmltools_0.3.6      yaml_2.1.16          rlang_0.1.6         
 [7] R.oo_1.21.0          withr_2.1.1          R.utils_2.5.0       
[10] foreach_1.4.4        stringr_1.2.0        munsell_0.4.3       
[13] gtable_0.2.0         R.methodsS3_1.7.1    devtools_1.13.3     
[16] codetools_0.2-15     leaps_3.0            evaluate_0.10.1     
[19] memoise_1.1.0        labeling_0.3         knitr_1.18          
[22] pscl_1.5.2           doParallel_1.0.11    irlba_2.2.1         
[25] parallel_3.3.0       curl_2.8.1           Rcpp_0.12.14        
[28] flashClust_1.01-2    backports_1.1.2      scatterplot3d_0.3-40
[31] truncnorm_1.0-7      gridExtra_2.3        digest_0.6.13       
[34] stringi_1.1.6        flashr2_0.4-0        grid_3.3.0          
[37] rprojroot_1.2        tools_3.3.0          magrittr_1.5        
[40] lazyeval_0.2.0       tibble_1.3.4         cluster_2.0.6       
[43] FactoMineR_1.36      SQUAREM_2017.10-1    httr_1.3.0          
[46] rstudioapi_0.6       iterators_1.0.9      R6_2.2.2            
[49] git2r_0.19.0        

This R Markdown site was created with workflowr