added some ANN stuff

master
Peter Babič 9 years ago
parent fd2e4d52cc
commit 4c7da1ffdc
  1. 31
      src/neural/BattlefieldParameterEvaluator.java

@ -31,6 +31,9 @@ public class BattlefieldParameterEvaluator {
final static int NUM_NN_HIDDEN_UNITS = 50;
// Number of epochs for training
final static int NUM_TRAINING_EPOCHS = 100000;
// The requested error in nn training
final static double NN_TRAINING_ERROR = 0.01;
static int NdxBattle;
static double[] FinalScore1;
static double[] FinalScore2;
@ -92,12 +95,30 @@ public class BattlefieldParameterEvaluator {
RawInputs[NdxSample][1] = GunCoolingRate[NdxSample] / MAXGUNCOOLINGRATE;
RawOutputs[NdxSample][0] = FinalScore1[NdxSample] / 250;
}
BasicNeuralDataSet MyDataSet = new BasicNeuralDataSet(RawInputs, RawOutputs);
// Create and train the neural network   
// ... TO DO ...
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, true, NUM_NN_INPUTS));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, NUM_NN_HIDDEN_UNITS));
network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
network.getStructure().finalizeStructure();
network.reset();
System.out.println("Training network...");
// ... TO DO ...
final ResilientPropagation train = new ResilientPropagation(network, MyDataSet);
int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" + train.getError());
epoch++;
} while(train.getError() > NN_TRAINING_ERROR);
train.finishTraining();
System.out.println("Training completed.");
System.out.println("Testing network...");
// Generate test samples to build an output image
int[] OutputRGBint = new int[NUMBATTLEFIELDSIZES * NUMCOOLINGRATES];
@ -129,7 +150,6 @@ public class BattlefieldParameterEvaluator {
int MyPixelIndex = (int) (Math.round(NUMCOOLINGRATES * ((GunCoolingRate[NdxSample] / MAXGUNCOOLINGRATE) - 0.1) / 0.9) + Math
.round(NUMBATTLEFIELDSIZES * ((BattlefieldSize[NdxSample] / MAXBATTLEFIELDSIZE) - 0.1) / 0.9)
* NUMCOOLINGRATES);
// int MyPixelIndex = 0;
if ((MyPixelIndex >= 0) && (MyPixelIndex < NUMCOOLINGRATES * NUMBATTLEFIELDSIZES)) {
OutputRGBint[MyPixelIndex] = MyColor.getRGB();
}
@ -139,7 +159,8 @@ public class BattlefieldParameterEvaluator {
File f = new File("hello.png");
try {
ImageIO.write(img, "png", f);
} catch (IOException e) {
}
catch (IOException e) {
// TODO Auto‐generated catch block
e.printStackTrace();
}
@ -182,7 +203,7 @@ public class BattlefieldParameterEvaluator {
// Called when the game sends out an information message during the battle
public void onBattleMessage(BattleMessageEvent e) {
// System.out.println("Msg> " + e.getMessage());
// System.out.println("Msg> " + e.getMessage());
}
// Called when the game sends out an error message during the battle

Loading…
Cancel
Save