Skip to content

Commit 66a6d55

Browse files
author
Nikita Kulikov
committed
Squashing against "dev/dataframe-interchange-api"
1 parent f92fb4c commit 66a6d55

File tree

11 files changed

+972
-7
lines changed

11 files changed

+972
-7
lines changed

onedal/common/dtype_dispatcher.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*******************************************************************************/
16-
16+
#include <iostream>
1717
#include <pybind11/pybind11.h>
1818

1919
#include "oneapi/dal/common.hpp"
20+
#include "oneapi/dal/table/common.hpp"
2021

2122
#include "onedal/common.hpp"
2223
#include "onedal/common/dtype_dispatcher.hpp"
@@ -25,8 +26,8 @@ namespace py = pybind11;
2526

2627
namespace oneapi::dal::python {
2728

28-
ONEDAL_PY_INIT_MODULE(dtype_dispatcher) {
29-
py::enum_<dal::data_type> py_dtype(m, "dtype");
29+
inline void instantiate_data_type(py::module& pm) {
30+
py::enum_<dal::data_type> py_dtype(pm, "dtype");
3031
py_dtype.value("int8", dal::data_type::int8);
3132
py_dtype.value("int16", dal::data_type::int16);
3233
py_dtype.value("int32", dal::data_type::int32);
@@ -38,6 +39,20 @@ ONEDAL_PY_INIT_MODULE(dtype_dispatcher) {
3839
py_dtype.value("float32", dal::data_type::float32);
3940
py_dtype.value("float64", dal::data_type::float64);
4041
py_dtype.export_values();
42+
}
43+
44+
// Registers `dal::feature_type` on the module `pm` as a pybind11 enum named
// "ftype" with the entries "ratio", "nominal", "ordinal" and "interval".
inline void instantiate_feature_type(py::module& pm) {
45+
py::enum_<dal::feature_type> py_ftype(pm, "ftype");
46+
py_ftype.value("ratio", dal::feature_type::ratio);
47+
py_ftype.value("nominal", dal::feature_type::nominal);
48+
py_ftype.value("ordinal", dal::feature_type::ordinal);
49+
py_ftype.value("interval", dal::feature_type::interval);
50+
// Per the pybind11 documentation, export_values() additionally exports the
// enum entries into the parent scope (here: `pm`).
py_ftype.export_values();
51+
}
52+
53+
// Module initializer for `dtype_dispatcher`: registers the "ftype" and "dtype"
// enums. `m` is not declared in this block - presumably it is the py::module
// supplied by the ONEDAL_PY_INIT_MODULE macro (TODO confirm against the macro).
ONEDAL_PY_INIT_MODULE(dtype_dispatcher) {
54+
// Both helpers are declared `void`, so the (void) casts discard nothing.
(void)instantiate_feature_type(m);
55+
(void)instantiate_data_type(m);
4156
} // ONEDAL_PY_INIT_MODULE(dtype_dispatcher)
4257

4358
} // namespace oneapi::dal::python

onedal/data_management/chunked_array/chunked_array.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void instantiate_chunked_array_by_type(py::module& m) {
4343

4444
py::class_<chunked_array_t> py_array(m, c_name);
4545
py_array.def(py::init<>());
46+
py_array.def(py::init<chunked_array_t>());
4647
py_array.def(py::pickle(
4748
[](const chunked_array_t& m) -> py::bytes {
4849
return serialize(m);
@@ -86,19 +87,31 @@ void instantiate_chunked_array_by_type(py::module& m) {
8687
});
8788
}
8889

90+
// Registers the module-level function "make_chunked_array(dtype, chunk_count)"
// on `pm`. The bound lambda constructs a `chunked_array<T>` from `chunk_count`
// and returns it to Python as a generic py::object.
inline void instantiate_make_chunked_array(py::module& pm) {
91+
constexpr const char name[] = "make_chunked_array";
92+
pm.def(name, [](data_type dtype, std::int64_t chunk_count) -> py::object {
93+
// NOTE(review): dispatch_by_data_type presumably invokes the callback with
// a tag value whose C++ type corresponds to the runtime `dtype` - confirm
// against its definition (not visible here).
return detail::dispatch_by_data_type(dtype, [&](auto type_tag) -> py::object {
94+
// Element type is taken from the decayed type of the dispatched tag.
using type_t = std::decay_t<decltype(type_tag)>;
95+
auto result = chunked_array<type_t>(chunk_count);
96+
return py::cast(std::move(result));
97+
});
98+
});
99+
}
100+
89101
template <typename... Types>
90102
inline void instantiate_chunked_array_impl(py::module& pm,
91103
const std::tuple<Types...>* const = nullptr) {
92104
auto instantiate = [&](auto type_tag) -> void {
93105
using type_t = std::decay_t<decltype(type_tag)>;
94-
return instantiate_chunked_array_by_type<type_t>(pm);
106+
instantiate_chunked_array_by_type<type_t>(pm);
95107
};
96108
return detail::apply(instantiate, Types{}...);
97109
}
98110

99111
void instantiate_chunked_array(py::module& pm) {
100112
constexpr const supported_types_t* types = nullptr;
101-
return instantiate_chunked_array_impl(pm, types);
113+
(void)instantiate_chunked_array_impl(pm, types);
114+
(void)instantiate_make_chunked_array(pm);
102115
}
103116

104117
} // namespace oneapi::dal::python::data_management

onedal/data_management/heterogen_table/heterogen_table.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
#include "oneapi/dal/table/heterogen.hpp"
3131
#include "oneapi/dal/table/common.hpp"
3232

33+
#include "oneapi/dal/table/detail/table_iface.hpp"
34+
#include "oneapi/dal/table/detail/table_utils.hpp"
35+
3336
#include "oneapi/dal/detail/common.hpp"
3437

3538
namespace py = pybind11;
@@ -53,6 +56,23 @@ inline void make_chunked_array_setter(py::class_<dal::heterogen_table>& table) {
5356
});
5457
}
5558

59+
// Adds a "get_column(col)" method to the Python binding of
// `dal::heterogen_table`. The method looks up the data type of column `col` in
// the table metadata and returns that column as a `dal::chunked_array` of the
// matching element type, wrapped in a py::object.
inline void instantiate_getter(py::class_<dal::heterogen_table>& table) {
60+
constexpr const char name[] = "get_column";
61+
// NOTE(review): the lambda parameter `table` shadows the outer `table`
// (the py::class_ being extended); inside the lambda it is the table instance.
table.def(name, [](const dal::heterogen_table& table, std::int64_t col) -> py::object {
62+
const auto dtype = table.get_metadata().get_data_type(col);
63+
return detail::dispatch_by_data_type(dtype, [&](auto type_tag) -> py::object {
64+
// Element type is taken from the decayed type of the dispatched tag.
using dtype_t = std::decay_t<decltype(type_tag)>;
65+
using column_t = dal::chunked_array<dtype_t>;
66+
const detail::pimpl_accessor acc{};
67+
68+
// Fetch the column through the table's interface object, then build a
// typed `column_t` from it via the pimpl accessor.
// NOTE(review): presumably `raw_column` is the column's underlying
// implementation object - confirm against get_heterogen_table_iface.
auto iface = detail::get_heterogen_table_iface(table);
69+
auto raw_column = iface->get_column(col);
70+
auto column = acc.make<column_t>(raw_column);
71+
return py::cast(std::move(column));
72+
});
73+
});
74+
}
75+
5676
template <typename Table, typename... Types>
5777
void instantiate_setters(py::class_<Table>& table, const std::tuple<Types...>* const = nullptr) {
5878
return detail::apply(
@@ -75,8 +95,9 @@ void instantiate_heterogen_table(py::module& pm) {
7595

7696
instantiate_table_iface(py_heterogen_table);
7797

98+
(void)instantiate_getter(py_heterogen_table);
7899
constexpr const supported_types_t* types = nullptr;
79-
instantiate_setters(py_heterogen_table, types);
100+
(void)instantiate_setters(py_heterogen_table, types);
80101
}
81102

82103
} // namespace oneapi::dal::python::data_management

onedal/data_management/table_metadata/table_metadata.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace py = pybind11;
2525

2626
namespace oneapi::dal::python::data_management {
2727

28-
void instantiate_table_metadata(py::module& pm) {
28+
inline void instantiate_table_metadata_impl(py::module& pm) {
2929
constexpr const char name[] = "table_metadata";
3030
using dtype_arr_t = dal::array<dal::data_type>;
3131
using ftype_arr_t = dal::array<dal::feature_type>;
@@ -34,6 +34,21 @@ void instantiate_table_metadata(py::module& pm) {
3434
py_metadata.def(py::init());
3535
py_metadata.def(py::init<dal::table_metadata>());
3636
py_metadata.def(py::init<dtype_arr_t, ftype_arr_t>());
37+
// TODO: Remove these tight conversions
38+
using dtype_t = typename py::enum_<dal::data_type>::Scalar;
39+
using ftype_t = typename py::enum_<dal::feature_type>::Scalar;
40+
py_metadata.def(py::init([](const dal::array<dtype_t>& dtypes,
41+
const dal::array<ftype_t>& ftypes) -> dal::table_metadata {
42+
dal::array<dal::data_type> casted_dtypes =
43+
dal::array<dal::data_type>(dtypes, //
44+
reinterpret_cast<const dal::data_type*>(dtypes.get_data()),
45+
dtypes.get_count());
46+
dal::array<dal::feature_type> casted_ftypes = dal::array<dal::feature_type>(
47+
ftypes, //
48+
reinterpret_cast<const dal::feature_type*>(ftypes.get_data()),
49+
ftypes.get_count());
50+
return dal::table_metadata(casted_dtypes, casted_ftypes);
51+
}));
3752
py_metadata.def(py::pickle(
3853
[](const table_metadata& m) -> py::bytes {
3954
return serialize(m);
@@ -48,4 +63,8 @@ void instantiate_table_metadata(py::module& pm) {
4863
py_metadata.def("get_feature_count", &table_metadata::get_feature_count);
4964
}
5065

66+
// Public entry point for registering the table metadata bindings on `pm`;
// it only forwards to instantiate_table_metadata_impl.
void instantiate_table_metadata(py::module& pm) {
67+
// The helper is declared `void`, so the (void) cast discards nothing.
(void)instantiate_table_metadata_impl(pm);
68+
}
69+
5170
} // namespace oneapi::dal::python::data_management
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# ==============================================================================
2+
# Copyright 2023 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from .table_interop import is_dataframe_entity, to_table
18+
19+
__all__ = ["is_dataframe_entity", "to_table"]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# ==============================================================================
2+
# Copyright 2023 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
import onedal
18+
19+
from .dataframe_protocol import Column
20+
from .dtype_conversion import get_data_type
21+
from ..array import to_array, is_array_entity
22+
23+
data_type = onedal._backend.dtype
24+
make_chunked_array = onedal._backend.data_management.make_chunked_array
25+
26+
# Collects chunks one by one and assembles them into a single chunked array
# created through the backend factory `make_chunked_array`.
class ChunkedArrayBuilder:
27+
# `dtype` is kept privately and later handed to `make_chunked_array` in
# `build`; `chunks` starts out as an empty list.
def __init__(self, dtype: data_type):
28+
self.__dtype = dtype
29+
self.chunks = list()
30+
31+
# Stores `chunk` and returns `self`, so calls can be chained.
# NOTE(review): the `is_array_entity` check is an `assert`, which Python
# drops when run with -O.
def append(self, chunk):
32+
assert is_array_entity(chunk)
33+
self.chunks.append(chunk)
34+
return self
35+
36+
# Number of chunks appended so far.
@property
37+
def chunk_count(self) -> int:
38+
return len(self.chunks)
39+
40+
# Read-only access to the dtype given at construction.
@property
41+
def dtype(self):
42+
return self.__dtype
43+
44+
# Sanity-checks a built array: its chunk count must equal the number of
# appended chunks and its own `validate()` must return a truthy value.
def __validate(self, result):
45+
count = result.get_chunk_count()
46+
assert count == self.chunk_count
47+
assert result.validate()
48+
49+
# Creates a chunked array with `chunk_count` slots, fills slot `index` with
# `to_array(self.chunks[index])`, validates the result and returns it.
def build(self):
50+
dtype, count = self.dtype, self.chunk_count
51+
result = make_chunked_array(dtype, count)
52+
53+
for index in range(count):
54+
chunk = self.chunks[index]
55+
array = to_array(chunk)
56+
result.set_chunk(index, array)
57+
58+
self.__validate(result)
59+
return result
60+
61+
# Builds a chunked array from a `Column`: the column dtype is converted with
# `get_data_type`, then the "data" buffer of every chunk is appended to a
# `ChunkedArrayBuilder`, whose `build()` result is returned.
def build_from_column(column: Column):
62+
dtype = get_data_type(column.dtype)
63+
builder = ChunkedArrayBuilder(dtype)
64+
65+
for chunk in column.get_chunks():
66+
# Each yielded chunk is expected to consist of exactly one chunk itself.
assert chunk.num_chunks() == 1
67+
buffers = chunk.get_buffers()
68+
# Only the "data" entry is read; its second element is discarded and no
# other entry of `buffers` is consulted.
raw_chunk, _ = buffers["data"]
69+
builder.append(raw_chunk)
70+
71+
return builder.build()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# ==============================================================================
2+
# Copyright 2023 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
import numpy as np
18+
19+
import onedal
20+
from ..array import to_array
21+
22+
from .dtype_conversion import get_data_type
23+
from .column_builder import build_from_column
24+
from .dataframe_protocol import Column, DataFrame
25+
26+
feature_type = onedal._backend.ftype
27+
type_array = onedal._backend.data_management.array_q
28+
metadata = onedal._backend.data_management.table_metadata
29+
heterogen_table = onedal._backend.data_management.heterogen_table
30+
31+
# Collects columns together with their dtypes and assembles them into a
# backend `heterogen_table`.
class DataFrameBuilder:
32+
# `dtypes` and `columns` are parallel lists, extended together in `append`.
def __init__(self):
33+
self.dtypes = []
34+
self.columns = []
35+
36+
# Number of appended columns; also asserts that `columns` and `dtypes`
# have the same length.
@property
37+
def column_count(self) -> int:
38+
result = len(self.columns)
39+
control = len(self.dtypes)
40+
assert result == control
41+
return result
42+
43+
# Records `column` and its `dtype`; returns `self` so calls can be chained.
def append(self, column: Column):
44+
self.dtypes.append(column.dtype)
45+
self.columns.append(column)
46+
return self
47+
48+
# Converts every stored dtype with `get_data_type`, packs the results into a
# NumPy array cast to int32 and wraps it with `to_array`.
def build_dtype_array(self) -> type_array:
49+
result = list()
50+
for index in range(self.column_count):
51+
dtype = self.dtypes[index]
52+
dal_dtype = get_data_type(dtype)
53+
result.append(dal_dtype)
54+
result = np.asarray(result)
55+
result = result.astype(np.int32)
56+
return to_array(result)
57+
58+
# TODO: implement logic supporting ... (the original note is unfinished).
# NOTE(review): presumably this refers to feature types other than `ratio` -
# as written, every column is reported as `feature_type.ratio`.
def build_ftype_array(self) -> type_array:
60+
ratio = feature_type.ratio
61+
col_count = self.column_count
62+
result = np.full(col_count, ratio)
63+
result = result.astype(np.int32)
64+
return to_array(result)
65+
66+
# Builds the table metadata object from the dtype and feature-type arrays.
def build_metadata(self) -> metadata:
67+
dtypes = self.build_dtype_array()
68+
ftypes = self.build_ftype_array()
69+
return metadata(dtypes, ftypes)
70+
71+
# Sanity check: the built table must report as many columns as were appended.
def __validate(self, result: heterogen_table):
72+
column_count = result.get_column_count()
73+
assert column_count == self.column_count
74+
75+
# Creates a `heterogen_table` from the metadata, converts each stored column
# with `build_from_column`, installs it at the same index, validates the
# column count and returns the table.
def build(self) -> heterogen_table:
76+
meta = self.build_metadata()
77+
result = heterogen_table(meta)
78+
79+
80+
for index in range(self.column_count):
81+
column = self.columns[index]
82+
array = build_from_column(column)
83+
result.set_column(index, array)
84+
85+
self.__validate(result)
86+
return result
87+
88+
# Converts a `DataFrame` into a `heterogen_table`: every column, fetched by
# index, is appended to a `DataFrameBuilder`, the table is built, and its row
# count is asserted to equal `df.num_rows()` before it is returned.
def build_from_dataframe(df: DataFrame) -> heterogen_table:
89+
column_count = df.num_columns()
90+
builder = DataFrameBuilder()
91+
92+
for index in range(column_count):
93+
column = df.get_column(index)
94+
builder.append(column)
95+
96+
result = builder.build()
97+
row_count = result.get_row_count()
98+
# NOTE(review): this check is an `assert`, which Python drops under -O.
assert row_count == df.num_rows()
99+
return result

0 commit comments

Comments
 (0)