add method and test

This commit is contained in:
Looly 2020-05-13 06:24:53 +08:00
parent c047fd5a23
commit e2428714a0
4 changed files with 100 additions and 20 deletions

View File

@ -10,7 +10,10 @@
* 【extra 】 增加Sftp.lsEntries方法Ftp和Sftp增加recursiveDownloadFolderpr#121@Gitee
* 【system 】 OshiUtil增加getNetworkIFs方法
* 【core 】 CollUtil增加unionDistinct、unionAll方法pr#122@Gitee
* 【core 】 增加IoUtil.readObj重载通过ValidateObjectInputStream由用户自定义安全检查。
### Bug修复
* 【core 】 修复IoUtil.readObj中反序列化安全检查导致的一些问题去掉安全检查。
-------------------------------------------------------------------------------------------------------------

View File

@ -293,7 +293,7 @@ public class CollUtil {
return coll1;
}
final ArrayList<T> result = new ArrayList<>();
final List<T> result = new ArrayList<>();
final Map<T, Integer> map1 = countMap(coll1);
final Map<T, Integer> map2 = countMap(coll2);
final Set<T> elts = newHashSet(coll2);

View File

@ -20,7 +20,6 @@ import java.io.Flushable;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
@ -648,14 +647,16 @@ public class IoUtil {
/**
* 从流中读取对象即对象的反序列化
*
* <p>
* 注意 此方法不会检查反序列化安全可能存在反序列化漏洞风险
* </p>
*
* @param <T> 读取对象的类型
* @param in 输入流
* @return 输出流
* @throws IORuntimeException IO异常
* @throws UtilException ClassNotFoundException包装
* @deprecated 由于存在对象反序列化漏洞风险请使用{@link #readObj(InputStream, Class)}
*/
@Deprecated
public static <T> T readObj(InputStream in) throws IORuntimeException, UtilException {
return readObj(in, null);
}
@ -663,6 +664,10 @@ public class IoUtil {
/**
* 从流中读取对象即对象的反序列化读取后不关闭流
*
* <p>
* 注意 此方法不会检查反序列化安全可能存在反序列化漏洞风险
* </p>
*
* @param <T> 读取对象的类型
* @param in 输入流
* @param clazz 读取对象类型
@ -671,14 +676,38 @@ public class IoUtil {
* @throws UtilException ClassNotFoundException包装
*/
public static <T> T readObj(InputStream in, Class<T> clazz) throws IORuntimeException, UtilException {
try {
return readObj((in instanceof ValidateObjectInputStream) ?
(ValidateObjectInputStream) in : new ValidateObjectInputStream(in),
clazz);
} catch (IOException e) {
throw new IORuntimeException(e);
}
}
/**
* 从流中读取对象即对象的反序列化读取后不关闭流
*
* <p>
* 此方法使用了{@link ValidateObjectInputStream}中的黑白名单方式过滤类用于避免反序列化漏洞<br>
* 通过构造{@link ValidateObjectInputStream}调用{@link ValidateObjectInputStream#accept(Class[])}
* 或者{@link ValidateObjectInputStream#refuse(Class[])}方法添加可以被序列化的类或者禁止序列化的类
* </p>
*
* @param <T> 读取对象的类型
* @param in 输入流使用{@link ValidateObjectInputStream}中的黑白名单方式过滤类用于避免反序列化漏洞
* @param clazz 读取对象类型
* @return 输出流
* @throws IORuntimeException IO异常
* @throws UtilException ClassNotFoundException包装
*/
public static <T> T readObj(ValidateObjectInputStream in, Class<T> clazz) throws IORuntimeException, UtilException {
if (in == null) {
throw new IllegalArgumentException("The InputStream must not be null");
}
ObjectInputStream ois;
try {
ois = new ValidateObjectInputStream(in, clazz);
//noinspection unchecked
return (T) ois.readObject();
return (T) in.readObject();
} catch (IOException e) {
throw new IORuntimeException(e);
} catch (ClassNotFoundException e) {
@ -989,7 +1018,7 @@ public class IoUtil {
*
* @param out 输出流
* @param isCloseOut 写入完毕是否关闭输出流
* @param obj 写入的对象内容
* @param obj 写入的对象内容
* @throws IORuntimeException IO异常
* @since 5.3.3
*/

View File

@ -1,10 +1,14 @@
package cn.hutool.core.io;
import cn.hutool.core.collection.CollUtil;
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.util.HashSet;
import java.util.Set;
/**
* 带有类验证的对象流用于避免反序列化漏洞<br>
@ -15,27 +19,48 @@ import java.io.ObjectStreamClass;
*/
public class ValidateObjectInputStream extends ObjectInputStream {
private Class<?> acceptClass;
private Set<String> whiteClassSet;
private Set<String> blackClassSet;
/**
* 构造
*
* @param inputStream
* @param acceptClass 接受的类
* @param acceptClasses 白名单的类
* @throws IOException IO异常
*/
public ValidateObjectInputStream(InputStream inputStream, Class<?> acceptClass) throws IOException {
public ValidateObjectInputStream(InputStream inputStream, Class<?>... acceptClasses) throws IOException {
super(inputStream);
this.acceptClass = acceptClass;
accept(acceptClasses);
}
/**
* 禁止反序列化的类用于反序列化验证
*
* @param refuseClasses 禁止反序列化的类
* @since 5.3.5
*/
public void refuse(Class<?>... refuseClasses) {
if(null == this.blackClassSet){
this.blackClassSet = new HashSet<>();
}
for (Class<?> acceptClass : refuseClasses) {
this.blackClassSet.add(acceptClass.getName());
}
}
/**
* 接受反序列化的类用于反序列化验证
*
* @param acceptClass 接受反序列化的类
* @param acceptClasses 接受反序列化的类
*/
public void accept(Class<?> acceptClass) {
this.acceptClass = acceptClass;
public void accept(Class<?>... acceptClasses) {
if(null == this.whiteClassSet){
this.whiteClassSet = new HashSet<>();
}
for (Class<?> acceptClass : acceptClasses) {
this.whiteClassSet.add(acceptClass.getName());
}
}
/**
@ -43,11 +68,34 @@ public class ValidateObjectInputStream extends ObjectInputStream {
*/
@Override
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
if (null != this.acceptClass && false == desc.getName().equals(acceptClass.getName())) {
throw new InvalidClassException(
"Unauthorized deserialization attempt",
desc.getName());
}
validateClassName(desc.getName());
return super.resolveClass(desc);
}
/**
* 验证反序列化的类是否合法
* @param className 类名
* @throws InvalidClassException 非法类
*/
private void validateClassName(String className) throws InvalidClassException {
// 黑名单
if(CollUtil.isNotEmpty(this.blackClassSet)){
if(this.blackClassSet.contains(className)){
throw new InvalidClassException("Unauthorized deserialization attempt by black list", className);
}
}
if(CollUtil.isEmpty(this.whiteClassSet)){
return;
}
if(className.startsWith("java.")){
// java中的类默认在白名单中
return;
}
if(this.whiteClassSet.contains(className)){
return;
}
throw new InvalidClassException("Unauthorized deserialization attempt", className);
}
}