Skip to content

Commit a67ab69

Browse files
committed
Update slice
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent a856c40 commit a67ab69

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,20 @@ def squeeze(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
546546
return _propagate_shape_value(node, op, state)
547547

548548

549+
@register("Slice")
550+
def slice_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
551+
"""Replace Slice with Identity when output shape matches input shape."""
552+
input = _get_input(node, 0)
553+
if input is None or input.shape is None:
554+
return None
555+
output = _get_output(node, 0)
556+
if output is None or output.shape is None:
557+
return None
558+
if _same_shape(input.shape, output.shape):
559+
return op.Identity(input)
560+
return None
561+
562+
549563
@register("Cast")
550564
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
551565
input = _get_input(node, 0)

0 commit comments

Comments
 (0)