/// (C) 2017, Andrew Polar under GPL ver. 3. // LICENSE // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License as // published by the Free Software Foundation; either version 3 of // the License, or (at your option) any later version. // // This program is distributed in the hope that it will be useful, but // WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // General Public License for more details at // Visit . // // Last modification 12.01.2017 // package com.org; import java.util.*; public class EntryPoint { static double[][] trainingSet = {{1,1,0,0,0,0,1,1}, {0,0,1,1,1,1,0,0}}; static double[][] W = {{1,1,-1,-1,-1,-1,1,1}, {-1,-1,1,1,1,1,-1,-1}}; static double learningRate = 0.5; static int numberOfEpochs = 5000; private static double sigmoid(double x) { return 1.0/(1.0 + Math.exp(-x)); } private static double[] computeHidden(double[] v, double[][] W) { int rowsW = W.length; int colsW = W[0].length; int colsv = v.length; if (colsv != colsW) { System.out.println("Sizes mismatch"); System.exit(1); } double[] h = new double[rowsW]; for (int i = 0; i < rowsW; ++i) { h[i] = 0.0; for (int j = 0; j < colsW; ++j) { h[i] += W[i][j] * v[j]; } h[i] = sigmoid(h[i]); } return h; } private static double[] reconstructVisible(double[] h, double[][] W) { int rowsW = W.length; int colsW = W[0].length; int rowsh = h.length; if (rowsW != rowsh) { System.out.println("Sizes not match"); System.exit(1); } double[] v = new double[colsW]; for (int j = 0; j < colsW; ++j) { v[j] = 0.0; for (int i = 0; i < rowsW; ++i) { v[j] += h[i] * W[i][j]; } v[j] = sigmoid(v[j]); } return v; } private static double[] roundVector(double[] x) { int size = x.length; for (int i = 0; i < size; ++i) { if (x[i] >= 0.5 && x[i] <= 1.0) x[i] = 1.0; else if (x[i] >= 0.0 && x[i] < 0.5) x[i] = 0.0; else { System.out.println("Wrong data " + x[i]); System.exit(1); } } return x; } private static double[] edgeVector(double[] x) { int pos = 0; double max = x[pos]; for (int i = 0; i < x.length; ++i) { if (x[i] > max) { pos = i; max = x[pos]; } } for (int i = 0; i < x.length; ++i) { x[i] = 0.0; } x[pos] = 1.0; return x; } private static void contrastiveDivergence() { //make random weight matrix Random rnd = new Random(); for (int i = 0; i < W.length; ++i) { for (int j = 0; j < W[0].length; ++j) { W[i][j] = (double)(50 - rnd.nextInt(100))/50.0; } } //contrastive divergence for (int epoch = 0; epoch < numberOfEpochs; ++epoch) { //contrastive divergence iteration for every training set for (int i = 0; i < trainingSet.length; ++i) { double[] hiddenPositive = computeHidden(trainingSet[i], W); hiddenPositive = edgeVector(hiddenPositive); double[] visibleReturned = reconstructVisible(hiddenPositive, W); visibleReturned = roundVector(visibleReturned); double[] hiddenNegative = computeHidden(visibleReturned, W); hiddenNegative = edgeVector(hiddenNegative); //update weight matrix for (int m = 0; m < W.length; ++m) { for (int n = 0; n < W[0].length; ++n) { W[m][n] += learningRate * (trainingSet[m][n] * hiddenPositive[m] - visibleReturned[n] * hiddenNegative[m]); } } } } } private static void showVector(double[] x, String title) { System.out.println(title); for (int i = 0; i < x.length; ++i) { System.out.print(String.format("%.2f ",x[i])); } System.out.println(); } private static void showMatrix(double[][] M, String title) { System.out.println(title); for (int i = 0; i < M.length; ++i) { for (int j = 0; j < M[i].length; ++j) { System.out.print(String.format("%.2f ",M[i][j])); } System.out.println(); } } public static void main(String[] args) { contrastiveDivergence(); //when commented it uses original W shown at the top showMatrix(W, "Weight matrix"); double[] h1 = computeHidden(trainingSet[0], W); double[] h2 = computeHidden(trainingSet[1], W); showVector(h1, "Probabilities for spread set"); showVector(h2, "Probabilities for centered set"); h1 = roundVector(h1); h2 = roundVector(h2); double[] v1 = reconstructVisible(h1, W); double[] v2 = reconstructVisible(h2, W); v1 = roundVector(v1); v2 = roundVector(v2); showVector(v1, "Reconstructed training set 1"); showVector(v2, "Reconstructed training set 2"); } }