Skip to content

Commit 593b962

Browse files
committed
finished chapter03
1 parent 3170351 commit 593b962

16 files changed

Lines changed: 554 additions & 151 deletions

Manifest.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ version = "0.3.1"
1414

1515
[[Adapt]]
1616
deps = ["LinearAlgebra"]
17-
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
17+
git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5"
1818
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
19-
version = "1.0.0"
19+
version = "1.0.1"
2020

2121
[[Arpack]]
2222
deps = ["Arpack_jll", "Libdl", "LinearAlgebra"]
@@ -58,21 +58,21 @@ version = "0.2.0"
5858

5959
[[CUDAapi]]
6060
deps = ["Libdl", "Logging"]
61-
git-tree-sha1 = "56a813440ac98a1aa64672ab460a1512552211a7"
61+
git-tree-sha1 = "d9614968b9a13df433870115acff20f41e7b400a"
6262
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
63-
version = "2.1.0"
63+
version = "3.0.0"
6464

6565
[[CUDAdrv]]
6666
deps = ["CEnum", "CUDAapi", "Printf"]
67-
git-tree-sha1 = "5660775f2a3214420add960e1ff2baf46d5297cd"
67+
git-tree-sha1 = "01e90fa34e25776bc7c8661183d4519149ebfe59"
6868
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
69-
version = "5.1.0"
69+
version = "6.0.0"
7070

7171
[[CUDAnative]]
7272
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
73-
git-tree-sha1 = "e0c2805c9a7d338823c0d8f574242e284410fa61"
73+
git-tree-sha1 = "59d6c3e313b874abc718f7d6ad02ea604f96db14"
7474
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
75-
version = "2.9.1"
75+
version = "2.10.0"
7676

7777
[[Clustering]]
7878
deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "SparseArrays", "Statistics", "StatsBase"]
@@ -118,9 +118,9 @@ version = "0.5.1"
118118

119119
[[CuArrays]]
120120
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
121-
git-tree-sha1 = "4e536542c5c898b1bf43011b6187f3c97ebcc91e"
121+
git-tree-sha1 = "9aac17f7e09017107c84ed2657f462e86b1d56b3"
122122
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
123-
version = "1.7.0"
123+
version = "1.7.1"
124124

125125
[[DataAPI]]
126126
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["TianJun <tianjun.cpp@gmail.com>"]
44
version = "0.2.0"
55

66
[deps]
7+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
910
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

notebooks/Chapter03_Grid_World.ipynb

Lines changed: 158 additions & 139 deletions
Large diffs are not rendered by default.

notebooks/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ deps = ["InteractiveUtils", "Markdown", "Sockets"]
480480
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
481481

482482
[[RLIntro]]
483-
deps = ["Distributions", "Plots", "Random", "ReinforcementLearningBase", "ReinforcementLearningCore", "SparseArrays", "StatsBase", "StatsPlots"]
483+
deps = ["DataStructures", "Distributions", "Flux", "Plots", "Random", "ReinforcementLearningBase", "ReinforcementLearningCore", "SparseArrays", "StatsBase", "StatsPlots"]
484484
path = ".."
485485
uuid = "02c1da58-b9a1-11e8-0212-f9611b8fe936"
486486
version = "0.2.0"

src/RLIntro.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module RLIntro
22

33
using ReinforcementLearningCore
4+
using ReinforcementLearningBase
45

56
include("environments/environments.jl")
67
include("extensions/extensions.jl")
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export DeterministicDistributionModel
2+
3+
"""
4+
DeterministicDistributionModel(table::Array{Vector{NamedTuple{(:nextstate, :reward, :prob),Tuple{Int,Float64,Float64}}}, 2})
5+
6+
Store all the transformations in the `table` field.
7+
"""
8+
struct DeterministicDistributionModel <: AbstractEnvironmentModel
9+
table::Array{
10+
Vector{NamedTuple{(:nextstate, :reward, :prob),Tuple{Int,Float64,Float64}}},
11+
2,
12+
}
13+
end
14+
15+
# The observation and action spaces are derived from the table dimensions:
# rows correspond to states, columns to actions.
function RLBase.get_observation_space(model::DeterministicDistributionModel)
    return DiscreteSpace(size(model.table, 1))
end

function RLBase.get_action_space(model::DeterministicDistributionModel)
    return DiscreteSpace(size(model.table, 2))
end

# Calling the model with a (state, action) pair returns the stored
# transition distribution for that pair.
function (model::DeterministicDistributionModel)(s::Int, a::Int)
    return model.table[s, a]
end
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
export DynamicDistributionModel
2+
3+
"""
4+
DynamicDistributionModel(f::Tf, ns::Int, na::Int) -> DynamicDistributionModel{Tf}
5+
6+
Use a general function `f` to store the transformations. `ns` and `na` are the number of states and actions.
7+
"""
8+
struct DynamicDistributionModel{Tf<:Function} <: AbstractEnvironmentModel
9+
f::Tf
10+
ns::Int
11+
na::Int
12+
end
13+
14+
# Spaces are fixed at construction time by the declared state/action counts.
function RLBase.get_observation_space(model::DynamicDistributionModel)
    return DiscreteSpace(model.ns)
end

function RLBase.get_action_space(model::DynamicDistributionModel)
    return DiscreteSpace(model.na)
end

# Delegate transition lookup to the wrapped function.
function (model::DynamicDistributionModel)(s, a)
    return model.f(s, a)
end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
include("deterministic_distribution_model.jl")
2+
include("dynamic_distribution_model.jl")
3+
include("experience_based_sample_model.jl")
4+
include("time_based_sample_model.jl")
5+
include("prioritized_sweeping_sample_model.jl")
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
export ExperienceBasedSampleModel, sample
2+
3+
import StatsBase: sample
4+
5+
"""
6+
ExperienceBasedSampleModel() -> ExperienceBasedSampleModel
7+
8+
Generate a transition based on previous experiences.
9+
"""
10+
mutable struct ExperienceBasedSampleModel <: AbstractEnvironmentModel
11+
experiences::Dict{
12+
Any,
13+
Dict{Any,NamedTuple{(:reward, :terminal, :nextstate),Tuple{Float64,Bool,Any}}},
14+
}
15+
sample_count::Int
16+
ExperienceBasedSampleModel() =
17+
new(
18+
Dict{
19+
Any,
20+
Dict{
21+
Any,
22+
NamedTuple{(:reward, :terminal, :nextstate),Tuple{Float64,Bool,Any}},
23+
},
24+
}(),
25+
0,
26+
)
27+
end
28+
29+
# Pull the latest (state, action, reward, terminal, next_state) transition out
# of the trajectory, or return `nothing` when the trajectory is still empty.
function RLBase.extract_experience(t::AbstractTrajectory, m::ExperienceBasedSampleModel)
    length(t) == 0 && return nothing
    return (
        get_trace(t, :state)[end],
        get_trace(t, :action)[end],
        get_trace(t, :reward)[end],
        get_trace(t, :terminal)[end],
        get_trace(t, :next_state)[end],
    )
end
40+
41+
# A `nothing` experience (empty trajectory) leaves the model untouched.
function RLBase.update!(m::ExperienceBasedSampleModel, ::Nothing)
    return nothing
end
42+
43+
# Store (or overwrite) the experience for the (s, a) pair extracted from
# `transition`, creating the per-state dictionary on first encounter.
function RLBase.update!(m::ExperienceBasedSampleModel, transition::Tuple)
    s, a, r, d, s′ = transition
    experience = (reward = r, terminal = d, nextstate = s′)
    if haskey(m.experiences, s)
        m.experiences[s][a] = experience
    else
        m.experiences[s] =
            Dict{Any,NamedTuple{(:reward, :terminal, :nextstate),Tuple{Float64,Bool,Any}}}(
                a => experience,
            )
    end
end
54+
55+
# Draw a uniformly random previously-seen (state, action) pair and return it
# together with its stored (reward, terminal, nextstate) outcome as a flat tuple.
function sample(model::ExperienceBasedSampleModel)
    state = rand(keys(model.experiences))
    action = rand(keys(model.experiences[state]))
    model.sample_count += 1
    return (state, action, model.experiences[state][action]...)
end
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
export PrioritizedSweepingSampleModel
2+
3+
using DataStructures: PriorityQueue, dequeue!
4+
5+
import StatsBase: sample
6+
7+
"""
8+
PrioritizedSweepingSampleModel(θ::Float64=1e-4)
9+
10+
See more details at Section (8.4) on Page 168 of the book *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.*
11+
"""
12+
mutable struct PrioritizedSweepingSampleModel <: AbstractEnvironmentModel
13+
experiences::Dict{Tuple{Any,Any},Tuple{Float64,Bool,Any}}
14+
PQueue::PriorityQueue{Tuple{Any,Any},Float64}
15+
predecessors::Dict{Any,Set{Tuple{Any,Any,Float64,Bool}}}
16+
θ::Float64
17+
sample_count::Int
18+
PrioritizedSweepingSampleModel::Float64 = 1e-4) =
19+
new(
20+
Dict{Tuple{Any,Any},Tuple{Float64,Bool,Any}}(),
21+
PriorityQueue{Tuple{Any,Any},Float64}(Base.Order.Reverse),
22+
Dict{Any,Set{Tuple{Any,Any,Float64,Bool}}}(),
23+
θ,
24+
0,
25+
)
26+
end
27+
28+
# Pull the latest (state, action, reward, terminal, next_state) transition out
# of the trajectory, or return `nothing` when the trajectory is still empty.
function RLBase.extract_experience(
    t::AbstractTrajectory,
    model::PrioritizedSweepingSampleModel,
)
    length(t) == 0 && return nothing
    return (
        get_trace(t, :state)[end],
        get_trace(t, :action)[end],
        get_trace(t, :reward)[end],
        get_trace(t, :terminal)[end],
        get_trace(t, :next_state)[end],
    )
end
42+
43+
# Record a transition and, when its priority `P` reaches the threshold θ,
# (re)queue the (s, a) pair for sweeping. The transition is also indexed under
# its destination state so predecessors can be found during a sweep.
function RLBase.update!(m::PrioritizedSweepingSampleModel, transition, P)
    s, a, r, d, s′ = transition
    m.experiences[(s, a)] = (r, d, s′)
    if P >= m.θ
        m.PQueue[(s, a)] = P
    end
    preds = get!(m.predecessors, s′) do
        Set{Tuple{Any,Any,Float64,Bool}}()
    end
    push!(preds, (s, a, r, d))
end
54+
55+
# Pop the highest-priority (s, a) pair and return its stored transition as a
# flat (s, a, reward, terminal, s′) tuple, or `nothing` if the queue is empty.
function sample(m::PrioritizedSweepingSampleModel)
    isempty(m.PQueue) && return nothing
    s, a = dequeue!(m.PQueue)
    r, d, s′ = m.experiences[(s, a)]
    m.sample_count += 1
    return (s, a, r, d, s′)
end

0 commit comments

Comments
 (0)