diff --git a/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java b/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java
index 41b0766d6..f2d13b4b8 100644
--- a/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java
+++ b/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java
@@ -1,5 +1,6 @@
package cn.hutool.core.io;
+import cn.hutool.core.io.copy.ChannelCopier;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.CharsetUtil;
import cn.hutool.core.util.StrUtil;
@@ -7,7 +8,6 @@ import cn.hutool.core.util.StrUtil;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
-import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
@@ -53,7 +53,23 @@ public class NioUtil {
* @throws IORuntimeException IO异常
*/
public static long copyByNIO(InputStream in, OutputStream out, int bufferSize, StreamProgress streamProgress) throws IORuntimeException {
- return copy(Channels.newChannel(in), Channels.newChannel(out), bufferSize, streamProgress);
+ return copyByNIO(in, out, bufferSize, -1, streamProgress);
+ }
+
+ /**
+ * 拷贝流
+ * 本方法不会关闭流
+ *
+ * @param in 输入流
+ * @param out 输出流
+ * @param bufferSize 缓存大小
+ * @param streamProgress 进度条
+ * @return 传输的byte数
+ * @throws IORuntimeException IO异常
+ * @since 5.7.8
+ */
+ public static long copyByNIO(InputStream in, OutputStream out, int bufferSize, long count, StreamProgress streamProgress) throws IORuntimeException {
+ return copy(Channels.newChannel(in), Channels.newChannel(out), bufferSize, count, streamProgress);
}
/**
@@ -114,31 +130,23 @@ public class NioUtil {
* @throws IORuntimeException IO异常
*/
public static long copy(ReadableByteChannel in, WritableByteChannel out, int bufferSize, StreamProgress streamProgress) throws IORuntimeException {
- Assert.notNull(in, "InputStream is null !");
- Assert.notNull(out, "OutputStream is null !");
+ return copy(in, out, bufferSize, -1, streamProgress);
+ }
- ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize <= 0 ? DEFAULT_BUFFER_SIZE : bufferSize);
- long size = 0;
- if (null != streamProgress) {
- streamProgress.start();
- }
- try {
- while (in.read(byteBuffer) != EOF) {
- byteBuffer.flip();// 写转读
- size += out.write(byteBuffer);
- byteBuffer.clear();
- if (null != streamProgress) {
- streamProgress.progress(size);
- }
- }
- } catch (IOException e) {
- throw new IORuntimeException(e);
- }
- if (null != streamProgress) {
- streamProgress.finish();
- }
-
- return size;
+ /**
+ * 拷贝流,使用NIO,不会关闭channel
+ *
+ * @param in {@link ReadableByteChannel}
+ * @param out {@link WritableByteChannel}
+ * @param bufferSize 缓冲大小,如果小于等于0,使用默认
+ * @param count 读取总长度
+ * @param streamProgress {@link StreamProgress}进度处理器
+ * @return 拷贝的字节数
+ * @throws IORuntimeException IO异常
+ * @since 5.7.8
+ */
+ public static long copy(ReadableByteChannel in, WritableByteChannel out, int bufferSize, long count, StreamProgress streamProgress) throws IORuntimeException {
+ return new ChannelCopier(bufferSize, count, streamProgress).copy(in, out);
}
/**
diff --git a/hutool-core/src/main/java/cn/hutool/core/io/copy/ChannelCopier.java b/hutool-core/src/main/java/cn/hutool/core/io/copy/ChannelCopier.java
new file mode 100755
index 000000000..7585cdd58
--- /dev/null
+++ b/hutool-core/src/main/java/cn/hutool/core/io/copy/ChannelCopier.java
@@ -0,0 +1,116 @@
+package cn.hutool.core.io.copy;
+
+import cn.hutool.core.io.IORuntimeException;
+import cn.hutool.core.io.IoUtil;
+import cn.hutool.core.io.StreamProgress;
+import cn.hutool.core.lang.Assert;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * {@link ReadableByteChannel} 向 {@link WritableByteChannel} 拷贝
+ *
+ * @author looly
+ * @since 5.7.8
+ */
+public class ChannelCopier extends IoCopier {
+
+ /**
+ * 构造
+ */
+ public ChannelCopier() {
+ this(IoUtil.DEFAULT_BUFFER_SIZE);
+ }
+
+ /**
+ * 构造
+ *
+ * @param bufferSize 缓存大小
+ */
+ public ChannelCopier(int bufferSize) {
+ this(bufferSize, -1);
+ }
+
+ /**
+ * 构造
+ *
+ * @param bufferSize 缓存大小
+ * @param count 拷贝总数
+ */
+ public ChannelCopier(int bufferSize, long count) {
+ this(bufferSize, count, null);
+ }
+
+ /**
+ * 构造
+ *
+ * @param bufferSize 缓存大小
+ * @param count 拷贝总数
+ * @param progress 进度条
+ */
+ public ChannelCopier(int bufferSize, long count, StreamProgress progress) {
+ super(bufferSize, count, progress);
+ }
+
+ @Override
+ public long copy(ReadableByteChannel source, WritableByteChannel target) {
+ Assert.notNull(source, "InputStream is null !");
+ Assert.notNull(target, "OutputStream is null !");
+
+ final StreamProgress progress = this.progress;
+ if (null != progress) {
+ progress.start();
+ }
+ final long size;
+ try {
+ size = doCopy(source, target, ByteBuffer.allocate(bufferSize(this.count)), progress);
+ } catch (IOException e) {
+ throw new IORuntimeException(e);
+ }
+
+ if (null != progress) {
+ progress.finish();
+ }
+ return size;
+ }
+
+ /**
+ * 执行拷贝,如果限制最大长度,则按照最大长度读取,否则一直读取直到遇到-1
+ *
+ * @param source {@link InputStream}
+ * @param target {@link OutputStream}
+ * @param buffer 缓存
+ * @param progress 进度条
+ * @return 拷贝总长度
+ * @throws IOException IO异常
+ */
+ private long doCopy(ReadableByteChannel source, WritableByteChannel target, ByteBuffer buffer, StreamProgress progress) throws IOException {
+ long numToRead = this.count > 0 ? this.count : Long.MAX_VALUE;
+ long total = 0;
+
+ int read;
+ while (numToRead > 0) {
+ read = source.read(buffer);
+ if (read < 0) {
+ // 提前读取到末尾
+ break;
+ }
+ buffer.flip();// 写转读
+ target.write(buffer);
+ buffer.clear();
+
+ numToRead -= read;
+ total += read;
+ if (null != progress) {
+ progress.progress(total);
+ }
+ }
+
+ return total;
+ }
+}
diff --git a/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java b/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java
index a80c5079f..e9299d349 100755
--- a/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java
+++ b/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java
@@ -29,12 +29,12 @@ public abstract class IoCopier {
* 构造
*
* @param bufferSize 缓存大小,< 0 表示默认{@link IoUtil#DEFAULT_BUFFER_SIZE}
- * @param count 拷贝总数
+ * @param count 拷贝总数,-1表示无限制
* @param progress 进度条
*/
public IoCopier(int bufferSize, long count, StreamProgress progress) {
this.bufferSize = bufferSize > 0 ? bufferSize : IoUtil.DEFAULT_BUFFER_SIZE;
- this.count = count;
+ this.count = count <= 0 ? Long.MAX_VALUE : count;
this.progress = progress;
}
@@ -52,9 +52,6 @@ public abstract class IoCopier {
* @return 缓存大小
*/
protected int bufferSize(long count) {
- if(count < 0){
- count = Long.MAX_VALUE;
- }
return Math.min(this.bufferSize, (int)count);
}
}
diff --git a/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java b/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java
index 060cca9a8..80f2f62fd 100644
--- a/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java
+++ b/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java
@@ -254,8 +254,9 @@ public class HttpResponse extends HttpBase implements Closeable {
if (null == out) {
throw new NullPointerException("[out] is null!");
}
+ final int contentLength = Convert.toInt(header(Header.CONTENT_LENGTH), -1);
try {
- return IoUtil.copyByNIO(bodyStream(), out, IoUtil.DEFAULT_BUFFER_SIZE, streamProgress);
+ return IoUtil.copyByNIO(bodyStream(), out, IoUtil.DEFAULT_BUFFER_SIZE, contentLength, streamProgress);
} finally {
IoUtil.close(this);
if (isCloseOut) {
@@ -462,10 +463,11 @@ public class HttpResponse extends HttpBase implements Closeable {
return;
}
- int contentLength = Convert.toInt(header(Header.CONTENT_LENGTH), 0);
- final FastByteArrayOutputStream out = contentLength > 0 ? new FastByteArrayOutputStream(contentLength) : new FastByteArrayOutputStream();
+ final int contentLength = Convert.toInt(header(Header.CONTENT_LENGTH), -1);
+ final FastByteArrayOutputStream out = contentLength > 0 ?
+ new FastByteArrayOutputStream(contentLength) : new FastByteArrayOutputStream();
try {
- IoUtil.copy(in, out);
+ IoUtil.copy(in, out, -1, -1, null);
} catch (IORuntimeException e) {
//noinspection StatementWithEmptyBody
if (e.getCause() instanceof EOFException || StrUtil.containsIgnoreCase(e.getMessage(), "Premature EOF")) {