2020#include " oneapi/dal/array.hpp"
2121#include " oneapi/dal/common.hpp"
2222#include " oneapi/dal/detail/common.hpp"
23+ #include " oneapi/dal/detail/array_utils.hpp"
2324
2425#include " onedal/common.hpp"
2526
@@ -46,6 +47,47 @@ inline void check_access(const dal::array<Type>& arr, std::int64_t idx) {
4647 }
4748}
4849
50+ template <typename T>
51+ inline auto get_policy (const dal::array<T>& arr) {
52+ return dal::detail::get_impl (arr).get_policy ();
53+ }
54+
55+ template <typename Policy>
56+ constexpr inline bool is_host_policy_v = dal::detail::is_one_of_v<Policy, //
57+ detail::default_host_policy, detail::host_policy>;
58+
59+ template <typename InpPolicy, typename OutPolicy>
60+ inline bool need_copy (const InpPolicy& inp, const OutPolicy& out) {
61+ using out_policy_t = std::decay_t <decltype (inp)>;
62+ using inp_policy_t = std::decay_t <decltype (out)>;
63+ constexpr bool is_inp_host = is_host_policy_v<inp_policy_t >;
64+ constexpr bool is_out_host = is_host_policy_v<out_policy_t >;
65+ constexpr bool result = !(is_inp_host && is_out_host);
66+ return result;
67+ }
68+
69+ // TODO: Check for the same policy
70+ template <typename Policy, typename T>
71+ inline dal::array<T> to_policy (const Policy& out, const dal::array<T>& source) {
72+ return std::visit ([&](const auto & inp) -> dal::array<T> {
73+ if (need_copy (inp, out)) {
74+ return detail::copy (out, source);
75+ }
76+ else {
77+ return dal::array<T>{ source };
78+ }
79+ }, get_policy (source));
80+ }
81+
82+ template <typename Policy, typename Array>
83+ inline void instantiate_to_policy (py::class_<Array>& py_array) {
84+ constexpr const char name[] = " to_policy" ;
85+ py_array.def (name, [](const Array& source, const Policy& policy) {
86+ auto result = to_policy (policy, source);
87+ return py::cast ( std::move (result) );
88+ });
89+ }
90+
4991template <typename Type>
5092void instantiate_array_by_type (py::module & pm) {
5193 const auto name = name_array<Type>();
@@ -67,6 +109,9 @@ void instantiate_array_by_type(py::module& pm) {
67109 py_array.def (" __len__" , &array_t ::get_count);
68110 py_array.def (" get_count" , &array_t ::get_count);
69111 py_array.def (" has_mutable_data" , &array_t ::has_mutable_data);
112+ py_array.def (" has_data" , [](const array_t & array) -> bool {
113+ return array.get_count () > std::int64_t (0l );
114+ });
70115 py_array.def (" get_slice" , [](const array_t & array,
71116 std::int64_t first, std::int64_t last) -> array_t {
72117 constexpr std::int64_t zero = 0l ;
@@ -92,7 +137,7 @@ void instantiate_array_by_type(py::module& pm) {
92137 py_array.def (" get_policy" , [](const array_t & arr) -> py::object {
93138 return std::visit ([](const auto & policy) -> py::object {
94139 return py::cast ( policy );
95- }, dal::detail::get_impl (arr). get_policy ());
140+ }, get_policy (arr ));
96141 });
97142 py_array.def (" __getitem__" , [](const array_t & arr, std::int64_t idx) -> Type {
98143 check_access<Type, false >(arr, idx);
@@ -105,6 +150,11 @@ void instantiate_array_by_type(py::module& pm) {
105150 py_array.def_property_readonly (" __is_onedal_array__" , [](const array_t &) -> bool {
106151 return true ;
107152 });
153+ instantiate_to_policy<detail::host_policy>(py_array);
154+ instantiate_to_policy<detail::default_host_policy>(py_array);
155+ #ifdef ONEDAL_DATA_PARALLEL
156+ instantiate_to_policy<detail::data_parallel_policy>(py_array);
157+ #endif // ONEDAL_DATA_PARALLEL
108158}
109159
110160template <typename ... Types>
0 commit comments