SMAC / Learning Group / AMOEBA3 · Commits

Commit 4cf7a490, authored 5 years ago by Hugo Roussel

    Added proof of concept

Parent: f14bb865
Part of 1 merge request: !4 "Exp rein"

Showing 1 changed file: AMOEBAonAMAK/src/experiments/SimpleReinforcement.java (+288 additions, −34 deletions)
@@ -8,15 +8,19 @@ import java.util.Deque;
 import java.util.HashMap;
 import java.util.Random;
 
+import agents.context.localModel.TypeLocalModel;
 import fr.irit.smac.amak.Configuration;
+import fr.irit.smac.amak.tools.Log;
 import fr.irit.smac.amak.ui.drawables.Drawable;
 import fr.irit.smac.amak.ui.drawables.DrawableOval;
 import gui.AmoebaWindow;
 import javafx.scene.paint.Color;
 import kernel.AMOEBA;
+import kernel.World;
 import kernel.backup.SaveHelperDummy;
 import utils.Pair;
 import utils.RandomUtils;
+import utils.TRACE_LEVEL;
 import utils.XmlConfigGenerator;
 
 /**
@@ -28,6 +32,10 @@ import utils.XmlConfigGenerator;
  *
  */
 public class SimpleReinforcement {
+    public static final int N_EXPLORE_LINE = 60;
+    public static final double MIN_EXPLO_RATE = 0.02;
+    public static final double EXPLO_RATE_DIMINUTION_FACTOR = 0.01;
+    private static int exploreLine;
     private Random rand = new Random();
     private double x = 0;
@@ -35,6 +43,12 @@ public class SimpleReinforcement {
     private Drawable pos;
 
     public static void main(String[] args) {
+        poc(true);
+        //exp1();
+    }
+
+    public static void exp1() {
         ArrayList<Pair<String, Boolean>> sensors = new ArrayList<>();
         sensors.add(new Pair<String, Boolean>("p1", false));
         sensors.add(new Pair<String, Boolean>("a1", true));
@@ -49,6 +63,8 @@ public class SimpleReinforcement {
         }
         Configuration.commandLineMode = true;
+        Log.defaultMinLevel = Log.Level.INFORM;
+        World.minLevel = TRACE_LEVEL.ERROR;
         AMOEBA amoeba = new AMOEBA(config.getAbsolutePath(), null);
         amoeba.saver = new SaveHelperDummy();
         SimpleReinforcement env = new SimpleReinforcement();
@@ -56,25 +72,30 @@ public class SimpleReinforcement {
         Random r = new Random();
         HashMap<String, Double> state = env.reset();
         HashMap<String, Double> state2;
-        double explo = 0.5;
-        for(int i = 0; i < 100; i++) {
-            boolean done = false;
+        double explo = 0.8;
+        int nbGood = 0;
+        int nbBad = 0;
+        for(int i = 0; i < 1000; i++) {
             Deque<HashMap<String, Double>> actions = new ArrayDeque<>();
             //System.out.println("Explore "+i);
             int nbStep = 0;
             state = env.reset();
-            while(!done) {
+            exploreLine = N_EXPLORE_LINE;
+            HashMap<String, Double> action = new HashMap<String, Double>();
+            // execute simulation cycles
+            boolean done = false;
+            boolean invalid = false;
+            while(!done && !invalid) {
                 nbStep++;
-                if(nbStep > 500) {
+                if(nbStep > 200) {
                     done = true;
+                    invalid = true;
                 }
                 state.remove("oracle");
-                HashMap<String, Double> action = amoeba.maximize(state);
+                state.remove("a1");
+                double lastAction = action.getOrDefault("a1", 0.0);
+                action = amoeba.maximize(state);
                 if(r.nextDouble() < 0.5 || action.get("oracle").equals(Double.NEGATIVE_INFINITY)) {
-                    //System.out.println("Random action");
-                    action.put("a1", (r.nextBoolean() ? 10.0 : -10.0));
+                    explore(r, explo, action, lastAction);
                 }
                 state2 = env.step(action.get("a1"));
                 if(state2.get("oracle") != -1.0) {
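
Aside on the exploration change: exp1() no longer flips a coin for ±10.0 directly; the new explore() helper (added further down in this diff) can commit the agent to repeating its last action for up to N_EXPLORE_LINE steps, so exploration crosses the state space instead of dithering in place. A minimal, self-contained sketch of that policy, assuming only the JDK — class and method names here are illustrative, not the repository's API:

import java.util.Random;

public class ExplorationSketch {
    static final int N_EXPLORE_LINE = 60;
    static int exploreLine = N_EXPLORE_LINE;

    // Returns the action to play: continue a straight-line run if one is in
    // progress, otherwise explore at random with probability `explo`, and
    // sometimes start a new straight-line run.
    static double chooseAction(Random r, double explo, double greedyAction, double lastAction) {
        if (exploreLine < N_EXPLORE_LINE && lastAction != 0.0) {
            exploreLine++;                       // keep going in a straight line
            return lastAction;
        }
        if (r.nextDouble() < explo) {
            if (r.nextBoolean()) {
                exploreLine = 0;                 // maybe next time go in a straight line
            }
            return r.nextBoolean() ? 1.0 : -1.0; // choose a random action
        }
        return greedyAction;                     // otherwise exploit the model
    }

    public static void main(String[] args) {
        Random r = new Random();
        double last = 0.0;
        for (int step = 0; step < 10; step++) {
            last = chooseAction(r, 0.8, 1.0, last);
            System.out.println("step " + step + " -> a1 = " + last);
        }
    }
}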
@@ -84,30 +105,263 @@ public class SimpleReinforcement {
                 action.put("p1", state.get("p1"));
                 action.put("oracle", state2.get("oracle"));
                 //System.out.println(action);
-                actions.add(action);
+                actions.push(action);
                 state = state2;
             }
-            //System.out.println("Learn "+i);
-            HashMap<String, Double> action = actions.pop();
-            double reward = action.get("oracle");
-            amoeba.learn(action);
-            while(!actions.isEmpty()) {
-                action = actions.pop();
-                reward += action.get("oracle");
-                action.put("oracle", reward);
-                amoeba.learn(action);
-            }
-            //System.exit(0);
-            System.out.println("Episode "+i+" reward : "+reward+" explo : "+explo);
-            if(explo > 0.1) {
-                explo -= 0.01;
-                if(explo < 0.1)
-                    explo = 0.1;
-            }
-        }
-    }
+            if(!invalid) {
+                // build learning set
+                HashMap<String, Double> step = actions.pop();
+                double reward = step.get("oracle");
+                Deque<HashMap<String, Double>> learnSet = new ArrayDeque<>();
+                learnSet.add(step);
+                while(!actions.isEmpty()) {
+                    step = actions.pop();
+                    reward += step.get("oracle");
+                    step.put("oracle", reward);
+                    learnSet.push(step);
+                }
+                // learn
+                while(!learnSet.isEmpty()) {
+                    HashMap<String, Double> a = learnSet.pop();
+                    //System.out.println("("+a.get("p1")+"\t, "+a.get("a1")+"\t, "+a.get("oracle")+")");
+                    amoeba.learn(a);
+                }
+                // update exploration rate
+                if(explo > MIN_EXPLO_RATE) {
+                    explo -= EXPLO_RATE_DIMINUTION_FACTOR;
+                    if(explo < MIN_EXPLO_RATE)
+                        explo = MIN_EXPLO_RATE;
+                }
+                String goobBad;
+                if(exploreLine < N_EXPLORE_LINE) {
+                    nbBad++;
+                    goobBad = "BAD ";
+                } else {
+                    nbGood++;
+                    goobBad = "GOOD ";
+                }
+                System.out.println(goobBad + "Episode " + i + " reward : " + reward + " explo : " + explo);
+            } else {
+                nbBad++;
+                System.out.println("BAD Episode " + i + " invalid.");
+            }
+        }
+        double percentGood = ((double) nbGood) / (nbGood + nbBad);
+        System.out.println("Good: " + nbGood + " Bad: " + nbBad + " Good%: " + percentGood);
+
+        // tests
+        double tot_reward = 0.0;
+        for(int i = 0; i < 500; i++) {
+            double reward = 0.0;
+            state = env.reset();
+            HashMap<String, Double> action = new HashMap<String, Double>();
+            // execute simulation cycles
+            boolean done = false;
+            int nbStep = 0;
+            while(!done) {
+                nbStep++;
+                if(nbStep > 1000) {
+                    done = true;
+                }
+                state.remove("oracle");
+                action = amoeba.maximize(state);
+                // random action if no proposition from amoeba
+                if(action.get("oracle").equals(Double.NEGATIVE_INFINITY)) {
+                    action.put("a1", (r.nextBoolean() ? 1.0 : -1.0));
+                }
+                state2 = env.step(action.get("a1"));
+                if(state2.get("oracle") != -1.0) {
+                    done = true;
+                }
+                reward += state2.get("oracle");
+                state = state2;
+            }
+            tot_reward += reward;
+        }
+        System.out.println("Average reward : " + tot_reward / 500.0);
+    }
+
+    /**
+     * This is a proof of concept, showing that if amoeba learn the correct model of the reward,
+     * it can produce a good solution.
+     * The expected average reward for the optimal solution is 75.
+     * The main cause of negative reward is infinite loop (usually near the objective). In such case, the reward is -200
+     */
+    public static void poc(boolean learnMalus) {
+        ArrayList<Pair<String, Boolean>> sensors = new ArrayList<>();
+        sensors.add(new Pair<String, Boolean>("p1", false));
+        sensors.add(new Pair<String, Boolean>("a1", true));
+        File config;
+        try {
+            config = File.createTempFile("config", "xml");
+            XmlConfigGenerator.makeXML(config, sensors);
+        } catch (IOException e) {
+            e.printStackTrace();
+            System.exit(1);
+            return; // now compilator know config is initialized
+        }
+        Log.defaultMinLevel = Log.Level.INFORM;
+        World.minLevel = TRACE_LEVEL.ERROR;
+        AMOEBA amoeba = new AMOEBA(config.getAbsolutePath(), null);
+        amoeba.saver = new SaveHelperDummy();
+        SimpleReinforcement env = new SimpleReinforcement();
+
+        // train
+        for(double n = 0.0; n < 0.5; n += 0.1) {
+            double pos = 50.0 - n;
+            for(int i = 0; i < 49; i++) {
+                double reward = 100 - Math.abs(pos);
+                HashMap<String, Double> action = new HashMap<String, Double>();
+                action.put("p1", pos);
+                action.put("a1", -1.0);
+                action.put("oracle", reward);
+                amoeba.learn(action);
+                if(learnMalus) {
+                    reward = -150 + Math.abs(pos);
+                    action.put("a1", 1.0);
+                    action.put("oracle", reward);
+                    amoeba.learn(action);
+                }
+                pos -= 1.0;
+            }
+            pos = -50.0 - n;
+            for(int i = 0; i < 49; i++) {
+                double reward = 100 - Math.abs(pos);
+                HashMap<String, Double> action = new HashMap<String, Double>();
+                action.put("p1", pos);
+                action.put("a1", 1.0);
+                action.put("oracle", reward);
+                amoeba.learn(action);
+                if(learnMalus) {
+                    reward = -150 + Math.abs(pos);
+                    action.put("a1", -1.0);
+                    action.put("oracle", reward);
+                    amoeba.learn(action);
+                }
+                pos += 1.0;
+            }
+        }
+
+        // increase precision of model near objective
+        // right now it make things worst
+        /*
+        for(int n = 0; n < 5; n++) {
+            for(double pos = 2.0; pos > 0.02; pos -= 0.1) {
+                double reward = 100 - Math.abs(pos);
+                HashMap<String, Double> action = new HashMap<String, Double>();
+                action.put("p1", pos);
+                action.put("a1", -1.0);
+                action.put("oracle", reward);
+                amoeba.learn(action);
+                if(learnMalus) {
+                    reward = -150 + Math.abs(pos);
+                    action.put("p1", pos);
+                    action.put("a1", 1.0);
+                    action.put("oracle", reward);
+                    amoeba.learn(action);
+                }
+                action.put("p1", -pos);
+                action.put("a1", 1.0);
+                action.put("oracle", reward);
+                amoeba.learn(action);
+                if(learnMalus) {
+                    reward = -150 + Math.abs(pos);
+                    action.put("p1", -pos);
+                    action.put("a1", -1.0);
+                    action.put("oracle", reward);
+                    amoeba.learn(action);
+                }
+            }
+        }
+        */
+
+        // tests
+        Random r = new Random();
+        HashMap<String, Double> state = env.reset();
+        HashMap<String, Double> state2;
+        double tot_reward = 0.0;
+        int nbTest = 100;
+        double nbPositiveReward = 0;
+        for(int i = 0; i < nbTest; i++) {
+            double reward = 0.0;
+            state = env.reset();
+            HashMap<String, Double> action = new HashMap<String, Double>();
+            // execute simulation cycles
+            boolean done = false;
+            int nbStep = 0;
+            while(!done) {
+                nbStep++;
+                if(nbStep > 200) {
+                    done = true;
+                }
+                state.remove("oracle");
+                action = amoeba.maximize(state);
+                // random action if no proposition from amoeba
+                if(action.get("oracle").equals(Double.NEGATIVE_INFINITY)) {
+                    action.put("a1", (r.nextBoolean() ? 1.0 : -1.0));
+                }
+                //System.out.println("action "+action);
+                state2 = env.step(action.get("a1"));
+                if(state2.get("oracle") != -1.0) {
+                    done = true;
+                }
+                reward += state2.get("oracle");
+                //System.out.println("state2 "+state2+" reward "+reward);
+                state = state2;
+            }
+            if(reward > 0) {
+                nbPositiveReward += 1.0;
+            }
+            tot_reward += reward;
+            //System.out.println("-----------------------------\nTot reward "+tot_reward+"\n-----------------------------");
+        }
+        System.out.println("Average reward : " + tot_reward / nbTest + " Positive reward %: " + (nbPositiveReward / nbTest));
+        AmoebaWindow.instance().point.move(100, 100);
+        AmoebaWindow.instance().mainVUI.updateCanvas();
+    }
+
+    private static void explore(Random r, double explo, HashMap<String, Double> action, double lastAction) {
+        // if we were in the process of going in a straight line, continue
+        if(exploreLine < N_EXPLORE_LINE && lastAction != 0.0) {
+            action.put("a1", lastAction);
+            exploreLine++;
+        } else {
+            // else if we have to explore
+            if(r.nextDouble() < explo || action.get("oracle").equals(Double.NEGATIVE_INFINITY)) {
+                // maybe next time go in a straight line
+                if(r.nextBoolean()) {
+                    exploreLine = 0;
+                }
+                // chose a random action
+                action.put("a1", (r.nextBoolean() ? 1.0 : -1.0));
+            }
+        }
+    }
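
Aside on the learning change: the old version learned each step with only its immediate reward; the new block is effectively an undiscounted Monte-Carlo backup — steps are pushed onto a stack during the episode, then popped so each one's "oracle" value becomes the total reward from that step to the end of the episode. A standalone sketch of just that backup, with illustrative names (double[] stands in for the HashMap<String, Double> samples):

import java.util.ArrayDeque;
import java.util.Deque;

public class RewardBackupSketch {
    public static void main(String[] args) {
        // Each entry is {p1, a1, reward}; the terminal step is pushed last.
        Deque<double[]> actions = new ArrayDeque<>();
        actions.push(new double[] { 3.0, -1.0, -1.0 });
        actions.push(new double[] { 2.0, -1.0, -1.0 });
        actions.push(new double[] { 1.0, -1.0, 100.0 }); // reached the goal

        // Pop from the end of the episode backwards, accumulating reward,
        // and push onto learnSet to restore chronological order.
        Deque<double[]> learnSet = new ArrayDeque<>();
        double reward = 0.0;
        while (!actions.isEmpty()) {
            double[] step = actions.pop();
            reward += step[2];
            step[2] = reward;  // replace immediate reward with reward-to-go
            learnSet.push(step);
        }
        for (double[] step : learnSet) {
            System.out.printf("p1=%.1f a1=%.1f oracle=%.1f%n", step[0], step[1], step[2]);
        }
    }
}

This prints oracle values 98, 99, 100 in chronological order: the earliest step gets credit for the eventual win minus the steps it took to get there.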
@@ -115,14 +369,14 @@ public class SimpleReinforcement {
      * Must be called AFTER an AMOEBA with GUI
      */
     public SimpleReinforcement() {
-        //Configuration.commandLineMode = false;
-        //AmoebaWindow instance = AmoebaWindow.instance();
-        //pos = new DrawableOval(0.5, 0.5, 1, 1);
-        //pos.setColor(new Color(0.5, 0.0, 0.0, 0.5));
-        //instance.mainVUI.add(pos);
-        //instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5);
-        //instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2);
+        if(!Configuration.commandLineMode) {
+            AmoebaWindow instance = AmoebaWindow.instance();
+            //pos = new DrawableOval(0.5, 0.5, 1, 1);
+            //pos.setColor(new Color(0.5, 0.0, 0.0, 0.5));
+            //instance.mainVUI.add(pos);
+            instance.mainVUI.createAndAddRectangle(-50, -0.25, 100, 0.5);
+            instance.mainVUI.createAndAddRectangle(-0.25, -1, 0.5, 2);
+        }
     }
@@ -140,7 +394,7 @@ public class SimpleReinforcement {
...
@@ -140,7 +394,7 @@ public class SimpleReinforcement {
reward
=
-
100.0
;
reward
=
-
100.0
;
}
else
if
(
x
==
0.0
||
sign
(
oldX
)
!=
sign
(
x
))
{
}
else
if
(
x
==
0.0
||
sign
(
oldX
)
!=
sign
(
x
))
{
// win !
// win !
reward
=
100
0
.0
;
reward
=
100.0
;
x
=
RandomUtils
.
nextDouble
(
rand
,
-
50.0
,
Math
.
nextUp
(
50.0
));
x
=
RandomUtils
.
nextDouble
(
rand
,
-
50.0
,
Math
.
nextUp
(
50.0
));
}
else
{
}
else
{
reward
=
-
1.0
;
reward
=
-
1.0
;
...
...
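
Aside on the reward change: dropping the win reward from 1000.0 to 100.0 is what makes the javadoc's "expected average reward for the optimal solution is 75" come out. With a start drawn uniformly from [-50, 50], the optimal policy pays -1 per step over an average distance of about 25 and then collects +100, so the expected return is roughly 100 - 25 = 75. A quick check (the integer grid of starting positions is an assumption for illustration):

public class OptimalRewardCheck {
    public static void main(String[] args) {
        int n = 0;
        double total = 0.0;
        for (double start = -50.0; start <= 50.0; start += 1.0, n++) {
            double steps = Math.abs(start); // optimal policy moves 1.0 toward 0 each step
            total += 100.0 - steps;         // +100 on winning, -1 per step
        }
        System.out.println("expected optimal return ~ " + total / n); // prints ~74.75
    }
}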