现在的位置: 首页 > 综合 > 正文

【deep learning学习笔记】注释yusugomori的RBM代码 — cpp文件 — 模型测试

2014年11月02日 ⁄ 综合 ⁄ 共 835字 ⁄ 字号 评论关闭

产生数据,调用上文的函数,训练RBM模型,并re-construct测试数据,用来验证训练的RBM模型。

void test_rbm() 
{
  srand(0);

  double learning_rate = 0.1;
  int training_epochs = 1000;
  int k = 1;
  
  int train_N = 6;
  int test_N = 2;
  int n_visible = 6;
  int n_hidden = 3;

  // training data
  int train_X[6][6] = {
    {1, 1, 1, 0, 0, 0},
    {1, 0, 1, 0, 0, 0},
    {1, 1, 1, 0, 0, 0},
    {0, 0, 1, 1, 1, 0},
    {0, 0, 1, 0, 1, 0},
    {0, 0, 1, 1, 1, 0}
  };


  // construct RBM
  RBM rbm(train_N, n_visible, n_hidden, NULL, NULL, NULL);

  // train
  for(int epoch=0; epoch<training_epochs; epoch++) 
  {
	// iterator all the training samples, and apply the CD-k algoirthm
    for(int i=0; i<train_N; i++) 
	{
      rbm.contrastive_divergence(train_X[i], learning_rate, k);
    }
  }

  // test data
  int test_X[2][6] = {
    {1, 1, 0, 0, 0, 0},
    {0, 0, 0, 1, 1, 0}
  };
  double reconstructed_X[2][6];


  // test
  for(int i=0; i<test_N; i++) 
  {
    rbm.reconstruct(test_X[i], reconstructed_X[i]);

    for(int j=0; j<n_visible; j++) 
	{
      printf("%.5f ", reconstructed_X[i][j]);
    }
    cout << endl;
  }

}

int main() {
  test_rbm();
  return 0;
}

运行结果如图:

抱歉!评论已关闭.