|
5 | 5 | #include <fstream> |
6 | 6 | #include "core/common/inlined_containers.h" |
7 | 7 | #include "core/common/span_utils.h" |
| 8 | +#include "core/flatbuffers/ort_format_version.h" |
| 9 | +#include "core/flatbuffers/schema/ort.fbs.h" |
8 | 10 | #include "core/framework/tensorprotoutils.h" |
| 11 | +#include "core/graph/graph_flatbuffers_utils.h" |
9 | 12 | #include "core/graph/graph_viewer.h" |
10 | 13 | #include "core/graph/model.h" |
11 | 14 | #include "core/graph/op.h" |
| 15 | +#include "core/graph/ort_format_load_options.h" |
12 | 16 | #include "core/session/inference_session.h" |
13 | 17 | #include "core/session/environment.h" |
14 | 18 | #include "test/providers/provider_test_utils.h" |
@@ -2828,5 +2832,150 @@ TEST_F(GraphTest, ShapeInferenceAfterInitializerExternalization) { |
2828 | 2832 | ASSERT_TRUE(second_resolve.IsOK()) << "Second resolve failed: " << second_resolve.ErrorMessage(); |
2829 | 2833 | } |
2830 | 2834 |
|
| 2835 | +// Targeted test for the TensorToTensorProto defense-in-depth: calling with a string tensor |
| 2836 | +// and use_tensor_buffer=true must produce a TensorProto with string_data (not external data). |
| 2837 | +TEST_F(GraphTest, TensorToTensorProtoStringTensorDefenseInDepth) { |
| 2838 | + const int num_strings = 20; |
| 2839 | + TensorShape shape({num_strings}); |
| 2840 | + Tensor string_tensor(DataTypeImpl::GetType<std::string>(), shape, CPUAllocator::DefaultInstance()); |
| 2841 | + auto* data = string_tensor.MutableData<std::string>(); |
| 2842 | + for (int i = 0; i < num_strings; ++i) { |
| 2843 | + data[i] = "test_value_" + std::to_string(i); |
| 2844 | + } |
| 2845 | + |
| 2846 | + // Verify the tensor is large enough to normally trigger the external data path. |
| 2847 | + ASSERT_GT(string_tensor.SizeInBytes(), utils::kSmallTensorExternalDataThreshold); |
| 2848 | + |
| 2849 | + // Call with use_tensor_buffer=true — defense-in-depth should still produce string_data. |
| 2850 | + auto tensor_proto = utils::TensorToTensorProto(string_tensor, "string_test", /*use_tensor_buffer=*/true); |
| 2851 | + |
| 2852 | + ASSERT_EQ(tensor_proto.string_data_size(), num_strings) |
| 2853 | + << "TensorToTensorProto should populate string_data for string tensors even with use_tensor_buffer=true"; |
| 2854 | + ASSERT_FALSE(utils::HasExternalDataInMemory(tensor_proto)) |
| 2855 | + << "String tensor should not use external data in memory"; |
| 2856 | + |
| 2857 | + for (int i = 0; i < num_strings; ++i) { |
| 2858 | + EXPECT_EQ(tensor_proto.string_data(i), "test_value_" + std::to_string(i)); |
| 2859 | + } |
| 2860 | +} |
| 2861 | + |
| 2862 | +// Regression test: ConvertInitializersIntoOrtValues must skip string tensors because their |
| 2863 | +// raw buffer contains std::string objects (with internal pointers), not serializable data. |
| 2864 | +// Without the fix, string initializer data was lost when the TensorProto was replaced with |
| 2865 | +// one using external data in memory, breaking ORT format model serialization. |
| 2866 | +TEST_F(GraphTest, ConvertInitializersIntoOrtValuesSkipsStringTensors) { |
| 2867 | + ModelProto model_proto; |
| 2868 | + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); |
| 2869 | + auto* opset = model_proto.add_opset_import(); |
| 2870 | + opset->set_version(17); |
| 2871 | + auto* graph_proto = model_proto.mutable_graph(); |
| 2872 | + graph_proto->set_name("test_string_initializer_conversion"); |
| 2873 | + |
| 2874 | + // Create a string initializer with enough elements to exceed kSmallTensorExternalDataThreshold (127 bytes). |
| 2875 | + // sizeof(std::string) is typically 32 bytes (MSVC/libstdc++) or 24 bytes (libc++), so 20 elements |
| 2876 | + // will exceed 127 bytes on all major platforms (20 * 24 = 480 > 127). |
| 2877 | + const int num_strings = 20; |
| 2878 | + auto* string_initializer = graph_proto->add_initializer(); |
| 2879 | + string_initializer->set_name("string_data"); |
| 2880 | + string_initializer->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); |
| 2881 | + string_initializer->add_dims(num_strings); |
| 2882 | + for (int i = 0; i < num_strings; ++i) { |
| 2883 | + string_initializer->add_string_data("value_" + std::to_string(i)); |
| 2884 | + } |
| 2885 | + |
| 2886 | + // Create a Gather node: Gather(string_data, indices) -> output |
| 2887 | + auto* gather_node = graph_proto->add_node(); |
| 2888 | + gather_node->set_op_type("Gather"); |
| 2889 | + gather_node->add_input("string_data"); |
| 2890 | + gather_node->add_input("indices"); |
| 2891 | + gather_node->add_output("output"); |
| 2892 | + auto* axis_attr = gather_node->add_attribute(); |
| 2893 | + axis_attr->set_name("axis"); |
| 2894 | + axis_attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); |
| 2895 | + axis_attr->set_i(0); |
| 2896 | + |
| 2897 | + // Add graph input for indices |
| 2898 | + auto* input = graph_proto->add_input(); |
| 2899 | + input->set_name("indices"); |
| 2900 | + auto* input_type = input->mutable_type()->mutable_tensor_type(); |
| 2901 | + input_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); |
| 2902 | + input_type->mutable_shape()->add_dim()->set_dim_value(1); |
| 2903 | + |
| 2904 | + // Add graph output |
| 2905 | + auto* output = graph_proto->add_output(); |
| 2906 | + output->set_name("output"); |
| 2907 | + auto* output_type = output->mutable_type()->mutable_tensor_type(); |
| 2908 | + output_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_STRING); |
| 2909 | + output_type->mutable_shape()->add_dim()->set_dim_value(1); |
| 2910 | + |
| 2911 | + // Load and resolve |
| 2912 | + std::shared_ptr<Model> model; |
| 2913 | + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, *logger_)); |
| 2914 | + Graph& graph = model->MainGraph(); |
| 2915 | + ASSERT_STATUS_OK(graph.Resolve()); |
| 2916 | + |
| 2917 | + // Verify string initializer has string_data before conversion |
| 2918 | + const ONNX_NAMESPACE::TensorProto* init_before = nullptr; |
| 2919 | + ASSERT_TRUE(graph.GetInitializedTensor("string_data", init_before)); |
| 2920 | + ASSERT_EQ(init_before->string_data_size(), num_strings); |
| 2921 | + ASSERT_FALSE(utils::HasExternalDataInMemory(*init_before)); |
| 2922 | + |
| 2923 | + // Convert initializers into OrtValues |
| 2924 | + ASSERT_STATUS_OK(graph.ConvertInitializersIntoOrtValues()); |
| 2925 | + |
| 2926 | + // After conversion, string initializer should still have string_data intact |
| 2927 | + // (i.e., it should NOT have been replaced with external data in memory). |
| 2928 | + const ONNX_NAMESPACE::TensorProto* init_after = nullptr; |
| 2929 | + ASSERT_TRUE(graph.GetInitializedTensor("string_data", init_after)); |
| 2930 | + ASSERT_EQ(init_after->string_data_size(), num_strings) |
| 2931 | + << "String initializer data was lost during ConvertInitializersIntoOrtValues"; |
| 2932 | + ASSERT_FALSE(utils::HasExternalDataInMemory(*init_after)) |
| 2933 | + << "String tensor should not use external data in memory"; |
| 2934 | + |
| 2935 | + // Verify the string content is preserved |
| 2936 | + for (int i = 0; i < num_strings; ++i) { |
| 2937 | + EXPECT_EQ(init_after->string_data(i), "value_" + std::to_string(i)); |
| 2938 | + } |
| 2939 | + |
| 2940 | + // End-to-end: save to ORT format and reload, verifying string data survives the round-trip. |
| 2941 | + flatbuffers::FlatBufferBuilder builder; |
| 2942 | + { |
| 2943 | + flatbuffers::Offset<fbs::Model> fbs_model_offset; |
| 2944 | + ASSERT_STATUS_OK(model->SaveToOrtFormat(builder, fbs_model_offset)); |
| 2945 | + flatbuffers::Offset<fbs::InferenceSession> fbs_session_offset = |
| 2946 | + fbs::CreateInferenceSessionDirect(builder, |
| 2947 | + std::to_string(kOrtModelVersion).c_str(), |
| 2948 | + fbs_model_offset, |
| 2949 | + 0); |
| 2950 | + builder.Finish(fbs_session_offset); |
| 2951 | + } |
| 2952 | + |
| 2953 | + // Load back from ORT format buffer |
| 2954 | + { |
| 2955 | + const auto* fbs_buffer = builder.GetBufferPointer(); |
| 2956 | + const auto* fbs_session = fbs::GetInferenceSession(fbs_buffer); |
| 2957 | + ASSERT_NE(fbs_session, nullptr); |
| 2958 | + ASSERT_NE(fbs_session->model(), nullptr); |
| 2959 | + |
| 2960 | + OrtFormatLoadOptions load_options; |
| 2961 | + std::unique_ptr<Model> loaded_model; |
| 2962 | + ASSERT_STATUS_OK(Model::LoadFromOrtFormat(*fbs_session->model(), |
| 2963 | + nullptr, // local_registries |
| 2964 | + load_options, |
| 2965 | + *logger_, |
| 2966 | + loaded_model)); |
| 2967 | + |
| 2968 | + // Verify the string initializer survived the ORT format round-trip |
| 2969 | + const auto& loaded_graph = loaded_model->MainGraph(); |
| 2970 | + const ONNX_NAMESPACE::TensorProto* loaded_init = nullptr; |
| 2971 | + ASSERT_TRUE(loaded_graph.GetInitializedTensor("string_data", loaded_init)); |
| 2972 | + ASSERT_EQ(loaded_init->string_data_size(), num_strings) |
| 2973 | + << "String initializer data was lost during ORT format save/load round-trip"; |
| 2974 | + for (int i = 0; i < num_strings; ++i) { |
| 2975 | + EXPECT_EQ(loaded_init->string_data(i), "value_" + std::to_string(i)); |
| 2976 | + } |
| 2977 | + } |
| 2978 | +} |
| 2979 | + |
2831 | 2980 | } // namespace test |
2832 | 2981 | } // namespace onnxruntime |
0 commit comments