diff --git a/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java b/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java index f2f4fc69d3..f75cbbf9d3 100644 --- a/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java +++ b/src/main/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitch.java @@ -19,6 +19,7 @@ import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.*; +import org.openrewrite.internal.ListUtils; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.JavaVisitor; @@ -31,12 +32,10 @@ import org.openrewrite.staticanalysis.groovy.GroovyFileChecker; import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; +import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import static java.util.Collections.singletonList; import static org.openrewrite.java.migrate.lang.NullCheck.Matcher.nullCheck; import static org.openrewrite.java.tree.J.Block.createEmptyBlock; @@ -218,7 +217,42 @@ public J.Return visitReturn(J.Return return_, AtomicBoolean atomicBoolean) { } switchBody.append("}\n"); - return JavaTemplate.apply(switchBody.toString(), cursor, if_.getCoordinates().replace(), arguments).withPrefix(if_.getPrefix()); + J.Switch result = JavaTemplate.apply(switchBody.toString(), cursor, if_.getCoordinates().replace(), arguments).withPrefix(if_.getPrefix()); + return fixTypeAttribution(result); + } + + /** + * JavaTemplate uses raw string substitution (#{}), which loses type information + * for non-JDK types. This method restores the original type information from the + * instanceof checks onto the generated switch case labels. + */ + private J.Switch fixTypeAttribution(J.Switch switch_) { + Iterator instanceOfs = patternMatchers.keySet().iterator(); + return switch_.withCases(switch_.getCases().withStatements( + ListUtils.map(switch_.getCases().getStatements(), stmt -> { + if (stmt instanceof J.Case && instanceOfs.hasNext()) { + J.Case case_ = (J.Case) stmt; + if (!case_.getCaseLabels().isEmpty() && case_.getCaseLabels().get(0) instanceof J.VariableDeclarations) { + J.InstanceOf instanceOf = instanceOfs.next(); + J.VariableDeclarations varDecl = (J.VariableDeclarations) case_.getCaseLabels().get(0); + // Replace typeExpression with the original clazz (which has proper type info) + varDecl = varDecl.withTypeExpression( + varDecl.getTypeExpression() != null ? + instanceOf.getClazz().withPrefix(varDecl.getTypeExpression().getPrefix()) : + instanceOf.getClazz().withPrefix(Space.EMPTY)); + // Fix variable type from original pattern + if (instanceOf.getPattern() instanceof J.Identifier && !varDecl.getVariables().isEmpty()) { + J.Identifier originalPattern = (J.Identifier) instanceOf.getPattern(); + J.VariableDeclarations.NamedVariable var0 = varDecl.getVariables().get(0); + varDecl = varDecl.withVariables(singletonList( + var0.withType(originalPattern.getType()) + .withName(var0.getName().withType(originalPattern.getType())))); + } + return case_.withCaseLabels(singletonList(varDecl.withPrefix(case_.getCaseLabels().get(0).getPrefix()))); + } + } + return stmt; + }))); } private Optional switchOn() { diff --git a/src/main/resources/META-INF/rewrite/java-version-21.yml b/src/main/resources/META-INF/rewrite/java-version-21.yml index fed5fadc7a..f2f47de43f 100644 --- a/src/main/resources/META-INF/rewrite/java-version-21.yml +++ b/src/main/resources/META-INF/rewrite/java-version-21.yml @@ -44,7 +44,7 @@ recipeList: - org.openrewrite.java.migrate.lang.SwitchCaseAssignmentsToSwitchExpression - org.openrewrite.java.migrate.lang.SwitchCaseReturnsToSwitchExpression - org.openrewrite.java.migrate.lang.SwitchExpressionYieldToArrow - #- org.openrewrite.java.migrate.lang.IfElseIfConstructToSwitch # FIXME `casecase` seen near non JDK types + - org.openrewrite.java.migrate.lang.IfElseIfConstructToSwitch - org.openrewrite.java.migrate.SwitchPatternMatching - org.openrewrite.java.migrate.lang.NullCheckAsSwitchCase diff --git a/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java b/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java index 7e83ab0dae..e495e41d96 100644 --- a/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java +++ b/src/test/java/org/openrewrite/java/migrate/lang/IfElseIfConstructToSwitchTest.java @@ -341,6 +341,111 @@ static String formatter(Object obj) { ); } + @Test + void nullCheckWithNonJdkTypes() { + rewriteRun( + java( + """ + public class Dog {} + """ + ), + java( + """ + public class Cat {} + """ + ), + //language=java + java( + """ + class Test { + static String describe(Object obj) { + String result; + if (obj == null) { + result = "nothing"; + } else if (obj instanceof Dog d) { + result = "dog"; + } else if (obj instanceof Cat c) { + result = "cat"; + } else { + result = "unknown"; + } + return result; + } + } + """, + """ + class Test { + static String describe(Object obj) { + String result; + switch (obj) { + case null -> result = "nothing"; + case Dog d -> result = "dog"; + case Cat c -> result = "cat"; + default -> result = "unknown"; + } + return result; + } + } + """ + ) + ); + } + + @Test + void threeNonJdkTypes() { + rewriteRun( + java( + """ + public class Dog {} + """ + ), + java( + """ + public class Cat {} + """ + ), + java( + """ + public class Bird {} + """ + ), + //language=java + java( + """ + class Test { + static String describe(Object obj) { + String result; + if (obj instanceof Dog d) { + result = "dog"; + } else if (obj instanceof Cat c) { + result = "cat"; + } else if (obj instanceof Bird b) { + result = "bird"; + } else { + result = "unknown"; + } + return result; + } + } + """, + """ + class Test { + static String describe(Object obj) { + String result; + switch (obj) { + case Dog d -> result = "dog"; + case Cat c -> result = "cat"; + case Bird b -> result = "bird"; + default -> result = "unknown"; + } + return result; + } + } + """ + ) + ); + } + @Test void switchBlockForNestedClasses() { rewriteRun(