Skip to content

Commit 0e73535

Browse files
vrasparCopilot
andauthored
Fix heap OOB write in MaxPoolGrad via indices bounds validation (#27903)
### Description `MaxPoolGrad` uses `Indices` tensor values as raw pointer offsets into the output buffer without bounds checking. A malicious model can supply arbitrary index values to write to arbitrary heap locations. **Fix:** Validate each index is in `[0, dX_size)` before use via `ORT_RETURN_IF`, returning an error for out-of-range values. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f9c83ae commit 0e73535

File tree

2 files changed

+123
-3
lines changed

2 files changed

+123
-3
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <memory>
5+
#include <string>
6+
#include <vector>
7+
8+
#include "gtest/gtest.h"
9+
10+
#include "test/providers/provider_test_utils.h"
11+
#include "test/util/include/default_providers.h"
12+
13+
namespace onnxruntime {
14+
namespace contrib {
15+
namespace test {
16+
17+
using namespace onnxruntime::test;
18+
19+
namespace {
20+
constexpr auto kOpsetVersion = 9;
21+
22+
void RunMaxPoolGradTest(const std::vector<int64_t>& dY_shape,
23+
const std::vector<float>& dY_data,
24+
const std::vector<int64_t>& indices_shape,
25+
const std::vector<int64_t>& indices_data,
26+
const std::vector<int64_t>& dX_shape,
27+
const std::vector<float>& dX_expected,
28+
bool expect_failure = false,
29+
const std::string& expected_failure_msg = "") {
30+
OpTester t{"MaxPoolGrad", kOpsetVersion, kOnnxDomain};
31+
t.AddInput("dY", dY_shape, dY_data);
32+
t.AddInput("Indices", indices_shape, indices_data);
33+
t.AddOutput<float>("dX", dX_shape, dX_expected);
34+
35+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
36+
execution_providers.push_back(DefaultCpuExecutionProvider());
37+
38+
if (expect_failure) {
39+
t.Run(OpTester::ExpectResult::kExpectFailure, expected_failure_msg,
40+
{}, nullptr, &execution_providers);
41+
} else {
42+
t.Run(OpTester::ExpectResult::kExpectSuccess, "",
43+
{}, nullptr, &execution_providers);
44+
}
45+
}
46+
} // namespace
47+
48+
TEST(MaxPoolGradTest, Basic) {
49+
// dY shape: [1, 1, 2, 2] = 4 elements
50+
// dX shape: [1, 1, 4, 4] = 16 elements
51+
// Indices point to valid positions within [0, 16)
52+
RunMaxPoolGradTest(
53+
/*dY_shape=*/{1, 1, 2, 2},
54+
/*dY_data=*/{1.0f, 2.0f, 3.0f, 4.0f},
55+
/*indices_shape=*/{1, 1, 2, 2},
56+
/*indices_data=*/{0, 2, 8, 10},
57+
/*dX_shape=*/{1, 1, 4, 4},
58+
/*dX_expected=*/{1.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 0.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
59+
}
60+
61+
TEST(MaxPoolGradTest, IndicesOutOfBoundsPositive) {
62+
// Index 1000 is far beyond dX size of 16
63+
RunMaxPoolGradTest(
64+
/*dY_shape=*/{1, 1, 2, 2},
65+
/*dY_data=*/{1.0f, 2.0f, 3.0f, 4.0f},
66+
/*indices_shape=*/{1, 1, 2, 2},
67+
/*indices_data=*/{0, 1000, 8, 10},
68+
/*dX_shape=*/{1, 1, 4, 4},
69+
/*dX_expected=*/std::vector<float>(16, 0.0f),
70+
/*expect_failure=*/true,
71+
/*expected_failure_msg=*/"out of range");
72+
}
73+
74+
TEST(MaxPoolGradTest, IndicesOutOfBoundsNegative) {
75+
// Negative index is invalid
76+
RunMaxPoolGradTest(
77+
/*dY_shape=*/{1, 1, 2, 2},
78+
/*dY_data=*/{1.0f, 2.0f, 3.0f, 4.0f},
79+
/*indices_shape=*/{1, 1, 2, 2},
80+
/*indices_data=*/{0, -1, 8, 10},
81+
/*dX_shape=*/{1, 1, 4, 4},
82+
/*dX_expected=*/std::vector<float>(16, 0.0f),
83+
/*expect_failure=*/true,
84+
/*expected_failure_msg=*/"out of range");
85+
}
86+
87+
TEST(MaxPoolGradTest, IndicesExactlyAtBoundary) {
88+
// Index 15 is the last valid position for dX size 16
89+
RunMaxPoolGradTest(
90+
/*dY_shape=*/{1, 1, 2, 2},
91+
/*dY_data=*/{1.0f, 2.0f, 3.0f, 4.0f},
92+
/*indices_shape=*/{1, 1, 2, 2},
93+
/*indices_data=*/{0, 5, 10, 15},
94+
/*dX_shape=*/{1, 1, 4, 4},
95+
/*dX_expected=*/{1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f, 0.0f, 4.0f});
96+
}
97+
98+
TEST(MaxPoolGradTest, IndicesOnePassedBoundary) {
99+
// Index 16 is one passed the last valid position (dX size = 16)
100+
RunMaxPoolGradTest(
101+
/*dY_shape=*/{1, 1, 2, 2},
102+
/*dY_data=*/{1.0f, 2.0f, 3.0f, 4.0f},
103+
/*indices_shape=*/{1, 1, 2, 2},
104+
/*indices_data=*/{0, 5, 10, 16},
105+
/*dX_shape=*/{1, 1, 4, 4},
106+
/*dX_expected=*/std::vector<float>(16, 0.0f),
107+
/*expect_failure=*/true,
108+
/*expected_failure_msg=*/"out of range");
109+
}
110+
111+
} // namespace test
112+
} // namespace contrib
113+
} // namespace onnxruntime

orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,16 @@ Status MaxPoolGrad<T>::Compute(OpKernelContext* context) const {
5252
const int64_t* indices_data = indices->template Data<int64_t>();
5353
T* dX_data = dX->template MutableData<T>();
5454

55-
EigenVectorMap<T>(dX_data, narrow<Eigen::Index>(dX_shape.Size())).setZero();
56-
57-
for (int64_t i = 0; i < dY->Shape().Size(); ++i) {
55+
const int64_t dX_size = dX_shape.Size();
56+
EigenVectorMap<T>(dX_data, narrow<Eigen::Index>(dX_size)).setZero();
57+
58+
const int64_t dY_size = dY->Shape().Size();
59+
for (int64_t i = 0; i < dY_size; ++i) {
60+
if (indices_data[i] < 0 || indices_data[i] >= dX_size) {
61+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
62+
"MaxPoolGrad: index value ", indices_data[i],
63+
" at position ", i, " is out of range [0, ", dX_size, ")");
64+
}
5865
T* p_dX_data = dX_data + indices_data[i];
5966
*p_dX_data += dY_data[i];
6067
}

0 commit comments

Comments
 (0)