1818
1919#include " oneapi/dal/algo/covariance/backend/cpu/compute_kernel.hpp"
2020#include " oneapi/dal/algo/covariance/backend/cpu/compute_kernel_common.hpp"
21+ #include " oneapi/dal/algo/covariance/backend/cpu/partial_compute_kernel.hpp"
22+ #include " oneapi/dal/algo/covariance/backend/cpu/finalize_compute_kernel.hpp"
2123#include " oneapi/dal/backend/interop/common.hpp"
2224#include " oneapi/dal/backend/interop/error_converter.hpp"
2325#include " oneapi/dal/backend/interop/table_conversion.hpp"
2426
27+ #include " oneapi/dal/backend/primitives/utils.hpp"
28+
2529#include " oneapi/dal/table/row_accessor.hpp"
2630
2731namespace oneapi ::dal::covariance::backend {
@@ -30,13 +34,83 @@ using dal::backend::context_cpu;
3034using descriptor_t = detail::descriptor_base<task::compute>;
3135using parameters_t = detail::compute_parameters<task::compute>;
3236
37+ namespace be = dal::backend;
38+ namespace pr = be::primitives;
3339namespace daal_covariance = daal::algorithms::covariance;
3440namespace interop = dal::backend::interop;
3541
3642template <typename Float, daal::internal::CpuType Cpu>
3743using daal_covariance_kernel_t = daal_covariance::internal::
3844 CovarianceDenseBatchKernel<Float, daal_covariance::Method::defaultDense, Cpu>;
3945
46+ template <typename Float, typename Task>
47+ static compute_result<Task> call_daal_spmd_kernel (const context_cpu& ctx,
48+ const detail::descriptor_base<Task>& desc,
49+ const detail::compute_parameters<Task>& params,
50+ const table& data) {
51+ auto & comm = ctx.get_communicator ();
52+ const std::int64_t component_count = data.get_column_count ();
53+
54+ // Compute partial results locally on this rank's data
55+ partial_compute_input<Task> partial_input (data);
56+ auto partial_result =
57+ partial_compute_kernel_cpu<Float, method::by_default, Task>{}(ctx, desc, partial_input);
58+
59+ // Extract partial results as mutable arrays
60+ auto nobs_nd = pr::table2ndarray<Float>(partial_result.get_partial_n_rows ());
61+ auto sums_nd = pr::table2ndarray<Float>(partial_result.get_partial_sum ());
62+ auto crossproduct_nd = pr::table2ndarray<Float>(partial_result.get_partial_crossproduct ());
63+
64+ auto nobs_ary = dal::array<Float>::wrap (nobs_nd.get_mutable_data (), nobs_nd.get_count ());
65+ auto sums_ary = dal::array<Float>::wrap (sums_nd.get_mutable_data (), sums_nd.get_count ());
66+ auto crossproduct_ary =
67+ dal::array<Float>::wrap (crossproduct_nd.get_mutable_data (), crossproduct_nd.get_count ());
68+
69+ // The DAAL online kernel stores centered crossproducts:
70+ // cp = X^T*X - sums*sums^T/nobs
71+ // Simple allreduce of centered crossproducts is incorrect because each
72+ // rank uses its local mean. Un-center before allreduce, then re-center
73+ // with global statistics after.
74+ const Float local_nobs = *nobs_ary.get_data ();
75+ if (!desc.get_assume_centered () && local_nobs >= 1.0 ) {
76+ Float* cp_ptr = crossproduct_ary.get_mutable_data ();
77+ const Float* sums_ptr = sums_ary.get_data ();
78+ const Float inv_nobs = Float (1 ) / local_nobs;
79+ for (std::int64_t i = 0 ; i < component_count; ++i) {
80+ for (std::int64_t j = 0 ; j < component_count; ++j) {
81+ cp_ptr[i * component_count + j] += inv_nobs * sums_ptr[i] * sums_ptr[j];
82+ }
83+ }
84+ }
85+
86+ // Allreduce raw crossproduct, sums, and nobs across all ranks
87+ comm.allreduce (nobs_ary).wait ();
88+ comm.allreduce (sums_ary).wait ();
89+ comm.allreduce (crossproduct_ary).wait ();
90+
91+ // Re-center with global statistics
92+ const Float global_nobs = *nobs_ary.get_data ();
93+ if (!desc.get_assume_centered () && global_nobs >= 1.0 ) {
94+ Float* cp_ptr = crossproduct_ary.get_mutable_data ();
95+ const Float* sums_ptr = sums_ary.get_data ();
96+ const Float inv_nobs = Float (1 ) / global_nobs;
97+ for (std::int64_t i = 0 ; i < component_count; ++i) {
98+ for (std::int64_t j = 0 ; j < component_count; ++j) {
99+ cp_ptr[i * component_count + j] -= inv_nobs * sums_ptr[i] * sums_ptr[j];
100+ }
101+ }
102+ }
103+
104+ // Reconstruct aggregated partial result and finalize
105+ partial_compute_result<Task> aggregated;
106+ aggregated.set_partial_n_rows (homogen_table::wrap (nobs_ary, 1 , 1 ));
107+ aggregated.set_partial_sum (homogen_table::wrap (sums_ary, 1 , component_count));
108+ aggregated.set_partial_crossproduct (
109+ homogen_table::wrap (crossproduct_ary, component_count, component_count));
110+
111+ return finalize_compute_kernel_cpu<Float, method::by_default, Task>{}(ctx, desc, aggregated);
112+ }
113+
40114template <typename Float, typename Task>
41115static compute_result<Task> call_daal_kernel (const context_cpu& ctx,
42116 const detail::descriptor_base<Task>& desc,
@@ -121,6 +195,9 @@ static compute_result<Task> compute(const context_cpu& ctx,
121195 const detail::descriptor_base<Task>& desc,
122196 const detail::compute_parameters<Task>& params,
123197 const compute_input<Task>& input) {
198+ if (ctx.get_communicator ().get_rank_count () > 1 ) {
199+ return call_daal_spmd_kernel<Float, Task>(ctx, desc, params, input.get_data ());
200+ }
124201 return call_daal_kernel<Float, Task>(ctx, desc, params, input.get_data ());
125202}
126203
0 commit comments