
Commit 9d92c19 (parent: 89cce76)

update environments to use RLCore

18 files changed: 258 additions & 330 deletions
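Every environment module touched below follows the same migration pattern: the `reset!`/`observe`/`interact!` re-exports and the `ReinforcementLearningEnvironments` imports are dropped in favor of `using ReinforcementLearningCore`; observation and action spaces move out of struct fields into `RLBase.get_observation_space`/`RLBase.get_action_space` methods; `interact!(env, a)` becomes a functor call `env(a)`; `observe` returns a plain NamedTuple instead of an `Observation`; and `reset!` extends `RLBase.reset!`. A minimal sketch of the pattern (hypothetical `ExampleEnv`; assumes `ReinforcementLearningCore` re-exports `RLBase`, `AbstractEnv`, and `DiscreteSpace`, as the diffs below rely on):

module Example

export ExampleEnv

using ReinforcementLearningCore

# Base.@kwdef replaces the hand-written inner constructors; defaults are
# evaluated per construction.
Base.@kwdef mutable struct ExampleEnv <: AbstractEnv
    state::Int = 1
end

# Spaces are provided by methods instead of struct fields.
RLBase.get_observation_space(env::ExampleEnv) = DiscreteSpace(2)
RLBase.get_action_space(env::ExampleEnv) = DiscreteSpace(2)

# The environment is callable; this replaces interact!(env, a).
function (env::ExampleEnv)(a::Int)
    env.state = a
    nothing
end

# observe now returns a plain NamedTuple instead of an Observation struct.
RLBase.observe(env::ExampleEnv) =
    (reward = 0.0, terminal = env.state == 2, state = env.state)

function RLBase.reset!(env::ExampleEnv)
    env.state = 1
    nothing
end

end # module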

Manifest.toml

Lines changed: 93 additions & 152 deletions
Large diffs are not rendered by default.

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@ version = "0.2.0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318"
-ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
+ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
+ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

notebooks/Chapter01_Tic_Tac_Toe.ipynb

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@
 {
  "data": {
   "text/plain": [
-   "Observation{Float64,Bool,Int64,NamedTuple{(:legal_actions,),Tuple{Array{Bool,1}}}}(0.0, false, 4186, (legal_actions = Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],))"
+   "{Float64,Bool,Int64,NamedTuple{(:legal_actions,),Tuple{Array{Bool,1}}}}(0.0, false, 4186, (legal_actions = Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],))"
   ]
  },
  "execution_count": 4,

notebooks/Chapter09_Random_Walk.ipynb

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@
 {
  "data": {
   "text/plain": [
-   "Observation{Float64,Bool,Int64,NamedTuple{(),Tuple{}}}(0.0, false, 501, NamedTuple())"
+   "{Float64,Bool,Int64,NamedTuple{(),Tuple{}}}(0.0, false, 501, NamedTuple())"
   ]
  },
  "execution_count": 6,

src/environments/AccessControl.jl

Lines changed: 15 additions & 23 deletions
@@ -1,9 +1,9 @@
 module AccessControl
 
-export AccessControlEnv, reset!, observe, interact!
+export AccessControlEnv
+
+using ReinforcementLearningCore
 
-using ReinforcementLearningEnvironments
-import ReinforcementLearningEnvironments: reset!, observe, interact!
 
 using Distributions
 
@@ -16,25 +16,17 @@ const CUSTOMERS = 1:length(PRIORITIES)
 
 const TRANSFORMER = LinearIndices((0:N_SERVERS, CUSTOMERS))
 
-mutable struct AccessControlEnv <: AbstractEnv
-    n_servers::Int
-    n_free_servers::Int
-    customer::Int
-    reward::Float64
-    observation_space::DiscreteSpace
-    action_space::DiscreteSpace
-    AccessControlEnv() =
-        new(
-            10,
-            0,
-            rand(CUSTOMERS),
-            0.0,
-            DiscreteSpace(length(TRANSFORMER)),
-            DiscreteSpace(2),
-        )
+Base.@kwdef mutable struct AccessControlEnv <: AbstractEnv
+    n_servers::Int = 10
+    n_free_servers::Int = 0
+    customer::Int = rand(CUSTOMERS)
+    reward::Float64 = 0.0
 end
 
-function interact!(env::AccessControlEnv, a)
+RLBase.get_observation_space(env::AccessControlEnv) = DiscreteSpace(length(TRANSFORMER))
+RLBase.get_action_space(env::AccessControlEnv) = DiscreteSpace(2)
+
+function (env::AccessControlEnv)(a)
     action, reward = ACTIONS[a], 0.0
     if env.n_free_servers > 0 && action == :accept
         env.n_free_servers -= 1
@@ -48,14 +40,14 @@ function interact!(env::AccessControlEnv, a)
     nothing
 end
 
-observe(env::AccessControlEnv) =
-    Observation(
+RLBase.observe(env::AccessControlEnv) =
+    (
         reward = env.reward,
         terminal = false,
         state = TRANSFORMER[CartesianIndex(env.n_free_servers + 1, env.customer)],
     )
 
-function reset!(env::AccessControlEnv)
+function RLBase.reset!(env::AccessControlEnv)
     env.n_free_servers = env.n_servers
     env.customer = rand(CUSTOMERS)
     env.reward = 0.0
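For downstream code, the switch looks roughly like the sketch below. It is hedged: it assumes `ReinforcementLearningCore` brings `RLBase` into scope as the diff uses it, and that `rand` draws from a `DiscreteSpace`, as in RLBase of this vintage.

env = AccessControlEnv()                 # keyword defaults from Base.@kwdef
RLBase.reset!(env)
obs = RLBase.observe(env)                # NamedTuple: obs.reward, obs.terminal, obs.state
a = rand(RLBase.get_action_space(env))   # assumption: rand is defined on DiscreteSpace
env(a)                                   # calling the env replaces interact!(env, a)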

src/environments/BairdCounter.jl

Lines changed: 11 additions & 12 deletions
@@ -1,20 +1,19 @@
 module BairdCounter
 
-export BairdCounterEnv, reset!, observe, interact!
+export BairdCounterEnv
 
-using ReinforcementLearningEnvironments
-import ReinforcementLearningEnvironments: reset!, observe, interact!
+using ReinforcementLearningCore
 
 const ACTIONS = (:dashed, :solid)
 
-mutable struct BairdCounterEnv <: AbstractEnv
-    current::Int
-    observation_space::DiscreteSpace
-    action_space::DiscreteSpace
-    BairdCounterEnv() = new(rand(1:7), DiscreteSpace(7), DiscreteSpace(length(ACTIONS)))
+Base.@kwdef mutable struct BairdCounterEnv <: AbstractEnv
+    current::Int = rand(1:7)
 end
 
-function interact!(env::BairdCounterEnv, a)
+RLBase.get_observation_space(env::BairdCounterEnv) = DiscreteSpace(7)
+RLBase.get_action_space(env::BairdCounterEnv) = DiscreteSpace(length(ACTIONS))
+
+function (env::BairdCounterEnv)(a)
     if ACTIONS[a] == :dashed
         env.current = rand(1:6)
     else
@@ -23,10 +22,10 @@ function interact!(env::BairdCounterEnv, a)
     nothing
 end
 
-observe(env::BairdCounterEnv) =
-    Observation(reward = 0.0, terminal = false, state = env.current)
+RLBase.observe(env::BairdCounterEnv) =
+    (reward = 0.0, terminal = false, state = env.current)
 
-function reset!(env::BairdCounterEnv)
+function RLBase.reset!(env::BairdCounterEnv)
     env.current = rand(1:6)
     nothing
 end
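The `Base.@kwdef` annotation adopted here generates a keyword constructor whose defaults are evaluated at each construction. Roughly (a simplified sketch, not the literal macro expansion):

mutable struct BairdCounterEnv <: AbstractEnv
    current::Int
end
BairdCounterEnv(; current = rand(1:7)) = BairdCounterEnv(current)

env = BairdCounterEnv()              # current drawn afresh from 1:7
env2 = BairdCounterEnv(current = 3)  # or pinned explicitly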

src/environments/BlackJack.jl

Lines changed: 9 additions & 12 deletions
@@ -1,10 +1,8 @@
 module BlackJack
 
-export BlackJackEnv, reset!, observe, interact!
-
-using ReinforcementLearningEnvironments
-import ReinforcementLearningEnvironments: reset!, observe, interact!
+export BlackJackEnv
 
+using ReinforcementLearningCore
 using Random
 
 const ACTIONS = [:hit, :stick]
@@ -48,10 +46,11 @@ mutable struct BlackJackEnv <: AbstractEnv
     reward::Float64
     is_exploring_start::Bool
     init::Union{Nothing,Tuple{Hands,Hands}}
-    observation_space::DiscreteSpace
-    action_space::DiscreteSpace
 end
 
+RLBase.get_observation_space(env::BlackJackEnv) = DiscreteSpace(length(INDS))
+RLBase.get_action_space(env::BlackJackEnv) = DiscreteSpace(2)
+
 function BlackJackEnv(; is_exploring_start = false, init = nothing)
     env = BlackJackEnv(
         Hands(),
@@ -60,8 +59,6 @@ function BlackJackEnv(; is_exploring_start = false, init = nothing)
         0.0,
         is_exploring_start,
         init,
-        DiscreteSpace(length(INDS)),
-        DiscreteSpace(2),
     )
     init_hands!(env)
     env
@@ -87,7 +84,7 @@ function init_hands!(env::BlackJackEnv)
     env.player_hands, env.dealer_hands = player_hands, dealer_hands
 end
 
-function interact!(env::BlackJackEnv, a::Int)
+function (env::BlackJackEnv)(a::Int)
     if ACTIONS[a] == :hit
         push!(env.player_hands, deal_card())
         if is_busted(env.player_hands)
@@ -117,7 +114,7 @@ function interact!(env::BlackJackEnv, a::Int)
     nothing
 end
 
-function reset!(env::BlackJackEnv)
+function RLBase.reset!(env::BlackJackEnv)
     env.is_end = false
     env.reward = 0.0
 
@@ -133,7 +130,7 @@ encode(env) =
         2 <= env.dealer_hands.sum <= 10 ? env.dealer_hands.sum : 1,
     ]
 
-observe(env::BlackJackEnv) =
-    Observation(reward = env.reward, terminal = env.is_end, state = encode(env))
+RLBase.observe(env::BlackJackEnv) =
+    (reward = env.reward, terminal = env.is_end, state = encode(env))
 
 end

src/environments/BranchMDP.jl

Lines changed: 9 additions & 7 deletions
@@ -1,9 +1,8 @@
 module BranchMDP
 
-export BranchMDPEnv, reset!, observe, interact!
+export BranchMDPEnv
 
-using ReinforcementLearningEnvironments
-import ReinforcementLearningEnvironments: reset!, observe, interact!
+using ReinforcementLearningCore
 
 mutable struct BranchMDPEnv <: AbstractEnv
     transition::Array{Int,3}
@@ -26,7 +25,10 @@ mutable struct BranchMDPEnv <: AbstractEnv
         )
 end
 
-function interact!(env::BranchMDPEnv, a::Int)
+RLBase.get_observation_space(env::BranchMDPEnv) = env.observation_space
+RLBase.get_action_space(env::BranchMDPEnv) = env.action_space
+
+function (env::BranchMDPEnv)(a::Int)
     if rand() < env.termination_prob
         env.reward = 0.0
         env.current = size(env.transition, 1) + 1
@@ -40,14 +42,14 @@ function interact!(env::BranchMDPEnv, a::Int)
     nothing
 end
 
-observe(env::BranchMDPEnv) =
-    Observation(
+RLBase.observe(env::BranchMDPEnv) =
+    (
         reward = env.reward,
         terminal = env.current == size(env.transition, 1) + 1,
         state = env.current,
     )
 
-function reset!(env::BranchMDPEnv, s::Int = 1)
+function RLBase.reset!(env::BranchMDPEnv, s::Int = 1)
     env.current = s
     nothing
 end

src/environments/CliffWalking.jl

Lines changed: 12 additions & 13 deletions
@@ -1,9 +1,9 @@
 module CliffWalking
 
-export CliffWalkingEnv, reset!, observe, interact!
+export CliffWalkingEnv
+
+using ReinforcementLearningCore
 
-using ReinforcementLearningEnvironments
-import ReinforcementLearningEnvironments: reset!, observe, interact!
 
 const NX = 4
 const NY = 12
@@ -23,28 +23,27 @@ function iscliff(p::CartesianIndex{2})
     x == 4 && y > 1 && y < NY
 end
 
-mutable struct CliffWalkingEnv <: AbstractEnv
-    position::CartesianIndex{2}
-    observation_space::DiscreteSpace
-    action_space::DiscreteSpace
-    CliffWalkingEnv() =
-        new(Start, DiscreteSpace(length(LinearInds)), DiscreteSpace(length(Actions)))
+Base.@kwdef mutable struct CliffWalkingEnv <: AbstractEnv
+    position::CartesianIndex{2} = Start
 end
 
-function interact!(env::CliffWalkingEnv, a::Int)
+RLBase.get_observation_space(env::CliffWalkingEnv) = DiscreteSpace(length(LinearInds))
+RLBase.get_action_space(env::CliffWalkingEnv) = DiscreteSpace(length(Actions))
+
+function (env::CliffWalkingEnv)(a::Int)
     x, y = Tuple(env.position + Actions[a])
     env.position = CartesianIndex(min(max(x, 1), NX), min(max(y, 1), NY))
     nothing
 end
 
-observe(env::CliffWalkingEnv) =
-    Observation(
+RLBase.observe(env::CliffWalkingEnv) =
+    (
         reward = env.position == Goal ? 0.0 : (iscliff(env.position) ? -100.0 : -1.0),
         terminal = env.position == Goal || iscliff(env.position),
        state = LinearInds[env.position],
     )
 
-function reset!(env::CliffWalkingEnv)
+function RLBase.reset!(env::CliffWalkingEnv)
     env.position = Start
     nothing
 end
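The state reported by `observe` is the linear index of the position in the 4×12 grid. `LinearInds` is not shown in this hunk, so the sketch below assumes it is `LinearIndices((NX, NY))`:

NX, NY = 4, 12
LinearInds = LinearIndices((NX, NY))   # assumption: matches the module's definition
LinearInds[CartesianIndex(4, 1)]       # == 4; column-major, so state = x + (y - 1) * NX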

src/environments/LeftRight.jl

Lines changed: 10 additions & 11 deletions
@@ -1,43 +1,42 @@
 module LeftRight
 
-export LeftRightEnv, reset!, observe, interact!
+export LeftRightEnv
+
+using ReinforcementLearningCore
 
-using ReinforcementLearningEnvironments
-import ReinforcementLearningEnvironments: reset!, observe, interact!
 
 using StatsBase
 
 mutable struct LeftRightEnv <: AbstractEnv
     transitions::Array{Float64,3}
     current_state::Int
-    observation_space::DiscreteSpace
-    action_space::DiscreteSpace
-    LeftRightEnv(transitions, current_state) =
-        new(transitions, current_state, DiscreteSpace(2), DiscreteSpace(2))
 end
 
+RLBase.get_observation_space(env::LeftRightEnv) = DiscreteSpace(2)
+RLBase.get_action_space(env::LeftRightEnv) = DiscreteSpace(2)
+
 function LeftRightEnv()
     t = zeros(2, 2, 2)
     t[1, :, :] = [0.9 0.1; 0.0 1.0]
     t[2, :, :] = [0.0 1.0; 0.0 1.0]
     LeftRightEnv(t, rand(1:2))
 end
 
-function interact!(env::LeftRightEnv, a::Int)
+function (env::LeftRightEnv)(a::Int)
     env.current_state = sample(Weights(
         @view(env.transitions[env.current_state, a, :]),
         1.0,
     ))
     nothing
 end
 
-function reset!(env::LeftRightEnv)
+function RLBase.reset!(env::LeftRightEnv)
     env.current_state = 1
     nothing
 end
 
-observe(env::LeftRightEnv) =
-    Observation(
+RLBase.observe(env::LeftRightEnv) =
+    (
         reward = Float64(env.current_state == 2),
         terminal = env.current_state == 2,
         state = env.current_state,