Skip to content

Commit 03b4d73

Browse files
author
Nikita Kulikov
committed
1 parent f6d23bf commit 03b4d73

21 files changed

+156
-118
lines changed

onedal/cluster/dbscan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _fit(self, X, y, sample_weight, module, queue):
8181
).ravel()
8282
else:
8383
self.core_sample_indices_ = np.array([], dtype=np.intc)
84+
self.core_sample_indices_ = self.core_sample_indices_.astype(np.intc)
8485
self.components_ = np.take(X, self.core_sample_indices_, axis=0)
8586
self.n_features_in_ = X.shape[1]
8687
return self

onedal/cluster/kmeans.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ def _check_params_vs_input(
9191
self, X_table, policy, default_n_init=10, dtype=np.float32
9292
):
9393
# n_clusters
94-
if X_table.shape[0] < self.n_clusters:
94+
X_row_count = X_table.get_row_count()
95+
if X_row_count < self.n_clusters:
9596
raise ValueError(
96-
f"n_samples={X_table.shape[0]} should be >= n_clusters={self.n_clusters}."
97+
f"n_samples={X_row_count} should be >= n_clusters={self.n_clusters}."
9798
)
9899

99100
# tol
@@ -182,7 +183,7 @@ def _init_centroids_custom(
182183
elif _is_arraylike_not_scalar(init):
183184
centers = np.asarray(init)
184185
assert centers.shape[0] == n_clusters
185-
assert centers.shape[1] == X_table.column_count
186+
assert centers.shape[1] == X_table.get_column_count()
186187
centers = _convert_to_supported(policy, init)
187188
centers_table = to_table(centers)
188189
else:
@@ -221,10 +222,6 @@ def _init_centroids_generic(self, X, init, random_state, policy, dtype=np.float3
221222
def _fit_backend(self, X_table, centroids_table, module, policy, dtype=np.float32):
222223
params = self._get_onedal_params(dtype)
223224

224-
# TODO: check all features for having correct type
225-
meta = _backend.get_table_metadata(X_table)
226-
assert meta.get_npy_dtype(0) == dtype
227-
228225
result = module.train(policy, params, X_table, centroids_table)
229226

230227
return (
@@ -238,7 +235,7 @@ def _fit(self, X, module, queue=None):
238235
policy = self._get_policy(queue, X)
239236
_, X_table, dtype = self._get_params_and_input(X, policy)
240237

241-
self.n_features_in_ = X_table.column_count
238+
self.n_features_in_ = X_table.get_column_count()
242239

243240
best_model, best_n_iter = None, None
244241
best_inertia, best_labels = None, None

onedal/common/dispatch_utils.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
*******************************************************************************/
1616

1717
#pragma once
18-
#include <iostream>
18+
1919
#include <pybind11/pybind11.h>
2020

2121
#include "oneapi/dal/train.hpp"
@@ -79,13 +79,8 @@ struct train_ops {
7979

8080
template <typename Float, typename Method, typename... Args>
8181
auto operator()(const pybind11::dict& params) {
82-
std::cout << "C++: Function: " << __PRETTY_FUNCTION__ << std::endl;
83-
std::cout << "C++: Before training" << std::endl;
8482
auto desc = ops.template operator()<Float, Method, Task, Args...>(params);
85-
std::cout << "C++: Descriptor has been created" << std::endl;
86-
auto result = dal::train(policy, desc, input);
87-
std::cout << "C++: Algorithms has been called" << std::endl;
88-
return result;
83+
return dal::train(policy, desc, input);
8984
}
9085

9186
Policy policy;

onedal/data_management/array/array.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "oneapi/dal/array.hpp"
2121
#include "oneapi/dal/common.hpp"
2222
#include "oneapi/dal/detail/common.hpp"
23+
#include "oneapi/dal/detail/array_utils.hpp"
2324

2425
#include "onedal/common.hpp"
2526

@@ -46,6 +47,47 @@ inline void check_access(const dal::array<Type>& arr, std::int64_t idx) {
4647
}
4748
}
4849

50+
template <typename T>
51+
inline auto get_policy(const dal::array<T>& arr) {
52+
return dal::detail::get_impl(arr).get_policy();
53+
}
54+
55+
template <typename Policy>
56+
constexpr inline bool is_host_policy_v = dal::detail::is_one_of_v<Policy, //
57+
detail::default_host_policy, detail::host_policy>;
58+
59+
template <typename InpPolicy, typename OutPolicy>
60+
inline bool need_copy(const InpPolicy& inp, const OutPolicy& out) {
61+
using out_policy_t = std::decay_t<decltype(inp)>;
62+
using inp_policy_t = std::decay_t<decltype(out)>;
63+
constexpr bool is_inp_host = is_host_policy_v<inp_policy_t>;
64+
constexpr bool is_out_host = is_host_policy_v<out_policy_t>;
65+
constexpr bool result = !(is_inp_host && is_out_host);
66+
return result;
67+
}
68+
69+
// TODO: Check for the same policy
70+
template <typename Policy, typename T>
71+
inline dal::array<T> to_policy(const Policy& out, const dal::array<T>& source) {
72+
return std::visit([&](const auto& inp) -> dal::array<T> {
73+
if (need_copy(inp, out)) {
74+
return detail::copy(out, source);
75+
}
76+
else {
77+
return dal::array<T>{ source };
78+
}
79+
}, get_policy(source));
80+
}
81+
82+
template <typename Policy, typename Array>
83+
inline void instantiate_to_policy(py::class_<Array>& py_array) {
84+
constexpr const char name[] = "to_policy";
85+
py_array.def(name, [](const Array& source, const Policy& policy) {
86+
auto result = to_policy(policy, source);
87+
return py::cast( std::move(result) );
88+
});
89+
}
90+
4991
template <typename Type>
5092
void instantiate_array_by_type(py::module& pm) {
5193
const auto name = name_array<Type>();
@@ -67,6 +109,9 @@ void instantiate_array_by_type(py::module& pm) {
67109
py_array.def("__len__", &array_t::get_count);
68110
py_array.def("get_count", &array_t::get_count);
69111
py_array.def("has_mutable_data", &array_t::has_mutable_data);
112+
py_array.def("has_data", [](const array_t& array) -> bool {
113+
return array.get_count() > std::int64_t(0l);
114+
});
70115
py_array.def("get_slice", [](const array_t& array,
71116
std::int64_t first, std::int64_t last) -> array_t {
72117
constexpr std::int64_t zero = 0l;
@@ -92,7 +137,7 @@ void instantiate_array_by_type(py::module& pm) {
92137
py_array.def("get_policy", [](const array_t& arr) -> py::object {
93138
return std::visit([](const auto& policy) -> py::object {
94139
return py::cast( policy );
95-
}, dal::detail::get_impl(arr).get_policy());
140+
}, get_policy(arr));
96141
});
97142
py_array.def("__getitem__", [](const array_t& arr, std::int64_t idx) -> Type {
98143
check_access<Type, false>(arr, idx);
@@ -105,6 +150,11 @@ void instantiate_array_by_type(py::module& pm) {
105150
py_array.def_property_readonly("__is_onedal_array__", [](const array_t&) -> bool {
106151
return true;
107152
});
153+
instantiate_to_policy<detail::host_policy>(py_array);
154+
instantiate_to_policy<detail::default_host_policy>(py_array);
155+
#ifdef ONEDAL_DATA_PARALLEL
156+
instantiate_to_policy<detail::data_parallel_policy>(py_array);
157+
#endif // ONEDAL_DATA_PARALLEL
108158
}
109159

110160
template <typename... Types>

onedal/data_management/table/table.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,6 @@ void instantiate_table(py::module& pm) {
6868
py_table.def(py::init<dal::homogen_table>());
6969
py_table.def(py::init<dal::heterogen_table>());
7070

71-
py_table.def("has_data", &dal::table::has_data);
72-
73-
py_table.def("__repr__", [](const dal::table& t) {
74-
std::stringstream stream;
75-
stream << "oneDAL table of kind: " << t.get_kind();
76-
return stream.str();
77-
});
78-
7971
return instantiate_table_iface(py_table);
8072
}
8173

onedal/data_management/table/table_iface.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
#pragma once
1818

19+
#include <string>
1920
#include <cstdint>
21+
#include <sstream>
2022

2123
#include <pybind11/pybind11.h>
2224

@@ -48,9 +50,25 @@ inline void instantiate_table_iface(py::class_<Table>& py_table) {
4850
py_table.def("get_row_count", &Table::get_row_count);
4951
py_table.def("get_data_layout", &Table::get_data_layout);
5052
py_table.def("get_column_count", &Table::get_column_count);
53+
py_table.def("has_data", [](const Table& t) -> bool {
54+
return t.has_data();
55+
});
5156
py_table.def("get_kind", [](const Table& t) -> table_kind {
5257
return static_cast<table_kind>(t.get_kind());
5358
});
59+
py_table.def_property_readonly("__is_onedal_table__", [](const Table&) -> bool {
60+
return true;
61+
});
62+
63+
py_table.def("__repr__", [](const Table& t) {
64+
std::stringstream stream;
65+
66+
static const auto type = py::type::handle_of<Table>();
67+
static const auto name = std::string{ py::str(type) };
68+
stream << "<oneDAL table of kind: " << t.get_kind();
69+
stream << " packed into " << name << '>';
70+
return stream.str();
71+
});
5472
}
5573

5674
} // namespace oneapi::dal::python::data_management

onedal/datatypes/_data_conversion.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424

2525
import onedal.interoperability as interop
2626

27-
to_table_one = interop.to_table
28-
from_table_one = interop.from_table
29-
3027
try:
3128
import dpctl
3229
import dpctl.tensor as dpt
@@ -42,10 +39,17 @@ def _apply_and_pass(func, *args):
4239
return tuple(map(func, args))
4340

4441

42+
def from_table_one(table):
43+
return interop.from_table(table)
44+
45+
4546
def from_table(*args):
4647
return _apply_and_pass(from_table_one, *args)
4748

4849

50+
def to_table_one(table):
51+
return interop.to_table(table)
52+
4953
def convert_one_to_table(arg):
5054
return to_table_one(make2d(arg))
5155

onedal/interoperability/buffer/buffer_and_table.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*******************************************************************************/
16-
#include <iostream>
16+
1717
#include <pybind11/numpy.h>
1818
#include <pybind11/pybind11.h>
1919

onedal/interoperability/csr_table.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import onedal
2+
import numpy as np
23

34
from .array import to_array, from_array
45

@@ -7,6 +8,7 @@
78
table = onedal._backend.data_management.table
89
csr_table = onedal._backend.data_management.csr_table
910
csr_kind = onedal._backend.data_management.table_kind.csr
11+
sparse_indexing = onedal._backend.data_management.sparse_indexing
1012

1113
def is_native_csr(entity) -> bool:
1214
if isinstance(entity, table):
@@ -31,14 +33,26 @@ def to_csr_table_native(entity) -> csr_table:
3133
assert is_native_csr(entity)
3234
return csr_table(entity)
3335

36+
def to_typed_array(x, dtypes = [np.int64]):
37+
result = x
38+
if x.dtype not in list(dtypes):
39+
result = x.astype(dtypes[0])
40+
assert result.dtype in dtypes
41+
return to_array(result)
42+
3443
# Converting python entity to table
44+
# TODO: Implement smarter logic
3545
def to_csr_table_python(entity) -> csr_table:
3646
assert isspmatrix_csr(entity)
3747
_, col_count = entity.shape
38-
ids = to_array(entity.indices)
39-
ofs = to_array(entity.indptr)
40-
nz = to_array(entity.nnz)
41-
result = csr_table(nz, ids, ofs, col_count)
48+
ids = to_typed_array(entity.indices)
49+
print(ids.get_policy())
50+
ofs = to_typed_array(entity.indptr)
51+
print(ofs.get_policy())
52+
nz = to_array(entity.data)
53+
print(nz.get_policy())
54+
result = csr_table(nz, ids, ofs, \
55+
col_count, sparse_indexing.zero_based)
4256
assert_table(result, entity)
4357
return result
4458

onedal/interoperability/dlpack/array_interop.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from .dlpack_utils import is_dlpack_entity, is_nd
44

55
wrap_to_array = onedal._backend.interop.dlpack.wrap_to_array
6-
wrap_from_array = onedal._backend.interop.dlpack.wrap_from_array
76

87
# TODO: implement more complex logic of
98
# checking shape & strides in entity
@@ -15,18 +14,3 @@ def is_dlpack_array(entity) -> bool:
1514
def to_array(entity):
1615
assert is_dlpack_array(entity)
1716
return wrap_to_array(entity)
18-
19-
class fake_dlpack_array:
20-
def __init__(self, array):
21-
self.array = array
22-
23-
@property
24-
def __dlpack__(self):
25-
if not hasattr(self, "dlpack") or self.dlpack is None:
26-
self.dlpack = wrap_from_array(self.array)
27-
return self.dlpack
28-
29-
def from_array(array):
30-
result = fake_dlpack_array(array)
31-
assert is_dlpack_array(result)
32-
return result

0 commit comments

Comments
 (0)