提交Collectors.toMap的对null友好实现,避免NPE

This commit is contained in:
VampireAchao 2022-01-15 11:48:51 +08:00 committed by achao
parent 9da17cf6c4
commit 78ac9fcdef
3 changed files with 83 additions and 9 deletions

View File

@ -199,7 +199,7 @@ public class CollStreamUtil {
if (CollUtil.isEmpty(collection) || key1 == null || key2 == null) { if (CollUtil.isEmpty(collection) || key1 == null || key2 == null) {
return Collections.emptyMap(); return Collections.emptyMap();
} }
return groupBy(collection, key1, Collectors.toMap(key2, Function.identity(), (l, r) -> l), isParallel); return groupBy(collection, key1, CollectorUtil.toMap(key2, Function.identity(), (l, r) -> l), isParallel);
} }
/** /**

View File

@ -7,6 +7,7 @@ import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.StringJoiner; import java.util.StringJoiner;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.BinaryOperator; import java.util.function.BinaryOperator;
@ -22,6 +23,16 @@ import java.util.stream.Collector;
*/ */
public class CollectorUtil { public class CollectorUtil {
/**
* 说明已包含IDENTITY_FINISH特征 Characteristics.IDENTITY_FINISH 的缩写
*/
public static final Set<Collector.Characteristics> CH_ID
= Collections.unmodifiableSet(EnumSet.of(Collector.Characteristics.IDENTITY_FINISH));
/**
* 说明不包含IDENTITY_FINISH特征
*/
public static final Set<Collector.Characteristics> CH_NOID = Collections.emptySet();
/** /**
* 提供任意对象的Join操作的{@link Collector}实现对象默认调用toString方法 * 提供任意对象的Join操作的{@link Collector}实现对象默认调用toString方法
* *
@ -93,17 +104,12 @@ public class CollectorUtil {
A container = m.computeIfAbsent(key, k -> downstreamSupplier.get()); A container = m.computeIfAbsent(key, k -> downstreamSupplier.get());
downstreamAccumulator.accept(container, t); downstreamAccumulator.accept(container, t);
}; };
BinaryOperator<Map<K, A>> merger = (m1, m2) -> { BinaryOperator<Map<K, A>> merger = mapMerger(downstream.combiner());
for (Map.Entry<K, A> e : m2.entrySet()) {
m1.merge(e.getKey(), e.getValue(), downstream.combiner());
}
return m1;
};
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Supplier<Map<K, A>> mangledFactory = (Supplier<Map<K, A>>) mapFactory; Supplier<Map<K, A>> mangledFactory = (Supplier<Map<K, A>>) mapFactory;
if (downstream.characteristics().contains(Collector.Characteristics.IDENTITY_FINISH)) { if (downstream.characteristics().contains(Collector.Characteristics.IDENTITY_FINISH)) {
return new SimpleCollector<>(mangledFactory, accumulator, merger, Collections.unmodifiableSet(EnumSet.of(Collector.Characteristics.IDENTITY_FINISH))); return new SimpleCollector<>(mangledFactory, accumulator, merger, CH_ID);
} else { } else {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Function<A, A> downstreamFinisher = (Function<A, A>) downstream.finisher(); Function<A, A> downstreamFinisher = (Function<A, A>) downstream.finisher();
@ -113,7 +119,7 @@ public class CollectorUtil {
M castResult = (M) intermediate; M castResult = (M) intermediate;
return castResult; return castResult;
}; };
return new SimpleCollector<>(mangledFactory, accumulator, merger, finisher, Collections.emptySet()); return new SimpleCollector<>(mangledFactory, accumulator, merger, finisher, CH_NOID);
} }
} }
@ -134,4 +140,65 @@ public class CollectorUtil {
return groupingBy(classifier, HashMap::new, downstream); return groupingBy(classifier, HashMap::new, downstream);
} }
/**
* 对null友好的 toMap 操作的 {@link Collector}实现默认使用HashMap
*
* @param keyMapper 指定map中的key
* @param valueMapper 指定map中的value
* @param mergeFunction 合并前对value进行的操作
* @param <T> 实体类型
* @param <K> map中key的类型
* @param <U> map中value的类型
* @return 对null友好的 toMap 操作的 {@link Collector}实现
*/
public static <T, K, U>
Collector<T, ?, Map<K, U>> toMap(Function<? super T, ? extends K> keyMapper,
Function<? super T, ? extends U> valueMapper,
BinaryOperator<U> mergeFunction) {
return toMap(keyMapper, valueMapper, mergeFunction, HashMap::new);
}
/**
* 对null友好的 toMap 操作的 {@link Collector}实现
*
* @param keyMapper 指定map中的key
* @param valueMapper 指定map中的value
* @param mergeFunction 合并前对value进行的操作
* @param mapSupplier 最终需要的map类型
* @param <T> 实体类型
* @param <K> map中key的类型
* @param <U> map中value的类型
* @param <M> map的类型
* @return 对null友好的 toMap 操作的 {@link Collector}实现
*/
public static <T, K, U, M extends Map<K, U>>
Collector<T, ?, M> toMap(Function<? super T, ? extends K> keyMapper,
Function<? super T, ? extends U> valueMapper,
BinaryOperator<U> mergeFunction,
Supplier<M> mapSupplier) {
BiConsumer<M, T> accumulator
= (map, element) -> map.put(Opt.ofNullable(element).map(keyMapper).get(), Opt.ofNullable(element).map(valueMapper).get());
return new SimpleCollector<>(mapSupplier, accumulator, mapMerger(mergeFunction), CH_ID);
}
/**
* 用户合并map的BinaryOperator传入合并前需要对value进行的操作
*
* @param mergeFunction 合并前需要对value进行的操作
* @param <K> key的类型
* @param <V> value的类型
* @param <M> map
* @return 用户合并map的BinaryOperator
*/
public static <K, V, M extends Map<K, V>> BinaryOperator<M> mapMerger(BinaryOperator<V> mergeFunction) {
return (m1, m2) -> {
for (Map.Entry<K, V> e : m2.entrySet()) {
m1.merge(e.getKey(), e.getValue(), mergeFunction);
}
return m1;
};
}
} }

View File

@ -150,6 +150,13 @@ public class CollStreamUtilTest {
compare.put(2L, map2); compare.put(2L, map2);
Assert.assertEquals(compare, map); Assert.assertEquals(compare, map);
// 对null友好
Map<Long, Map<Long, Student>> termIdClassIdStudentMap = CollStreamUtil.group2Map(Arrays.asList(null, new Student(2, 2, 1, "王五")), Student::getTermId, Student::getClassId);
Map<Long, Map<Long, Student>> termIdClassIdStudentCompareMap = new HashMap<Long, Map<Long, Student>>() {{
put(null, MapUtil.of(null, null));
put(2L, MapUtil.of(2L, new Student(2, 2, 1, "王五")));
}};
Assert.assertEquals(termIdClassIdStudentCompareMap, termIdClassIdStudentMap);
} }
@Test @Test