export PrioritizedSweepingSampleModel

using DataStructures: PriorityQueue, dequeue!

import StatsBase: sample

"""
    PrioritizedSweepingSampleModel(θ::Float64=1e-4)

An environment model for prioritized sweeping: observed transitions are stored
and replayed in order of decreasing priority, and only transitions whose
priority is at least `θ` are queued for planning.

See more details at Section (8.4) on Page 168 of the book *Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.*

# Fields
- `experiences`: maps a `(state, action)` pair to the latest observed `(reward, terminal, next_state)`.
- `PQueue`: max-first priority queue (built with `Base.Order.Reverse`) of `(state, action)` pairs awaiting a planning update.
- `predecessors`: maps a state `s′` to the set of `(s, a, r, d)` transitions observed to lead into it.
- `θ`: priority threshold below which transitions are not queued.
- `sample_count`: number of transitions drawn via `sample` so far.
"""
mutable struct PrioritizedSweepingSampleModel <: AbstractEnvironmentModel
    experiences::Dict{Tuple{Any,Any},Tuple{Float64,Bool,Any}}
    PQueue::PriorityQueue{Tuple{Any,Any},Float64}
    predecessors::Dict{Any,Set{Tuple{Any,Any,Float64,Bool}}}
    θ::Float64
    sample_count::Int
    PrioritizedSweepingSampleModel(θ::Float64 = 1e-4) = new(
        Dict{Tuple{Any,Any},Tuple{Float64,Bool,Any}}(),
        PriorityQueue{Tuple{Any,Any},Float64}(Base.Order.Reverse),
        Dict{Any,Set{Tuple{Any,Any,Float64,Bool}}}(),
        θ,
        0,
    )
end

"""
    RLBase.extract_experience(t::AbstractTrajectory, model::PrioritizedSweepingSampleModel)

Return the most recent `(state, action, reward, terminal, next_state)` tuple
recorded in trajectory `t`, or `nothing` when the trajectory is empty.
"""
function RLBase.extract_experience(
    t::AbstractTrajectory,
    model::PrioritizedSweepingSampleModel,
)
    if length(t) > 0
        (
            get_trace(t, :state)[end],
            get_trace(t, :action)[end],
            get_trace(t, :reward)[end],
            get_trace(t, :terminal)[end],
            get_trace(t, :next_state)[end],
        )
    else
        nothing
    end
end

"""
    RLBase.update!(m::PrioritizedSweepingSampleModel, transition, P)

Record a `(s, a, r, d, s′)` `transition` with priority `P`.

The experience for `(s, a)` is always overwritten with the latest observation
and `(s, a, r, d)` is always registered as a predecessor of `s′`; the pair is
queued for planning only when `P ≥ m.θ` (Sutton & Barto, 2nd ed., Section 8.4).
"""
function RLBase.update!(m::PrioritizedSweepingSampleModel, transition, P)
    s, a, r, d, s′ = transition
    m.experiences[(s, a)] = (r, d, s′)
    # Skip low-priority transitions to keep the queue focused on useful updates.
    if P >= m.θ
        m.PQueue[(s, a)] = P
    end
    # Track (s, a, r, d) as a predecessor of s′ for backward sweeping;
    # get! creates the set on first sight of s′.
    push!(get!(m.predecessors, s′, Set{Tuple{Any,Any,Float64,Bool}}()), (s, a, r, d))
end

"""
    sample(m::PrioritizedSweepingSampleModel)

Pop the highest-priority `(s, a)` pair from the queue and return the stored
`(s, a, r, d, s′)` transition, incrementing `m.sample_count`. Return `nothing`
when the queue is empty.
"""
function sample(m::PrioritizedSweepingSampleModel)
    if isempty(m.PQueue)
        nothing
    else
        s, a = dequeue!(m.PQueue)
        r, d, s′ = m.experiences[(s, a)]
        m.sample_count += 1
        (s, a, r, d, s′)
    end
end