From e2428714a0ce7945b5c9fd485299c98659ac8d08 Mon Sep 17 00:00:00 2001 From: Looly Date: Wed, 13 May 2020 06:24:53 +0800 Subject: [PATCH] add method and test --- CHANGELOG.md | 3 + .../cn/hutool/core/collection/CollUtil.java | 2 +- .../main/java/cn/hutool/core/io/IoUtil.java | 43 +++++++++-- .../core/io/ValidateObjectInputStream.java | 72 +++++++++++++++---- 4 files changed, 100 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c84e852e5..49d4b4c03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,10 @@ * 【extra 】 增加Sftp.lsEntries方法,Ftp和Sftp增加recursiveDownloadFolder(pr#121@Gitee) * 【system 】 OshiUtil增加getNetworkIFs方法 * 【core 】 CollUtil增加unionDistinct、unionAll方法(pr#122@Gitee) +* 【core 】 增加IoUtil.readObj重载,通过ValidateObjectInputStream由用户自定义安全检查。 + ### Bug修复 +* 【core 】 修复IoUtil.readObj中反序列化安全检查导致的一些问题,去掉安全检查。 ------------------------------------------------------------------------------------------------------------- diff --git a/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java b/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java index ec3c5437a..c3ae49c56 100644 --- a/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java @@ -293,7 +293,7 @@ public class CollUtil { return coll1; } - final ArrayList result = new ArrayList<>(); + final List result = new ArrayList<>(); final Map map1 = countMap(coll1); final Map map2 = countMap(coll2); final Set elts = newHashSet(coll2); diff --git a/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java b/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java index cda8b5efe..2896c0102 100644 --- a/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java @@ -20,7 +20,6 @@ import java.io.Flushable; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.OutputStreamWriter; @@ -648,14 +647,16 @@ public class IoUtil { /** * 从流中读取对象,即对象的反序列化 * + *

+ * 注意!!! 此方法不会检查反序列化安全,可能存在反序列化漏洞风险!!! + *

+ * * @param 读取对象的类型 * @param in 输入流 * @return 输出流 * @throws IORuntimeException IO异常 * @throws UtilException ClassNotFoundException包装 - * @deprecated 由于存在对象反序列化漏洞风险,请使用{@link #readObj(InputStream, Class)} */ - @Deprecated public static T readObj(InputStream in) throws IORuntimeException, UtilException { return readObj(in, null); } @@ -663,6 +664,10 @@ public class IoUtil { /** * 从流中读取对象,即对象的反序列化,读取后不关闭流 * + *

+ * 注意!!! 此方法不会检查反序列化安全,可能存在反序列化漏洞风险!!! + *

+ * * @param 读取对象的类型 * @param in 输入流 * @param clazz 读取对象类型 @@ -671,14 +676,38 @@ public class IoUtil { * @throws UtilException ClassNotFoundException包装 */ public static T readObj(InputStream in, Class clazz) throws IORuntimeException, UtilException { + try { + return readObj((in instanceof ValidateObjectInputStream) ? + (ValidateObjectInputStream) in : new ValidateObjectInputStream(in), + clazz); + } catch (IOException e) { + throw new IORuntimeException(e); + } + } + + /** + * 从流中读取对象,即对象的反序列化,读取后不关闭流 + * + *

+ * 此方法使用了{@link ValidateObjectInputStream}中的黑白名单方式过滤类,用于避免反序列化漏洞
+ * 通过构造{@link ValidateObjectInputStream},调用{@link ValidateObjectInputStream#accept(Class[])} + * 或者{@link ValidateObjectInputStream#refuse(Class[])}方法添加可以被序列化的类或者禁止序列化的类。 + *

+ * + * @param 读取对象的类型 + * @param in 输入流,使用{@link ValidateObjectInputStream}中的黑白名单方式过滤类,用于避免反序列化漏洞 + * @param clazz 读取对象类型 + * @return 输出流 + * @throws IORuntimeException IO异常 + * @throws UtilException ClassNotFoundException包装 + */ + public static T readObj(ValidateObjectInputStream in, Class clazz) throws IORuntimeException, UtilException { if (in == null) { throw new IllegalArgumentException("The InputStream must not be null"); } - ObjectInputStream ois; try { - ois = new ValidateObjectInputStream(in, clazz); //noinspection unchecked - return (T) ois.readObject(); + return (T) in.readObject(); } catch (IOException e) { throw new IORuntimeException(e); } catch (ClassNotFoundException e) { @@ -989,7 +1018,7 @@ public class IoUtil { * * @param out 输出流 * @param isCloseOut 写入完毕是否关闭输出流 - * @param obj 写入的对象内容 + * @param obj 写入的对象内容 * @throws IORuntimeException IO异常 * @since 5.3.3 */ diff --git a/hutool-core/src/main/java/cn/hutool/core/io/ValidateObjectInputStream.java b/hutool-core/src/main/java/cn/hutool/core/io/ValidateObjectInputStream.java index ae077f412..91c1a130e 100644 --- a/hutool-core/src/main/java/cn/hutool/core/io/ValidateObjectInputStream.java +++ b/hutool-core/src/main/java/cn/hutool/core/io/ValidateObjectInputStream.java @@ -1,10 +1,14 @@ package cn.hutool.core.io; +import cn.hutool.core.collection.CollUtil; + import java.io.IOException; import java.io.InputStream; import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectStreamClass; +import java.util.HashSet; +import java.util.Set; /** * 带有类验证的对象流,用于避免反序列化漏洞
@@ -15,27 +19,48 @@ import java.io.ObjectStreamClass; */ public class ValidateObjectInputStream extends ObjectInputStream { - private Class acceptClass; + private Set whiteClassSet; + private Set blackClassSet; /** * 构造 * * @param inputStream 流 - * @param acceptClass 接受的类 + * @param acceptClasses 白名单的类 * @throws IOException IO异常 */ - public ValidateObjectInputStream(InputStream inputStream, Class acceptClass) throws IOException { + public ValidateObjectInputStream(InputStream inputStream, Class... acceptClasses) throws IOException { super(inputStream); - this.acceptClass = acceptClass; + accept(acceptClasses); + } + + /** + * 禁止反序列化的类,用于反序列化验证 + * + * @param refuseClasses 禁止反序列化的类 + * @since 5.3.5 + */ + public void refuse(Class... refuseClasses) { + if(null == this.blackClassSet){ + this.blackClassSet = new HashSet<>(); + } + for (Class acceptClass : refuseClasses) { + this.blackClassSet.add(acceptClass.getName()); + } } /** * 接受反序列化的类,用于反序列化验证 * - * @param acceptClass 接受反序列化的类 + * @param acceptClasses 接受反序列化的类 */ - public void accept(Class acceptClass) { - this.acceptClass = acceptClass; + public void accept(Class... acceptClasses) { + if(null == this.whiteClassSet){ + this.whiteClassSet = new HashSet<>(); + } + for (Class acceptClass : acceptClasses) { + this.whiteClassSet.add(acceptClass.getName()); + } } /** @@ -43,11 +68,34 @@ public class ValidateObjectInputStream extends ObjectInputStream { */ @Override protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - if (null != this.acceptClass && false == desc.getName().equals(acceptClass.getName())) { - throw new InvalidClassException( - "Unauthorized deserialization attempt", - desc.getName()); - } + validateClassName(desc.getName()); return super.resolveClass(desc); } + + /** + * 验证反序列化的类是否合法 + * @param className 类名 + * @throws InvalidClassException 非法类 + */ + private void validateClassName(String className) throws InvalidClassException { + // 黑名单 + if(CollUtil.isNotEmpty(this.blackClassSet)){ + if(this.blackClassSet.contains(className)){ + throw new InvalidClassException("Unauthorized deserialization attempt by black list", className); + } + } + + if(CollUtil.isEmpty(this.whiteClassSet)){ + return; + } + if(className.startsWith("java.")){ + // java中的类默认在白名单中 + return; + } + if(this.whiteClassSet.contains(className)){ + return; + } + + throw new InvalidClassException("Unauthorized deserialization attempt", className); + } }