diff --git a/src/neural/BattlefieldParameterEvaluator.java b/src/neural/BattlefieldParameterEvaluator.java index 99fabf6..6e9c073 100644 --- a/src/neural/BattlefieldParameterEvaluator.java +++ b/src/neural/BattlefieldParameterEvaluator.java @@ -31,6 +31,9 @@ public class BattlefieldParameterEvaluator { final static int NUM_NN_HIDDEN_UNITS = 50; // Number of epochs for training final static int NUM_TRAINING_EPOCHS = 100000; + // The requested error in nn training + final static double NN_TRAINING_ERROR = 0.01; + static int NdxBattle; static double[] FinalScore1; static double[] FinalScore2; @@ -92,12 +95,30 @@ public class BattlefieldParameterEvaluator { RawInputs[NdxSample][1] = GunCoolingRate[NdxSample] / MAXGUNCOOLINGRATE; RawOutputs[NdxSample][0] = FinalScore1[NdxSample] / 250; } + BasicNeuralDataSet MyDataSet = new BasicNeuralDataSet(RawInputs, RawOutputs); + // Create and train the neural network    - // ... TO DO ... + BasicNetwork network = new BasicNetwork(); + network.addLayer(new BasicLayer(null, true, NUM_NN_INPUTS)); + network.addLayer(new BasicLayer(new ActivationSigmoid(), true, NUM_NN_HIDDEN_UNITS)); + network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1)); + network.getStructure().finalizeStructure(); + network.reset(); + System.out.println("Training network..."); - // ... TO DO ... + final ResilientPropagation train = new ResilientPropagation(network, MyDataSet); + + int epoch = 1; + + do { + train.iteration(); + System.out.println("Epoch #" + epoch + " Error:" + train.getError()); + epoch++; + } while(train.getError() > NN_TRAINING_ERROR); + train.finishTraining(); System.out.println("Training completed."); + System.out.println("Testing network..."); // Generate test samples to build an output image int[] OutputRGBint = new int[NUMBATTLEFIELDSIZES * NUMCOOLINGRATES]; @@ -129,7 +150,6 @@ public class BattlefieldParameterEvaluator { int MyPixelIndex = (int) (Math.round(NUMCOOLINGRATES * ((GunCoolingRate[NdxSample] / MAXGUNCOOLINGRATE) - 0.1) / 0.9) + Math .round(NUMBATTLEFIELDSIZES * ((BattlefieldSize[NdxSample] / MAXBATTLEFIELDSIZE) - 0.1) / 0.9) * NUMCOOLINGRATES); - // int MyPixelIndex = 0; if ((MyPixelIndex >= 0) && (MyPixelIndex < NUMCOOLINGRATES * NUMBATTLEFIELDSIZES)) { OutputRGBint[MyPixelIndex] = MyColor.getRGB(); } @@ -139,7 +159,8 @@ public class BattlefieldParameterEvaluator { File f = new File("hello.png"); try { ImageIO.write(img, "png", f); - } catch (IOException e) { + } + catch (IOException e) { // TODO Auto‐generated catch block e.printStackTrace(); } @@ -182,7 +203,7 @@ public class BattlefieldParameterEvaluator { // Called when the game sends out an information message during the battle public void onBattleMessage(BattleMessageEvent e) { -// System.out.println("Msg> " + e.getMessage()); + // System.out.println("Msg> " + e.getMessage()); } // Called when the game sends out an error message during the battle