Skip to content

Commit df87e74

Browse files
committed
add files from uxlfoundation#1568
1 parent ae038cc commit df87e74

15 files changed

+1388
-0
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include <optional>
18+
19+
#ifdef ONEDAL_DATA_PARALLEL
20+
#include <sycl/sycl.hpp>
21+
#endif // ONEDAL_DATA_PARALLEL
22+
23+
#include <pybind11/pybind11.h>
24+
25+
#include "onedal/common/device_lookup.hpp"
26+
#include "onedal/interop/dlpack/api/dlpack.h"
27+
#include "onedal/interop/dlpack/device_conversion.hpp"
28+
29+
namespace py = pybind11;
30+
31+
namespace oneapi::dal::python::interop::dlpack {
32+
33+
constexpr inline auto cpu = DLDeviceType::kDLCPU;
34+
constexpr inline auto oneapi = DLDeviceType::kDLOneAPI;
35+
36+
DLDevice get_cpu_device() {
37+
return DLDevice{ cpu, 0 };
38+
}
39+
40+
bool is_cpu_device(DLDevice device) {
41+
const bool is_trivial = device.device_id == 0;
42+
const bool is_dlpack_cpu = device.device_type == cpu;
43+
return is_dlpack_cpu && is_trivial;
44+
}
45+
46+
bool is_oneapi_device(DLDevice device) {
47+
const bool trivial = device.device_id == 0;
48+
const bool is_dlpack_oneapi = device.device_type == oneapi;
49+
#ifdef ONEDAL_DATA_PARALLEL
50+
auto dev_opt = get_device_by_id(device.device_id);
51+
const bool is_known_by_id = dev_opt.has_value();
52+
#else // ONEDAL_DATA_PARALLEL
53+
constexpr bool is_known_by_id = false;
54+
#endif // ONEDAL_DATA_PARALLEL
55+
return is_dlpack_oneapi && is_known_by_id;
56+
}
57+
58+
bool is_unknown_device(DLDevice device) {
59+
const bool dlpack_cpu = is_cpu_device(device);
60+
const bool dlpack_oneapi = is_oneapi_device(device);
61+
return !dlpack_cpu && !dlpack_oneapi;
62+
}
63+
64+
#ifdef ONEDAL_DATA_PARALLEL
65+
66+
std::optional<sycl::device> convert_to_sycl(DLDevice device) {
67+
if (is_cpu_device(device)) {
68+
return sycl::ext::oneapi::detail::select_device( //
69+
&sycl::cpu_selector_v);
70+
}
71+
else if (is_oneapi_device(device)) {
72+
return get_device_by_id(device.device_id);
73+
}
74+
else {
75+
return {};
76+
}
77+
}
78+
79+
std::optional<DLDevice> convert_from_sycl(sycl::device device) {
80+
if (auto id = get_device_id(device)) {
81+
const std::uint32_t uid = id.value();
82+
auto raw = static_cast<std::int32_t>(uid);
83+
return { DLDevice{ oneapi, raw } };
84+
}
85+
else {
86+
return {};
87+
}
88+
}
89+
90+
#endif // ONEDAL_DATA_PARALLEL
91+
92+
py::object to_policy(DLDevice device) {
93+
if (is_cpu_device(device)) {
94+
detail::default_host_policy pol{};
95+
return py::cast(std::move(pol));
96+
}
97+
#ifdef ONEDAL_DATA_PARALLEL
98+
else if (is_oneapi_device(device)) {
99+
auto dev = convert_to_sycl(device);
100+
sycl::queue queue{ dev.value() };
101+
detail::data_parallel_policy pol{ queue };
102+
return py::cast(std::move(pol));
103+
}
104+
#endif // ONEDAL_DATA_PARALLEL
105+
else {
106+
throw std::runtime_error("Unknown device");
107+
}
108+
}
109+
110+
DLDevice get_device(std::shared_ptr<sycl::queue> ptr) {
111+
#ifdef ONEDAL_DATA_PARALLEL
112+
if (ptr.get() != nullptr) {
113+
const sycl::device dev = ptr->get_device();
114+
if (auto device = convert_from_sycl(dev)) {
115+
return device.value();
116+
}
117+
}
118+
#endif // ONEDAL_DATA_PARALLEL
119+
return get_cpu_device();
120+
}
121+
122+
std::shared_ptr<sycl::queue> get_queue(DLDevice device) {
123+
if (is_cpu_device(device)) {
124+
return { nullptr };
125+
}
126+
#ifdef ONEDAL_DATA_PARALLEL
127+
if (is_oneapi_device(device)) {
128+
const auto dev = convert_to_sycl(device);
129+
const sycl::device& unwrapped = dev.value();
130+
return std::make_shared<sycl::queue>(unwrapped);
131+
}
132+
#endif // ONEDAL_DATA_PARALLEL
133+
else {
134+
throw std::runtime_error("Unknown device");
135+
return { nullptr };
136+
}
137+
}
138+
139+
py::object to_policy(py::tuple tp) {
140+
constexpr const char err[] = "Ill-formed device tuple";
141+
142+
if (tp.size() != py::ssize_t{ 2ul }) {
143+
throw std::runtime_error(err);
144+
}
145+
146+
auto type = tp[0ul].cast<std::int32_t>();
147+
148+
DLDevice desc{ static_cast<DLDeviceType>(type), tp[1].cast<std::int32_t>() };
149+
150+
return to_policy(std::move(desc));
151+
}
152+
153+
void instantiate_convert_to_policy(py::module& pm) {
154+
pm.def("convert_to_policy", [](py::tuple tp) -> py::object {
155+
return to_policy(std::move(tp));
156+
});
157+
}
158+
159+
} // namespace oneapi::dal::python::interop::dlpack
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#pragma once
18+
19+
#include <optional>
20+
21+
#ifdef ONEDAL_DATA_PARALLEL
22+
#include <sycl/sycl.hpp>
23+
#else // ONEDAL_DATA_PARALLEL
24+
namespace sycl {
25+
class queue;
26+
} //namespace sycl
27+
#endif // ONEDAL_DATA_PARALLEL
28+
29+
#include <pybind11/pybind11.h>
30+
31+
#include "oneapi/dal/array.hpp"
32+
33+
#include "onedal/interop/dlpack/api/dlpack.h"
34+
35+
namespace py = pybind11;
36+
37+
namespace oneapi::dal::python::interop::dlpack {
38+
39+
DLDevice get_cpu_device();
40+
41+
// Passing by value is done for a reason
42+
// These structures are extremely small
43+
// and will not create any perf overheads
44+
bool is_cpu_device(DLDevice device);
45+
bool is_oneapi_device(DLDevice device);
46+
bool is_unknown_device(DLDevice device);
47+
48+
#ifdef ONEDAL_DATA_PARALLEL
49+
50+
std::optional<DLDevice> convert_from_sycl(sycl::device device);
51+
std::optional<sycl::device> convert_to_sycl(DLDevice device);
52+
53+
#endif // ONEDAL_DATA_PARALLEL
54+
55+
template <typename Type>
56+
inline DLDevice make_device(const dal::array<Type>& arr) {
57+
#ifdef ONEDAL_DATA_PARALLEL
58+
if (auto queue = arr.get_queue()) {
59+
auto device = queue.value().get_device();
60+
return convert_from_sycl(device).value();
61+
}
62+
#endif // ONEDAL_DATA_PARALLEL
63+
return get_cpu_device();
64+
}
65+
66+
DLDevice get_device(std::shared_ptr<sycl::queue> ptr);
67+
std::shared_ptr<sycl::queue> get_queue(DLDevice device);
68+
69+
void instantiate_convert_to_policy(py::module& pm);
70+
71+
} // namespace oneapi::dal::python::interop::dlpack

onedal/datatypes/dlpack/dlpack.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*******************************************************************************
2+
* Copyright 2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "onedal/interop/dlpack/dlpack.hpp"
18+
#include "onedal/interop/dlpack/dlpack_helper.hpp"
19+
#include "onedal/interop/dlpack/dlpack_and_array.hpp"
20+
#include "onedal/interop/dlpack/dlpack_and_table.hpp"
21+
#include "onedal/interop/dlpack/device_conversion.hpp"
22+
23+
namespace oneapi::dal::python::interop {
24+
25+
void instantiate_dlpack_interop(py::module& pm) {
26+
auto sub_module = pm.def_submodule("dlpack");
27+
dlpack::instantiate_dlpack_helper(sub_module);
28+
dlpack::instantiate_dlpack_and_array(sub_module);
29+
dlpack::instantiate_dlpack_and_table(sub_module);
30+
dlpack::instantiate_convert_to_policy(sub_module);
31+
}
32+
33+
} // namespace oneapi::dal::python::interop

0 commit comments

Comments
 (0)