Binary neural nets in R, part 2: Helmholtz Machine

GitHub source for this project

In the last post, I looked at coding a restricted Boltzmann machine (RBM) in R. In this one, I will compare that algorithm and its results to the Helmholtz Machine (HM), which is trained with the wake-sleep algorithm.

The basic idea behind the wake-sleep algorithm is that the model learns in two alternating phases. In the wake phase, random values for the hidden nodes are generated according to the expectations of the recognition network, and then the generative network's weights are adjusted. In the sleep phase, random values for the hidden nodes and the data are generated according to the generative network, and then the recognition weights are adjusted. This process trains both the recognition and generative networks, and at a good solution the two should nearly agree. The HM trained with wake-sleep differs from the RBM in that the RBM has only one set of weights supplying both the top-down and bottom-up expectations for each node, whereas the HM keeps a separate weight matrix for each direction.
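
In outline, one training iteration therefore looks like this (just the skeleton, using the variable names from the full code below):

for (iter in 1:maxIter) {
  # wake: clamp a data batch to the visible layer, sample each hidden layer
  #       bottom-up from the recognition weights Wr, then nudge the
  #       generative weights Wg toward reproducing the sampled states
  # sleep: sample a "dream" top-down from the generative weights Wg, then
  #        nudge the recognition weights Wr toward recovering the hidden
  #        states that generated it
}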

We can generate data for an HM the same way as for the RBM, first by setting up a list L with the model structure:

> L
[[1]]
[1] "bias"

[[2]]
 [1] "x1"  "x2"  "x3"  "x4"  "x5"  "x6"  "x7"  "x8"  "x9"  "x10" "x11" "x12" "x13" "x14" "x15" "x16" "x17" "x18" "x19" "x20"

[[3]]
[1] "h1_1" "h1_2" "h1_3" "h1_4" "h1_5" "h1_6" "h1_7" "h1_8"

[[4]]
[1] "h2_1" "h2_2" "h2_3" "h2_4"

Then I used a single adjacency matrix for the entire generative model, Wg, so that I can repeatedly apply the same matrix multiplication, using L to subset the result each time:

rRBM <- function(n, Wg) {
  # draw n samples from the generative model; assumes the layer list L,
  # the layer count Q, and the logistic function lgt are in scope
  X <- matrix(0, n, ncol(Wg), dimnames = list(NULL, colnames(Wg)))
  X[, 1] <- 1  # bias column is always on
  # sample each layer top-down, from the deepest hidden layer to the data
  for (i in Q:2) {
    p <- lgt((X %*% t(Wg))[, L[[i]], drop = FALSE])  # activation probabilities
    X[, L[[i]]] <- matrix(rbinom(length(p), size = 1, c(p)), nrow(p), ncol(p))
  }
  return(X)
}
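
Two things here aren't defined anywhere in the post: the helper lgt, which is presumably the logistic function, and the objects the training loop below relies on (the weight matrices Wg and Wr, their free-weight masks Wg.f and Wr.f, the data matrix dat, and the learning-rate schedule temp). A plausible reconstruction, with every name and value below an illustrative assumption:

lgt <- function(x) 1 / (1 + exp(-x))  # assumed: the logistic (sigmoid) function

# assumed setup for the training loop; none of this appears in the post
nodes  <- unlist(L)
nNodes <- length(nodes)
chunkSize <- 100                            # mini-batch size
maxIter   <- 100000
lr.s      <- 0.3                            # starting learning rate
temp <- seq(1, 0.1, length.out = maxIter)   # learning-rate decay schedule

X <- matrix(0, chunkSize, nNodes, dimnames = list(NULL, nodes))
X[, 1] <- 1                                 # bias column is always on

# W[j, k] is the weight into node j from node k: generative edges run from
# layer i+1 down to layer i, recognition edges the reverse, and every
# non-bias node gets a bias weight
Wg   <- Wr   <- matrix(0, nNodes, nNodes, dimnames = list(nodes, nodes))
Wg.f <- Wr.f <- matrix(FALSE, nNodes, nNodes, dimnames = list(nodes, nodes))
for (i in 2:(Q - 1)) {
  Wg.f[L[[i]], L[[i + 1]]] <- TRUE
  Wr.f[L[[i + 1]], L[[i]]] <- TRUE
}
Wg.f[-1, "bias"] <- Wr.f[-1, "bias"] <- TRUE
Wg[Wg.f] <- rnorm(sum(Wg.f), 0, 0.1)        # small random starting weights
Wr[Wr.f] <- rnorm(sum(Wr.f), 0, 0.1)

# dat is the training data over the visible nodes, e.g. drawn from a
# data-generating model: dat <- rRBM(10000, Wg.true)[, L[[2]]]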

The wake-sleep algorithm code is given here, first with the wake phase, using Wr to generate node values from the bottom up:

for (iter in 1:maxIter) {
  # wake phase: clamp a random mini-batch of data to the visible layer
  dat.chunk <- dat[sample(1:nrow(dat), size = chunkSize, replace = FALSE), ]
  X[, L[[2]]] <- dat.chunk

  lr <- lr.s * temp[iter]  # decaying learning rate
  # sample each hidden layer bottom-up from the recognition weights Wr
  for (i in 3:Q) {
    p <- lgt(X %*% t(Wr))[, L[[i]], drop = FALSE]
    X[, L[[i]]] <- matrix(rbinom(length(p), size = 1, c(p)), nrow(p), ncol(p))
  }
  # generative predictions for every node, then the local delta rule on Wg
  X.p <- lgt(X %*% t(Wg))
  X.p[, 1] <- 1  # the bias node is always on
  Wg.gr <- t((t(X) %*% (X - X.p)) / nrow(X))
  Wg[Wg.f] <- Wg[Wg.f] + lr * Wg.gr[Wg.f]  # update only the free weights

Then the sleep phase, using Wg to generate node values from the top down:

  # sleep phase: generate a "dream" sample top-down from the generative
  # weights Wg, starting from the deepest hidden layer
  X2 <- X
  for (i in Q:2) {
    p <- lgt(X2 %*% t(Wg))[, L[[i]], drop = FALSE]
    X2[, L[[i]]] <- matrix(rbinom(length(p), size = 1, c(p)), nrow(p), ncol(p))
  }
  # recognition predictions, then the same delta rule on Wr
  X.p <- lgt(X2 %*% t(Wr))
  X.p[, 1] <- 1
  Wr.gr <- t((t(X2) %*% (X2 - X.p)) / nrow(X2))
  Wr[Wr.f] <- Wr[Wr.f] + lr * Wr.gr[Wr.f]
}

Note that both phases use the same local delta rule to obtain the gradient: the outer product of the observed node values with the prediction errors (observed minus predicted), averaged over the mini-batch.
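
To see that the matrix cross-product in the code is exactly this sum of outer products, here's a quick check on hypothetical toy values (the loop additionally transposes the result and scales by the batch size):

set.seed(1)
S <- matrix(rbinom(6, 1, 0.5), 2, 3)   # observed binary states: 2 samples, 3 nodes
P <- matrix(runif(6), 2, 3)            # their predicted probabilities
g.matrix <- t(S) %*% (S - P)           # cross-product form from the loop
g.outer  <- outer(S[1, ], S[1, ] - P[1, ]) +
            outer(S[2, ], S[2, ] - P[2, ])   # explicit sum of outer products
all.equal(g.matrix, g.outer)           # TRUE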

Here are some results from an initial run with the same dimensions as the RBM, again looking at a redacted (|r| > .7 only) correlation matrix of the data-generating weight vectors with the estimated ones:

Iteration 61900 

Learning rate: 0.3 

Hidden Layer 1 and 2 weight vector correlations: 
     bias h1_1 h1_2  h1_3  h1_4  h1_5 h1_6 h1_7 h1_8
bias                                                
h1_1                -0.78                           
h1_2      0.98                                      
h1_3                      -0.99                     
h1_4                                                
h1_5           0.99                                 
h1_6                                            0.99
h1_7                                  0.99          
h1_8                            -0.99               
     h2_1  h2_2 h2_3  h2_4
h2_1            -0.8      
h2_2                 -0.83
h2_3                      
h2_4      -0.76           
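
For reference, a redacted table like the ones above can be produced along these lines (a sketch: Wg.true is a hypothetical name for the data-generating weights, and the post doesn't show the actual construction):

# correlate true vs. estimated generative weight vectors into the visible
# layer; each column of Wg[L[[2]], idx] is one unit's vector over the 20 x's
idx <- c(L[[1]], L[[3]])                   # bias plus the layer-1 hidden units
C1  <- cor(Wg.true[L[[2]], idx], Wg[L[[2]], idx])
C1.show <- ifelse(abs(C1) > 0.7, round(C1, 2), NA)
print(C1.show, na.print = "")              # blank out entries below threshold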

Note that I had to use the first matrix to depermute the first hidden layer before calculating the correlations for the second: the hidden units are only recovered up to permutation (and sign), so the second-layer weights have to be reordered to match, as sketched below. Each layer is missing one node, but it's pretty cool to see both layers recovering the data-generating weights as promised! Since the learning is stochastic, it will (in theory) pick up those last couple of weight vectors in the course of eternity.
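
That depermutation step might look something like this (a sketch, reusing the hypothetical C1 and Wg.true from above; with binary units a negative correlation means the estimated unit is roughly complemented, so the sign flip is only approximate):

# second-layer generative weights: rows = h1 units, columns = h2 units
Wg2.true <- Wg.true[L[[3]], L[[4]]]
Wg2.est  <- Wg[L[[3]], L[[4]]]

C1.h1 <- C1[L[[3]], L[[3]]]                  # h1 block of the matrix above
perm  <- apply(abs(C1.h1), 1, which.max)     # best estimated match per true unit
sgn   <- sign(C1.h1[cbind(seq_along(perm), perm)])

# reorder the estimated h1 rows into the true order, flipping signs
Wg2.aligned <- sweep(Wg2.est[perm, , drop = FALSE], 1, sgn, `*`)
C2 <- cor(Wg2.true, Wg2.aligned)             # h2-by-h2 correlation table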

I increased the complexity of the simulated data and the model, then left it running for a much longer time. It looks like it was pretty successful:

Iteration 900000 

Learning rate: 0.1 

Hidden Layer 1 and 2 weight vector correlations: 
      bias h1_1  h1_2 h1_3 h1_4  h1_5 h1_6 h1_7  h1_8  h1_9 h1_10 h1_11 h1_12 h1_13 h1_14 h1_15 h1_16 h1_17 h1_18 h1_19 h1_20 h1_21 h1_22 h1_23 h1_24 h1_25 h1_26 h1_27 h1_28 h1_29 h1_30
bias                                                                                                                                                                                     
h1_1                                                                                                               0.98                                                                  
h1_2                                                                                                                                                         0.93                        
h1_3                                                                                                                                -0.99                                                
h1_4                                                                                                                                                              -0.99                  
h1_5                                       0.99                                                                                                                                          
h1_6                                                  -0.99                                                                                                                              
h1_7                                                                                                        -0.97                                                                        
h1_8                                                                           0.98                                                                                                      
h1_9                  0.98                                                                                                                                                               
h1_10                                                                                                                                                                                    
h1_11                                                                                                                                      0.98                                          
h1_12      0.95                                                                                                                                                                          
h1_13                                                             -0.83                                                                                                                  
h1_14                                                                                           -0.96                                                                                    
h1_15                                                                                                                                                  0.98                              
h1_16                                                                                                                                                                          0.97      
h1_17           -0.96                                                                                                                                                                    
h1_18                                                                                                                                                                                    
h1_19                                           -0.95                                                                                                                                    
h1_20                                                                               -0.98                                                                                                
h1_21                                                                    0.98                                                                                                            
h1_22                                                       -0.96                                                                                                                        
h1_23                           -0.98                                                                                                                                                    
h1_24                                                                                                                          0.97                                                      
h1_25                                                                                                                                           -0.96                                    
h1_26                                                                                                 -0.97                                                                              
h1_27                                                                                                                                                                                0.93
h1_28                                                                                     -0.93                                                                                          
h1_29                      0.98                                                                                                                                                          
h1_30                                                                                                                                                                    0.97            
     h2_1 h2_2 h2_3 h2_4  h2_5 h2_6  h2_7 h2_8
h2_1                     -0.94                
h2_2      0.87                                
h2_3                           0.89           
h2_4                0.89                      
h2_5                                       0.8
h2_6           0.96                           
h2_7                                -0.93     
h2_8 0.78
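
For anyone reproducing the larger run, the layer structure implied by the output above would be something like this (the visible-layer size isn't shown in the post, so nVis is a placeholder):

nVis <- 50  # placeholder; the actual number of visible nodes isn't recoverable
L <- list("bias",
          paste0("x",   1:nVis),
          paste0("h1_", 1:30),   # 30 first-layer hidden units, per the output
          paste0("h2_", 1:8))    # 8 second-layer hidden units
Q <- length(L)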