diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 250c692d16f8..627171bdcb2a 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -68,7 +68,16 @@ fn fixes(ctx: &DiagnosticsContext<'_, '_>, d: &hir::TypeMismatch<'_>) -> Option< let mut fixes = Vec::new(); if let Some(expr_ptr) = d.expr_or_pat.value.cast::() { - let expr_ptr = &InFile { file_id: d.expr_or_pat.file_id, value: expr_ptr }; + let file_id = d.expr_or_pat.file_id; + let root = ctx.sema.db.parse_or_expand(file_id); + let mut expr = expr_ptr.to_node(&root); + while let ast::Expr::BlockExpr(block) = expr.clone() { + match block.tail_expr() { + Some(tail) => expr = tail, + None => break, + } + } + let expr_ptr = &InFile { file_id, value: AstPtr::new(&expr) }; add_reference(ctx, d, expr_ptr, &mut fixes); add_missing_ok_or_some(ctx, d, expr_ptr, &mut fixes); remove_unnecessary_wrapper(ctx, d, expr_ptr, &mut fixes); @@ -342,6 +351,44 @@ mod tests { check_has_fix, check_no_fix, }; + #[test] + fn add_reference_when_tail_expr_of_block() { + check_fix( + r#" +fn main() { + let a = 0; + let b = 1; + if false { &a } else { b$0 }; +} + "#, + r#" +fn main() { + let a = 0; + let b = 1; + if false { &a } else { &b }; +} + "#, + ); + } + + #[test] + fn str_ref_to_owned_when_tail_expr_of_block() { + check_has_fix( + r#" +struct String; +fn main() { + if true { String } else { ""$0 }; +} + "#, + r#" +struct String; +fn main() { + if true { String } else { "".to_owned() }; +} + "#, + ); + } + #[test] fn missing_reference() { check_diagnostics(