Skip to content

Commit a856c40

Browse files
committed
Improve Reshape constant fold
Constant fold the input to reshape if from the output shape we know it can be folded Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 2b2618e commit a856c40

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
196196

197197
def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
198198
const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
199+
if const_value is None:
200+
# Reshape accepts shape input of INT32 type as well, so we also check for INT32 here
201+
# This is common for tflite models
202+
const_value = _get_numpy_value(value, ir.DataType.INT32, size_limit=10)
199203
if const_value is not None:
200204
if const_value.ndim == 1:
201205
return ir.Shape(const_value.tolist())
@@ -471,7 +475,11 @@ def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnVa
471475
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
472476
"""Replace a Reshape node by Identity when applicable.
473477
474-
Also propagate symbolic shape values.
478+
Also propagate symbolic shape values. When the output shape is known
479+
(from shape inference) and the shape input is not already a constant,
480+
replace the shape input with a concrete constant:
481+
- Fully static output shape → constant with exact dims.
482+
- Exactly one symbolic dim → replace it with -1 (Reshape infers it).
475483
"""
476484
input = _get_input(node, 0)
477485
shape = _get_input(node, 1)
@@ -482,6 +490,10 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
482490
shape_value = state.get_shape_value(shape)
483491

484492
if shape_value is None or input_shape is None:
493+
# Try to materialize the shape input from the output shape
494+
result = _try_materialize_reshape_shape(node, op, input, shape)
495+
if result is not None:
496+
return result
485497
return _propagate_shape_value(node, op, state)
486498

487499
# No need to check for special values like -1, 0, etc. here
@@ -490,6 +502,44 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
490502
return _propagate_shape_value(node, op, state)
491503

492504

505+
def _try_materialize_reshape_shape(
506+
node: ir.Node, op, input: ir.Value, shape: ir.Value
507+
) -> ReturnValue:
508+
"""Create a constant shape input for Reshape from the output shape.
509+
510+
When the Reshape output already has a known shape (e.g., from symbolic
511+
shape inference) but the shape input is not a constant, we can replace
512+
the shape input with a materialized constant tensor.
513+
514+
- If the output shape is fully static, use the exact dims.
515+
- If exactly one dim is symbolic, replace it with -1.
516+
"""
517+
output = _get_output(node, 0)
518+
if output is None or output.shape is None:
519+
return None
520+
521+
out_shape = output.shape
522+
dims = list(out_shape)
523+
sym_count = sum(1 for d in dims if not isinstance(d, int))
524+
525+
if sym_count == 0:
526+
# Fully static — create constant with exact dims
527+
new_shape = op.Constant(
528+
value=ir.tensor(dims, dtype=ir.DataType.INT64),
529+
)
530+
return op.Reshape(input, new_shape)
531+
532+
if sym_count == 1:
533+
# Replace the single symbolic dim with -1
534+
concrete_dims = [-1 if not isinstance(d, int) else d for d in dims]
535+
new_shape = op.Constant(
536+
value=ir.tensor(concrete_dims, dtype=ir.DataType.INT64),
537+
)
538+
return op.Reshape(input, new_shape)
539+
540+
return None
541+
542+
493543
@register("Squeeze")
494544
def squeeze(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
495545
"""Propagate symbolic shape values."""

0 commit comments

Comments
 (0)