Skip to content

Commit 03a44c0

Browse files
committed
Create slice and reshape improvements
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent f6a7de1 commit 03a44c0

File tree

5 files changed

+79
-1
lines changed

5 files changed

+79
-1
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
196196

197197
def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
198198
const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
199+
if const_value is None:
200+
# Reshape accepts shape input of INT32 type as well, so we also check for INT32 here
201+
# This is common for tflite models
202+
const_value = _get_numpy_value(value, ir.DataType.INT32, size_limit=10)
199203
if const_value is not None:
200204
if const_value.ndim == 1:
201205
return ir.Shape(const_value.tolist())

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
_fuse_batchnorm,
4343
_fuse_pad_into_conv,
4444
_fuse_relus_clips,
45+
_materialize_reshape_shape,
4546
_min_max_to_clip,
4647
_no_op,
4748
_redundant_scatter_nd,
@@ -54,6 +55,7 @@
5455
*_broadcast_to_matmul.rules,
5556
*_cast_constant_of_shape.rules,
5657
*_collapse_slices.rules,
58+
*_materialize_reshape_shape.rules,
5759
*_min_max_to_clip.rules,
5860
*_fuse_relus_clips.rules,
5961
*_basic_rules.basic_optimization_rules(),

onnxscript/rewriter/rules/common/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"max_min_rule",
2626
"gemm_to_matmul_add_rule",
2727
"matmul_add_to_gemm_rule",
28+
"materialize_reshape_shape_rule",
2829
"mul_by_1_rule",
2930
"no_op_cast_rule",
3031
"no_op_dynamic_scatter_nd_rule",
@@ -107,6 +108,9 @@
107108
transpose_ab_matmul_add_to_gemm_rule,
108109
transpose_b_matmul_add_to_gemm_rule,
109110
)
111+
from onnxscript.rewriter.rules.common._materialize_reshape_shape import (
112+
materialize_reshape_shape_rule,
113+
)
110114
from onnxscript.rewriter.rules.common._min_max_to_clip import (
111115
max_max_rule,
112116
max_min_rule,

onnxscript/rewriter/rules/common/_collapse_slices.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_
8282
if data.shape is None or slice_output.shape is None:
8383
return False
8484

85-
if not _ir_utils.is_singleton_value(steps, 1):
85+
# All steps must be 1
86+
steps_np = _ir_utils.get_numpy_value(steps)
87+
if steps_np is not None:
88+
if not all(s == 1 for s in steps_np.flatten()):
89+
return False
90+
elif not _ir_utils.is_singleton_value(steps, 1):
8691
return False
8792

8893
return _ir_utils.same_shape(data.shape, slice_output.shape)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Materialize Reshape shape input from known output shape.
4+
5+
When symbolic shape inference has been run, a Reshape node may have a known
6+
output shape even though its shape input is computed dynamically (e.g., via a
7+
Shape → Cast → Split → Concat chain). This rule replaces the shape input
8+
with a concrete constant, allowing the dynamic chain to become dead code and
9+
be removed by unused-node elimination.
10+
11+
- Fully static output shape → constant with exact dims.
12+
- Exactly one symbolic dim → replace it with ``-1`` (Reshape infers it).
13+
"""
14+
15+
from __future__ import annotations
16+
17+
from onnxscript import ir
18+
from onnxscript.rewriter import _ir_utils as ir_utils
19+
from onnxscript.rewriter._basics import MatchResult
20+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
21+
22+
23+
class MaterializeReshapeShape(RewriteRuleClassBase):
24+
"""Replace a dynamic Reshape shape input with a constant when output shape is known."""
25+
26+
def pattern(self, op, data, shape):
27+
return op.Reshape(data, shape)
28+
29+
def check(self, context, data: ir.Value, shape: ir.Value) -> MatchResult:
30+
check_result = MatchResult()
31+
32+
# Shape input must not already be a constant
33+
if ir_utils.get_numpy_value(shape) is not None:
34+
return check_result.fail("Shape input is already a constant.")
35+
36+
output = context.output_values[0]
37+
if output.shape is None:
38+
return check_result.fail("Output shape is not known.")
39+
40+
dims = list(output.shape)
41+
sym_count = sum(1 for d in dims if not isinstance(d, int))
42+
43+
if sym_count == 0:
44+
self._new_dims = [int(d) for d in dims]
45+
elif sym_count == 1:
46+
self._new_dims = [-1 if not isinstance(d, int) else int(d) for d in dims]
47+
else:
48+
return check_result.fail(
49+
f"Output shape has {sym_count} symbolic dims, cannot materialize."
50+
)
51+
return check_result
52+
53+
def rewrite(self, op, data: ir.Value, shape: ir.Value):
54+
new_shape = op.Constant(
55+
value=ir.tensor(self._new_dims, dtype=ir.DataType.INT64),
56+
)
57+
return op.Reshape(data, new_shape)
58+
59+
60+
materialize_reshape_shape_rule = MaterializeReshapeShape.rule()
61+
62+
rules = RewriteRuleSet([materialize_reshape_shape_rule])
63+

0 commit comments

Comments
 (0)