@@ -196,6 +196,10 @@ def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
196196
197197 def get_shape_value (self , value : ir .Value | None ) -> ir .Shape | None :
198198 const_value = _get_numpy_value (value , ir .DataType .INT64 , size_limit = 10 )
199+ if const_value is None :
200+ # Reshape accepts shape input of INT32 type as well, so we also check for INT32 here
201+ # This is common for tflite models
202+ const_value = _get_numpy_value (value , ir .DataType .INT32 , size_limit = 10 )
199203 if const_value is not None :
200204 if const_value .ndim == 1 :
201205 return ir .Shape (const_value .tolist ())
@@ -471,7 +475,11 @@ def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnVa
471475def reshape (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
472476 """Replace a Reshape node by Identity when applicable.
473477
474- Also propagate symbolic shape values.
478+ Also propagate symbolic shape values. When the output shape is known
479+ (from shape inference) and the shape input is not already a constant,
480+ replace the shape input with a concrete constant:
481+ - Fully static output shape → constant with exact dims.
482+ - Exactly one symbolic dim → replace it with -1 (Reshape infers it).
475483 """
476484 input = _get_input (node , 0 )
477485 shape = _get_input (node , 1 )
@@ -482,6 +490,10 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
482490 shape_value = state .get_shape_value (shape )
483491
484492 if shape_value is None or input_shape is None :
493+ # Try to materialize the shape input from the output shape
494+ result = _try_materialize_reshape_shape (node , op , input , shape )
495+ if result is not None :
496+ return result
485497 return _propagate_shape_value (node , op , state )
486498
487499 # No need to check for special values like -1, 0, etc. here
@@ -490,6 +502,44 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
490502 return _propagate_shape_value (node , op , state )
491503
492504
505+ def _try_materialize_reshape_shape (
506+ node : ir .Node , op , input : ir .Value , shape : ir .Value
507+ ) -> ReturnValue :
508+ """Create a constant shape input for Reshape from the output shape.
509+
510+ When the Reshape output already has a known shape (e.g., from symbolic
511+ shape inference) but the shape input is not a constant, we can replace
512+ the shape input with a materialized constant tensor.
513+
514+ - If the output shape is fully static, use the exact dims.
515+ - If exactly one dim is symbolic, replace it with -1.
516+ """
517+ output = _get_output (node , 0 )
518+ if output is None or output .shape is None :
519+ return None
520+
521+ out_shape = output .shape
522+ dims = list (out_shape )
523+ sym_count = sum (1 for d in dims if not isinstance (d , int ))
524+
525+ if sym_count == 0 :
526+ # Fully static — create constant with exact dims
527+ new_shape = op .Constant (
528+ value = ir .tensor (dims , dtype = ir .DataType .INT64 ),
529+ )
530+ return op .Reshape (input , new_shape )
531+
532+ if sym_count == 1 :
533+ # Replace the single symbolic dim with -1
534+ concrete_dims = [- 1 if not isinstance (d , int ) else d for d in dims ]
535+ new_shape = op .Constant (
536+ value = ir .tensor (concrete_dims , dtype = ir .DataType .INT64 ),
537+ )
538+ return op .Reshape (input , new_shape )
539+
540+ return None
541+
542+
493543@register ("Squeeze" )
494544def squeeze (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
495545 """Propagate symbolic shape values."""
0 commit comments