diff --git a/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java b/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java index b63df76532b63de09980521021efc64b13d007eb..85be42842f749fcb26b09c7ae377c9c3a386bf30 100644 --- a/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java +++ b/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java @@ -2,7 +2,9 @@ package experiments; import java.io.File; import java.io.IOException; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Deque; import java.util.HashMap; import java.util.Random; @@ -12,6 +14,7 @@ import fr.irit.smac.amak.ui.drawables.DrawableOval; import gui.AmoebaWindow; import javafx.scene.paint.Color; import kernel.AMOEBA; +import kernel.backup.SaveHelperDummy; import utils.Pair; import utils.RandomUtils; import utils.XmlConfigGenerator; @@ -45,25 +48,66 @@ public class SimpleReinforcement { return; // now compilator know config is initialized } + Configuration.commandLineMode = true; AMOEBA amoeba = new AMOEBA(config.getAbsolutePath(), null); + amoeba.saver = new SaveHelperDummy(); SimpleReinforcement env = new SimpleReinforcement(); Random r = new Random(); HashMap<String, Double> state = env.reset(); HashMap<String, Double> state2; - for(int i = 0; i < 10000; i++) { - state.remove("oracle"); - HashMap<String, Double> action = amoeba.maximize(state); - if(r.nextDouble() < 0.2 || action.get("oracle").equals(Double.NEGATIVE_INFINITY) ) { - System.out.println("Random action"); - action.put("a1", (r.nextBoolean() ? 10.0 : -10.0)); + double explo = 0.5; + for(int i = 0; i < 100; i++) { + boolean done = false; + Deque<HashMap<String, Double>> actions = new ArrayDeque<>(); + //System.out.println("Explore "+i); + int nbStep = 0; + state = env.reset(); + while(!done) { + nbStep++; + if(nbStep > 500) { + done = true; + } + state.remove("oracle"); + state.remove("a1"); + HashMap<String, Double> action = amoeba.maximize(state); + if(r.nextDouble() < 0.5 || action.get("oracle").equals(Double.NEGATIVE_INFINITY) ) { + //System.out.println("Random action"); + action.put("a1", (r.nextBoolean() ? 10.0 : -10.0)); + } + state2 = env.step(action.get("a1")); + + if(state2.get("oracle") != -1.0) { + done = true; + } + + action.put("p1", state.get("p1")); + action.put("oracle", state2.get("oracle")); + //System.out.println(action); + actions.add(action); + + state = state2; } - state2 = env.step(action.get("a1")); - action.put("oracle", state2.get("oracle")); - action.put("p1", state.get("p1")); - System.out.println(action); + + //System.out.println("Learn "+i); + HashMap<String, Double> action = actions.pop(); + double reward = action.get("oracle"); amoeba.learn(action); - state = state2; + + while(!actions.isEmpty()) { + action = actions.pop(); + reward += action.get("oracle"); + action.put("oracle", reward); + amoeba.learn(action); + } + + if(explo > 0.1) { + explo -= 0.01; + if(explo < 0.1) + explo = 0.1; + } + + System.out.println("Episode "+i+" reward : "+reward+" explo : "+explo); } } @@ -71,13 +115,13 @@ public class SimpleReinforcement { * Must be called AFTER an AMOEBA with GUI */ public SimpleReinforcement() { - Configuration.commandLineMode = false; - AmoebaWindow instance = AmoebaWindow.instance(); - pos = new DrawableOval(0.5, 0.5, 1, 1); - pos.setColor(new Color(0.5, 0.0, 0.0, 0.5)); - instance.mainVUI.add(pos); - instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5); - instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2); + //Configuration.commandLineMode = false; + //AmoebaWindow instance = AmoebaWindow.instance(); + //pos = new DrawableOval(0.5, 0.5, 1, 1); + //pos.setColor(new Color(0.5, 0.0, 0.0, 0.5)); + //instance.mainVUI.add(pos); + //instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5); + //instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2); @@ -96,7 +140,7 @@ public class SimpleReinforcement { reward = -100.0; } else if(x == 0.0 || sign(oldX) != sign(x)) { // win ! - reward = 100.0; + reward = 1000.0; x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0)); } else { reward = -1.0; @@ -104,7 +148,7 @@ public class SimpleReinforcement { HashMap<String, Double> ret = new HashMap<>(); ret.put("p1", x); ret.put("oracle", reward); - pos.move(x+0.5, 0.5); + //pos.move(x+0.5, 0.5); return ret; }