Skip to content

Commit 555f7d7

Browse files
Merge pull request #3132 from ChrisRackauckas-Claude/unify-sde-integrator-methods
Unify SDE integrator methods via field-gated conditionals
2 parents 32aede9 + 5ff8d2d commit 555f7d7

4 files changed

Lines changed: 52 additions & 11 deletions

File tree

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OrdinaryDiffEqCore"
22
uuid = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
33
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
4-
version = "3.19.0"
4+
version = "3.20.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1321,12 +1321,17 @@ function interpolation_differential_vars(differential_vars, y₀, idxs)
13211321
end
13221322
end
13231323

1324-
# If no dispatch found, assume Hermite
1324+
# If no dispatch found, assume Hermite (or linear when k is empty, e.g. SDE)
13251325
function _ode_interpolant(
13261326
Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars
13271327
) where {TI}
13281328
TI > 3 && throw(DerivativeOrderNotPossibleError())
13291329

1330+
# Linear fallback when no dense output vectors (e.g. SDE)
1331+
if isempty(k)
1332+
return linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
1333+
end
1334+
13301335
differential_vars = interpolation_differential_vars(differential_vars, y₀, idxs)
13311336
return hermite_interpolant(
13321337
Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache},
@@ -1339,6 +1344,10 @@ function _ode_interpolant!(
13391344
) where {TI}
13401345
TI > 3 && throw(DerivativeOrderNotPossibleError())
13411346

1347+
if isempty(k)
1348+
return linear_interpolant!(out, Θ, dt, y₀, y₁, idxs, T)
1349+
end
1350+
13421351
differential_vars = interpolation_differential_vars(differential_vars, y₀, idxs)
13431352
return hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T, differential_vars)
13441353
end

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,21 @@ function _change_t_via_interpolation!(
1717
else
1818
integrator(integrator.u, t)
1919
end
20+
# SDE path: reject noise to rewind W/P, update sqdt
21+
W = _get_W(integrator)
22+
if !isnothing(W)
23+
reject_noise!(W, t - integrator.tprev, integrator.u, integrator.p)
24+
reject_noise!(_get_P(integrator), t - integrator.tprev, integrator.u, integrator.p)
25+
end
2026
integrator.t = t
2127
integrator.dt = integrator.t - integrator.tprev
22-
SciMLBase.reeval_internals_due_to_modification!(
23-
integrator; callback_initializealg = reinitialize_alg
24-
)
28+
if isnothing(W)
29+
SciMLBase.reeval_internals_due_to_modification!(
30+
integrator; callback_initializealg = reinitialize_alg
31+
)
32+
else
33+
integrator.sqdt = sqrt(abs(integrator.dt))
34+
end
2535
if T
2636
solution_endpoint_match_cur_integrator!(integrator)
2737
end
@@ -374,6 +384,15 @@ function SciMLBase.set_rng!(integrator::ODEIntegrator, rng)
374384
)
375385
end
376386
integrator.rng = rng
387+
# Sync framework-constructed noise processes (SDE only, no-op when W/P are nothing)
388+
W = _get_W(integrator)
389+
if !isnothing(W) && integrator.noise === nothing
390+
W.rng = rng
391+
end
392+
P = _get_P(integrator)
393+
if !isnothing(P)
394+
P.rng = rng
395+
end
377396
return nothing
378397
end
379398

@@ -435,8 +454,10 @@ function SciMLBase.reinit!(
435454

436455
tType = typeof(integrator.t)
437456
tspan = (tType(t0), tType(tf))
438-
reinit_tstops!(tType, integrator.opts.tstops, tstops, d_discontinuities, tspan;
439-
p = parameter_values(integrator))
457+
reinit_tstops!(
458+
tType, integrator.opts.tstops, tstops, d_discontinuities, tspan;
459+
p = parameter_values(integrator)
460+
)
440461
reinit_saveat!(tType, integrator.opts.saveat, saveat, tspan)
441462
reinit_d_discontinuities!(tType, integrator.opts.d_discontinuities, d_discontinuities, tspan)
442463
if erase_sol
@@ -447,7 +468,9 @@ function SciMLBase.reinit!(
447468
end
448469
resize!(integrator.sol.u, resize_start)
449470
resize!(integrator.sol.t, resize_start)
450-
resize!(integrator.sol.k, resize_start)
471+
if _has_ks(integrator)
472+
resize!(integrator.sol.k, resize_start)
473+
end
451474

452475
if integrator.opts.save_start || (!isempty(saveat) && saveat[1] == tType(t0))
453476
copyat_or_push!(integrator.sol.t, 1, t0)
@@ -461,7 +484,7 @@ function SciMLBase.reinit!(
461484
if integrator.sol.u_analytic !== nothing
462485
resize!(integrator.sol.u_analytic, 0)
463486
end
464-
if integrator.alg isa OrdinaryDiffEqCompositeAlgorithm
487+
if is_composite_algorithm(integrator.alg)
465488
resize!(integrator.sol.alg_choice, resize_start)
466489
end
467490
integrator.saveiter = resize_start
@@ -505,17 +528,25 @@ function SciMLBase.reinit!(
505528
if reinit_retcode
506529
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.Default)
507530
end
531+
532+
reinit_noise!(_get_W(integrator), integrator.dt)
508533
return nothing
509534
end
510535

511-
function SciMLBase.auto_dt_reset!(integrator::ODEIntegrator)
512-
integrator.dt = ode_determine_initdt(
536+
# Extensible initdt hook: ODE defaults to ode_determine_initdt.
537+
# SDE extends for SDE algorithm types to call sde_determine_initdt.
538+
function _determine_initdt(integrator)
539+
return ode_determine_initdt(
513540
integrator.u, integrator.t,
514541
integrator.tdir, integrator.opts.dtmax,
515542
integrator.opts.abstol, integrator.opts.reltol,
516543
integrator.opts.internalnorm, integrator.sol.prob,
517544
integrator
518545
)
546+
end
547+
548+
function SciMLBase.auto_dt_reset!(integrator::ODEIntegrator)
549+
integrator.dt = _determine_initdt(integrator)
519550
integrator.dtpropose = integrator.dt
520551
return increment_nf!(integrator.stats, 2)
521552
end

lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ reject_noise!(::Nothing, args...) = nothing
55
save_noise!(::Nothing) = nothing
66
noise_curt(::Nothing) = nothing
77
is_noise_saveable(::Nothing) = false
8+
reinit_noise!(::Nothing, dt) = nothing
89

910
# Noise field accessors — safe for any integrator type.
1011
# ODEIntegrator has W/P/sqdt; other integrators (DDEIntegrator) don't.

0 commit comments

Comments
 (0)