Skip to content

Commit e99e0de

Browse files
Harsh SinghHarsh Singh
authored andcommitted
refactor(RKN): unify DPRKN/ERKN velocity-independent methods via NystromVITableau
Introduces NystromVITableau{T,T2} struct and generic NystromVICache / NystromVIConstantCache, eliminating 9 separate per-method cache structs and their per-method perform_step! implementations. Methods unified: DPRKN4, DPRKN5, DPRKN6FM, DPRKN8, DPRKN12, ERKN4, ERKN5, ERKN7, Nystrom5VelocityIndependent. Net: -1623 lines. All nystrom_convergence_tests pass.
1 parent a2c092a commit e99e0de

5 files changed

Lines changed: 477 additions & 1623 deletions

File tree

lib/OrdinaryDiffEqRKN/src/OrdinaryDiffEqRKN.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include("rkn_caches.jl")
2727
include("interp_func.jl")
2828
include("interpolants.jl")
2929
include("rkn_perform_step.jl")
30+
include("generic_rkn_vi_perform_step.jl")
3031

3132
export Nystrom4, FineRKN4, FineRKN5, Nystrom4VelocityIndependent,
3233
Nystrom5VelocityIndependent,
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
## Generic Nyström velocity-independent perform_step!
2+
## Solves: y'' = f(t, y) where f is velocity-independent
3+
## kᵢ = f1(duprev, yᵢ, p, t + cᵢ*h) (h = dt; duprev constant throughout)
4+
## yᵢ = y₀ + cᵢ*h*y'₀ + h²*Σⱼ<ᵢ aᵢⱼ*kⱼ
5+
## y₁ = y₀ + h*y'₀ + h²*Σᵢ bᵢ*kᵢ
6+
## y'₁ = y'₀ + h*Σᵢ bpᵢ*kᵢ
7+
8+
# Out-of-place setup for the generic velocity-independent Nyström cache.
#
# Allocates a two-slot `integrator.k`, evaluates `f.f1` and `f.f2` once at the
# previous state to build `fsalfirst`, creates a zeroed `fsallast` of the same
# shape, and records one `f1` and one `f2` evaluation in the stats.
function initialize!(integrator, cache::NystromVIConstantCache)
    (; f, p, t) = integrator
    dy0, y0 = integrator.uprev.x

    integrator.kshortsize = 2
    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)

    f1_at_start = f.f1(dy0, y0, p, t)
    f2_at_start = f.f2(dy0, y0, p, t)
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.stats.nf2 += 1

    integrator.fsalfirst = ArrayPartition((f1_at_start, f2_at_start))
    integrator.fsallast = zero(integrator.fsalfirst)
    integrator.k[1] = integrator.fsalfirst

    # The second slot shares the `fsallast` object; that object is also the
    # value handed back to the caller.
    last_slot = integrator.fsallast
    integrator.k[2] = last_slot
    return last_slot
end
21+
22+
# In-place setup for the generic velocity-independent Nyström cache.
#
# Shrinks/grows `integrator.k` to two slots that alias the existing
# `fsalfirst` / `fsallast` buffers, then fills both halves of `fsalfirst` by
# calling `f.f1` and `f.f2` in place at the previous state. One `f1` and one
# `f2` evaluation are recorded; the updated `nf2` counter is returned.
function initialize!(integrator, cache::NystromVICache)
    (; f, p, t) = integrator
    dy0, y0 = integrator.uprev.x
    first_slot = integrator.fsalfirst

    integrator.kshortsize = 2
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = first_slot
    integrator.k[2] = integrator.fsallast

    f.f1(first_slot.x[1], dy0, y0, p, t)
    f.f2(first_slot.x[2], dy0, y0, p, t)
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    return integrator.stats.nf2 += 1
end
33+
34+
# Out-of-place step for the tableau-driven velocity-independent Nyström methods.
#
# Reads the coefficients `a`, `b`, `bp`, `btilde`, `bptilde`, `c` and the flag
# `pos_only_error` from `cache.tab`. Stage 1 is taken from
# `integrator.fsalfirst.x[1]`; stage `i >= 2` uses the node `c[i - 1]` and the
# row `a[i, 1:(i - 1)]`.
#
# Writes `integrator.u`, `integrator.fsallast`, `integrator.k[1:2]` and the
# function-evaluation counters. When `integrator.opts.adaptive` is set and
# `btilde` is non-empty it also writes `integrator.EEst`.
@muladd function perform_step!(
        integrator, cache::NystromVIConstantCache, repeat_step = false)
    (; t, dt, f, p) = integrator
    duprev, uprev = integrator.uprev.x
    (; tab) = cache
    (; a, b, bp, btilde, bptilde, c, pos_only_error) = tab
    k1 = integrator.fsalfirst.x[1]
    nstages = length(b)
    dtsq = dt^2

    # Stage evaluations; zero tableau entries are skipped.
    ks = Vector{typeof(k1)}(undef, nstages)
    ks[1] = k1
    for i in 2:nstages
        ku = uprev + dt * c[i - 1] * duprev
        for j in 1:(i - 1)
            if !iszero(a[i, j])
                ku = ku + dtsq * a[i, j] * ks[j]
            end
        end
        ks[i] = f.f1(duprev, ku, p, t + dt * c[i - 1])
    end

    # Position update: u = uprev + dt*duprev + dt^2 * sum(b[i]*ks[i])
    u = uprev + dt * duprev
    for i in 1:nstages
        if !iszero(b[i])
            u = u + dtsq * b[i] * ks[i]
        end
    end

    # Velocity update: du = duprev + dt * sum(bp[i]*ks[i])
    du = duprev
    for i in 1:nstages
        if !iszero(bp[i])
            du = du + dt * bp[i] * ks[i]
        end
    end

    integrator.u = ArrayPartition((du, u))
    integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt)))
    # nstages - 1 stage calls above plus the f1 call for fsallast.
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, nstages)
    integrator.stats.nf2 += 1
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast

    if integrator.opts.adaptive && !isempty(btilde)
        # Position error term, needed by both error-norm variants.
        uhat = zero(uprev)
        for i in 1:nstages
            if !iszero(btilde[i])
                uhat = uhat + dtsq * btilde[i] * ks[i]
            end
        end
        if pos_only_error
            # Error norm over the position component only.
            atmp = calculate_residuals(uhat, integrator.uprev.x[2], integrator.u.x[2],
                integrator.opts.abstol, integrator.opts.reltol,
                integrator.opts.internalnorm, t)
            integrator.EEst = integrator.opts.internalnorm(atmp, t)
        else
            # Velocity error term; stays zero when the tableau has no bptilde.
            duhat = zero(duprev)
            if !isempty(bptilde)
                for i in 1:nstages
                    if !iszero(bptilde[i])
                        duhat = duhat + dt * bptilde[i] * ks[i]
                    end
                end
            end
            utilde = ArrayPartition((duhat, uhat))
            atmp = calculate_residuals(utilde, integrator.uprev, integrator.u,
                integrator.opts.abstol, integrator.opts.reltol,
                integrator.opts.internalnorm, t)
            integrator.EEst = integrator.opts.internalnorm(atmp, t)
        end
    end
end
105+
106+
# In-place step for the tableau-driven velocity-independent Nyström methods.
#
# Reads the coefficients `a`, `b`, `bp`, `btilde`, `bptilde`, `c` and the flag
# `pos_only_error` from `cache.tab`. Stage 1 is `integrator.fsalfirst.x[1]`;
# stage `i >= 2` is written into `cache.ks[i - 1]` using the node `c[i - 1]`
# and the row `a[i, 1:(i - 1)]`.
#
# Overwrites `integrator.u` (through its two partitions), `cache.k`,
# `cache.ks`, `cache.tmp.x[2]` and the function-evaluation counters. When
# `integrator.opts.adaptive` is set and `btilde` is non-empty it also
# overwrites `cache.utilde` / `cache.atmp` (fully or only their second
# partition, depending on `pos_only_error`) and sets `integrator.EEst`.
@muladd function perform_step!(
        integrator, cache::NystromVICache, repeat_step = false)
    (; t, dt, f, p) = integrator
    du, u = integrator.u.x
    duprev, uprev = integrator.uprev.x
    (; ks, k, utilde, tmp, atmp, tab) = cache
    (; a, b, bp, btilde, bptilde, c, pos_only_error) = tab
    ku = tmp.x[2]
    k1 = integrator.fsalfirst.x[1]
    nstages = length(b)
    dtsq = dt^2

    # Stage 1 is the FSAL value; stages 2..nstages are stored in ks[1..nstages-1].
    stage(i) = (i == 1) ? k1 : ks[i - 1]

    # Compute intermediate stages k2..knstages; zero tableau entries are skipped.
    for i in 2:nstages
        @.. broadcast=false ku = uprev + dt * c[i - 1] * duprev
        for j in 1:(i - 1)
            if !iszero(a[i, j])
                kj = stage(j)
                @.. broadcast=false ku = ku + dtsq * a[i, j] * kj
            end
        end
        f.f1(ks[i - 1], duprev, ku, p, t + dt * c[i - 1])
    end

    # Position update: u = uprev + dt*duprev + dt^2 * sum(b[i]*ki)
    @.. broadcast=false u = uprev + dt * duprev
    for i in 1:nstages
        if !iszero(b[i])
            ki = stage(i)
            @.. broadcast=false u = u + dtsq * b[i] * ki
        end
    end

    # Velocity update: du = duprev + dt * sum(bp[i]*ki)
    @.. broadcast=false du = duprev
    for i in 1:nstages
        if !iszero(bp[i])
            ki = stage(i)
            @.. broadcast=false du = du + dt * bp[i] * ki
        end
    end

    f.f1(k.x[1], du, u, p, t + dt)
    f.f2(k.x[2], du, u, p, t + dt)
    # nstages - 1 stage calls above plus the f1 call into k.x[1].
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, nstages)
    integrator.stats.nf2 += 1

    if integrator.opts.adaptive && !isempty(btilde)
        # Position error term, needed by both error-norm variants.
        uhat = utilde.x[2]
        @.. broadcast=false uhat = zero(uhat)
        for i in 1:nstages
            if !iszero(btilde[i])
                ki = stage(i)
                @.. broadcast=false uhat = uhat + dtsq * btilde[i] * ki
            end
        end
        if pos_only_error
            # Error norm over the position component only.
            calculate_residuals!(atmp.x[2], uhat, integrator.uprev.x[2], integrator.u.x[2],
                integrator.opts.abstol, integrator.opts.reltol,
                integrator.opts.internalnorm, t)
            integrator.EEst = integrator.opts.internalnorm(atmp.x[2], t)
        else
            # Velocity error term; stays zero when the tableau has no bptilde.
            duhat = utilde.x[1]
            @.. broadcast=false duhat = zero(duhat)
            if !isempty(bptilde)
                for i in 1:nstages
                    if !iszero(bptilde[i])
                        ki = stage(i)
                        @.. broadcast=false duhat = duhat + dt * bptilde[i] * ki
                    end
                end
            end
            calculate_residuals!(atmp, utilde, integrator.uprev, integrator.u,
                integrator.opts.abstol, integrator.opts.reltol,
                integrator.opts.internalnorm, t)
            integrator.EEst = integrator.opts.internalnorm(atmp, t)
        end
    end
end

0 commit comments

Comments
 (0)