diff --git a/AMOEBAonAMAK/src/agents/context/localModel/LocalModel.java b/AMOEBAonAMAK/src/agents/context/localModel/LocalModel.java index a2268693f46e26bb676a1f63ed7c56b8a7ba20db..bb2bec7d64e982e8a3fac65c04a5dd0ed84c56b2 100644 --- a/AMOEBAonAMAK/src/agents/context/localModel/LocalModel.java +++ b/AMOEBAonAMAK/src/agents/context/localModel/LocalModel.java @@ -1,6 +1,7 @@ package agents.context.localModel; import java.util.ArrayList; +import java.util.HashMap; import agents.context.Context; import agents.context.Experiment; @@ -31,6 +32,7 @@ public abstract class LocalModel { public abstract double getProposition(Context context); public abstract double getProposition(Experiment experiment); public abstract double getMaxProposition(Context context); + public abstract HashMap<String, Double> getMax(Context context); public abstract double getMinProposition(Context context); diff --git a/AMOEBAonAMAK/src/agents/context/localModel/LocalModelMillerRegression.java b/AMOEBAonAMAK/src/agents/context/localModel/LocalModelMillerRegression.java index 87998fe51fe7f0fbd06036fa3e26bf74c662f8b9..a27ec497b082cb35ae4cf2563d2e8fca265d509d 100644 --- a/AMOEBAonAMAK/src/agents/context/localModel/LocalModelMillerRegression.java +++ b/AMOEBAonAMAK/src/agents/context/localModel/LocalModelMillerRegression.java @@ -2,6 +2,7 @@ package agents.context.localModel; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import agents.context.Context; @@ -132,6 +133,35 @@ public class LocalModelMillerRegression extends LocalModel{ return result; } + public HashMap<String, Double> getMax(Context context){ + ArrayList<Percept> percepts = context.getAmas().getPercepts(); + + HashMap<String, Double> result = new HashMap<String, Double>(); + result.put("oracle", coefs[0]); + + if (coefs[0] == Double.NaN) + throw new ArithmeticException("First coeficient of model cannot be NaN"); + + for (int i = 1 ; i < coefs.length ; i++) { + + if (Double.isNaN(coefs[i])) coefs[i] = 0.0; + if(coefs[i]>0) { + Percept p = percepts.get(i-1); + double value = coefs[i] * context.getRanges().get(p).getEnd(); + result.put("oracle", value); + result.put(p.getName(), context.getRanges().get(p).getEnd()); + } + else { + Percept p = percepts.get(i-1); + double value = coefs[i] * context.getRanges().get(p).getStart(); + result.put("oracle", value); + result.put(p.getName(), context.getRanges().get(p).getStart()); + } + } + + return result; + } + public double getMinProposition(Context context) { ArrayList<Percept> percepts = context.getAmas().getPercepts(); diff --git a/AMOEBAonAMAK/src/agents/head/Head.java b/AMOEBAonAMAK/src/agents/head/Head.java index f4de3348de5bde8f9ede9591fa28887989137332..c3070c579b432016e1b7f6d258a2b32b0a8236f8 100644 --- a/AMOEBAonAMAK/src/agents/head/Head.java +++ b/AMOEBAonAMAK/src/agents/head/Head.java @@ -17,7 +17,7 @@ import agents.percept.Percept; import kernel.AMOEBA; import ncs.NCS; import utils.Pair; -import utils.PickRandom; +import utils.RandomUtils; import utils.PrintOnce; import utils.Quadruplet; import utils.TRACE_LEVEL; @@ -347,7 +347,7 @@ public class Head extends AmoebaAgent { // to limit performance impact, we limit our search on a random sample. // a better way would be to increase neighborhood. PrintOnce.print("Play without oracle : no nearest context in neighbors, searching in a random sample. (only shown once)"); - List<Context> searchList = PickRandom.pickNRandomElements(getAmas().getContexts(), 100); + List<Context> searchList = RandomUtils.pickNRandomElements(getAmas().getContexts(), 100); nearestContext = this.getNearestContext(searchList); if(nearestContext != null) { getAmas().data.prediction = nearestContext.getActionProposal(); diff --git a/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java b/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java new file mode 100644 index 0000000000000000000000000000000000000000..75c015c32db16918c66f2de5589ce3cc4a12de1a --- /dev/null +++ b/AMOEBAonAMAK/src/experiments/SimpleReinforcement.java @@ -0,0 +1,117 @@ +package experiments; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Random; + +import fr.irit.smac.amak.Configuration; +import fr.irit.smac.amak.ui.drawables.Drawable; +import fr.irit.smac.amak.ui.drawables.DrawableOval; +import gui.AmoebaWindow; +import javafx.scene.paint.Color; +import kernel.AMOEBA; +import utils.Pair; +import utils.RandomUtils; +import utils.XmlConfigGenerator; + +public class SimpleReinforcement { + + private Random rand = new Random(); + private double x = 0; + private double reward = 0; + private Drawable pos; + + public static void main(String[] args) { + ArrayList<Pair<String, Boolean>> sensors = new ArrayList<>(); + sensors.add(new Pair<String, Boolean>("p1", false)); + sensors.add(new Pair<String, Boolean>("a1", true)); + File config; + try { + config = File.createTempFile("config", "xml"); + XmlConfigGenerator.makeXML(config, sensors); + } catch (IOException e) { + e.printStackTrace(); + System.exit(1); + return; // now compilator know config is initialized + } + + AMOEBA amoeba = new AMOEBA(config.getAbsolutePath(), null); + SimpleReinforcement env = new SimpleReinforcement(); + + Random r = new Random(); + HashMap<String, Double> state = env.reset(); + HashMap<String, Double> state2; + for(int i = 0; i < 10000; i++) { + state.remove("oracle"); + HashMap<String, Double> action = amoeba.maximize(state); + if(r.nextDouble() < 0.1 || action.get("oracle").equals(Double.NEGATIVE_INFINITY) ) { + System.out.println("Random action"); + action.put("a1", (r.nextBoolean() ? 10.0 : -10.0)); + } + state2 = env.step(action.get("a1")); + action.put("oracle", state2.get("oracle")); + action.put("p1", state.get("p1")); + System.out.println(action); + amoeba.learn(action); + state = state2; + } + } + + /** + * Must be called AFTER an AMOEBA with GUI + */ + public SimpleReinforcement() { + Configuration.commandLineMode = false; + AmoebaWindow instance = AmoebaWindow.instance(); + pos = new DrawableOval(0.5, 0.5, 1, 1); + pos.setColor(new Color(0.5, 0.0, 0.0, 0.5)); + instance.mainVUI.add(pos); + instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5); + instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2); + + + + } + + public HashMap<String, Double> step(double action){ + if(action == 0.0) action = rand.nextDouble(); + if(action > 0.0) action = Math.ceil(action); + if(action < 0.0 ) action = Math.floor(action); + if(action > 1.0) action = 1.0; + if(action < -1.0) action = -1.0; + double oldX = x; + x = x + action; + if(x < -50.0 || x > 50.0) { + x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0)); + reward = -100.0; + } else if(x == 0.0 || sign(oldX) != sign(x)) { + // win ! + reward = 100.0; + x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0)); + } else { + reward = -1.0; + } + HashMap<String, Double> ret = new HashMap<>(); + ret.put("p1", x); + ret.put("oracle", reward); + pos.move(x, 0); + return ret; + } + + public HashMap<String, Double> reset(){ + x = RandomUtils.nextDouble(rand, -50.0, Math.nextUp(50.0)); + reward = 0.0; + + HashMap<String, Double> ret = new HashMap<>(); + ret.put("p1", x); + ret.put("oracle", reward); + return ret; + } + + private int sign(double x) { + return x < 0 ? -1 : 1; + } + +} diff --git a/AMOEBAonAMAK/src/gui/ContextRendererFX.java b/AMOEBAonAMAK/src/gui/ContextRendererFX.java index cfe62c32fd7d5c9909ad1710e570cf216962df7f..75ddf4125607f21b3ef19ca0dcaaea006105be4d 100644 --- a/AMOEBAonAMAK/src/gui/ContextRendererFX.java +++ b/AMOEBAonAMAK/src/gui/ContextRendererFX.java @@ -83,7 +83,7 @@ public class ContextRendererFX extends RenderStrategy { */ public DrawableRectangle getDrawable() { if (!context.isDying() && drawable == null) { - drawable = new DrawableContext(0, 0, 10, 10, context); + drawable = new DrawableContext(0, 0, 0, 0, context); AmoebaWindow.instance().mainVUI.add(drawable); } return drawable; diff --git a/AMOEBAonAMAK/src/kernel/AMOEBA.java b/AMOEBAonAMAK/src/kernel/AMOEBA.java index d05843c62e72bac2fd2c9abb45de07fb06c2625c..4da9804eb1ce27da404b8c31afe82fe59c63df1f 100644 --- a/AMOEBAonAMAK/src/kernel/AMOEBA.java +++ b/AMOEBAonAMAK/src/kernel/AMOEBA.java @@ -429,78 +429,23 @@ public class AMOEBA extends Amas<World> implements IAMOEBA { if(good) pac.add(c); } - ArrayList<Pair<HashMap<String, Double>, Double>> sol = new ArrayList<>(); + ArrayList<HashMap<String, Double>> sol = new ArrayList<>(); for(Context c : pac) { - sol.add(maximiseContext(known, percepts, unknown, c)); + sol.add(c.getLocalModel().getMax(c)); } HashMap<String, Double> max = new HashMap<>(); - // set default value if no solution - for(Percept p : unknown) { - max.put(p.getName(), 0.0); - } + Double maxValue = Double.NEGATIVE_INFINITY; + max.put("oracle", maxValue); //find best solution - for(Pair<HashMap<String, Double>, Double> s : sol) { - if(s.getB() > maxValue) { - maxValue = s.getB(); - max = s.getA(); + for(HashMap<String, Double> s : sol) { + if(s.get("oracle") > maxValue) { + maxValue = s.get("oracle"); + max = s; } } - max.put("oracle", maxValue); return max; } - - private Pair<HashMap<String, Double>, Double> maximiseDummy(HashMap<String, Double> known, - ArrayList<Percept> percepts, ArrayList<Percept> unknown, Context c) { - HashMap<String, Double> res = new HashMap<>(); - for(Percept p : unknown) { - res.put(p.getName(), c.getRangeByPercept(p).getCenter()); - } - HashMap<String, Double> tmpReq = new HashMap<>(res); - HashMap<String, Double> old = perceptions; - perceptions = tmpReq; - Pair<HashMap<String, Double>, Double> ret = new Pair<>(res, c.getActionProposal()); - perceptions = old; - return ret; - } - - //TODO tests ! - private Pair<HashMap<String, Double>, Double> maximiseContext(HashMap<String, Double> known, - ArrayList<Percept> percepts, ArrayList<Percept> unknown, Context c) { - HashMap<String, Double> res = new HashMap<>(); - - Double[] coefs = c.getLocalModel().getCoef(); - double[] vCoefs = new double[coefs.length-1]; - for(int i = 1; i < coefs.length; i++) { - vCoefs[i-1] = coefs[i]; - } - LinearObjectiveFunction fct = new LinearObjectiveFunction(vCoefs, coefs[0]); - ArrayList<LinearConstraint> constraints = new ArrayList<>(); - //TODO : problem : we are not sure that the order of percepts is the same as coefs - int i = 0; - for(String p : known.keySet()) { - double[] cf = new double[percepts.size()]; - cf[i++] = 1.0; - constraints.add(new LinearConstraint(cf, Relationship.EQ, known.get(p))); - } - int unknowStart = i; - for(Percept p : unknown) { - double[] cf = new double[percepts.size()]; - cf[i++] = 1.0; - constraints.add(new LinearConstraint(cf, Relationship.GEQ, c.getRangeByPercept(p).getStart())); - constraints.add(new LinearConstraint(cf, Relationship.LEQ, c.getRangeByPercept(p).getEnd())); - } - SimplexSolver solver = new SimplexSolver(); - LinearConstraintSet set = new LinearConstraintSet(constraints); - PointValuePair sol = solver.optimize(fct, set, GoalType.MAXIMIZE); - for(Percept p : unknown) { - //TODO check if the order match - res.put(p.getName(), sol.getFirst()[unknowStart++]); - } - - Pair<HashMap<String, Double>, Double> ret = new Pair<>(res, sol.getSecond()); - return ret; - } public LocalModel buildLocalModel(Context context) { switch (localModel) { diff --git a/AMOEBAonAMAK/src/utils/PickRandom.java b/AMOEBAonAMAK/src/utils/RandomUtils.java similarity index 63% rename from AMOEBAonAMAK/src/utils/PickRandom.java rename to AMOEBAonAMAK/src/utils/RandomUtils.java index ae94493c4a00b50e6654ab8722eac7f98de66117..9440f17725430ccad2abcee4edb6470d61072134 100644 --- a/AMOEBAonAMAK/src/utils/PickRandom.java +++ b/AMOEBAonAMAK/src/utils/RandomUtils.java @@ -5,7 +5,7 @@ import java.util.List; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; -public class PickRandom { +public class RandomUtils { /** * Pick N random element from the list. if n is bigger than the list, return the list. @@ -40,4 +40,19 @@ public class PickRandom { public static <E> List<E> pickNRandomElements(List<E> list, int n) { return pickNRandomElements(list, n, ThreadLocalRandom.current()); } + + /** + * Generate a pseudorandom double values, conforming to the given origin (inclusive) and bound(exclusive). + * @param rand + * @param origin the origin (inclusive) of the random value + * @param bound the bound (exclusive) of the random value + * @return + */ + public static double nextDouble(Random rand, double origin, double bound) { + double r = rand.nextDouble(); + r = r * (bound - origin) + origin; + if (r >= bound) // correct for rounding + r = Math.nextDown(bound); + return r; + } }