Skip to content
Snippets Groups Projects
Commit e463b7bb authored by Hugo Roussel's avatar Hugo Roussel
Browse files

Add a simple reinforcement experiment

parent 8d64e669
No related branches found
No related tags found
2 merge requests!3Merge masters,!2Merge dev into develop
package agents.context.localModel;
import java.util.ArrayList;
import java.util.HashMap;
import agents.context.Context;
import agents.context.Experiment;
......@@ -31,6 +32,7 @@ public abstract class LocalModel {
/**
 * Proposition of this local model for the given context.
 * (Semantics defined by concrete subclasses, e.g. the regression model below.)
 */
public abstract double getProposition(Context context);
/** Proposition of this local model evaluated on an experiment's values. */
public abstract double getProposition(Experiment experiment);
/** Maximum value this model can propose within the context's ranges. */
public abstract double getMaxProposition(Context context);
/**
 * Maximum of the model over the context's ranges, returned as a map from
 * percept name to the maximizing input value, plus an "oracle" entry holding
 * the maximum itself (see LocalModelMillerRegression.getMax).
 */
public abstract HashMap<String, Double> getMax(Context context);
/** Minimum value this model can propose within the context's ranges. */
public abstract double getMinProposition(Context context);
......
......@@ -2,6 +2,7 @@ package agents.context.localModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import agents.context.Context;
......@@ -132,6 +133,35 @@ public class LocalModelMillerRegression extends LocalModel{
return result;
}
/**
 * Maximizes the linear model coefs[0] + sum_i coefs[i] * x_i over the
 * hyper-rectangle defined by the context's ranges: a positive coefficient
 * contributes at its range's end, a non-positive one at its range's start.
 *
 * @param context the context whose ranges bound each percept
 * @return a map from each percept name to its maximizing value, plus an
 *         "oracle" entry holding the maximum value of the model
 * @throws ArithmeticException if the model's constant term is NaN
 */
public HashMap<String, Double> getMax(Context context){
	ArrayList<Percept> percepts = context.getAmas().getPercepts();
	HashMap<String, Double> result = new HashMap<String, Double>();
	// NOTE: '== Double.NaN' is always false; Double.isNaN is required.
	if (Double.isNaN(coefs[0]))
		throw new ArithmeticException("First coeficient of model cannot be NaN");
	// Accumulate the maximum instead of overwriting "oracle" at each step.
	double max = coefs[0];
	for (int i = 1 ; i < coefs.length ; i++) {
		// A NaN coefficient is neutralized (kept as 0 in the model, as before).
		if (Double.isNaN(coefs[i])) coefs[i] = 0.0;
		Percept p = percepts.get(i-1);
		// Positive slope -> range end maximizes; otherwise range start.
		double bound = (coefs[i] > 0) ? context.getRanges().get(p).getEnd()
		                              : context.getRanges().get(p).getStart();
		max += coefs[i] * bound;
		result.put(p.getName(), bound);
	}
	result.put("oracle", max);
	return result;
}
public double getMinProposition(Context context) {
ArrayList<Percept> percepts = context.getAmas().getPercepts();
......
......@@ -17,7 +17,7 @@ import agents.percept.Percept;
import kernel.AMOEBA;
import ncs.NCS;
import utils.Pair;
import utils.PickRandom;
import utils.RandomUtils;
import utils.PrintOnce;
import utils.Quadruplet;
import utils.TRACE_LEVEL;
......@@ -347,7 +347,7 @@ public class Head extends AmoebaAgent {
// to limit performance impact, we limit our search on a random sample.
// a better way would be to increase neighborhood.
PrintOnce.print("Play without oracle : no nearest context in neighbors, searching in a random sample. (only shown once)");
List<Context> searchList = PickRandom.pickNRandomElements(getAmas().getContexts(), 100);
List<Context> searchList = RandomUtils.pickNRandomElements(getAmas().getContexts(), 100);
nearestContext = this.getNearestContext(searchList);
if(nearestContext != null) {
getAmas().data.prediction = nearestContext.getActionProposal();
......
package experiments;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;
import fr.irit.smac.amak.Configuration;
import fr.irit.smac.amak.ui.drawables.Drawable;
import fr.irit.smac.amak.ui.drawables.DrawableOval;
import gui.AmoebaWindow;
import javafx.scene.paint.Color;
import kernel.AMOEBA;
import utils.Pair;
import utils.RandomUtils;
import utils.XmlConfigGenerator;
/**
 * A simple 1D reinforcement experiment: an agent moves left/right on the
 * segment [-50, 50] and is rewarded for reaching or crossing the origin.
 * AMOEBA learns (position, action) -> reward and acts epsilon-greedily.
 */
public class SimpleReinforcement {
	// Environment randomness (respawn positions, zero-action jitter).
	private Random rand = new Random();
	// Current agent position on the line.
	private double x = 0;
	// Reward produced by the last call to step().
	private double reward = 0;
	// Visual marker of the agent's position in the GUI.
	private Drawable pos;

	public static void main(String[] args) {
		ArrayList<Pair<String, Boolean>> sensors = new ArrayList<>();
		sensors.add(new Pair<String, Boolean>("p1", false));
		sensors.add(new Pair<String, Boolean>("a1", true));
		File config;
		try {
			// ".xml" (with the dot) so the temp file really gets an .xml extension,
			// and clean it up when the JVM exits.
			config = File.createTempFile("config", ".xml");
			config.deleteOnExit();
			XmlConfigGenerator.makeXML(config, sensors);
		} catch (IOException e) {
			e.printStackTrace();
			System.exit(1);
			return; // now the compiler knows config is initialized
		}
		AMOEBA amoeba = new AMOEBA(config.getAbsolutePath(), null);
		SimpleReinforcement env = new SimpleReinforcement();
		Random r = new Random();
		HashMap<String, Double> state = env.reset();
		HashMap<String, Double> state2;
		for(int i = 0; i < 10000; i++) {
			state.remove("oracle");
			HashMap<String, Double> action = amoeba.maximize(state);
			// Epsilon-greedy: 10% random action, also used when AMOEBA
			// produced no proposal (oracle left at -infinity).
			if(r.nextDouble() < 0.1 || action.get("oracle").equals(Double.NEGATIVE_INFINITY) ) {
				System.out.println("Random action");
				action.put("a1", (r.nextBoolean() ? 10.0 : -10.0));
			}
			state2 = env.step(action.get("a1"));
			// Learn the transition: previous position + action -> reward.
			action.put("oracle", state2.get("oracle"));
			action.put("p1", state.get("p1"));
			System.out.println(action);
			amoeba.learn(action);
			state = state2;
		}
	}

	/**
	 * Must be called AFTER an AMOEBA with GUI
	 */
	public SimpleReinforcement() {
		Configuration.commandLineMode = false;
		AmoebaWindow instance = AmoebaWindow.instance();
		pos = new DrawableOval(0.5, 0.5, 1, 1);
		pos.setColor(new Color(0.5, 0.0, 0.0, 0.5));
		instance.mainVUI.add(pos);
		// Horizontal track and vertical goal marker at the origin.
		instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5);
		instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2);
	}

	/**
	 * Advances the environment by one step.
	 * The action is clamped to a unit move: any positive value becomes +1,
	 * any negative value becomes -1, and exactly 0 is randomized first.
	 *
	 * @param action requested move
	 * @return the new state: "p1" = position, "oracle" = reward
	 *         (+100 on reaching/crossing 0, -100 on leaving [-50, 50], -1 otherwise)
	 */
	public HashMap<String, Double> step(double action){
		if(action == 0.0) action = rand.nextDouble();
		if(action > 0.0) action = Math.ceil(action);
		if(action < 0.0 ) action = Math.floor(action);
		if(action > 1.0) action = 1.0;
		if(action < -1.0) action = -1.0;
		double oldX = x;
		x = x + action;
		if(x < -50.0 || x > 50.0) {
			// Out of bounds: punish and respawn at a random position.
			x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0));
			reward = -100.0;
		} else if(x == 0.0 || sign(oldX) != sign(x)) {
			// win !
			reward = 100.0;
			x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0));
		} else {
			reward = -1.0;
		}
		HashMap<String, Double> ret = new HashMap<>();
		ret.put("p1", x);
		ret.put("oracle", reward);
		pos.move(x, 0);
		return ret;
	}

	/**
	 * Resets the environment to a random position with zero reward.
	 *
	 * @return the initial state: "p1" = position, "oracle" = 0.0
	 */
	public HashMap<String, Double> reset(){
		x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0));
		reward = 0.0;
		HashMap<String, Double> ret = new HashMap<>();
		ret.put("p1", x);
		ret.put("oracle", reward);
		return ret;
	}

	// Sign convention: 0 counts as positive (x == 0 is handled before use).
	private int sign(double x) {
		return x < 0 ? -1 : 1;
	}
}
......@@ -83,7 +83,7 @@ public class ContextRendererFX extends RenderStrategy {
*/
public DrawableRectangle getDrawable() {
if (!context.isDying() && drawable == null) {
drawable = new DrawableContext(0, 0, 10, 10, context);
drawable = new DrawableContext(0, 0, 0, 0, context);
AmoebaWindow.instance().mainVUI.add(drawable);
}
return drawable;
......
......@@ -429,78 +429,23 @@ public class AMOEBA extends Amas<World> implements IAMOEBA {
if(good) pac.add(c);
}
ArrayList<Pair<HashMap<String, Double>, Double>> sol = new ArrayList<>();
ArrayList<HashMap<String, Double>> sol = new ArrayList<>();
for(Context c : pac) {
sol.add(maximiseContext(known, percepts, unknown, c));
sol.add(c.getLocalModel().getMax(c));
}
HashMap<String, Double> max = new HashMap<>();
// set default value if no solution
for(Percept p : unknown) {
max.put(p.getName(), 0.0);
}
Double maxValue = Double.NEGATIVE_INFINITY;
max.put("oracle", maxValue);
//find best solution
for(Pair<HashMap<String, Double>, Double> s : sol) {
if(s.getB() > maxValue) {
maxValue = s.getB();
max = s.getA();
for(HashMap<String, Double> s : sol) {
if(s.get("oracle") > maxValue) {
maxValue = s.get("oracle");
max = s;
}
}
max.put("oracle", maxValue);
return max;
}
/**
 * Naive maximisation: proposes the center of each unknown percept's range in
 * context c, and evaluates the context's action proposal at that point by
 * temporarily installing those centers as the current perceptions.
 */
private Pair<HashMap<String, Double>, Double> maximiseDummy(HashMap<String, Double> known,
		ArrayList<Percept> percepts, ArrayList<Percept> unknown, Context c) {
	HashMap<String, Double> centers = new HashMap<>();
	for (Percept percept : unknown) {
		centers.put(percept.getName(), c.getRangeByPercept(percept).getCenter());
	}
	// Swap in the candidate point, query the proposal, then restore the
	// previous perceptions so the agent's state is left untouched.
	HashMap<String, Double> saved = perceptions;
	perceptions = new HashMap<>(centers);
	Pair<HashMap<String, Double>, Double> proposal = new Pair<>(centers, c.getActionProposal());
	perceptions = saved;
	return proposal;
}
//TODO tests !
/**
 * Maximizes context c's linear local model over the context's ranges using a
 * simplex linear program: known percepts are pinned to their given values,
 * unknown percepts are bounded by the context's ranges.
 *
 * NOTE(review): the two TODOs below flag a real risk — the constraint columns
 * are built from iteration order (known.keySet(), then unknown) while the
 * objective coefficients come from the model's percept order; these are not
 * guaranteed to match. Confirm the ordering before trusting the result.
 *
 * @return a pair of (unknown percept name -> maximizing value, maximum value)
 */
private Pair<HashMap<String, Double>, Double> maximiseContext(HashMap<String, Double> known,
ArrayList<Percept> percepts, ArrayList<Percept> unknown, Context c) {
HashMap<String, Double> res = new HashMap<>();
// coefs[0] is the constant term; coefs[1..] are the per-percept slopes.
Double[] coefs = c.getLocalModel().getCoef();
double[] vCoefs = new double[coefs.length-1];
for(int i = 1; i < coefs.length; i++) {
vCoefs[i-1] = coefs[i];
}
// Objective: maximize vCoefs . x + coefs[0].
LinearObjectiveFunction fct = new LinearObjectiveFunction(vCoefs, coefs[0]);
ArrayList<LinearConstraint> constraints = new ArrayList<>();
//TODO : problem : we are not sure that the order of percepts is the same as coefs
int i = 0;
// Equality constraints: each known percept is fixed to its given value.
for(String p : known.keySet()) {
double[] cf = new double[percepts.size()];
cf[i++] = 1.0;
constraints.add(new LinearConstraint(cf, Relationship.EQ, known.get(p)));
}
// Remember where the unknown-percept columns start in the solution vector.
int unknowStart = i;
// Box constraints: each unknown percept stays within the context's range.
for(Percept p : unknown) {
double[] cf = new double[percepts.size()];
cf[i++] = 1.0;
constraints.add(new LinearConstraint(cf, Relationship.GEQ, c.getRangeByPercept(p).getStart()));
constraints.add(new LinearConstraint(cf, Relationship.LEQ, c.getRangeByPercept(p).getEnd()));
}
SimplexSolver solver = new SimplexSolver();
LinearConstraintSet set = new LinearConstraintSet(constraints);
PointValuePair sol = solver.optimize(fct, set, GoalType.MAXIMIZE);
// Read back the maximizing values of the unknown percepts.
for(Percept p : unknown) {
//TODO check if the order match
res.put(p.getName(), sol.getFirst()[unknowStart++]);
}
Pair<HashMap<String, Double>, Double> ret = new Pair<>(res, sol.getSecond());
return ret;
}
public LocalModel buildLocalModel(Context context) {
switch (localModel) {
......
......@@ -5,7 +5,7 @@ import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
public class PickRandom {
public class RandomUtils {
/**
* Pick N random element from the list. if n is bigger than the list, return the list.
......@@ -40,4 +40,19 @@ public class PickRandom {
/**
 * Picks n random elements from the list, using the current thread's
 * ThreadLocalRandom. Delegates to the seeded overload; if n is bigger than
 * the list, the list itself is returned (per that overload's contract).
 *
 * @param list the list to pick from
 * @param n how many elements to pick
 * @return n randomly picked elements (or the list if n exceeds its size)
 */
public static <E> List<E> pickNRandomElements(List<E> list, int n) {
return pickNRandomElements(list, n, ThreadLocalRandom.current());
}
/**
 * Generates a pseudorandom double in [origin, bound): origin inclusive,
 * bound exclusive. Mirrors the contract of
 * ThreadLocalRandom.nextDouble(origin, bound) for an arbitrary Random.
 *
 * @param rand the random generator to draw from
 * @param origin the origin (inclusive) of the random value
 * @param bound the bound (exclusive) of the random value
 * @return a pseudorandom double d with origin <= d < bound
 * @throws IllegalArgumentException if origin is not strictly less than bound
 */
public static double nextDouble(Random rand, double origin, double bound) {
	if (!(origin < bound))
		throw new IllegalArgumentException("origin must be less than bound");
	double r = rand.nextDouble();
	r = r * (bound - origin) + origin;
	if (r >= bound) // correct for rounding
		r = Math.nextDown(bound);
	return r;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment