Skip to content
Snippets Groups Projects
Commit bc8341fd authored by BrunoDatoMeneses's avatar BrunoDatoMeneses
Browse files

ENH: 2D reinforcement Amoeba naive

parent 761e73c4
Branches
No related tags found
1 merge request: !4 "Exp rein"
......@@ -47,6 +47,7 @@ public class Context extends AmoebaAgent {
private double action;
private Double actionProposition = null;
public Double lastPrediction = null;
//private boolean valid;
......@@ -1485,6 +1486,8 @@ public class Context extends AmoebaAgent {
}
/**
 * Returns the action value this context proposes for the current situation.
 *
 * @return the proposition computed by this context's local model
 */
public double getActionProposal() {
	final double proposition = this.localModel.getProposition();
	return proposition;
}
......
......@@ -750,6 +750,7 @@ public class Head extends AmoebaAgent {
newContext = context;
newContextCreated = true;
newContext.lastPrediction = newContext.getActionProposal();
}
getAmas().data.executionTimes[9]=System.currentTimeMillis()- getAmas().data.executionTimes[9];
......@@ -893,6 +894,7 @@ public class Head extends AmoebaAgent {
double minDistanceToOraclePrediction = Double.POSITIVE_INFINITY;
for (Context activatedContext : activatedContexts) {
System.out.println(activatedContext.getName());
currentDistanceToOraclePrediction = activatedContext.getLocalModel()
.distance(activatedContext.getCurrentExperiment());
getAmas().data.distanceToRegression = currentDistanceToOraclePrediction;
......@@ -923,6 +925,7 @@ public class Head extends AmoebaAgent {
activatedContext.criticalities.addCriticality("distanceToRegression", currentDistanceToOraclePrediction);
//getEnvironment().trace(new ArrayList<String>(Arrays.asList("ADD CRITICALITY TO CTXT", ""+activatedContext.getName(), ""+criticalities.getLastValues().get("distanceToRegression").size())));
activatedContext.lastPrediction = activatedContext.getActionProposal();
}
......
......@@ -31,7 +31,7 @@ import utils.XmlConfigGenerator;
public abstract class SimpleReinforcement2D {
/* Learn and Test */
public static final int MAX_STEP_PER_EPISODE = 200;
public static final int N_LEARN = 400;
public static final int N_LEARN = 400;//400
public static final int N_TEST = 100;
/* Exploration */
......@@ -50,11 +50,11 @@ public abstract class SimpleReinforcement2D {
learning(new QLearning());
System.out.println("----- END QLEARNING -----");*/
ArrayList<ArrayList<Double>> results = new ArrayList<>();
for(int i = 0; i < 100; i++) {
for(int i = 0; i < 1; i++) {
//LearningAgent agent = new QLearning();
LearningAgent agent = new AmoebaQL();
//LearningAgent agent = new AmoebaCoop();
Environment env = new TwoDimensionEnv();
Environment env = new TwoDimensionEnv(10);
results.add(learning(agent, env));
System.out.println(i);
}
......@@ -69,7 +69,7 @@ public abstract class SimpleReinforcement2D {
System.out.println(""+i+"\t"+average);
}
System.exit(0);
//System.exit(0);
}
/**
......@@ -110,7 +110,7 @@ public abstract class SimpleReinforcement2D {
public AmoebaQL() {
amoeba = setup();
amoeba.setLocalModel(TypeLocalModel.MILLER_REGRESSION);
amoeba.getEnvironment().setMappingErrorAllowed(0.1);
amoeba.getEnvironment().setMappingErrorAllowed(0.025);
}
@Override
......@@ -158,11 +158,12 @@ public abstract class SimpleReinforcement2D {
}
HashMap<String, Double> learn = new HashMap<>(action);
learn.put("oracle", lr * q);
//learn.put("oracle", lr * q);
learn.put("oracle", reward);
// learn : previous state, current action and current Q learning reward
System.out.println(learn);
amoeba.learn(learn);
}
@Override
......@@ -266,9 +267,13 @@ public abstract class SimpleReinforcement2D {
private double x = 0;
private double y = 0;
private double reward = 0;
private double size;
private Drawable pos;
public TwoDimensionEnv() {
public TwoDimensionEnv(double envSize) {
size = envSize;
if(!Configuration.commandLineMode) {
AmoebaWindow instance = AmoebaWindow.instance();
//pos = new DrawableOval(0.5, 0.5, 1, 1);
......@@ -283,9 +288,9 @@ public abstract class SimpleReinforcement2D {
@Override
public HashMap<String, Double> reset(){
x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0));
x = RandomUtils.nextDouble(rand, -size, Math.nextUp(size));
x = Math.round(x);
y = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0));
y = RandomUtils.nextDouble(rand, -size, Math.nextUp(size));
y = Math.round(x);
reward = 0.0;
//pos.move(x+0.5, 0.5);
......@@ -306,7 +311,7 @@ public abstract class SimpleReinforcement2D {
if(action > 1.0) action = 1.0;
if(action < -1.0) action = -1.0;
double oldX = x;
x = x + action*10;
x = x + action;
double action2 = actionMap.get("a2");
//if(action2 == 0.0) action2 = rand.nextDouble();
......@@ -315,14 +320,14 @@ public abstract class SimpleReinforcement2D {
if(action2 > 1.0) action2 = 1.0;
if(action2 < -1.0) action2 = -1.0;
double oldY = y;
y = y + action2*10;
y = y + action2;
//System.out.println("ACTIONS " + " a1 " +action + " " + " a2 " + action2);
if(x < -50.0 || x > 50.0 || y < -50.0 || y > 50.0) {
reward = -100.0;
if(x < -size || x > size || y < -size || y > size) {
reward = -1000.0;
} else if((x == 0.0 && y == 0.0) || (sign(oldX) != sign(x) && sign(oldY) != sign(y) )) {
// win !
reward = 100.0;
reward = 1000.0;
} else {
reward = -1.0;
}
......@@ -345,8 +350,8 @@ public abstract class SimpleReinforcement2D {
@Override
public List<String> perceptionSpace() {
	// Describe the two continuous perceptions (p1, p2); their range is
	// derived from the configurable environment size rather than being
	// hard-coded, so the same descriptor works for any TwoDimensionEnv(size).
	// NOTE(review): the scraped diff showed both the old "[-50, 50]" lines
	// and these parameterized ones; only the parameterized pair belongs in
	// the post-commit version — keeping both would register four perceptions.
	ArrayList<String> l = new ArrayList<>();
	l.add("p1 enum:false [-"+size+", "+size+"]");
	l.add("p2 enum:false [-"+size+", "+size+"]");
	return l;
}
......@@ -454,6 +459,8 @@ public abstract class SimpleReinforcement2D {
state = state2;
}
System.out.println("-----------------------------------------------------------------------");
// update exploration rate
if(explo > MIN_EXPLO_RATE) {
explo -= EXPLO_RATE_DIMINUTION_FACTOR;
......@@ -521,7 +528,7 @@ public abstract class SimpleReinforcement2D {
*/
public static void poc(boolean learnMalus) {
AMOEBA amoeba = setup();
Environment env = new TwoDimensionEnv();
Environment env = new TwoDimensionEnv(50);
// train
for(double n = 0.0; n < 0.5; n+=0.1) {
......
......@@ -40,10 +40,37 @@ public class ContextRendererFX extends RenderStrategy {
}
/**
 * Refreshes this renderer's color; delegates to the prediction-based
 * coloring scheme (the coefficient-based scheme is kept but unused here).
 */
private void updateColor() {
	this.setColorWithPrediction();
}
/**
 * Colors the drawable from the context's local-model coefficients,
 * with a fixed alpha of 90/255.
 */
private void setColorWithCoefs() {
	final Double[] rgb = ContextColor.colorFromCoefs(context.getFunction().getCoef());
	final Color coefColor = new Color(rgb[0], rgb[1], rgb[2], 90d / 255d);
	drawable.setColor(coefColor);
}
/**
 * Colors the drawable from the context's last prediction:
 * red for a strongly negative prediction (&lt; -900), green for a strongly
 * positive one (&gt; 900), blue when no prediction has been made yet.
 * Alpha is fixed at 90/255.
 */
private void setColorWithPrediction() {
	double r = 0.0;
	double g = 0.0;
	double b = 0.0;
	if(context.lastPrediction!=null) {
		r = context.lastPrediction < -900 ? 1.0 : 0.0;
		g = context.lastPrediction > 900 ? 1.0 : 0.0;
		// Debug trace for saturated predictions. Kept inside the null
		// check: the original called Math.abs(context.lastPrediction)
		// unconditionally, which auto-unboxes and throws an NPE whenever
		// lastPrediction is still null (the very case the else-branch
		// below is designed to handle).
		if(Math.abs(context.lastPrediction)>900) {
			System.out.println("---------------------------------------------" +context.getName() + " " + context.lastPrediction + " r " + r + " g " + g);
		}
	}else {
		b = 1.0;
	}
	drawable.setColor(new Color(r, g, b, 90d / 255d));
}
public String getColorForUnity() {
Double[] c = ContextColor.colorFromCoefs(context.getFunction().getCoef());
return c[0].intValue() + "," + c[1].intValue() + "," + c[2].intValue() + ",100";
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment