Skip to content

Commit af725c0

Browse files
committed
ORC-2161: [C++] UnionColumnReader should reject out-of-range union tags
### What changes were proposed in this pull request? This PR adds validation for C++ union tag values before they are used as child indexes. The change covers: - UnionColumnReader::skip - UnionColumnReader::nextInternal - UnionColumnPrinter::printRow If a malformed ORC file contains a union tag that is greater than or equal to the number of union children, the C++ reader/printer now throws ParseError instead of indexing out of bounds. ### Why are the changes needed? Union tags are decoded from the ORC data stream as byte values, but the valid range depends on the number of union children. Malformed input can contain a tag outside that range. The C++ reader previously trusted the tag value directly when indexing per-child state. This patch makes malformed union tags fail cleanly. ### How was this patch tested? Added C++ unit tests for a two-child union with invalid tag value 200, covering: - next - next with nulls - skip - ColumnPrinter ### Was this patch authored or co-authored using generative AI tooling? Generated-by: OpenAI Codex GPT-5. Closes #2618 from wgtmac/fix_union. Authored-by: Gang Wu <ustcwg@gmail.com> Signed-off-by: Gang Wu <ustcwg@gmail.com> (cherry picked from commit 3563ee5) Signed-off-by: Gang Wu <ustcwg@gmail.com>
1 parent df7c5ea commit af725c0

4 files changed

Lines changed: 114 additions & 7 deletions

File tree

c++/src/ColumnPrinter.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,11 +540,17 @@ namespace orc {
540540
if (hasNulls && !notNull[rowId]) {
541541
writeString(buffer, "null");
542542
} else {
543+
const size_t tag = static_cast<size_t>(tags_[rowId]);
544+
if (tag >= fieldPrinter_.size()) {
545+
throw ParseError("Invalid union tag " + to_string(static_cast<int64_t>(tag)) +
546+
" for union with " +
547+
to_string(static_cast<int64_t>(fieldPrinter_.size())) + " children");
548+
}
543549
writeString(buffer, "{\"tag\": ");
544-
const auto numBuffer = std::to_string(static_cast<int64_t>(tags_[rowId]));
550+
const auto numBuffer = std::to_string(static_cast<int64_t>(tag));
545551
writeString(buffer, numBuffer.c_str());
546552
writeString(buffer, ", \"value\": ");
547-
fieldPrinter_[tags_[rowId]]->printRow(offsets_[rowId]);
553+
fieldPrinter_[tag]->printRow(offsets_[rowId]);
548554
writeChar(buffer, '}');
549555
}
550556
}

c++/src/ColumnReader.cc

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,16 @@ namespace orc {
11301130
}
11311131
}
11321132

1133+
size_t getCheckedUnionTag(unsigned char tag, uint64_t numChildren) {
1134+
size_t child = static_cast<size_t>(tag);
1135+
if (child >= numChildren) {
1136+
throw ParseError("Invalid union tag " + to_string(static_cast<int64_t>(child)) +
1137+
" for union with " + to_string(static_cast<int64_t>(numChildren)) +
1138+
" children");
1139+
}
1140+
return child;
1141+
}
1142+
11331143
class UnionColumnReader : public ColumnReader {
11341144
private:
11351145
std::unique_ptr<ByteRleDecoder> rle_;
@@ -1180,15 +1190,15 @@ namespace orc {
11801190
uint64_t UnionColumnReader::skip(uint64_t numValues) {
11811191
numValues = ColumnReader::skip(numValues);
11821192
const uint64_t BUFFER_SIZE = 1024;
1183-
char buffer[BUFFER_SIZE];
1193+
unsigned char buffer[BUFFER_SIZE];
11841194
uint64_t lengthsRead = 0;
11851195
int64_t* counts = childrenCounts_.data();
11861196
memset(counts, 0, sizeof(int64_t) * numChildren_);
11871197
while (lengthsRead < numValues) {
11881198
uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE);
1189-
rle_->next(buffer, chunk, nullptr);
1199+
rle_->next(reinterpret_cast<char*>(buffer), chunk, nullptr);
11901200
for (size_t i = 0; i < chunk; ++i) {
1191-
counts[static_cast<size_t>(buffer[i])] += 1;
1201+
counts[getCheckedUnionTag(buffer[i], numChildren_)] += 1;
11921202
}
11931203
lengthsRead += chunk;
11941204
}
@@ -1224,12 +1234,14 @@ namespace orc {
12241234
if (notNull) {
12251235
for (size_t i = 0; i < numValues; ++i) {
12261236
if (notNull[i]) {
1227-
offsets[i] = static_cast<uint64_t>(counts[static_cast<size_t>(tags[i])]++);
1237+
size_t tag = getCheckedUnionTag(tags[i], numChildren_);
1238+
offsets[i] = static_cast<uint64_t>(counts[tag]++);
12281239
}
12291240
}
12301241
} else {
12311242
for (size_t i = 0; i < numValues; ++i) {
1232-
offsets[i] = static_cast<uint64_t>(counts[static_cast<size_t>(tags[i])]++);
1243+
size_t tag = getCheckedUnionTag(tags[i], numChildren_);
1244+
offsets[i] = static_cast<uint64_t>(counts[tag]++);
12331245
}
12341246
}
12351247
// read the right number of each child column

c++/test/TestColumnPrinter.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,25 @@ namespace orc {
506506
}
507507
}
508508

509+
TEST(TestColumnPrinter, UnionColumnPrinterRejectsInvalidTag) {
510+
std::string line;
511+
std::unique_ptr<Type> type = createUnionType();
512+
type->addUnionChild(createPrimitiveType(LONG));
513+
type->addUnionChild(createPrimitiveType(INT));
514+
std::unique_ptr<ColumnPrinter> printer = createColumnPrinter(line, type.get());
515+
516+
UnionVectorBatch batch(1, *getDefaultPool());
517+
batch.children.push_back(new LongVectorBatch(1, *getDefaultPool()));
518+
batch.children.push_back(new LongVectorBatch(1, *getDefaultPool()));
519+
batch.numElements = 1;
520+
batch.hasNulls = false;
521+
batch.tags[0] = 200;
522+
batch.offsets[0] = 0;
523+
524+
printer->reset(batch);
525+
EXPECT_THROW(printer->printRow(0), ParseError);
526+
}
527+
509528
TEST(TestColumnPrinter, StructColumnPrinter) {
510529
std::string line;
511530
std::unique_ptr<Type> type = createStructType();

c++/test/TestColumnReader.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3717,6 +3717,76 @@ namespace orc {
37173717
batch.toString());
37183718
}
37193719

3720+
const unsigned char INVALID_UNION_TAG[] = {0xff, 0xc8};
3721+
const unsigned char ONE_PRESENT_VALUE[] = {0x00, 0x80};
3722+
3723+
std::unique_ptr<Type> createTwoChildUnionRowType() {
3724+
std::unique_ptr<Type> unionType = createUnionType();
3725+
unionType->addUnionChild(createPrimitiveType(LONG));
3726+
unionType->addUnionChild(createPrimitiveType(INT));
3727+
std::unique_ptr<Type> rowType = createStructType();
3728+
rowType->addStructField("col0", std::move(unionType));
3729+
return rowType;
3730+
}
3731+
3732+
std::unique_ptr<ColumnReader> buildInvalidUnionTagReader(MockStripeStreams& streams,
3733+
bool hasNulls = false) {
3734+
std::vector<bool> selectedColumns(4, false);
3735+
selectedColumns[0] = true;
3736+
selectedColumns[1] = true;
3737+
EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns));
3738+
EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(nullptr));
3739+
3740+
proto::ColumnEncoding directEncoding;
3741+
directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT);
3742+
EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding));
3743+
3744+
EXPECT_CALL(streams, getStreamProxy(testing::_, proto::Stream_Kind_PRESENT, true))
3745+
.WillRepeatedly(testing::Return(nullptr));
3746+
3747+
if (hasNulls) {
3748+
EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_PRESENT, true))
3749+
.WillRepeatedly(testing::Return(
3750+
new SeekableArrayInputStream(ONE_PRESENT_VALUE, ARRAY_SIZE(ONE_PRESENT_VALUE))));
3751+
}
3752+
3753+
EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_DATA, true))
3754+
.WillRepeatedly(testing::Return(
3755+
new SeekableArrayInputStream(INVALID_UNION_TAG, ARRAY_SIZE(INVALID_UNION_TAG))));
3756+
3757+
std::unique_ptr<Type> rowType = createTwoChildUnionRowType();
3758+
return buildReader(*rowType, streams);
3759+
}
3760+
3761+
void addSingleUnionBatch(StructVectorBatch& batch) {
3762+
batch.fields.push_back(new UnionVectorBatch(1, *getDefaultPool()));
3763+
}
3764+
3765+
TEST(TestColumnReader, testUnionRejectsInvalidTag) {
3766+
MockStripeStreams streams;
3767+
std::unique_ptr<ColumnReader> reader = buildInvalidUnionTagReader(streams);
3768+
3769+
StructVectorBatch batch(1, *getDefaultPool());
3770+
addSingleUnionBatch(batch);
3771+
EXPECT_THROW(reader->next(batch, 1, 0), ParseError);
3772+
}
3773+
3774+
TEST(TestColumnReader, testUnionRejectsInvalidTagWithNulls) {
3775+
MockStripeStreams streams;
3776+
std::unique_ptr<ColumnReader> reader = buildInvalidUnionTagReader(streams, true);
3777+
3778+
StructVectorBatch batch(1, *getDefaultPool());
3779+
addSingleUnionBatch(batch);
3780+
EXPECT_THROW(reader->next(batch, 1, 0), ParseError);
3781+
}
3782+
3783+
TEST(TestColumnReader, testUnionSkipRejectsInvalidTag) {
3784+
MockStripeStreams streams;
3785+
std::unique_ptr<ColumnReader> reader = buildInvalidUnionTagReader(streams);
3786+
3787+
EXPECT_THROW(reader->skip(1), ParseError);
3788+
}
3789+
37203790
TEST(TestColumnReader, testUnionWithNulls) {
37213791
MockStripeStreams streams;
37223792

0 commit comments

Comments
 (0)