Skip to content

Commit a2c092a

Browse files
singhharsh1708Harsh SinghChrisRackauckas
authored
Add MREEF multirate extrapolated explicit Euler solver with tests (#3139)
* Add MREEF multirate extrapolated explicit Euler solver with tests * Update FBDF GPU test from @test_broken to @test (now passing) * Add MREEF multirate extrapolated Euler method with tests and proper integration * Fix MREEF LTS issue and formatting * Runic formatting fixes after rebase * Fix typo in inference test * Remove unrelated changes from inference tests * Align MREEF split with SciML convention (f1 fast, f2 slow) * Move MREEF to new OrdinaryDiffEqMultirate subpackage and address review feedback * Move MREEF to OrdinaryDiffEqMultirate subpackage and address review feedback * Apply suggestion from @ChrisRackauckas --------- Co-authored-by: Harsh Singh <harsh@Harshs-MacBook-Air.local> Co-authored-by: Christopher Rackauckas <accounts@chrisrackauckas.com>
1 parent 2293fa7 commit a2c092a

9 files changed

Lines changed: 582 additions & 0 deletions

File tree

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
name = "OrdinaryDiffEqMultirate"
2+
uuid = "d4b830b4-ac80-426b-8507-16693d424963"
3+
authors = ["singhharsh1708 <hs1663531@gmail.com>"]
4+
version = "1.0.0"
5+
6+
[deps]
7+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
9+
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
10+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
11+
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
12+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
13+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
14+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
15+
16+
[compat]
17+
Aqua = "0.8.11"
18+
DiffEqBase = "6.194"
19+
DiffEqDevTools = "2.44.4"
20+
FastBroadcast = "0.3"
21+
JET = "0.9, 0.11"
22+
MuladdMacro = "0.2"
23+
OrdinaryDiffEqCore = "3"
24+
OrdinaryDiffEqLowOrderRK = "1"
25+
Pkg = "1"
26+
RecursiveArrayTools = "3.36"
27+
Reexport = "1.2"
28+
SafeTestsets = "0.1.0"
29+
SciMLBase = "2.116"
30+
Test = "<0.0.1, 1"
31+
julia = "1.10"
32+
33+
[extras]
34+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
35+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
36+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
37+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
38+
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
39+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
40+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
41+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
42+
43+
[sources.OrdinaryDiffEqCore]
44+
path = "../OrdinaryDiffEqCore"
45+
46+
[sources.OrdinaryDiffEqLowOrderRK]
47+
path = "../OrdinaryDiffEqLowOrderRK"
48+
49+
[targets]
50+
test = ["DiffEqDevTools", "LinearAlgebra", "OrdinaryDiffEqLowOrderRK", "SafeTestsets", "Test", "Pkg"]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module OrdinaryDiffEqMultirate
2+
3+
import OrdinaryDiffEqCore: alg_order, isfsal,
4+
OrdinaryDiffEqAdaptiveAlgorithm,
5+
generic_solver_docstring,
6+
unwrap_alg, initialize!, perform_step!,
7+
calculate_residuals, calculate_residuals!,
8+
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
9+
@cache, alg_cache, full_cache, get_fsalfirstlast
10+
import OrdinaryDiffEqCore
11+
import FastBroadcast: @..
12+
import MuladdMacro: @muladd
13+
import RecursiveArrayTools: recursivefill!
14+
import DiffEqBase: prepare_alg
15+
16+
using Reexport
17+
@reexport using SciMLBase
18+
19+
include("algorithms.jl")
20+
include("alg_utils.jl")
21+
include("mreef_caches.jl")
22+
include("mreef_perform_step.jl")
23+
24+
export MREEF
25+
26+
end
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
alg_order(alg::MREEF) = alg.order
2+
isfsal(::MREEF) = false
3+
4+
function prepare_alg(alg::MREEF, u0::AbstractArray, p, prob)
5+
return alg
6+
end
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
@doc generic_solver_docstring(
2+
"Multirate Richardson Extrapolation with Euler as the base method (MREEF).
3+
4+
Solves a split ODE of the form `du/dt = f1(u,t) + f2(u,t)` where `f1` is the
5+
fast component and `f2` is the slow component (SciML convention). The slow rate
6+
`f2` is frozen over each macro interval and the fast rate `f1` is integrated
7+
with `m` explicit Euler substeps. Aitken–Neville Richardson extrapolation over
8+
`order` base solutions is then applied to boost accuracy.",
9+
"MREEF",
10+
"Multirate explicit method.",
11+
"""@article{engstrom2009multirate,
12+
title={Multirate explicit Adams methods for time integration of conservation laws},
13+
author={Engstr{\\\"o}m, C and Ferm, L and L{\\\"o}tstedt, P and Sj{\\\"o}green, B},
14+
year={2009}}""",
15+
"""
16+
- `m`: number of fast substeps per macro interval. Default is `4`.
17+
- `order`: extrapolation order (number of base solutions). Default is `4`.
18+
- `seq`: subdivision sequence, `:harmonic` (default) or `:romberg`.
19+
""",
20+
"""
21+
m::Int = 4,
22+
order::Int = 4,
23+
seq::Symbol = :harmonic,
24+
"""
25+
)
26+
Base.@kwdef struct MREEF <: OrdinaryDiffEqAdaptiveAlgorithm
27+
m::Int = 4
28+
order::Int = 4
29+
seq::Symbol = :harmonic
30+
end
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
struct MREEFConstantCache{T} <: OrdinaryDiffEqConstantCache
2+
T::T # pre-allocated extrapolation table: Vector of length `order`
3+
end
4+
5+
@cache mutable struct MREEFCache{uType, rateType, uNoUnitsType} <: OrdinaryDiffEqMutableCache
6+
u::uType
7+
uprev::uType
8+
tmp::uType
9+
atmp::uNoUnitsType
10+
k_slow::rateType
11+
k_fast::rateType
12+
T::Array{uType, 1}
13+
fsalfirst::rateType
14+
k::rateType
15+
end
16+
17+
get_fsalfirstlast(cache::MREEFCache, u) = (cache.fsalfirst, cache.k)
18+
19+
function alg_cache(
20+
alg::MREEF, u, rate_prototype, ::Type{uEltypeNoUnits},
21+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
22+
dt, reltol, p, calck,
23+
::Val{true}, verbose
24+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
25+
tmp = zero(u)
26+
atmp = similar(u, uEltypeNoUnits)
27+
recursivefill!(atmp, false)
28+
k_slow = zero(rate_prototype)
29+
k_fast = zero(rate_prototype)
30+
T = [zero(u) for _ in 1:(alg.order)]
31+
fsalfirst = zero(rate_prototype)
32+
k = zero(rate_prototype)
33+
return MREEFCache(u, uprev, tmp, atmp, k_slow, k_fast, T, fsalfirst, k)
34+
end
35+
36+
function alg_cache(
37+
alg::MREEF, u, rate_prototype, ::Type{uEltypeNoUnits},
38+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
39+
dt, reltol, p, calck,
40+
::Val{false}, verbose
41+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
42+
T = Vector{typeof(u)}(undef, alg.order)
43+
return MREEFConstantCache(T)
44+
end
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# ── MREEF: step-count sequence ────────────────────────────────────────────────
2+
3+
@inline function _mreef_sequence(seq::Symbol, order::Int)
4+
if seq === :harmonic
5+
return ntuple(j -> j, order)
6+
elseif seq === :romberg
7+
return ntuple(j -> 1 << (j - 1), order)
8+
else
9+
throw(ArgumentError("MREEF: unknown sequence `$seq`, choose :harmonic or :romberg"))
10+
end
11+
end
12+
13+
# ── MREEF initialize! ─────────────────────────────────────────────────────────
14+
15+
function initialize!(integrator, cache::MREEFCache)
16+
integrator.kshortsize = 2
17+
(; fsalfirst, k) = cache
18+
integrator.fsalfirst = fsalfirst
19+
integrator.fsallast = k
20+
resize!(integrator.k, integrator.kshortsize)
21+
integrator.k[1] = integrator.fsalfirst
22+
integrator.k[2] = integrator.fsallast
23+
integrator.f.f1(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
24+
integrator.f.f2(cache.tmp, integrator.uprev, integrator.p, integrator.t)
25+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) # f1
26+
integrator.stats.nf2 += 1 # f2
27+
return integrator.fsalfirst .+= cache.tmp
28+
end
29+
30+
function initialize!(integrator, cache::MREEFConstantCache)
31+
integrator.kshortsize = 2
32+
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
33+
integrator.fsalfirst = integrator.f.f1(integrator.uprev, integrator.p, integrator.t) +
34+
integrator.f.f2(integrator.uprev, integrator.p, integrator.t)
35+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
36+
integrator.stats.nf2 += 1
37+
integrator.fsallast = zero(integrator.fsalfirst)
38+
integrator.k[1] = integrator.fsalfirst
39+
return integrator.k[2] = integrator.fsallast
40+
end
41+
42+
# ── MREEF perform_step! (in-place, MutableCache) ──────────────────────────────
43+
#
44+
# Base multirate Euler with nj macro intervals, m fast substeps each:
45+
# 1. k_slow = f.f2(u, p, t_mac) — frozen slow rate for the macro interval
46+
# 2. m fast substeps: u += h_fast*(k_slow + f.f1(u, p, t_fast))
47+
# f1 = fast/stiff (large eigenvalues), f2 = slow/non-stiff (SciML convention).
48+
# Then apply Aitken–Neville Richardson extrapolation over T[1..order].
49+
50+
function perform_step!(integrator, cache::MREEFCache, repeat_step = false)
51+
(; t, dt, uprev, u, f, p) = integrator
52+
(; tmp, atmp, k_slow, k_fast, T) = cache
53+
alg = unwrap_alg(integrator, false)
54+
m = alg.m
55+
order = alg.order
56+
ns = _mreef_sequence(alg.seq, order)
57+
58+
# Fill first tableau column: T[j] = base method with ns[j] macro intervals
59+
for j in 1:order
60+
nj = ns[j]
61+
h_mac = dt / nj
62+
h_fast = h_mac / m
63+
64+
@.. broadcast = false T[j] = uprev
65+
66+
for i_mac in 1:nj
67+
t_mac = t + (i_mac - 1) * h_mac
68+
69+
# Slow evaluation (f2): frozen for all m fast substeps
70+
f.f2(k_slow, T[j], p, t_mac)
71+
integrator.stats.nf2 += 1
72+
73+
for i_fast in 1:m
74+
t_fast = t_mac + (i_fast - 1) * h_fast
75+
f.f1(k_fast, T[j], p, t_fast)
76+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
77+
@.. broadcast = false T[j] = T[j] + h_fast * k_slow + h_fast * k_fast
78+
end
79+
end
80+
end
81+
82+
# Aitken–Neville Richardson extrapolation (in-place, reverse-row order)
83+
# Formula: T[j] <- T[j] + (T[j] - T[j-1]) / (ns[j]/ns[j-k] - 1)
84+
for k in 1:(order - 1)
85+
for j in order:-1:(k + 1)
86+
ratio = ns[j] / ns[j - k]
87+
@.. broadcast = false tmp = (T[j] - T[j - 1]) / (ratio - 1)
88+
@.. broadcast = false T[j] = T[j] + tmp
89+
end
90+
end
91+
92+
@.. broadcast = false u = T[order]
93+
94+
return if integrator.opts.adaptive
95+
@.. broadcast = false tmp = T[order] - T[order - 1]
96+
calculate_residuals!(
97+
atmp,
98+
tmp,
99+
uprev,
100+
u,
101+
integrator.opts.abstol,
102+
integrator.opts.reltol,
103+
integrator.opts.internalnorm,
104+
t,
105+
)
106+
integrator.EEst = integrator.opts.internalnorm(atmp, t)
107+
end
108+
end
109+
110+
# ── MREEF perform_step! (out-of-place, ConstantCache) ─────────────────────────
111+
112+
@muladd function perform_step!(integrator, cache::MREEFConstantCache, repeat_step = false)
113+
(; t, dt, uprev, f, p) = integrator
114+
alg = unwrap_alg(integrator, false)
115+
m = alg.m
116+
order = alg.order
117+
ns = _mreef_sequence(alg.seq, order)
118+
T = cache.T
119+
120+
for j in 1:order
121+
nj = ns[j]
122+
h_mac = dt / nj
123+
h_fast = h_mac / m
124+
125+
u_cur = uprev
126+
for i_mac in 1:nj
127+
t_mac = t + (i_mac - 1) * h_mac
128+
k_slow = f.f2(u_cur, p, t_mac)
129+
integrator.stats.nf2 += 1
130+
for i_fast in 1:m
131+
t_fast = t_mac + (i_fast - 1) * h_fast
132+
k_fast = f.f1(u_cur, p, t_fast)
133+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
134+
u_cur = @.. broadcast = false u_cur + h_fast * k_slow + h_fast * k_fast
135+
end
136+
end
137+
T[j] = u_cur
138+
end
139+
140+
# Aitken–Neville Richardson extrapolation
141+
for k in 1:(order - 1)
142+
for j in order:-1:(k + 1)
143+
ratio = ns[j] / ns[j - k]
144+
T[j] = @.. broadcast = false T[j] + (T[j] - T[j - 1]) / (ratio - 1)
145+
end
146+
end
147+
148+
integrator.u = T[order]
149+
150+
if integrator.opts.adaptive
151+
utilde = @.. broadcast = false T[order] - T[order - 1]
152+
atmp = calculate_residuals(
153+
utilde,
154+
uprev,
155+
integrator.u,
156+
integrator.opts.abstol,
157+
integrator.opts.reltol,
158+
integrator.opts.internalnorm,
159+
t,
160+
)
161+
integrator.EEst = integrator.opts.internalnorm(atmp, t)
162+
end
163+
end

0 commit comments

Comments
 (0)