Avoid imports from current package

Closes gh-1421
This commit is contained in:
Stephane Nicoll
2023-06-07 15:44:48 +02:00
parent 8acbad503a
commit 4c77196504
6 changed files with 132 additions and 104 deletions

View File

@@ -39,6 +39,7 @@ import io.spring.initializr.generator.io.IndentingWriterFactory;
import io.spring.initializr.generator.language.Annotatable; import io.spring.initializr.generator.language.Annotatable;
import io.spring.initializr.generator.language.Annotation; import io.spring.initializr.generator.language.Annotation;
import io.spring.initializr.generator.language.CodeBlock.FormattingOptions; import io.spring.initializr.generator.language.CodeBlock.FormattingOptions;
import io.spring.initializr.generator.language.CompilationUnit;
import io.spring.initializr.generator.language.Parameter; import io.spring.initializr.generator.language.Parameter;
import io.spring.initializr.generator.language.SourceCode; import io.spring.initializr.generator.language.SourceCode;
import io.spring.initializr.generator.language.SourceCodeWriter; import io.spring.initializr.generator.language.SourceCodeWriter;
@@ -259,34 +260,30 @@ public class GroovySourceCodeWriter implements SourceCodeWriter<GroovySourceCode
private Set<String> determineImports(GroovyCompilationUnit compilationUnit) { private Set<String> determineImports(GroovyCompilationUnit compilationUnit) {
List<String> imports = new ArrayList<>(); List<String> imports = new ArrayList<>();
for (GroovyTypeDeclaration typeDeclaration : compilationUnit.getTypeDeclarations()) { for (GroovyTypeDeclaration typeDeclaration : compilationUnit.getTypeDeclarations()) {
if (requiresImport(typeDeclaration.getExtends())) { imports.add(typeDeclaration.getExtends());
imports.add(typeDeclaration.getExtends()); imports.addAll(appendImports(typeDeclaration.getAnnotations(), this::determineImports));
}
imports.addAll(getRequiredImports(typeDeclaration.getAnnotations(), this::determineImports));
for (GroovyFieldDeclaration fieldDeclaration : typeDeclaration.getFieldDeclarations()) { for (GroovyFieldDeclaration fieldDeclaration : typeDeclaration.getFieldDeclarations()) {
if (requiresImport(fieldDeclaration.getReturnType())) { imports.add(fieldDeclaration.getReturnType());
imports.add(fieldDeclaration.getReturnType()); imports.addAll(appendImports(fieldDeclaration.getAnnotations(), this::determineImports));
}
imports.addAll(getRequiredImports(fieldDeclaration.getAnnotations(), this::determineImports));
} }
for (GroovyMethodDeclaration methodDeclaration : typeDeclaration.getMethodDeclarations()) { for (GroovyMethodDeclaration methodDeclaration : typeDeclaration.getMethodDeclarations()) {
if (requiresImport(methodDeclaration.getReturnType())) { imports.add(methodDeclaration.getReturnType());
imports.add(methodDeclaration.getReturnType()); imports.addAll(appendImports(methodDeclaration.getAnnotations(), this::determineImports));
} imports.addAll(appendImports(methodDeclaration.getParameters(),
imports.addAll(getRequiredImports(methodDeclaration.getAnnotations(), this::determineImports));
imports.addAll(getRequiredImports(methodDeclaration.getParameters(),
(parameter) -> Collections.singletonList(parameter.getType()))); (parameter) -> Collections.singletonList(parameter.getType())));
imports.addAll(methodDeclaration.getCode().getImports()); imports.addAll(methodDeclaration.getCode().getImports());
determineImportsFromStatements(imports, methodDeclaration); determineImportsFromStatements(imports, methodDeclaration);
} }
} }
Collections.sort(imports); return imports.stream()
return new LinkedHashSet<>(imports); .filter((candidate) -> isImportCandidate(compilationUnit, candidate))
.sorted()
.collect(Collectors.toCollection(LinkedHashSet::new));
} }
@SuppressWarnings("removal") @SuppressWarnings("removal")
private void determineImportsFromStatements(List<String> imports, GroovyMethodDeclaration methodDeclaration) { private void determineImportsFromStatements(List<String> imports, GroovyMethodDeclaration methodDeclaration) {
imports.addAll(getRequiredImports( imports.addAll(appendImports(
methodDeclaration.getStatements() methodDeclaration.getStatements()
.stream() .stream()
.filter(GroovyExpressionStatement.class::isInstance) .filter(GroovyExpressionStatement.class::isInstance)
@@ -314,15 +311,12 @@ public class GroovySourceCodeWriter implements SourceCodeWriter<GroovySourceCode
return imports; return imports;
} }
private <T> List<String> getRequiredImports(List<T> candidates, Function<T, Collection<String>> mapping) { private <T> List<String> appendImports(List<T> candidates, Function<T, Collection<String>> mapping) {
return getRequiredImports(candidates.stream(), mapping); return appendImports(candidates.stream(), mapping);
} }
private <T> List<String> getRequiredImports(Stream<T> candidates, Function<T, Collection<String>> mapping) { private <T> List<String> appendImports(Stream<T> candidates, Function<T, Collection<String>> mapping) {
return candidates.map(mapping) return candidates.map(mapping).flatMap(Collection::stream).collect(Collectors.toList());
.flatMap(Collection::stream)
.filter(this::requiresImport)
.collect(Collectors.toList());
} }
private String getUnqualifiedName(String name) { private String getUnqualifiedName(String name) {
@@ -332,12 +326,12 @@ public class GroovySourceCodeWriter implements SourceCodeWriter<GroovySourceCode
return name.substring(name.lastIndexOf(".") + 1); return name.substring(name.lastIndexOf(".") + 1);
} }
private boolean requiresImport(String name) { private boolean isImportCandidate(CompilationUnit<?> compilationUnit, String name) {
if (name == null || !name.contains(".")) { if (name == null || !name.contains(".")) {
return false; return false;
} }
String packageName = name.substring(0, name.lastIndexOf('.')); String packageName = name.substring(0, name.lastIndexOf('.'));
return !"java.lang".equals(packageName); return !"java.lang".equals(packageName) && !compilationUnit.getPackageName().equals(packageName);
} }
static class GroovyFormattingOptions implements FormattingOptions { static class GroovyFormattingOptions implements FormattingOptions {

View File

@@ -39,6 +39,7 @@ import io.spring.initializr.generator.io.IndentingWriterFactory;
import io.spring.initializr.generator.language.Annotatable; import io.spring.initializr.generator.language.Annotatable;
import io.spring.initializr.generator.language.Annotation; import io.spring.initializr.generator.language.Annotation;
import io.spring.initializr.generator.language.CodeBlock; import io.spring.initializr.generator.language.CodeBlock;
import io.spring.initializr.generator.language.CompilationUnit;
import io.spring.initializr.generator.language.Parameter; import io.spring.initializr.generator.language.Parameter;
import io.spring.initializr.generator.language.SourceCode; import io.spring.initializr.generator.language.SourceCode;
import io.spring.initializr.generator.language.SourceCodeWriter; import io.spring.initializr.generator.language.SourceCodeWriter;
@@ -260,34 +261,31 @@ public class JavaSourceCodeWriter implements SourceCodeWriter<JavaSourceCode> {
private Set<String> determineImports(JavaCompilationUnit compilationUnit) { private Set<String> determineImports(JavaCompilationUnit compilationUnit) {
List<String> imports = new ArrayList<>(); List<String> imports = new ArrayList<>();
for (JavaTypeDeclaration typeDeclaration : compilationUnit.getTypeDeclarations()) { for (JavaTypeDeclaration typeDeclaration : compilationUnit.getTypeDeclarations()) {
if (requiresImport(typeDeclaration.getExtends())) { imports.add(typeDeclaration.getExtends());
imports.add(typeDeclaration.getExtends());
} imports.addAll(appendImports(typeDeclaration.getAnnotations(), this::determineImports));
imports.addAll(getRequiredImports(typeDeclaration.getAnnotations(), this::determineImports));
for (JavaFieldDeclaration fieldDeclaration : typeDeclaration.getFieldDeclarations()) { for (JavaFieldDeclaration fieldDeclaration : typeDeclaration.getFieldDeclarations()) {
if (requiresImport(fieldDeclaration.getReturnType())) { imports.add(fieldDeclaration.getReturnType());
imports.add(fieldDeclaration.getReturnType()); imports.addAll(appendImports(fieldDeclaration.getAnnotations(), this::determineImports));
}
imports.addAll(getRequiredImports(fieldDeclaration.getAnnotations(), this::determineImports));
} }
for (JavaMethodDeclaration methodDeclaration : typeDeclaration.getMethodDeclarations()) { for (JavaMethodDeclaration methodDeclaration : typeDeclaration.getMethodDeclarations()) {
if (requiresImport(methodDeclaration.getReturnType())) { imports.add(methodDeclaration.getReturnType());
imports.add(methodDeclaration.getReturnType()); imports.addAll(appendImports(methodDeclaration.getAnnotations(), this::determineImports));
} imports.addAll(appendImports(methodDeclaration.getParameters(),
imports.addAll(getRequiredImports(methodDeclaration.getAnnotations(), this::determineImports));
imports.addAll(getRequiredImports(methodDeclaration.getParameters(),
(parameter) -> Collections.singletonList(parameter.getType()))); (parameter) -> Collections.singletonList(parameter.getType())));
determineImportsFromStatements(imports, methodDeclaration); determineImportsFromStatements(imports, methodDeclaration);
imports.addAll(methodDeclaration.getCode().getImports()); imports.addAll(methodDeclaration.getCode().getImports());
} }
} }
Collections.sort(imports); return imports.stream()
return new LinkedHashSet<>(imports); .filter((candidate) -> isImportCandidate(compilationUnit, candidate))
.sorted()
.collect(Collectors.toCollection(LinkedHashSet::new));
} }
@SuppressWarnings("removal") @SuppressWarnings("removal")
private void determineImportsFromStatements(List<String> imports, JavaMethodDeclaration methodDeclaration) { private void determineImportsFromStatements(List<String> imports, JavaMethodDeclaration methodDeclaration) {
imports.addAll(getRequiredImports( imports.addAll(appendImports(
methodDeclaration.getStatements() methodDeclaration.getStatements()
.stream() .stream()
.filter(JavaExpressionStatement.class::isInstance) .filter(JavaExpressionStatement.class::isInstance)
@@ -315,15 +313,12 @@ public class JavaSourceCodeWriter implements SourceCodeWriter<JavaSourceCode> {
return imports; return imports;
} }
private <T> List<String> getRequiredImports(List<T> candidates, Function<T, Collection<String>> mapping) { private <T> List<String> appendImports(List<T> candidates, Function<T, Collection<String>> mapping) {
return getRequiredImports(candidates.stream(), mapping); return appendImports(candidates.stream(), mapping);
} }
private <T> List<String> getRequiredImports(Stream<T> candidates, Function<T, Collection<String>> mapping) { private <T> List<String> appendImports(Stream<T> candidates, Function<T, Collection<String>> mapping) {
return candidates.map(mapping) return candidates.map(mapping).flatMap(Collection::stream).collect(Collectors.toList());
.flatMap(Collection::stream)
.filter(this::requiresImport)
.collect(Collectors.toList());
} }
private String getUnqualifiedName(String name) { private String getUnqualifiedName(String name) {
@@ -333,12 +328,12 @@ public class JavaSourceCodeWriter implements SourceCodeWriter<JavaSourceCode> {
return name.substring(name.lastIndexOf(".") + 1); return name.substring(name.lastIndexOf(".") + 1);
} }
private boolean requiresImport(String name) { private boolean isImportCandidate(CompilationUnit<?> compilationUnit, String name) {
if (name == null || !name.contains(".")) { if (name == null || !name.contains(".")) {
return false; return false;
} }
String packageName = name.substring(0, name.lastIndexOf('.')); String packageName = name.substring(0, name.lastIndexOf('.'));
return !"java.lang".equals(packageName); return !"java.lang".equals(packageName) && !compilationUnit.getPackageName().equals(packageName);
} }
} }

View File

@@ -35,6 +35,7 @@ import io.spring.initializr.generator.io.IndentingWriterFactory;
import io.spring.initializr.generator.language.Annotatable; import io.spring.initializr.generator.language.Annotatable;
import io.spring.initializr.generator.language.Annotation; import io.spring.initializr.generator.language.Annotation;
import io.spring.initializr.generator.language.CodeBlock.FormattingOptions; import io.spring.initializr.generator.language.CodeBlock.FormattingOptions;
import io.spring.initializr.generator.language.CompilationUnit;
import io.spring.initializr.generator.language.Parameter; import io.spring.initializr.generator.language.Parameter;
import io.spring.initializr.generator.language.SourceCode; import io.spring.initializr.generator.language.SourceCode;
import io.spring.initializr.generator.language.SourceCodeWriter; import io.spring.initializr.generator.language.SourceCodeWriter;
@@ -299,10 +300,8 @@ public class KotlinSourceCodeWriter implements SourceCodeWriter<KotlinSourceCode
private Set<String> determineImports(KotlinCompilationUnit compilationUnit) { private Set<String> determineImports(KotlinCompilationUnit compilationUnit) {
List<String> imports = new ArrayList<>(); List<String> imports = new ArrayList<>();
for (KotlinTypeDeclaration typeDeclaration : compilationUnit.getTypeDeclarations()) { for (KotlinTypeDeclaration typeDeclaration : compilationUnit.getTypeDeclarations()) {
if (requiresImport(typeDeclaration.getExtends())) { imports.add(typeDeclaration.getExtends());
imports.add(typeDeclaration.getExtends()); imports.addAll(appendImports(typeDeclaration.getAnnotations(), this::determineImports));
}
imports.addAll(getRequiredImports(typeDeclaration.getAnnotations(), this::determineImports));
typeDeclaration.getPropertyDeclarations() typeDeclaration.getPropertyDeclarations()
.forEach(((propertyDeclaration) -> imports.addAll(determinePropertyImports(propertyDeclaration)))); .forEach(((propertyDeclaration) -> imports.addAll(determinePropertyImports(propertyDeclaration))));
typeDeclaration.getFunctionDeclarations() typeDeclaration.getFunctionDeclarations()
@@ -310,32 +309,29 @@ public class KotlinSourceCodeWriter implements SourceCodeWriter<KotlinSourceCode
} }
compilationUnit.getTopLevelFunctions() compilationUnit.getTopLevelFunctions()
.forEach((functionDeclaration) -> imports.addAll(determineFunctionImports(functionDeclaration))); .forEach((functionDeclaration) -> imports.addAll(determineFunctionImports(functionDeclaration)));
Collections.sort(imports); return imports.stream()
return new LinkedHashSet<>(imports); .filter((candidate) -> isImportCandidate(compilationUnit, candidate))
.sorted()
.collect(Collectors.toCollection(LinkedHashSet::new));
} }
private Set<String> determinePropertyImports(KotlinPropertyDeclaration propertyDeclaration) { private Set<String> determinePropertyImports(KotlinPropertyDeclaration propertyDeclaration) {
Set<String> imports = new LinkedHashSet<>(); return (propertyDeclaration.getReturnType() != null) ? Set.of(propertyDeclaration.getReturnType())
if (requiresImport(propertyDeclaration.getReturnType())) { : Collections.emptySet();
imports.add(propertyDeclaration.getReturnType());
}
return imports;
} }
private Set<String> determineFunctionImports(KotlinFunctionDeclaration functionDeclaration) { private Set<String> determineFunctionImports(KotlinFunctionDeclaration functionDeclaration) {
Set<String> imports = new LinkedHashSet<>(); Set<String> imports = new LinkedHashSet<>();
if (requiresImport(functionDeclaration.getReturnType())) { imports.add(functionDeclaration.getReturnType());
imports.add(functionDeclaration.getReturnType()); imports.addAll(appendImports(functionDeclaration.getAnnotations(), this::determineImports));
} imports.addAll(appendImports(functionDeclaration.getParameters(),
imports.addAll(getRequiredImports(functionDeclaration.getAnnotations(), this::determineImports));
imports.addAll(getRequiredImports(functionDeclaration.getParameters(),
(parameter) -> Collections.singleton(parameter.getType()))); (parameter) -> Collections.singleton(parameter.getType())));
imports.addAll(functionDeclaration.getCode().getImports()); imports.addAll(functionDeclaration.getCode().getImports());
imports.addAll(getRequiredImports( imports.addAll(appendImports(
getKotlinExpressions(functionDeclaration).filter(KotlinFunctionInvocation.class::isInstance) getKotlinExpressions(functionDeclaration).filter(KotlinFunctionInvocation.class::isInstance)
.map(KotlinFunctionInvocation.class::cast), .map(KotlinFunctionInvocation.class::cast),
(invocation) -> Collections.singleton(invocation.getTarget()))); (invocation) -> Collections.singleton(invocation.getTarget())));
imports.addAll(getRequiredImports( imports.addAll(appendImports(
getKotlinExpressions(functionDeclaration).filter(KotlinReifiedFunctionInvocation.class::isInstance) getKotlinExpressions(functionDeclaration).filter(KotlinReifiedFunctionInvocation.class::isInstance)
.map(KotlinReifiedFunctionInvocation.class::cast), .map(KotlinReifiedFunctionInvocation.class::cast),
(invocation) -> Collections.singleton(invocation.getName()))); (invocation) -> Collections.singleton(invocation.getName())));
@@ -368,15 +364,12 @@ public class KotlinSourceCodeWriter implements SourceCodeWriter<KotlinSourceCode
.map(KotlinExpressionStatement::getExpression); .map(KotlinExpressionStatement::getExpression);
} }
private <T> List<String> getRequiredImports(List<T> candidates, Function<T, Collection<String>> mapping) { private <T> List<String> appendImports(List<T> candidates, Function<T, Collection<String>> mapping) {
return getRequiredImports(candidates.stream(), mapping); return appendImports(candidates.stream(), mapping);
} }
private <T> List<String> getRequiredImports(Stream<T> candidates, Function<T, Collection<String>> mapping) { private <T> List<String> appendImports(Stream<T> candidates, Function<T, Collection<String>> mapping) {
return candidates.map(mapping) return candidates.map(mapping).flatMap(Collection::stream).collect(Collectors.toList());
.flatMap(Collection::stream)
.filter(this::requiresImport)
.collect(Collectors.toList());
} }
private String getUnqualifiedName(String name) { private String getUnqualifiedName(String name) {
@@ -386,12 +379,12 @@ public class KotlinSourceCodeWriter implements SourceCodeWriter<KotlinSourceCode
return name.substring(name.lastIndexOf(".") + 1); return name.substring(name.lastIndexOf(".") + 1);
} }
private boolean requiresImport(String name) { private boolean isImportCandidate(CompilationUnit<?> compilationUnit, String name) {
if (name == null || !name.contains(".")) { if (name == null || !name.contains(".")) {
return false; return false;
} }
String packageName = name.substring(0, name.lastIndexOf('.')); String packageName = name.substring(0, name.lastIndexOf('.'));
return !"java.lang".equals(packageName); return !"java.lang".equals(packageName) && !compilationUnit.getPackageName().equals(packageName);
} }
static class KotlinFormattingOptions implements FormattingOptions { static class KotlinFormattingOptions implements FormattingOptions {

View File

@@ -137,6 +137,20 @@ class GroovySourceCodeWriterTests {
" String trim(String value) {", " value.trim()", " }", "", "}"); " String trim(String value) {", " value.trim()", " }", "", "}");
} }
@Test
void importsFromSamePackageAreDiscarded() throws IOException {
GroovySourceCode sourceCode = new GroovySourceCode();
GroovyCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test");
GroovyTypeDeclaration test = compilationUnit.createTypeDeclaration("Test");
test.addFieldDeclaration(GroovyFieldDeclaration.field("another").returning("com.example.Another"));
test.addFieldDeclaration(GroovyFieldDeclaration.field("sibling").returning("com.example.Sibling"));
test.addFieldDeclaration(GroovyFieldDeclaration.field("external").returning("com.example.another.External"));
List<String> lines = writeSingleType(sourceCode, "com/example/Test.groovy");
assertThat(lines).doesNotContain("import com.example.Another")
.doesNotContain("import com.example.Sibling")
.contains("import com.example.another.External");
}
@Test @Test
void springBootApplication() throws IOException { void springBootApplication() throws IOException {
GroovySourceCode sourceCode = new GroovySourceCode(); GroovySourceCode sourceCode = new GroovySourceCode();
@@ -205,11 +219,12 @@ class GroovySourceCodeWriterTests {
GroovySourceCode sourceCode = new GroovySourceCode(); GroovySourceCode sourceCode = new GroovySourceCode();
GroovyCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test"); GroovyCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test");
GroovyTypeDeclaration test = compilationUnit.createTypeDeclaration("Test"); GroovyTypeDeclaration test = compilationUnit.createTypeDeclaration("Test");
test.addFieldDeclaration( test.addFieldDeclaration(GroovyFieldDeclaration.field("testString")
GroovyFieldDeclaration.field("testString").modifiers(Modifier.PUBLIC).returning("com.example.One")); .modifiers(Modifier.PUBLIC)
.returning("com.example.another.One"));
List<String> lines = writeSingleType(sourceCode, "com/example/Test.groovy"); List<String> lines = writeSingleType(sourceCode, "com/example/Test.groovy");
assertThat(lines).containsExactly("package com.example", "", "import com.example.One", "", "class Test {", "", assertThat(lines).containsExactly("package com.example", "", "import com.example.another.One", "",
" public One testString", "", "}"); "class Test {", "", " public One testString", "", "}");
} }
@Test @Test
@@ -261,11 +276,12 @@ class GroovySourceCodeWriterTests {
@Test @Test
void annotationWithClassArrayAttribute() throws IOException { void annotationWithClassArrayAttribute() throws IOException {
List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication", List<String> lines = writeClassAnnotation(
(builder) -> builder.attribute("target", Class.class, "com.example.One", "com.example.Two"))); Annotation.name("org.springframework.test.TestApplication", (builder) -> builder.attribute("target",
assertThat(lines).containsExactly("package com.example", "", "import com.example.One", "import com.example.Two", Class.class, "com.example.another.One", "com.example.another.Two")));
"import org.springframework.test.TestApplication", "", "@TestApplication(target = [ One, Two ])", assertThat(lines).containsExactly("package com.example", "", "import com.example.another.One",
"class Test {", "", "}"); "import com.example.another.Two", "import org.springframework.test.TestApplication", "",
"@TestApplication(target = [ One, Two ])", "class Test {", "", "}");
} }
@Test @Test
@@ -273,8 +289,8 @@ class GroovySourceCodeWriterTests {
List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication", List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication",
(builder) -> builder.attribute("target", Class.class, "com.example.One") (builder) -> builder.attribute("target", Class.class, "com.example.One")
.attribute("unit", ChronoUnit.class, "java.time.temporal.ChronoUnit.NANOS"))); .attribute("unit", ChronoUnit.class, "java.time.temporal.ChronoUnit.NANOS")));
assertThat(lines).containsExactly("package com.example", "", "import com.example.One", assertThat(lines).containsExactly("package com.example", "", "import java.time.temporal.ChronoUnit",
"import java.time.temporal.ChronoUnit", "import org.springframework.test.TestApplication", "", "import org.springframework.test.TestApplication", "",
"@TestApplication(target = One, unit = ChronoUnit.NANOS)", "class Test {", "", "}"); "@TestApplication(target = One, unit = ChronoUnit.NANOS)", "class Test {", "", "}");
} }

View File

@@ -157,9 +157,9 @@ class JavaSourceCodeWriterTests {
JavaCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test"); JavaCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test");
JavaTypeDeclaration test = compilationUnit.createTypeDeclaration("Test"); JavaTypeDeclaration test = compilationUnit.createTypeDeclaration("Test");
test.addFieldDeclaration( test.addFieldDeclaration(
JavaFieldDeclaration.field("testString").modifiers(Modifier.PUBLIC).returning("com.example.One")); JavaFieldDeclaration.field("testString").modifiers(Modifier.PUBLIC).returning("com.another.One"));
List<String> lines = writeSingleType(sourceCode, "com/example/Test.java"); List<String> lines = writeSingleType(sourceCode, "com/example/Test.java");
assertThat(lines).containsExactly("package com.example;", "", "import com.example.One;", "", "class Test {", "", assertThat(lines).containsExactly("package com.example;", "", "import com.another.One;", "", "class Test {", "",
" public One testString;", "", "}"); " public One testString;", "", "}");
} }
@@ -213,6 +213,20 @@ class JavaSourceCodeWriterTests {
" public float testFloat = 99.999f;", "", " boolean testBool = true;", "", "}"); " public float testFloat = 99.999f;", "", " boolean testBool = true;", "", "}");
} }
@Test
void importsFromSamePackageAreDiscarded() throws IOException {
JavaSourceCode sourceCode = new JavaSourceCode();
JavaCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test");
JavaTypeDeclaration test = compilationUnit.createTypeDeclaration("Test");
test.addFieldDeclaration(JavaFieldDeclaration.field("another").returning("com.example.Another"));
test.addFieldDeclaration(JavaFieldDeclaration.field("sibling").returning("com.example.Sibling"));
test.addFieldDeclaration(JavaFieldDeclaration.field("external").returning("com.example.another.External"));
List<String> lines = writeSingleType(sourceCode, "com/example/Test.java");
assertThat(lines).doesNotContain("import com.example.Another;")
.doesNotContain("import com.example.Sibling;")
.contains("import com.example.another.External;");
}
@Test @Test
void springBootApplication() throws IOException { void springBootApplication() throws IOException {
JavaSourceCode sourceCode = new JavaSourceCode(); JavaSourceCode sourceCode = new JavaSourceCode();
@@ -272,9 +286,9 @@ class JavaSourceCodeWriterTests {
@Test @Test
void annotationWithClassArrayAttribute() throws IOException { void annotationWithClassArrayAttribute() throws IOException {
List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication", List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication",
(builder) -> builder.attribute("target", Class.class, "com.example.One", "com.example.Two"))); (builder) -> builder.attribute("target", Class.class, "com.another.One", "com.another.Two")));
assertThat(lines).containsExactly("package com.example;", "", "import com.example.One;", assertThat(lines).containsExactly("package com.example;", "", "import com.another.One;",
"import com.example.Two;", "import org.springframework.test.TestApplication;", "", "import com.another.Two;", "import org.springframework.test.TestApplication;", "",
"@TestApplication(target = { One.class, Two.class })", "class Test {", "", "}"); "@TestApplication(target = { One.class, Two.class })", "class Test {", "", "}");
} }
@@ -283,8 +297,8 @@ class JavaSourceCodeWriterTests {
List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication", List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication",
(builder) -> builder.attribute("target", Class.class, "com.example.One") (builder) -> builder.attribute("target", Class.class, "com.example.One")
.attribute("unit", ChronoUnit.class, "java.time.temporal.ChronoUnit.NANOS"))); .attribute("unit", ChronoUnit.class, "java.time.temporal.ChronoUnit.NANOS")));
assertThat(lines).containsExactly("package com.example;", "", "import com.example.One;", assertThat(lines).containsExactly("package com.example;", "", "import java.time.temporal.ChronoUnit;",
"import java.time.temporal.ChronoUnit;", "import org.springframework.test.TestApplication;", "", "import org.springframework.test.TestApplication;", "",
"@TestApplication(target = One.class, unit = ChronoUnit.NANOS)", "class Test {", "", "}"); "@TestApplication(target = One.class, unit = ChronoUnit.NANOS)", "class Test {", "", "}");
} }

View File

@@ -168,10 +168,10 @@ class KotlinSourceCodeWriterTests {
KotlinCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test"); KotlinCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test");
KotlinTypeDeclaration test = compilationUnit.createTypeDeclaration("Test"); KotlinTypeDeclaration test = compilationUnit.createTypeDeclaration("Test");
test.addPropertyDeclaration( test.addPropertyDeclaration(
KotlinPropertyDeclaration.val("testProp").returning("com.example.One").emptyValue()); KotlinPropertyDeclaration.val("testProp").returning("com.example.another.One").emptyValue());
List<String> lines = writeSingleType(sourceCode, "com/example/Test.kt"); List<String> lines = writeSingleType(sourceCode, "com/example/Test.kt");
assertThat(lines).containsExactly("package com.example", "", "import com.example.One", "", "class Test {", "", assertThat(lines).containsExactly("package com.example", "", "import com.example.another.One", "",
" val testProp: One", "", "}"); "class Test {", "", " val testProp: One", "", "}");
} }
@Test @Test
@@ -295,6 +295,21 @@ class KotlinSourceCodeWriterTests {
" lateinit var testProp: Int", "", "}"); " lateinit var testProp: Int", "", "}");
} }
@Test
void importsFromSamePackageAreDiscarded() throws IOException {
KotlinSourceCode sourceCode = new KotlinSourceCode();
KotlinCompilationUnit compilationUnit = sourceCode.createCompilationUnit("com.example", "Test");
KotlinTypeDeclaration test = compilationUnit.createTypeDeclaration("Test");
test.addPropertyDeclaration(KotlinPropertyDeclaration.var("another").returning("com.example.Another").empty());
test.addPropertyDeclaration(KotlinPropertyDeclaration.var("sibling").returning("com.example.Sibling").empty());
test.addPropertyDeclaration(
KotlinPropertyDeclaration.var("external").returning("com.example.another.External").empty());
List<String> lines = writeSingleType(sourceCode, "com/example/Test.kt");
assertThat(lines).doesNotContain("import com.example.Another")
.doesNotContain("import com.example.Sibling")
.contains("import com.example.another.External");
}
@Test @Test
void springBootApplication() throws IOException { void springBootApplication() throws IOException {
KotlinSourceCode sourceCode = new KotlinSourceCode(); KotlinSourceCode sourceCode = new KotlinSourceCode();
@@ -346,10 +361,11 @@ class KotlinSourceCodeWriterTests {
@Test @Test
void annotationWithClassArrayAttribute() throws IOException { void annotationWithClassArrayAttribute() throws IOException {
List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication", List<String> lines = writeClassAnnotation(
(builder) -> builder.attribute("target", Class.class, "com.example.One", "com.example.Two"))); Annotation.name("org.springframework.test.TestApplication", (builder) -> builder.attribute("target",
assertThat(lines).containsExactly("package com.example", "", "import com.example.One", "import com.example.Two", Class.class, "com.example.another.One", "com.example.another.Two")));
"import org.springframework.test.TestApplication", "", assertThat(lines).containsExactly("package com.example", "", "import com.example.another.One",
"import com.example.another.Two", "import org.springframework.test.TestApplication", "",
"@TestApplication(target = [One::class, Two::class])", "class Test"); "@TestApplication(target = [One::class, Two::class])", "class Test");
} }
@@ -358,8 +374,8 @@ class KotlinSourceCodeWriterTests {
List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication", List<String> lines = writeClassAnnotation(Annotation.name("org.springframework.test.TestApplication",
(builder) -> builder.attribute("target", Class.class, "com.example.One") (builder) -> builder.attribute("target", Class.class, "com.example.One")
.attribute("unit", ChronoUnit.class, "java.time.temporal.ChronoUnit.NANOS"))); .attribute("unit", ChronoUnit.class, "java.time.temporal.ChronoUnit.NANOS")));
assertThat(lines).containsExactly("package com.example", "", "import com.example.One", assertThat(lines).containsExactly("package com.example", "", "import java.time.temporal.ChronoUnit",
"import java.time.temporal.ChronoUnit", "import org.springframework.test.TestApplication", "", "import org.springframework.test.TestApplication", "",
"@TestApplication(target = One::class, unit = ChronoUnit.NANOS)", "class Test"); "@TestApplication(target = One::class, unit = ChronoUnit.NANOS)", "class Test");
} }