Skip to content

Commit c5e6bd8

Browse files
Copilot and tianleiwu authored
Fix string tensor deserialization in ORT format models (#28133)
### Description `ConvertInitializersIntoOrtValues()` replaces initializer TensorProtos with ones pointing to in-memory raw buffers via `TensorToTensorProto(..., use_tensor_buffer=true)`. For string tensors exceeding 127 bytes, this stores a pointer to `std::string` C++ objects as "external data"—but those objects contain heap pointers, not serializable content. The `string_data` field ends up empty, so ORT format save loses all string data. On reload: shape says N elements, `string_data_size()` is 0 → deserialization fails. Changes: - **`tensorprotoutils.cc`**: Add `!tensor.IsDataTypeString()` guard in `TensorToTensorProto` so string tensors always populate `string_data` rather than taking the external-data-in-memory path - **`graph.cc`**: Skip string tensors in `ConvertInitializersIntoOrtValues()` since the raw-buffer optimization is fundamentally incompatible with string data - **`graph_test.cc`**: Add regression test creating a 20-element string initializer, calling `ConvertInitializersIntoOrtValues()`, and verifying string data survives ### Motivation and Context Since onnxruntime 1.23.0, loading ORT format models with string tensor initializers fails with: ``` INVALID_ARGUMENT: Deserialize tensor failed. UnpackTensor: the pre-allocate size does not match the size in proto ``` Reproduction: any model with a string initializer (e.g. Gather over a string array) saved via `optimized_model_filepath` with `.ort` extension, then reloaded. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
1 parent 3729b61 commit c5e6bd8

3 files changed

Lines changed: 164 additions & 2 deletions

File tree

onnxruntime/core/framework/tensorprotoutils.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1934,7 +1934,10 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
19341934
}
19351935

19361936
tensor_proto.set_data_type(tensor.GetElementType());
1937-
if (use_tensor_buffer && tensor.SizeInBytes() > kSmallTensorExternalDataThreshold) {
1937+
// String tensors cannot use the external data in-memory optimization because their raw buffer
1938+
// contains std::string objects (with internal pointers), not serializable string content.
1939+
if (use_tensor_buffer && !tensor.IsDataTypeString() &&
1940+
tensor.SizeInBytes() > kSmallTensorExternalDataThreshold) {
19381941
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302
19391942
const auto* raw_data = tensor.DataRaw();
19401943
ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor.");

onnxruntime/core/graph/graph.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3775,6 +3775,13 @@ Status Graph::ConvertInitializersIntoOrtValues() {
37753775
continue;
37763776
}
37773777

3778+
// String tensors cannot use the raw buffer in-memory optimization because their raw data
3779+
// contains std::string objects (with internal pointers), not serializable content.
3780+
// They are kept as regular TensorProtos and deserialized normally during inference.
3781+
if (utils::HasString(tensor_proto)) {
3782+
continue;
3783+
}
3784+
37783785
size_t size_in_bytes = 0;
37793786
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
37803787
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
@@ -4092,7 +4099,10 @@ Status Graph::InjectExternalInitializedTensors(const InlinedHashMap<std::string,
40924099
OrtValue ort_value;
40934100
TensorProto tensor_proto;
40944101
constexpr const bool use_tensor_buffer_true = true;
4095-
if (user_tensor.SizeInBytes() > utils::kSmallTensorExternalDataThreshold) {
4102+
// String tensors cannot use the raw buffer in-memory optimization because their raw data
4103+
// contains std::string objects (with internal pointers), not serializable content.
4104+
if (!user_tensor.IsDataTypeString() &&
4105+
user_tensor.SizeInBytes() > utils::kSmallTensorExternalDataThreshold) {
40964106
if (user_tensor.OwnsBuffer()) {
40974107
// If the user tensor has its own memory, we avoid copying
40984108
tensor_proto = utils::TensorToTensorProto(user_tensor, name, use_tensor_buffer_true);

onnxruntime/test/ir/graph_test.cc

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
#include <fstream>
66
#include "core/common/inlined_containers.h"
77
#include "core/common/span_utils.h"
8+
#include "core/flatbuffers/ort_format_version.h"
9+
#include "core/flatbuffers/schema/ort.fbs.h"
810
#include "core/framework/tensorprotoutils.h"
11+
#include "core/graph/graph_flatbuffers_utils.h"
912
#include "core/graph/graph_viewer.h"
1013
#include "core/graph/model.h"
1114
#include "core/graph/op.h"
15+
#include "core/graph/ort_format_load_options.h"
1216
#include "core/session/inference_session.h"
1317
#include "core/session/environment.h"
1418
#include "test/providers/provider_test_utils.h"
@@ -2828,5 +2832,150 @@ TEST_F(GraphTest, ShapeInferenceAfterInitializerExternalization) {
28282832
ASSERT_TRUE(second_resolve.IsOK()) << "Second resolve failed: " << second_resolve.ErrorMessage();
28292833
}
28302834

2835+
// Targeted test for the TensorToTensorProto defense-in-depth: calling with a string tensor
2836+
// and use_tensor_buffer=true must produce a TensorProto with string_data (not external data).
2837+
TEST_F(GraphTest, TensorToTensorProtoStringTensorDefenseInDepth) {
2838+
const int num_strings = 20;
2839+
TensorShape shape({num_strings});
2840+
Tensor string_tensor(DataTypeImpl::GetType<std::string>(), shape, CPUAllocator::DefaultInstance());
2841+
auto* data = string_tensor.MutableData<std::string>();
2842+
for (int i = 0; i < num_strings; ++i) {
2843+
data[i] = "test_value_" + std::to_string(i);
2844+
}
2845+
2846+
// Verify the tensor is large enough to normally trigger the external data path.
2847+
ASSERT_GT(string_tensor.SizeInBytes(), utils::kSmallTensorExternalDataThreshold);
2848+
2849+
// Call with use_tensor_buffer=true — defense-in-depth should still produce string_data.
2850+
auto tensor_proto = utils::TensorToTensorProto(string_tensor, "string_test", /*use_tensor_buffer=*/true);
2851+
2852+
ASSERT_EQ(tensor_proto.string_data_size(), num_strings)
2853+
<< "TensorToTensorProto should populate string_data for string tensors even with use_tensor_buffer=true";
2854+
ASSERT_FALSE(utils::HasExternalDataInMemory(tensor_proto))
2855+
<< "String tensor should not use external data in memory";
2856+
2857+
for (int i = 0; i < num_strings; ++i) {
2858+
EXPECT_EQ(tensor_proto.string_data(i), "test_value_" + std::to_string(i));
2859+
}
2860+
}
2861+
2862+
// Regression test: ConvertInitializersIntoOrtValues must skip string tensors because their
2863+
// raw buffer contains std::string objects (with internal pointers), not serializable data.
2864+
// Without the fix, string initializer data was lost when the TensorProto was replaced with
2865+
// one using external data in memory, breaking ORT format model serialization.
2866+
TEST_F(GraphTest, ConvertInitializersIntoOrtValuesSkipsStringTensors) {
2867+
ModelProto model_proto;
2868+
model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
2869+
auto* opset = model_proto.add_opset_import();
2870+
opset->set_version(17);
2871+
auto* graph_proto = model_proto.mutable_graph();
2872+
graph_proto->set_name("test_string_initializer_conversion");
2873+
2874+
// Create a string initializer with enough elements to exceed kSmallTensorExternalDataThreshold (127 bytes).
2875+
// sizeof(std::string) is typically 32 bytes (MSVC/libstdc++) or 24 bytes (libc++), so 20 elements
2876+
// will exceed 127 bytes on all major platforms (20 * 24 = 480 > 127).
2877+
const int num_strings = 20;
2878+
auto* string_initializer = graph_proto->add_initializer();
2879+
string_initializer->set_name("string_data");
2880+
string_initializer->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING);
2881+
string_initializer->add_dims(num_strings);
2882+
for (int i = 0; i < num_strings; ++i) {
2883+
string_initializer->add_string_data("value_" + std::to_string(i));
2884+
}
2885+
2886+
// Create a Gather node: Gather(string_data, indices) -> output
2887+
auto* gather_node = graph_proto->add_node();
2888+
gather_node->set_op_type("Gather");
2889+
gather_node->add_input("string_data");
2890+
gather_node->add_input("indices");
2891+
gather_node->add_output("output");
2892+
auto* axis_attr = gather_node->add_attribute();
2893+
axis_attr->set_name("axis");
2894+
axis_attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
2895+
axis_attr->set_i(0);
2896+
2897+
// Add graph input for indices
2898+
auto* input = graph_proto->add_input();
2899+
input->set_name("indices");
2900+
auto* input_type = input->mutable_type()->mutable_tensor_type();
2901+
input_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
2902+
input_type->mutable_shape()->add_dim()->set_dim_value(1);
2903+
2904+
// Add graph output
2905+
auto* output = graph_proto->add_output();
2906+
output->set_name("output");
2907+
auto* output_type = output->mutable_type()->mutable_tensor_type();
2908+
output_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_STRING);
2909+
output_type->mutable_shape()->add_dim()->set_dim_value(1);
2910+
2911+
// Load and resolve
2912+
std::shared_ptr<Model> model;
2913+
ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, *logger_));
2914+
Graph& graph = model->MainGraph();
2915+
ASSERT_STATUS_OK(graph.Resolve());
2916+
2917+
// Verify string initializer has string_data before conversion
2918+
const ONNX_NAMESPACE::TensorProto* init_before = nullptr;
2919+
ASSERT_TRUE(graph.GetInitializedTensor("string_data", init_before));
2920+
ASSERT_EQ(init_before->string_data_size(), num_strings);
2921+
ASSERT_FALSE(utils::HasExternalDataInMemory(*init_before));
2922+
2923+
// Convert initializers into OrtValues
2924+
ASSERT_STATUS_OK(graph.ConvertInitializersIntoOrtValues());
2925+
2926+
// After conversion, string initializer should still have string_data intact
2927+
// (i.e., it should NOT have been replaced with external data in memory).
2928+
const ONNX_NAMESPACE::TensorProto* init_after = nullptr;
2929+
ASSERT_TRUE(graph.GetInitializedTensor("string_data", init_after));
2930+
ASSERT_EQ(init_after->string_data_size(), num_strings)
2931+
<< "String initializer data was lost during ConvertInitializersIntoOrtValues";
2932+
ASSERT_FALSE(utils::HasExternalDataInMemory(*init_after))
2933+
<< "String tensor should not use external data in memory";
2934+
2935+
// Verify the string content is preserved
2936+
for (int i = 0; i < num_strings; ++i) {
2937+
EXPECT_EQ(init_after->string_data(i), "value_" + std::to_string(i));
2938+
}
2939+
2940+
// End-to-end: save to ORT format and reload, verifying string data survives the round-trip.
2941+
flatbuffers::FlatBufferBuilder builder;
2942+
{
2943+
flatbuffers::Offset<fbs::Model> fbs_model_offset;
2944+
ASSERT_STATUS_OK(model->SaveToOrtFormat(builder, fbs_model_offset));
2945+
flatbuffers::Offset<fbs::InferenceSession> fbs_session_offset =
2946+
fbs::CreateInferenceSessionDirect(builder,
2947+
std::to_string(kOrtModelVersion).c_str(),
2948+
fbs_model_offset,
2949+
0);
2950+
builder.Finish(fbs_session_offset);
2951+
}
2952+
2953+
// Load back from ORT format buffer
2954+
{
2955+
const auto* fbs_buffer = builder.GetBufferPointer();
2956+
const auto* fbs_session = fbs::GetInferenceSession(fbs_buffer);
2957+
ASSERT_NE(fbs_session, nullptr);
2958+
ASSERT_NE(fbs_session->model(), nullptr);
2959+
2960+
OrtFormatLoadOptions load_options;
2961+
std::unique_ptr<Model> loaded_model;
2962+
ASSERT_STATUS_OK(Model::LoadFromOrtFormat(*fbs_session->model(),
2963+
nullptr, // local_registries
2964+
load_options,
2965+
*logger_,
2966+
loaded_model));
2967+
2968+
// Verify the string initializer survived the ORT format round-trip
2969+
const auto& loaded_graph = loaded_model->MainGraph();
2970+
const ONNX_NAMESPACE::TensorProto* loaded_init = nullptr;
2971+
ASSERT_TRUE(loaded_graph.GetInitializedTensor("string_data", loaded_init));
2972+
ASSERT_EQ(loaded_init->string_data_size(), num_strings)
2973+
<< "String initializer data was lost during ORT format save/load round-trip";
2974+
for (int i = 0; i < num_strings; ++i) {
2975+
EXPECT_EQ(loaded_init->string_data(i), "value_" + std::to_string(i));
2976+
}
2977+
}
2978+
}
2979+
28312980
} // namespace test
28322981
} // namespace onnxruntime

0 commit comments

Comments
 (0)