Skip to content

Commit c220e9a

Browse files
committed
sync
1 parent e2fa30b commit c220e9a

File tree

9 files changed

+7734
-7462
lines changed

9 files changed

+7734
-7462
lines changed

notebooks/Chapter08_Maze.ipynb

Lines changed: 192 additions & 26 deletions
Large diffs are not rendered by default.

notebooks/Chapter09_Random_Walk.ipynb

Lines changed: 7403 additions & 7434 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include("linear_approximator.jl")
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
export LinearVApproximator, LinearQApproximator
2+
3+
using LinearAlgebra: dot
4+
5+
"""
6+
LinearVApproximator(weights::Array{Float64, N}) -> LinearVApproximator{N}
7+
Use the weighted sum to represent the estimation of a state.
8+
The state is expected to have the same length with `weights`.
9+
See also [`LinearQApproximator`](@ref)
10+
"""
11+
struct LinearVApproximator{N} <: AbstractApproximator
12+
weights::Array{Float64,N}
13+
end
14+
15+
RLBase.ApproximatorStyle(::LinearVApproximator) = VApproximator()
16+
17+
# TODO: support Vector
18+
(V::LinearVApproximator)(s) = dot(s, V.weights)
19+
20+
function RLBase.update!(V::LinearVApproximator, correction::Pair)
21+
s, e = correction
22+
V.weights .+= s .* e
23+
end
24+
25+
"""
26+
LinearQApproximator(weights::Vector{Float64}, feature_func::F, actions::Vector{Int}) -> LinearQApproximator{F}
27+
Use weighted sum to represent the estimation given a state and an action.
28+
# Fields
29+
- `weights::Vector{Float64}`: the weight of each feature.
30+
- `feature_func::Function`: decide how to generate a feature vector of `length(weights)` given a state and an action as parameters.
31+
- `actions::Vector{Int}`: all possible actions.
32+
See also [`LinearVApproximator`](@ref).
33+
"""
34+
Base.@kwdef struct LinearQApproximator{F} <: AbstractApproximator
35+
weights::Vector{Float64}
36+
feature_func::F
37+
actions::Vector{Int}
38+
end
39+
40+
RLBase.ApproximatorStyle(::LinearQApproximator) = QApproximator()
41+
42+
(Q::LinearQApproximator)(s, a::Int) = dot(Q.weights, Q.feature_func(s, a))
43+
44+
(Q::LinearQApproximator)(s) = [Q(s, a) for a in Q.actions]
45+
46+
function RLBase.update!(Q::LinearQApproximator, correction::Pair)
47+
(s, a), e = correction
48+
xs = Q.feature_func(s, a)
49+
Q.weights .+= xs .* e
50+
end

src/extensions/extensions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
include("utils/utils.jl")
2+
include("preprocessors.jl")
3+
include("approximators/approximators.jl")
24
include("environment_models/environment_models.jl")
35
include("learners/learners.jl")
46
include("policies/policies.jl")

src/extensions/preprocessors.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
export FourierPreprocessor, PolynomialPreprocessor, TilingPreprocessor
2+
3+
"""
4+
FourierPreprocessor(order::Int)
5+
Transform a scalar to a vector of `order+1` Fourier bases.
6+
"""
7+
struct FourierPreprocessor <: AbstractPreprocessor
8+
order::Int
9+
end
10+
11+
(p::FourierPreprocessor)(s::Number) = [cos(i * π * s) for i = 0:p.order]
12+
13+
"""
14+
PolynomialPreprocessor(order::Int)
15+
Transform a scalar to vector of maximum `order` polynomial.
16+
"""
17+
struct PolynomialPreprocessor <: AbstractPreprocessor
18+
order::Int
19+
end
20+
21+
(p::PolynomialPreprocessor)(s::Number) = [s^i for i = 0:p.order]
22+
23+
"""
24+
TilingPreprocessor(tilings::Vector{<:Tiling})
25+
Use each `tilings` to encode the state and return a vector.
26+
"""
27+
struct TilingPreprocessor{Tt<:Tiling} <: AbstractPreprocessor
28+
tilings::Vector{Tt}
29+
end
30+
31+
(p::TilingPreprocessor)(s::Union{<:Number,<:Array}) = [encode(t, s) for t in p.tilings]

src/extensions/utils/base.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ end
5959

6060
discount_rewards(rewards, γ) = discount_rewards!(similar(rewards), rewards, γ)
6161

62-
discount_rewards_reduced(rewards, γ) = foldr((r, g) -> r + γ * g, rewards)
62+
discount_rewards_reduced(rewards, γ) = foldr((r, g) -> r + γ * g, rewards)
63+

src/extensions/utils/tiling.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
export Tiling, encode
2+
3+
import Base: length, -
4+
5+
"""
6+
Tiling(ranges::NTuple{N, Tr}) where {N, Tr}
7+
Using a tuple of `ranges` to simulate a tiling.
8+
The length of `ranges` indicates the dimension of tilling.
9+
# Example
10+
```julia
11+
julia> t = Tiling((1:2:5, 10:5:20))
12+
Tiling{2,StepRange{Int64,Int64}}((1:2:5, 10:5:20), [1 3; 2 4])
13+
julia> encode(t, (2, 12)) # encode into an Int
14+
1
15+
julia> encode(t, (2, 18))
16+
3
17+
julia> t2 = t - (1, 3) # shift a little to get a new Tiling
18+
Tiling{2,StepRange{Int64,Int64}}((0:2:4, 7:5:17), [1 3; 2 4])
19+
```
20+
"""
21+
struct Tiling{N,Tr<:AbstractRange}
22+
ranges::NTuple{N,Tr}
23+
inds::LinearIndices{N,NTuple{N,Base.OneTo{Int}}}
24+
Tiling(ranges::NTuple{N,Tr}) where {N,Tr} =
25+
new{N,Tr}(ranges, LinearIndices(Tuple(length(r) - 1 for r in ranges)))
26+
end
27+
28+
"""
29+
(-)(t::Tiling, xs)
30+
Shift `t` along each dimension by each element in `xs`.
31+
"""
32+
function Base.:-(t::Tiling, xs)
33+
Tiling(Tuple(r .- x for (r, x) in zip(t.ranges, xs)))
34+
end
35+
36+
Base.length(t::Tiling) = reduce(*, (length(r) - 1 for r in t.ranges))
37+
38+
encode(range::AbstractRange, x) = floor(Int, div(x - range[1], step(range)) + 1)
39+
40+
41+
# TODO: use @generator here!
42+
encode(t::Tiling{1}, x::Number) = encode(t.ranges[1], x)
43+
encode(t::Tiling{1}, xs) = encode(t.ranges[1], xs[1])
44+
encode(t::Tiling{2}, xs) =
45+
t.inds[CartesianIndex(encode(t.ranges[1], xs[1]), encode(t.ranges[2], xs[2]))]
46+
encode(t::Tiling{3}, xs) =
47+
t.inds[CartesianIndex(
48+
encode(t.ranges[1], xs[1]),
49+
encode(t.ranges[2], xs[2]),
50+
encode(t.ranges[3], xs[3]),
51+
)]

src/extensions/utils/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
include("base.jl")
2-
include("optimizers.jl")
2+
include("optimizers.jl")
3+
include("tiling.jl")

0 commit comments

Comments
 (0)