Skip to content
Snippets Groups Projects
Commit d2b27c41 authored by Hugo Roussel's avatar Hugo Roussel
Browse files

Add AmoebaQL

parent 587145c6
No related branches found
No related tags found
1 merge request!4Exp rein
......@@ -31,13 +31,13 @@ import utils.XmlConfigGenerator;
public class SimpleReinforcement {
/* Learn and Test */
public static final int MAX_STEP_PER_EPISODE = 200;
public static final int N_LEARN = 200;
public static final int N_TEST = 10;
public static final int N_LEARN = 100;
public static final int N_TEST = 100;
/* Exploration */
public static final int N_EXPLORE_LINE = 0;
public static final int N_EXPLORE_LINE = 60;
public static final double MIN_EXPLO_RATE = 0.02;
public static final double EXPLO_RATE_DIMINUTION_FACTOR = 0.0;
public static final double EXPLO_RATE_DIMINUTION_FACTOR = 0.01;
public static final double EXPLO_RATE_BASE = 1;
public static final String EXPLORATION_STRATEGY = "random"; // can be "random" or "line"
private static int exploreLine;
......@@ -57,16 +57,59 @@ public class SimpleReinforcement {
public void learn(HashMap<String, Double> state, HashMap<String, Double> state2, HashMap<String, Double> action, boolean done);
}
public static class AmoebaQL implements LearningAgent {
public AMOEBA amoeba;
public double lr = 0.8;
public double gamma = 0.9;
private Random rand = new Random();
public AmoebaQL() {
amoeba = setup();
amoeba.setLocalModel(TypeLocalModel.MILLER_REGRESSION);
amoeba.getEnvironment().setMappingErrorAllowed(0.04);
}
@Override
public double choose(HashMap<String, Double> state) {
double a = amoeba.maximize(state).getOrDefault("a1", 0.0);
if(a == 0.0) {
a = rand.nextBoolean() ? -1 : 1;
}
return a;
}
@Override
public void learn(HashMap<String, Double> state, HashMap<String, Double> state2,
HashMap<String, Double> action, boolean done) {
HashMap<String, Double> state2Copy = new HashMap<>(state2);
state2Copy.remove("oracle");
double reward = state2.get("oracle");
double q;
if(!done) {
double expectedReward = amoeba.request(action);
double futureAction = this.choose(state2Copy);
q = reward + gamma * futureAction - expectedReward;
} else {
q = reward;
}
HashMap<String, Double> learn = new HashMap<>(action);
learn.put("oracle", lr * q);
amoeba.learn(learn);
}
}
/**
* Wrapper for AMOEBA
* @author Hugo
*
*/
public static class Amoeba implements LearningAgent {
public static class AmoebaCoop implements LearningAgent {
public AMOEBA amoeba;
private Random rand = new Random();
public Amoeba() {
public AmoebaCoop() {
amoeba = setup();
amoeba.setLocalModel(TypeLocalModel.COOP_MILLER_REGRESSION);
amoeba.getEnvironment().setMappingErrorAllowed(0.009);
......@@ -139,7 +182,7 @@ public class SimpleReinforcement {
//poc(true);
//Configuration.commandLineMode = true;
System.out.println("----- AMOEBA -----");
learning(new Amoeba());
learning(new AmoebaQL());
System.out.println("----- END AMOEBA -----");
/*System.out.println("\n\n----- QLEARNING -----");
learning(new QLearning());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment