diff --git a/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java b/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java index 41b0766d6..f2d13b4b8 100644 --- a/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/io/NioUtil.java @@ -1,5 +1,6 @@ package cn.hutool.core.io; +import cn.hutool.core.io.copy.ChannelCopier; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.CharsetUtil; import cn.hutool.core.util.StrUtil; @@ -7,7 +8,6 @@ import cn.hutool.core.util.StrUtil; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.nio.ByteBuffer; import java.nio.MappedByteBuffer; import java.nio.channels.Channels; import java.nio.channels.FileChannel; @@ -53,7 +53,23 @@ public class NioUtil { * @throws IORuntimeException IO异常 */ public static long copyByNIO(InputStream in, OutputStream out, int bufferSize, StreamProgress streamProgress) throws IORuntimeException { - return copy(Channels.newChannel(in), Channels.newChannel(out), bufferSize, streamProgress); + return copyByNIO(in, out, bufferSize, -1, streamProgress); + } + + /** + * 拷贝流
+ * 本方法不会关闭流 + * + * @param in 输入流 + * @param out 输出流 + * @param bufferSize 缓存大小 + * @param streamProgress 进度条 + * @return 传输的byte数 + * @throws IORuntimeException IO异常 + * @since 5.7.8 + */ + public static long copyByNIO(InputStream in, OutputStream out, int bufferSize, long count, StreamProgress streamProgress) throws IORuntimeException { + return copy(Channels.newChannel(in), Channels.newChannel(out), bufferSize, count, streamProgress); } /** @@ -114,31 +130,23 @@ public class NioUtil { * @throws IORuntimeException IO异常 */ public static long copy(ReadableByteChannel in, WritableByteChannel out, int bufferSize, StreamProgress streamProgress) throws IORuntimeException { - Assert.notNull(in, "InputStream is null !"); - Assert.notNull(out, "OutputStream is null !"); + return copy(in, out, bufferSize, -1, streamProgress); + } - ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize <= 0 ? DEFAULT_BUFFER_SIZE : bufferSize); - long size = 0; - if (null != streamProgress) { - streamProgress.start(); - } - try { - while (in.read(byteBuffer) != EOF) { - byteBuffer.flip();// 写转读 - size += out.write(byteBuffer); - byteBuffer.clear(); - if (null != streamProgress) { - streamProgress.progress(size); - } - } - } catch (IOException e) { - throw new IORuntimeException(e); - } - if (null != streamProgress) { - streamProgress.finish(); - } - - return size; + /** + * 拷贝流,使用NIO,不会关闭channel + * + * @param in {@link ReadableByteChannel} + * @param out {@link WritableByteChannel} + * @param bufferSize 缓冲大小,如果小于等于0,使用默认 + * @param count 读取总长度 + * @param streamProgress {@link StreamProgress}进度处理器 + * @return 拷贝的字节数 + * @throws IORuntimeException IO异常 + * @since 5.7.8 + */ + public static long copy(ReadableByteChannel in, WritableByteChannel out, int bufferSize, long count, StreamProgress streamProgress) throws IORuntimeException { + return new ChannelCopier(bufferSize, count, streamProgress).copy(in, out); } /** diff --git a/hutool-core/src/main/java/cn/hutool/core/io/copy/ChannelCopier.java b/hutool-core/src/main/java/cn/hutool/core/io/copy/ChannelCopier.java new file mode 100755 index 000000000..7585cdd58 --- /dev/null +++ b/hutool-core/src/main/java/cn/hutool/core/io/copy/ChannelCopier.java @@ -0,0 +1,116 @@ +package cn.hutool.core.io.copy; + +import cn.hutool.core.io.IORuntimeException; +import cn.hutool.core.io.IoUtil; +import cn.hutool.core.io.StreamProgress; +import cn.hutool.core.lang.Assert; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; + +/** + * {@link ReadableByteChannel} 向 {@link WritableByteChannel} 拷贝 + * + * @author looly + * @since 5.7.8 + */ +public class ChannelCopier extends IoCopier { + + /** + * 构造 + */ + public ChannelCopier() { + this(IoUtil.DEFAULT_BUFFER_SIZE); + } + + /** + * 构造 + * + * @param bufferSize 缓存大小 + */ + public ChannelCopier(int bufferSize) { + this(bufferSize, -1); + } + + /** + * 构造 + * + * @param bufferSize 缓存大小 + * @param count 拷贝总数 + */ + public ChannelCopier(int bufferSize, long count) { + this(bufferSize, count, null); + } + + /** + * 构造 + * + * @param bufferSize 缓存大小 + * @param count 拷贝总数 + * @param progress 进度条 + */ + public ChannelCopier(int bufferSize, long count, StreamProgress progress) { + super(bufferSize, count, progress); + } + + @Override + public long copy(ReadableByteChannel source, WritableByteChannel target) { + Assert.notNull(source, "InputStream is null !"); + Assert.notNull(target, "OutputStream is null !"); + + final StreamProgress progress = this.progress; + if (null != progress) { + progress.start(); + } + final long size; + try { + size = doCopy(source, target, ByteBuffer.allocate(bufferSize(this.count)), progress); + } catch (IOException e) { + throw new IORuntimeException(e); + } + + if (null != progress) { + progress.finish(); + } + return size; + } + + /** + * 执行拷贝,如果限制最大长度,则按照最大长度读取,否则一直读取直到遇到-1 + * + * @param source {@link InputStream} + * @param target {@link OutputStream} + * @param buffer 缓存 + * @param progress 进度条 + * @return 拷贝总长度 + * @throws IOException IO异常 + */ + private long doCopy(ReadableByteChannel source, WritableByteChannel target, ByteBuffer buffer, StreamProgress progress) throws IOException { + long numToRead = this.count > 0 ? this.count : Long.MAX_VALUE; + long total = 0; + + int read; + while (numToRead > 0) { + read = source.read(buffer); + if (read < 0) { + // 提前读取到末尾 + break; + } + buffer.flip();// 写转读 + target.write(buffer); + buffer.clear(); + + numToRead -= read; + total += read; + if (null != progress) { + progress.progress(total); + } + } + + return total; + } +} diff --git a/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java b/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java index a80c5079f..e9299d349 100755 --- a/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java +++ b/hutool-core/src/main/java/cn/hutool/core/io/copy/IoCopier.java @@ -29,12 +29,12 @@ public abstract class IoCopier { * 构造 * * @param bufferSize 缓存大小,< 0 表示默认{@link IoUtil#DEFAULT_BUFFER_SIZE} - * @param count 拷贝总数 + * @param count 拷贝总数,-1表示无限制 * @param progress 进度条 */ public IoCopier(int bufferSize, long count, StreamProgress progress) { this.bufferSize = bufferSize > 0 ? bufferSize : IoUtil.DEFAULT_BUFFER_SIZE; - this.count = count; + this.count = count <= 0 ? Long.MAX_VALUE : count; this.progress = progress; } @@ -52,9 +52,6 @@ public abstract class IoCopier { * @return 缓存大小 */ protected int bufferSize(long count) { - if(count < 0){ - count = Long.MAX_VALUE; - } return Math.min(this.bufferSize, (int)count); } } diff --git a/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java b/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java index 060cca9a8..80f2f62fd 100644 --- a/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java +++ b/hutool-http/src/main/java/cn/hutool/http/HttpResponse.java @@ -254,8 +254,9 @@ public class HttpResponse extends HttpBase implements Closeable { if (null == out) { throw new NullPointerException("[out] is null!"); } + final int contentLength = Convert.toInt(header(Header.CONTENT_LENGTH), -1); try { - return IoUtil.copyByNIO(bodyStream(), out, IoUtil.DEFAULT_BUFFER_SIZE, streamProgress); + return IoUtil.copyByNIO(bodyStream(), out, IoUtil.DEFAULT_BUFFER_SIZE, contentLength, streamProgress); } finally { IoUtil.close(this); if (isCloseOut) { @@ -462,10 +463,11 @@ public class HttpResponse extends HttpBase implements Closeable { return; } - int contentLength = Convert.toInt(header(Header.CONTENT_LENGTH), 0); - final FastByteArrayOutputStream out = contentLength > 0 ? new FastByteArrayOutputStream(contentLength) : new FastByteArrayOutputStream(); + final int contentLength = Convert.toInt(header(Header.CONTENT_LENGTH), -1); + final FastByteArrayOutputStream out = contentLength > 0 ? + new FastByteArrayOutputStream(contentLength) : new FastByteArrayOutputStream(); try { - IoUtil.copy(in, out); + IoUtil.copy(in, out, -1, -1, null); } catch (IORuntimeException e) { //noinspection StatementWithEmptyBody if (e.getCause() instanceof EOFException || StrUtil.containsIgnoreCase(e.getMessage(), "Premature EOF")) {