Commit e2fa30b ("sync")

Parent: 9cc7acb

File tree: 6 files changed (+2912, -2972 lines)


notebooks/Chapter08_Maze.ipynb

2099 additions, 2259 deletions (large diff not rendered by default)

notebooks/Chapter08_Trajectory_Sampling.ipynb

722 additions, 709 deletions (large diff not rendered by default)

src/extensions/environment_models/prioritized_sweeping_sample_model.jl

7 additions, 2 deletions

@@ -40,8 +40,13 @@ function RLBase.extract_experience(
     end
 end
 
-function RLBase.update!(m::PrioritizedSweepingSampleModel, transition, P)
-    s, a, r, d, s′ = transition
+function RLBase.update!(m::PrioritizedSweepingSampleModel, t::AbstractTrajectory, p::AbstractPolicy)
+    experience = extract_experience(t, m)
+    isnothing(experience) || update!(m, (experience..., get_priority(p, experience)))
+end
+
+function RLBase.update!(m::PrioritizedSweepingSampleModel, transition::Tuple)
+    s, a, r, d, s′, P = transition
     m.experiences[(s, a)] = (r, d, s′)
     if P >= m.θ
         m.PQueue[(s, a)] = P
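
Note: the update path is now split across two methods. The trajectory-based method extracts the latest transition and scores it with the policy's get_priority; the tuple-based method stores the experience and enqueues (s, a) when its priority clears the threshold θ. A minimal sketch of the intended call pattern, where `model`, `trajectory`, and `policy` are assumed bindings, not defined in this commit:

    # Assumed wiring; only update! and get_priority come from this commit.
    update!(model, trajectory, policy)
    # ...which extracts the latest (s, a, r, d, s′) and forwards to:
    # update!(model, (s, a, r, d, s′, get_priority(policy, (s, a, r, d, s′))))
    # storing (r, d, s′) under key (s, a) and, when P >= model.θ,
    # inserting (s, a) into model.PQueue with priority P.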

src/extensions/environment_models/time_based_sample_model.jl

1 addition, 1 deletion

@@ -30,7 +30,7 @@ mutable struct TimeBasedSampleModel <: AbstractEnvironmentModel
     )
 end
 
-function extract_transitions(t::AbstractTrajectory, m::TimeBasedSampleModel)
+function RLBase.extract_experience(t::AbstractTrajectory, m::TimeBasedSampleModel)
     if length(t) > 0
         get_trace(t, :state)[end],
         get_trace(t, :action)[end],

The rename replaces the ad-hoc extract_transitions with the RLBase.extract_experience method name that the other sample models in this commit implement.

src/extensions/extensions.jl

1 addition, 1 deletion

@@ -1,5 +1,5 @@
 include("utils/utils.jl")
+include("environment_models/environment_models.jl")
 include("learners/learners.jl")
 include("policies/policies.jl")
-include("environment_models/environment_models.jl")
 include("iteration_methods.jl")

The environment models now load before the learners, since the new learner methods below dispatch on the model types (TimeBasedSampleModel, ExperienceBasedSampleModel, PrioritizedSweepingSampleModel).

src/extensions/learners/temporal_difference_learner.jl

82 additions, 0 deletions

@@ -208,6 +208,88 @@ function RLBase.extract_experience(
     end
 end
 
+function RLBase.update!(
+    learner::TDLearner{<:AbstractApproximator,:SARS},
+    model::Union{TimeBasedSampleModel,ExperienceBasedSampleModel},
+    t::AbstractTrajectory,
+    plan_step::Int,
+)
+    @assert learner.n == 0 "n must be 0 here"
+    for _ = 1:plan_step
+        transitions = extract_experience(model, learner)
+        if !isnothing(transitions)
+            update!(learner, transitions)
+        end
+    end
+end
+
+function RLBase.extract_experience(
+    model::Union{ExperienceBasedSampleModel,TimeBasedSampleModel},
+    learner::TDLearner{<:AbstractApproximator,:SARS},
+)
+    if length(model.experiences) > 0
+        s = sample(model)
+        (
+            states = [s[1]],
+            actions = [s[2]],
+            rewards = [s[3]],
+            terminals = [s[4]],
+            next_states = [s[5]],
+        )
+    else
+        nothing
+    end
+end
+
+function RLBase.get_priority(learner::TDLearner{<:AbstractApproximator,:SARS}, transition::Tuple)
+    s, a, r, d, s′ = transition
+    γ, Q, opt = learner.γ, learner.approximator, learner.optimizer
+    error = d ? apply!(opt, (s, a), r - Q(s, a)) :
+            apply!(opt, (s, a), r + γ^(learner.n + 1) * maximum(Q(s′)) - Q(s, a))
+    abs(error)
+end
+
+function RLBase.update!(
+    learner::TDLearner{<:AbstractApproximator,:SARS},
+    model::PrioritizedSweepingSampleModel,
+    t::AbstractTrajectory,
+    plan_step::Int,
+)
+    for _ = 1:plan_step
+        # @assert learner.n == 0 "n must be 0 here"
+        transitions = extract_experience(model, learner)
+        if !isnothing(transitions)
+            update!(learner, transitions)
+            s, _, _, _, _ = transitions
+            s = s[]  # length(s) is assumed to be 1
+            for (s̄, ā, r̄, d̄) in model.predecessors[s]
+                P = get_priority(learner, (s̄, ā, r̄, d̄, s))
+                if P >= model.θ
+                    model.PQueue[(s̄, ā)] = P
+                end
+            end
+        end
+    end
+end
+
+function RLBase.extract_experience(
+    model::PrioritizedSweepingSampleModel,
+    learner::TDLearner{<:AbstractApproximator,:SARS},
+)
+    if length(model.PQueue) > 0
+        s = sample(model)
+        (
+            states = [s[1]],
+            actions = [s[2]],
+            rewards = [s[3]],
+            terminals = [s[4]],
+            next_states = [s[5]],
+        )
+    else
+        nothing
+    end
+end
+
 #####
 # SARS DoubleLearner
 #####
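
Note: together these methods implement the planning half of prioritized sweeping (Sutton & Barto, Chapter 8). The priority of a transition is its absolute TD error, |r + γ^(n+1) * max_a′ Q(s′, a′) - Q(s, a)| (with learner.n == 0 here, the discount reduces to γ), and each planning step pops the highest-priority (s, a) pair, backs it up, and re-scores that state's predecessors. A hedged sketch of the surrounding agent loop; `env`, `trajectory`, `policy`, and the step counts are illustrative assumptions, not part of this commit:

    # Assumed glue around the methods added above.
    plan_step = 5                          # simulated backups per real step
    for _ in 1:n_real_steps
        # ... act in env, push (s, a, r, d, s′) into trajectory ...
        update!(model, trajectory, policy)              # enqueue by priority
        update!(learner, model, trajectory, plan_step)  # pop, back up, re-score predecessors
    end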
