Skip to content

Commit 0c3bae5

Browse files
committed
Improve tests
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 03a44c0 commit 0c3bae5

File tree

5 files changed

+251
-5
lines changed

5 files changed

+251
-5
lines changed

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,22 @@ def test_gather_symdim(self):
614614
optimized = self._fold(model)
615615
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
616616

617+
def test_reshape_identity_int32_shape(self):
    """Reshape with a constant INT32 shape input should be recognized as identity.

    The shape input is produced by Cast<to=6> (INT32) from an int64 constant;
    the folder must still detect that the requested shape equals the input
    shape and rewrite the Reshape into an Identity.
    """
    model_ir = ir.from_onnx_text(
        """
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[3, 4] x) => (float[3, 4] z)
        {
            shape_i64 = Constant <value_ints=[3, 4]> ()
            shape = Cast <to=6> (shape_i64)
            z = Reshape (x, shape)
        }
        """
    )
    optimized = self._fold(model_ir)
    # The no-op Reshape must have been replaced by Identity as the final node.
    self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
632+
617633
def test_input_size_limit(self):
618634
model_text = """
619635
<ir_version: 7, opset_import: [ "" : 17]>

onnxscript/rewriter/rules/common/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,15 @@
102102
successive_relu_rule,
103103
)
104104
from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule
105+
from onnxscript.rewriter.rules.common._materialize_reshape_shape import (
106+
materialize_reshape_shape_rule,
107+
)
105108
from onnxscript.rewriter.rules.common._matmul_add_to_gemm import (
106109
matmul_add_to_gemm_rule,
107110
transpose_a_matmul_add_to_gemm_rule,
108111
transpose_ab_matmul_add_to_gemm_rule,
109112
transpose_b_matmul_add_to_gemm_rule,
110113
)
111-
from onnxscript.rewriter.rules.common._materialize_reshape_shape import (
112-
materialize_reshape_shape_rule,
113-
)
114114
from onnxscript.rewriter.rules.common._min_max_to_clip import (
115115
max_max_rule,
116116
max_min_rule,

onnxscript/rewriter/rules/common/_collapse_slices_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,62 @@ def test_slice_equal_dynamic_shape_but_step_reverse(self):
119119
count = _collapse_slices.rules.apply_to_model(model)
120120
# Should not change the output shape if we did not use the default step of 1
121121
self.assertEqual(count, 0)
122+
123+
def test_multi_element_steps_all_ones_collapses(self):
    """Slice with multi-axis steps=[1,1] and matching shapes should collapse."""
    model = ir.from_onnx_text(
        """
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[L, M] data) => (float[L, M] output)
        {
            starts = Constant<value: tensor = int64[2] {0, 0}>()
            ends = Constant<value: tensor = int64[2] {9999, 9999}>()
            axes = Constant<value: tensor = int64[2] {0, 1}>()
            steps = Constant<value: tensor = int64[2] {1, 1}>()
            output = Slice (data, starts, ends, axes, steps)
        }
        """
    )
    applied = _collapse_slices.rules.apply_to_model(model)
    # Exactly one rewrite fires and the Slice becomes an Identity.
    self.assertEqual(applied, 1)
    op_types = [node.op_type for node in model.graph]
    self.assertIn("Identity", op_types)
141+
142+
def test_multi_element_steps_with_non_one_does_not_collapse(self):
    """Slice with steps containing a non-1 element should not collapse."""
    model = ir.from_onnx_text(
        """
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[10, 20] data) => (float[10, 10] output)
        {
            starts = Constant<value: tensor = int64[2] {0, 0}>()
            ends = Constant<value: tensor = int64[2] {10, 20}>()
            axes = Constant<value: tensor = int64[2] {0, 1}>()
            steps = Constant<value: tensor = int64[2] {1, 2}>()
            output = Slice (data, starts, ends, axes, steps)
        }
        """
    )
    applied = _collapse_slices.rules.apply_to_model(model)
    # steps=[1, 2] subsamples axis 1, so the Slice is not a no-op.
    self.assertEqual(applied, 0)
159+
160+
def test_multi_element_steps_numerical_correctness(self):
    """Verify numerical correctness of multi-axis collapse."""
    model_text = """
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[4, 5] data) => (float[4, 5] output)
        {
            starts = Constant<value: tensor = int64[2] {0, 0}>()
            ends = Constant<value: tensor = int64[2] {100, 100}>()
            axes = Constant<value: tensor = int64[2] {0, 1}>()
            steps = Constant<value: tensor = int64[2] {1, 1}>()
            output = Slice (data, starts, ends, axes, steps)
        }
    """
    # Parse the same text twice: one copy stays untouched as the reference.
    original = ir.from_onnx_text(model_text)
    rewritten = ir.from_onnx_text(model_text)
    _collapse_slices.rules.apply_to_model(rewritten)
    sample_input = np.random.rand(4, 5).astype(np.float32)
    testing.assert_numerically_equal(original, rewritten, (sample_input,))

onnxscript/rewriter/rules/common/_materialize_reshape_shape.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,18 @@ def check(self, context, data: ir.Value, shape: ir.Value) -> MatchResult:
4848
return check_result.fail(
4949
f"Output shape has {sym_count} symbolic dims, cannot materialize."
5050
)
51+
52+
# Preserve allowzero attribute from original node
53+
self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0)
5154
return check_result
5255

5356
def rewrite(self, op, data: ir.Value, shape: ir.Value):
    """Replace the dynamic shape input with the materialized constant shape.

    ``self._new_dims`` and ``self._allowzero`` are stashed by ``check`` on
    the matched node before this runs.
    """
    const_shape = op.Constant(
        value=ir.tensor(self._new_dims, dtype=ir.DataType.INT64),
    )
    # Passing None omits the attribute; keep allowzero only when it was set.
    return op.Reshape(data, const_shape, allowzero=self._allowzero or None)
5861

5962

6063
# Public rule object and the rule set applied by callers of this module.
materialize_reshape_shape_rule = MaterializeReshapeShape.rule()

rules = RewriteRuleSet([materialize_reshape_shape_rule])
63-
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import numpy as np
8+
9+
from onnxscript import ir
10+
from onnxscript.rewriter import testing
11+
from onnxscript.rewriter.rules.common import _materialize_reshape_shape
12+
13+
14+
class MaterializeReshapeShapeTest(unittest.TestCase):
    """Unit tests for the materialize-reshape-shape rewrite rule set."""

    def _set_reshape_output_shape(self, model, shape):
        """Set the inferred output shape on the first Reshape node in *model*.

        ``shape`` may be ``None`` to simulate an unknown output shape.
        """
        for node in model.graph:
            if node.op_type == "Reshape":
                node.outputs[0].shape = shape
                return

    def _single_reshape(self, model):
        """Assert that exactly one Reshape remains in *model* and return it."""
        reshape_nodes = [n for n in model.graph if n.op_type == "Reshape"]
        self.assertEqual(len(reshape_nodes), 1)
        return reshape_nodes[0]

    def test_fully_static_output_shape_materializes(self):
        """When output shape is fully static, replace dynamic shape input with constant."""
        model = ir.from_onnx_text(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[6] data) => (float[2, 3] output)
            {
                shape = Shape(data)
                output = Reshape(data, shape)
            }
            """
        )
        self._set_reshape_output_shape(model, ir.Shape([2, 3]))
        count = _materialize_reshape_shape.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        shape_input = self._single_reshape(model).inputs[1]
        self.assertIsNotNone(shape_input.const_value)
        self.assertEqual(shape_input.const_value.numpy().tolist(), [2, 3])

    def test_one_symbolic_dim_uses_minus_one(self):
        """When output has one symbolic dim, replace it with -1."""
        model = ir.from_onnx_text(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[6] data) => (float[B, 3] output)
            {
                shape = Shape(data)
                output = Reshape(data, shape)
            }
            """
        )
        self._set_reshape_output_shape(model, ir.Shape(["B", 3]))
        count = _materialize_reshape_shape.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        shape_input = self._single_reshape(model).inputs[1]
        self.assertIsNotNone(shape_input.const_value)
        # The single symbolic dim "B" must have been encoded as -1.
        self.assertEqual(shape_input.const_value.numpy().tolist(), [-1, 3])

    def test_two_symbolic_dims_not_materialized(self):
        """When output has two symbolic dims, the rule should not fire."""
        model = ir.from_onnx_text(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[6] data) => (float[B, C] output)
            {
                shape = Shape(data)
                output = Reshape(data, shape)
            }
            """
        )
        self._set_reshape_output_shape(model, ir.Shape(["B", "C"]))
        count = _materialize_reshape_shape.rules.apply_to_model(model)
        # Two unknowns cannot both be expressed with a single -1.
        self.assertEqual(count, 0)

    def test_constant_shape_input_not_replaced(self):
        """When the shape input is already a constant, the rule should not fire."""
        model = ir.from_onnx_text(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[6] data) => (float[2, 3] output)
            {
                shape = Constant<value: tensor = int64[2] {2, 3}>()
                output = Reshape(data, shape)
            }
            """
        )
        count = _materialize_reshape_shape.rules.apply_to_model(model)
        self.assertEqual(count, 0)

    def test_unknown_output_shape_not_materialized(self):
        """When the output shape is unknown, the rule should not fire."""
        model = ir.from_onnx_text(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[6] data) => (float output)
            {
                shape = Shape(data)
                output = Reshape(data, shape)
            }
            """
        )
        self._set_reshape_output_shape(model, None)
        count = _materialize_reshape_shape.rules.apply_to_model(model)
        self.assertEqual(count, 0)

    def test_allowzero_attribute_preserved(self):
        """The allowzero attribute should be preserved on the new Reshape."""
        model = ir.from_onnx_text(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[6] data) => (float[2, 3] output)
            {
                shape = Shape(data)
                output = Reshape<allowzero=1>(data, shape)
            }
            """
        )
        self._set_reshape_output_shape(model, ir.Shape([2, 3]))
        count = _materialize_reshape_shape.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        rewritten = self._single_reshape(model)
        self.assertEqual(rewritten.attributes.get_int("allowzero", 0), 1)

    def test_numerical_correctness_static(self):
        """Verify numerical equivalence for fully static materialization."""
        # The Reshape's shape input comes from a Shape node over a second
        # graph input; after materialization it becomes a constant [3, 4].
        model_text = """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[12] data, float[3, 4] ref) => (float[3, 4] output)
            {
                shape = Shape(ref)
                output = Reshape(data, shape)
            }
        """
        # Parse twice so the untouched copy serves as the reference model.
        original = ir.from_onnx_text(model_text)
        rewritten = ir.from_onnx_text(model_text)
        self._set_reshape_output_shape(rewritten, ir.Shape([3, 4]))
        _materialize_reshape_shape.rules.apply_to_model(rewritten)
        testing.assert_numerically_equal(
            original,
            rewritten,
            (
                np.arange(12).astype(np.float32),
                np.zeros((3, 4), dtype=np.float32),
            ),
        )
168+
if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()

0 commit comments

Comments
 (0)