Skip to content
Snippets Groups Projects
Commit 442a78db authored by Hugo Roussel's avatar Hugo Roussel
Browse files

LocalModel refactor and factories

Now use factories, each TypeLocalModel is linked to a factory.
parent c20e7953
Branches
No related tags found
1 merge request!4Exp rein
Showing
with 285 additions and 140 deletions
......@@ -8,7 +8,6 @@ import java.util.HashMap;
import agents.AmoebaAgent;
import agents.context.localModel.LocalModel;
import agents.context.localModel.LocalModelCoop;
import agents.context.localModel.LocalModelMillerRegression;
import agents.context.localModel.TypeLocalModel;
import agents.head.Criticalities;
......@@ -158,15 +157,13 @@ public class Context extends AmoebaAgent {
// expand();
this.confidence = fatherContext.confidence;
if (fatherContext.getLocalModel().getType() == TypeLocalModel.MILLER_REGRESSION) {
this.localModel = getAmas().buildLocalModel(this);
// this.formulaLocalModel = ((LocalModelMillerRegression)
// bestNearestContext.localModel).getFormula(bestNearestContext);
Double[] coef = fatherContext.localModel.getCoef();
this.localModel.setCoef(coef);
this.actionProposition = fatherContext.localModel.getProposition();
}
this.localModel = getAmas().buildLocalModel(this);
// this.formulaLocalModel = ((LocalModelMillerRegression)
// bestNearestContext.localModel).getFormula(bestNearestContext);
Double[] coef = fatherContext.localModel.getCoef();
this.localModel.setCoef(coef);
this.actionProposition = fatherContext.localModel.getProposition();
getAmas().addAlteredContext(this);
this.setName(String.valueOf(this.hashCode()));
......@@ -199,15 +196,12 @@ public class Context extends AmoebaAgent {
//expand();
this.confidence = bestNearestContext.confidence;
if (bestNearestContext.getLocalModel().getType() == TypeLocalModel.MILLER_REGRESSION) {
this.localModel = getAmas().buildLocalModel(this);
// this.formulaLocalModel = ((LocalModelMillerRegression)
// bestNearestContext.localModel).getFormula(bestNearestContext);
Double[] coef = bestNearestContext.localModel.getCoef();
this.localModel.setCoef(coef);
this.actionProposition = bestNearestContext.localModel.getProposition();
}
this.localModel = getAmas().buildLocalModel(this);
// this.formulaLocalModel = ((LocalModelMillerRegression)
// bestNearestContext.localModel).getFormula(bestNearestContext);
Double[] coef = bestNearestContext.localModel.getCoef();
this.localModel.setCoef(coef);
this.actionProposition = bestNearestContext.localModel.getProposition();
localModel.setFirstExperiments(new ArrayList<Experiment>(bestNearestContext.getLocalModel().getFirstExperiments()));
......@@ -1400,10 +1394,8 @@ public class Context extends AmoebaAgent {
s += "creation tick : " + tickCreation +"\n";
s += "\n";
s += "Model : ";
s += "Model "+this.localModel.getType()+" :";
s += this.localModel.getCoefsFormula() + "\n";
s += "Cooperative Spatial Model : ";
s += ((LocalModelCoop)this.localModel).getCoefsFormulaCoop() + "\n";
s += "\n";
s += "Ranges :\n";
......
......@@ -9,102 +9,176 @@ import agents.context.Experiment;
/**
* A LocalModel is used by a Context to store information and generate prediction.
*/
public interface LocalModel {
public abstract class LocalModel {
protected LocalModel modifier = null;
protected LocalModel modified = null; // Be careful ! One letter and it's totally a different thing !
/**
* Sets the context that use the LocalModel
* @param context
*/
public void setContext(Context context);
public abstract void setContext(Context context);
/**
* gets the context that use the LocalModel
* @return
*/
public Context getContext();
public abstract Context getContext();
/**
* Gets the proposition.
*
* @return the proposition
*/
public double getProposition();
public abstract double getProposition();
/**
* Gets the proposition with the highest value possible
* @return
*/
public double getMaxProposition();
public abstract double getMaxProposition();
/**
* Return the point (percept value) that produce the max proposition, considering some percepts are fixed.
* @return a HashMap with percept names as key, and their corresponding value. The oracle is the max proposition
* @see LocalModel#getMaxProposition(Context)
*/
public HashMap<String, Double> getMaxWithConstraint(HashMap<String, Double> fixedPercepts);;
public abstract HashMap<String, Double> getMaxWithConstraint(HashMap<String, Double> fixedPercepts);;
/**
* Gets the proposition with the lowest value possible
* @return
*/
public double getMinProposition();
public abstract double getMinProposition();
/**
* Gets the formula of the model
* @return
*/
public String getCoefsFormula();
public String getCoefsFormula() {
Double[] coefs = getCoef();
String result = "" +coefs[0];
if (coefs[0] == Double.NaN) System.exit(0);
for (int i = 1 ; i < coefs.length ; i++) {
if (Double.isNaN(coefs[i])) coefs[i] = 0.0;
result += "\t" + coefs[i] + " (" + getContext().getAmas().getPercepts().get(i-1) +")";
}
return result;
}
/**
* Update the model with a new experiment.
* @param newExperiment
* @param weight the weight of the new experiment in the compute of the model
*/
public void updateModel(Experiment newExperiment, double weight);
public abstract void updateModel(Experiment newExperiment, double weight);
public String coefsToString();
public String coefsToString() {
String coefsString = "";
Double[] coefs = getCoef();
if(coefs != null) {
for(int i=0; i<coefs.length; i ++) {
coefsString += coefs[i] + "\t";
}
}
return coefsString;
}
/**
* The distance between an experiment and the model.
* @param experiment
* @return
*/
public double distance(Experiment experiment);
public abstract double distance(Experiment experiment);
/**
* Gets the experiments used to properly initialize the model.
* @return
*/
public ArrayList<Experiment> getFirstExperiments();
public abstract ArrayList<Experiment> getFirstExperiments();
/**
* Sets the experiments used to properly initialize the model.
* This may not trigger an update of the model.
* @param frstExp
*/
public void setFirstExperiments( ArrayList<Experiment> frstExp);
public abstract void setFirstExperiments( ArrayList<Experiment> frstExp);
/**
* Tells if the model has enough experiments to produce a good prediction.
* For example, a regression need a number of experiments equals or superior to the number of dimension.
* @return
*/
public boolean finishedFirstExperiments();
public abstract boolean finishedFirstExperiments();
/**
* Gets coefficients of the model
* @return
*/
public Double[] getCoef();
public abstract Double[] getCoef();
/**
* Sets coefficients of the model
* @return
*/
public void setCoef(Double[] coef);
public abstract void setCoef(Double[] coef);
/**
* Gets the {@link TypeLocalModel} corresponding to this LocalModel
*/
public TypeLocalModel getType();
public abstract TypeLocalModel getType();
/**
* Sets the type of the LocalModel, if it ever has to change.
*/
public abstract void setType(TypeLocalModel type);
/**
* Set an LocalModel that modify the behavior of the current LocalModel.<br/>
* The modifier can then be used by calling {@link LocalModel#getModifier()} on the modified LocalModel.
* @param Modifier a LocalModel
* @see LocalModel#getModifier()
* @see LocalModel#hasModifier()
*/
public void setModifier(LocalModel modifier) {
this.modifier = modifier;
modifier.modified = this;
}
/**
* @return true if the LocalModel has an usable modifier.
*/
public boolean hasModifier() {
return modifier != null;
}
/**
* Get the LocalModel that modify the behavior of the current LocalModel.
* @return a LocalModel or null
*/
public LocalModel getModifier() {
return modifier;
}
/**
*
* @return true if the LocalModel is a modifier, it means that {@link LocalModel#getModified()} is not null
*/
public boolean hasModified() {
return modified != null;
}
/**
* If the LocalModel is a modifier, return the modified LocalModel
* @return a LocalModel or null
*/
public LocalModel getModified() {
return modified;
}
}
......@@ -11,22 +11,25 @@ import agents.context.Experiment;
import agents.percept.Percept;
import utils.Pair;
public class LocalModelCoop implements LocalModel {
public class LocalModelCoopModifier extends LocalModel {
private LocalModel localModel;
private TypeLocalModel type;
public LocalModelCoop(LocalModel localModel) {
public LocalModelCoopModifier(LocalModel localModel, TypeLocalModel type) {
this.localModel = localModel;
localModel.setModifier(this);
setType(type);
}
@Override
public void setContext(Context context) {
localModel.setContext(context);
}
@Override
public Context getContext() {
return localModel.getContext();
}
@Override
public void setContext(Context context) {
localModel.setContext(context);
}
@Override
public double getProposition() {
......@@ -45,10 +48,6 @@ public class LocalModelCoop implements LocalModel {
@Override
public HashMap<String, Double> getMaxWithConstraint(HashMap<String, Double> fixedPercepts) {
return localModel.getMaxWithConstraint(fixedPercepts);
}
public HashMap<String, Double> getMaxWithConstraintCoop(HashMap<String, Double> fixedPercepts) {
ArrayList<Percept> percepts = getContext().getAmas().getPercepts();
HashMap<String, Double> result = new HashMap<String, Double>();
......@@ -86,28 +85,6 @@ public class LocalModelCoop implements LocalModel {
return localModel.getMinProposition();
}
@Override
public String getCoefsFormula() {
return localModel.getCoefsFormula();
}
public String getCoefsFormulaCoop() {
Double[] coefs = getCoefCoop();
String result = "" +coefs[0];
// //System.out.println("Result 0" + " : " + result);
if (coefs[0] == Double.NaN) System.exit(0);
for (int i = 1 ; i < coefs.length ; i++) {
if (Double.isNaN(coefs[i])) coefs[i] = 0.0;
result += "\t" + coefs[i] + " (" + getContext().getAmas().getPercepts().get(i-1) +")";
}
return result;
}
@Override
public void updateModel(Experiment newExperiment, double weight) {
localModel.updateModel(newExperiment, weight);
......@@ -137,19 +114,40 @@ public class LocalModelCoop implements LocalModel {
public boolean finishedFirstExperiments() {
return localModel.finishedFirstExperiments();
}
@Override
public Double[] getCoef() {
return localModel.getCoef();
}
@Override
public String getCoefsFormula() {
Double[] coefs = getCoefCoop();
String result = "" +coefs[0];
if (coefs[0] == Double.NaN) System.exit(0);
for (int i = 1 ; i < coefs.length ; i++) {
if (Double.isNaN(coefs[i])) coefs[i] = 0.0;
result += "\t" + coefs[i] + " (" + getContext().getAmas().getPercepts().get(i-1) +")";
}
result += "\nFrom " +localModel.getType()+" : "+localModel.getCoefsFormula();
return result;
}
public Double[] getCoefCoop() {
Set<Context> neighbors = getNeighbors();
Double[] coef = localModel.getCoef().clone();
int i = 0;
for(Percept p : getContext().getRanges().keySet()) {
for(Context c : neighbors) {
Double[] coef2 = c.getLocalModel().getCoef();
LocalModel model = c.getLocalModel();
while(model.hasModified()) {
model = model.getModified();
}
Double[] coef2 = model.getCoef();
coef[i] += coef2[i]/neighbors.size()*getCommonFrontierCoef(p, getContext(), c);
}
i++;
......@@ -195,7 +193,12 @@ public class LocalModelCoop implements LocalModel {
@Override
public TypeLocalModel getType() {
return localModel.getType();
return type;
}
@Override
public void setType(TypeLocalModel type) {
this.type = type;
}
}
......@@ -14,9 +14,7 @@ import utils.TRACE_LEVEL;
/**
* The Class LocalModelMillerRegression.
*/
public class LocalModelMillerRegression implements LocalModel{
private Context context;
public class LocalModelMillerRegression extends LocalModel{
/** The n parameters. */
private int nParameters;
......@@ -24,6 +22,7 @@ public class LocalModelMillerRegression implements LocalModel{
/** The regression. */
transient private Regression regression;
private Context context;
/** The coef. */
private Double[] coefs;
......@@ -36,7 +35,7 @@ public class LocalModelMillerRegression implements LocalModel{
* @param world the world
*/
public LocalModelMillerRegression(Context associatedContext) {
context = associatedContext;
this.context = associatedContext;
ArrayList<Percept> var = associatedContext.getAmas().getPercepts();
this.nParameters = var.size();
regression = new Regression(nParameters,true);
......@@ -44,7 +43,7 @@ public class LocalModelMillerRegression implements LocalModel{
}
public LocalModelMillerRegression(Context associatedContext, Double[] coefsCopy, List<Experiment> fstExperiments) {
context = associatedContext;
this.context = associatedContext;
ArrayList<Percept> var = associatedContext.getAmas().getPercepts();
this.nParameters = var.size();
regression = new Regression(nParameters,true);
......@@ -53,30 +52,20 @@ public class LocalModelMillerRegression implements LocalModel{
}
@Override
public void setContext(Context context) {
this.context = context;
public Context getContext() {
return context;
}
@Override
public Context getContext() {
return context;
public void setContext(Context context) {
this.context = context;
}
/**
* Sets the coef.
*
* @param coef the new coef
*/
@Override
public void setCoef(Double[] coef) {
this.coefs = coef.clone();
}
/**
* Gets the coef.
*
* @return the coef
*/
@Override
public Double[] getCoef() {
return coefs;
......@@ -215,22 +204,6 @@ public class LocalModelMillerRegression implements LocalModel{
}
@Override
public String getCoefsFormula() {
String result = "" +coefs[0];
// //System.out.println("Result 0" + " : " + result);
if (coefs[0] == Double.NaN) System.exit(0);
for (int i = 1 ; i < coefs.length ; i++) {
if (Double.isNaN(coefs[i])) coefs[i] = 0.0;
result += "\t" + coefs[i] + " (" + context.getAmas().getPercepts().get(i-1) +")";
}
return result;
}
@Override
public void updateModel(Experiment newExperiment, double weight) {
context.getAmas().getEnvironment().trace(TRACE_LEVEL.INFORM, new ArrayList<String>(Arrays.asList(context.getName(),"NEW POINT REGRESSION", "FIRST POINTS :", ""+firstExperiments.size(), "OLD MODEL :", coefsToString())));
......@@ -500,4 +473,8 @@ public class LocalModelMillerRegression implements LocalModel{
public TypeLocalModel getType() {
return TypeLocalModel.MILLER_REGRESSION;
}
@Override
public void setType(TypeLocalModel type) {
}
}
package agents.context.localModel;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import agents.context.localModel.factories.LocalModelCoopFactory;
import agents.context.localModel.factories.LocalModelFactory;
import agents.context.localModel.factories.LocalModelMillerRegressionFactory;
/**
* Defines the different implemented local model. Each local model is associated
......@@ -9,5 +15,24 @@ import java.io.Serializable;
*/
public enum TypeLocalModel implements Serializable {
/** The miller regression. */
MILLER_REGRESSION
MILLER_REGRESSION(new LocalModelMillerRegressionFactory()),
COOP_MILLER_REGRESSION(new LocalModelCoopFactory(MILLER_REGRESSION.factory));
public final LocalModelFactory factory;
private static final Map<LocalModelFactory, TypeLocalModel> BY_FACTORY = new HashMap<>();
static {
for (TypeLocalModel t : values()) {
BY_FACTORY.put(t.factory, t);
}
}
private TypeLocalModel(LocalModelFactory factory) {
this.factory = factory;
}
public static TypeLocalModel valueOf(LocalModelCoopFactory factory) {
return BY_FACTORY.get(factory);
}
}
package agents.context.localModel.factories;
import agents.context.localModel.LocalModel;
import agents.context.localModel.LocalModelCoopModifier;
import agents.context.localModel.TypeLocalModel;
/**
* A factory for creating {@link LocalModelCoopModifier}. Take a {@link LocalModel} as param,
* or a {@link LocalModelFactory} with all param to build a LocalModel.
* @author Hugo
*
*/
public class LocalModelCoopFactory implements LocalModelFactory {
private LocalModelFactory factory;
public LocalModelCoopFactory(LocalModelFactory factory) {
this.factory = factory;
}
public LocalModelCoopFactory() {
this.factory = null;
}
@Override
public LocalModel buildLocalModel(Object... params) {
if(factory != null) {
return new LocalModelCoopModifier(factory.buildLocalModel(params), TypeLocalModel.valueOf(this));
} else {
if(params.length != 1) {
throw new IllegalArgumentException("Expected one "+LocalModel.class+", got "+params.length+" arguments");
}
if(!(params[0] instanceof LocalModel)) {
throw new IllegalArgumentException("Expected "+LocalModel.class+", got "+params[0].getClass());
}
LocalModel lm = (LocalModel) params[0];
return new LocalModelCoopModifier(lm, TypeLocalModel.valueOf(this));
}
}
}
package agents.context.localModel.factories;
import agents.context.localModel.LocalModel;
public interface LocalModelFactory {
public LocalModel buildLocalModel(Object ...params);
}
package agents.context.localModel.factories;
import agents.context.Context;
import agents.context.localModel.LocalModel;
import agents.context.localModel.LocalModelMillerRegression;
/**
* A factory for creating {@link LocalModelMillerRegression}. Take a {@link Context} as param.
* @author Hugo
*
*/
public class LocalModelMillerRegressionFactory implements LocalModelFactory {
@Override
public LocalModel buildLocalModel(Object... params) {
if(params.length != 1) {
throw new IllegalArgumentException("Expected one "+Context.class+", got "+params.length+" arguments");
}
if(!(params[0] instanceof Context)) {
throw new IllegalArgumentException("Expected "+Context.class+", got "+params[0].getClass());
}
Context c = (Context) params[0];
return new LocalModelMillerRegression(c);
}
}
......@@ -5,7 +5,9 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;
import java.util.Scanner;
import agents.context.localModel.TypeLocalModel;
import fr.irit.smac.amak.Configuration;
import fr.irit.smac.amak.tools.Log;
import fr.irit.smac.amak.ui.drawables.Drawable;
......@@ -30,14 +32,14 @@ public class SimpleReinforcement {
/* Learn and Test */
public static final int MAX_STEP_PER_EPISODE = 200;
public static final int N_LEARN = 200;
public static final int N_TEST = 200;
public static final int N_TEST = 10;
/* Exploration */
public static final int N_EXPLORE_LINE = 100;
public static final int N_EXPLORE_LINE = 0;
public static final double MIN_EXPLO_RATE = 0.02;
public static final double EXPLO_RATE_DIMINUTION_FACTOR = 0.01;
public static final double EXPLO_RATE_DIMINUTION_FACTOR = 0.0;
public static final double EXPLO_RATE_BASE = 1;
public static final String EXPLORATION_STRATEGY = "line"; // can be "random" or "line"
public static final String EXPLORATION_STRATEGY = "random"; // can be "random" or "line"
private static int exploreLine;
private Random rand = new Random();
......@@ -66,6 +68,8 @@ public class SimpleReinforcement {
public Amoeba() {
amoeba = setup();
amoeba.setLocalModel(TypeLocalModel.COOP_MILLER_REGRESSION);
amoeba.getEnvironment().setMappingErrorAllowed(0.009);
}
@Override
......@@ -95,11 +99,17 @@ public class SimpleReinforcement {
public double[][] Q = new double[102][2];
public double lr = 0.8;
public double gamma = 0.9;
private Random rand = new Random();
@Override
public double choose(HashMap<String, Double> state) {
int p = state.get("p1").intValue()+50;
double a = Q[p][0] > Q[p][1] ? -1 : 1;
double a;
if(Q[p][0] == Q[p][1]) {
a = rand.nextBoolean() ? -1 : 1;
} else {
a = Q[p][0] > Q[p][1] ? -1 : 1;
}
return a;
}
......@@ -127,16 +137,17 @@ public class SimpleReinforcement {
public static void main(String[] args) {
//poc(true);
Configuration.commandLineMode = true;
//Configuration.commandLineMode = true;
System.out.println("----- AMOEBA -----");
learning(new Amoeba());
System.out.println("----- END AMOEBA -----");
System.out.println("\n\n----- QLEARNING -----");
/*System.out.println("\n\n----- QLEARNING -----");
learning(new QLearning());
System.out.println("----- END QLEARNING -----");
System.out.println("----- END QLEARNING -----");*/
/*ArrayList<ArrayList<Double>> results = new ArrayList<>();
for(int i = 0; i < 1; i++) {
results.add(exp1());
LearningAgent agent = new QLearning();
for(int i = 0; i < 100; i++) {
results.add(learning(agent));
System.out.println(i);
}
......@@ -150,7 +161,7 @@ public class SimpleReinforcement {
System.out.println(""+i+"\t"+average);
}
*/
System.exit(0);
//System.exit(0);
}
/**
......@@ -244,6 +255,9 @@ public class SimpleReinforcement {
System.out.println("Episode "+i+" reward : "+totReward+" explo : "+explo);
double testAR = test(agent, env, r, N_TEST);
averageRewards.add(testAR);
//Scanner scan = new Scanner(System.in);
//scan.nextLine();
}
return averageRewards;
......@@ -287,12 +301,6 @@ public class SimpleReinforcement {
double averageReward = tot_reward/nbTest;
System.out.println("Test average reward : "+averageReward+" Positive reward %: "+(nbPositiveReward/nbTest));
if(!Configuration.commandLineMode) {
AmoebaWindow.instance().point.hide();
AmoebaWindow.instance().rectangle.hide();
AmoebaWindow.instance().mainVUI.updateCanvas();
}
return averageReward;
}
......@@ -469,7 +477,7 @@ public class SimpleReinforcement {
instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5);
instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2);
instance.point.hide();
instance.rectangle.hide();
//instance.rectangle.hide();
}
......
......@@ -12,8 +12,6 @@ import java.util.stream.Stream;
import agents.AmoebaAgent;
import agents.context.Context;
import agents.context.localModel.LocalModel;
import agents.context.localModel.LocalModelCoop;
import agents.context.localModel.LocalModelMillerRegression;
import agents.context.localModel.TypeLocalModel;
import agents.head.Head;
import agents.percept.Percept;
......@@ -419,7 +417,7 @@ public class AMOEBA extends Amas<World> implements IAMOEBA {
ArrayList<HashMap<String, Double>> sol = new ArrayList<>();
for(Context c : pac) {
sol.add(((LocalModelCoop)c.getLocalModel()).getMaxWithConstraintCoop(known));
sol.add(c.getLocalModel().getMaxWithConstraint(known));
}
HashMap<String, Double> max = new HashMap<>();
......@@ -436,13 +434,7 @@ public class AMOEBA extends Amas<World> implements IAMOEBA {
}
public LocalModel buildLocalModel(Context context, TypeLocalModel type) {
switch (type) {
case MILLER_REGRESSION:
return new LocalModelCoop(new LocalModelMillerRegression(context));
default:
throw new IllegalArgumentException("Unknown model " + localModel + ".");
}
return type.factory.buildLocalModel(context);
}
public LocalModel buildLocalModel(Context context) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment