mirror of
https://gitee.com/dromara/hutool.git
synced 2026-01-09 09:55:12 +08:00
Merge pull request #4207 from elichow/v7-dev
hutool-ai对gemini的实现 for v7
This commit is contained in:
@@ -21,6 +21,7 @@ import cn.hutool.v7.ai.core.AIService;
|
||||
import cn.hutool.v7.ai.core.Message;
|
||||
import cn.hutool.v7.ai.model.deepseek.DeepSeekService;
|
||||
import cn.hutool.v7.ai.model.doubao.DoubaoService;
|
||||
import cn.hutool.v7.ai.model.gemini.GeminiService;
|
||||
import cn.hutool.v7.ai.model.grok.GrokService;
|
||||
import cn.hutool.v7.ai.model.hutool.HutoolService;
|
||||
import cn.hutool.v7.ai.model.openai.OpenaiService;
|
||||
@@ -114,6 +115,17 @@ public class AIUtil {
|
||||
return getAIService(config, OpenaiService.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取Gemini模型服务
|
||||
*
|
||||
* @param config 创建的AI服务模型的配置
|
||||
* @return GeminiService
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static GeminiService getGeminiService(final AIConfig config) {
|
||||
return getAIService(config, GeminiService.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* AI大模型对话功能
|
||||
*
|
||||
|
||||
@@ -46,7 +46,11 @@ public enum ModelName {
|
||||
/**
|
||||
* ollama
|
||||
*/
|
||||
OLLAMA("ollama");
|
||||
OLLAMA("ollama"),
|
||||
/**
|
||||
* gemini
|
||||
*/
|
||||
GEMINI("gemini");
|
||||
|
||||
private final String value;
|
||||
|
||||
|
||||
@@ -141,8 +141,8 @@ public class Models {
|
||||
DOUBAO_EMBEDDING_TEXT_240715("doubao-embedding-text-240715"),
|
||||
DOUBAO_EMBEDDING_VISION("doubao-embedding-vision-241215"),
|
||||
DOUBAO_SEEDREAM_3_0_T2I("doubao-seedream-3-0-t2i-250415"),
|
||||
DOUBAO_SEEDDANCE_1_0_LITE_T2V("doubao-seedance-1-0-lite-t2v-250428"),
|
||||
DOUBAO_SEEDDANCE_1_0_lite_I2V("doubao-seedance-1-0-lite-i2v-250428"),
|
||||
DOUBAO_SEEDANCE_1_0_LITE_T2V("doubao-seedance-1-0-lite-t2v-250428"),
|
||||
DOUBAO_SEEDANCE_1_0_LITE_I2V("doubao-seedance-1-0-lite-i2v-250428"),
|
||||
WAN2_1_14B_T2V("wan2-1-14b-t2v-250225"),
|
||||
WAN2_1_14B_I2V("wan2-1-14b-i2v-250225");
|
||||
|
||||
@@ -217,4 +217,42 @@ public class Models {
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini models
/**
 * Model identifiers available for the Gemini service. Each constant wraps the
 * model name string returned by {@link #getModel()}.
 */
public enum Gemini {
    GEMINI_2_5_PRO_PREVIEW_TTS("gemini-2.5-pro-preview-tts"),
    GEMINI_2_5_FLASH_PREVIEW_TTS("gemini-2.5-flash-preview-tts"),
    VEO_2_0_GENERATE_001("veo-2.0-generate-001"),
    VEO_3_0_FAST_GENERATE_001("veo-3.0-fast-generate-001"),
    VEO_3_0_GENERATE_001("veo-3.0-generate-001"),
    VEO_3_1_FAST_GENERATE_PREVIEW("veo-3.1-fast-generate-preview"),
    VEO_3_1_GENERATE_PREVIEW("veo-3.1-generate-preview"),
    IMAGEN_4_0_GENERATE_001("imagen-4.0-generate-001"),
    IMAGEN_4_0_ULTRA_GENERATE_001("imagen-4.0-ultra-generate-001"),
    IMAGEN_4_0_FAST_GENERATE_001("imagen-4.0-fast-generate-001"),
    IMAGEN_3_0_GENERATE_002("imagen-3.0-generate-002"),
    GEMINI_3_PRO_PREVIEW("gemini-3-pro-preview"),
    GEMINI_3_FLASH("gemini-3-flash"),
    GEMINI_2_5_PRO("gemini-2.5-pro"),
    GEMINI_2_5_FLASH("gemini-2.5-flash"),
    GEMINI_2_5_FLASH_LITE("gemini-2.5-flash-lite"),
    GEMINI_2_5_FLASH_IMAGE("gemini-2.5-flash-image"),
    GEMINI_2_0_FLASH("gemini-2.0-flash"),
    GEMINI_2_0_FLASH_LITE("gemini-2.0-flash-lite"),
    GEMINI_2_0_PRO_EXP("gemini-2.0-pro-exp"),
    GEMINI_1_5_FLASH("gemini-1.5-flash"),
    GEMINI_1_5_PRO("gemini-1.5-pro"),
    GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b"),
    GEMINI_1_0_PRO("gemini-1.0-pro");

    /** Model name string carried by this constant. */
    private final String modelId;

    Gemini(final String modelId) {
        this.modelId = modelId;
    }

    /**
     * @return the model name string of this constant
     */
    public String getModel() {
        return this.modelId;
    }
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
|
||||
/**
 * Shared Gemini option enums.
 *
 * @author elichow
 * @since 6.0.0
 */
public class GeminiCommon {

    /** Number of images to generate. */
    public enum GeminiImageCount {
        ONE(1),
        TWO(2),
        THREE(3),
        FOUR(4);

        private final int imageCount;

        GeminiImageCount(int imageCount) {
            this.imageCount = imageCount;
        }

        /**
         * @return the numeric image count
         */
        public int getCount() {
            return this.imageCount;
        }
    }

    /** Size of the generated image (imageSize) - Standard and Ultra only. */
    public enum GeminiImageSize {
        SIZE_1K("1K"),
        SIZE_2K("2K");

        private final String sizeLabel;

        GeminiImageSize(String sizeLabel) {
            this.sizeLabel = sizeLabel;
        }

        /**
         * @return the size label string
         */
        public String getValue() {
            return this.sizeLabel;
        }
    }

    /** Aspect ratio. */
    public enum GeminiAspectRatio {
        SQUARE("1:1"),
        PORTRAIT_3_4("3:4"),
        LANDSCAPE_4_3("4:3"),
        PORTRAIT_9_16("9:16"),
        LANDSCAPE_16_9("16:9");

        private final String ratioText;

        GeminiAspectRatio(String ratioText) {
            this.ratioText = ratioText;
        }

        /**
         * @return the ratio in {@code width:height} text form
         */
        public String getRatio() {
            return this.ratioText;
        }
    }

    /** Person generation permission. */
    public enum GeminiPersonGeneration {
        DONT_ALLOW("dont_allow"),
        ALLOW_ADULT("allow_adult"),
        ALLOW_ALL("allow_all");

        private final String permission;

        GeminiPersonGeneration(String permission) {
            this.permission = permission;
        }

        /**
         * @return the permission string
         */
        public String getValue() {
            return this.permission;
        }
    }

    /** Duration of the generated video. */
    public enum GeminiDurationSeconds {
        FOUR(4),
        SIX(6),
        EIGHT(8);

        private final Integer seconds;

        GeminiDurationSeconds(Integer seconds) {
            this.seconds = seconds;
        }

        /**
         * @return the duration value
         */
        public Integer getValue() {
            return this.seconds;
        }
    }

    /** Voice timbre for speech output. */
    public enum GeminiVoice {
        AOEDE("Aoede"),
        CHARON("Charon"),
        KORE("Kore"),
        FENRIR("Fenrir"),
        PUCK("Puck");

        private final String voiceName;

        GeminiVoice(String voiceName) {
            this.voiceName = voiceName;
        }

        /**
         * @return the voice name string
         */
        public String getValue() {
            return this.voiceName;
        }
    }

}
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
|
||||
import cn.hutool.v7.ai.Models;
|
||||
import cn.hutool.v7.ai.core.BaseAIConfig;
|
||||
|
||||
|
||||
/**
|
||||
* Gemini配置类,初始化API接口地址,设置默认的模型
|
||||
*
|
||||
* @author elichow
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public class GeminiConfig extends BaseAIConfig {
|
||||
|
||||
// Google Generative AI 的基础 URL
|
||||
private final String API_URL = "https://generativelanguage.googleapis.com/v1beta";
|
||||
|
||||
// 默认模型
|
||||
private final String DEFAULT_MODEL = Models.Gemini.GEMINI_2_5_FLASH.getModel();
|
||||
|
||||
public GeminiConfig() {
|
||||
setApiUrl(API_URL);
|
||||
setModel(DEFAULT_MODEL);
|
||||
}
|
||||
|
||||
public GeminiConfig(String apiKey) {
|
||||
this();
|
||||
setApiKey(apiKey);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return "gemini";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
|
||||
|
||||
import cn.hutool.v7.ai.core.AIConfig;
|
||||
import cn.hutool.v7.ai.core.AIServiceProvider;
|
||||
|
||||
/**
|
||||
* 创建Gemini服务实现类
|
||||
*
|
||||
* @author elichow
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public class GeminiProvider implements AIServiceProvider {
|
||||
|
||||
@Override
|
||||
public String getServiceName() {
|
||||
return "gemini";
|
||||
}
|
||||
|
||||
@Override
|
||||
public GeminiService create(final AIConfig config) {
|
||||
return new GeminiServiceImpl(config);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
|
||||
|
||||
|
||||
import cn.hutool.v7.ai.core.AIService;
|
||||
import cn.hutool.v7.ai.core.Message;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Gemini服务支持的扩展接口
|
||||
*
|
||||
* @author elichow
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public interface GeminiService extends AIService {
|
||||
|
||||
/**
|
||||
* 全模态理解(图像/视频/音频/PDF):模型会依据传入的媒体资源给出回复。
|
||||
*
|
||||
* @param prompt 指令
|
||||
* @param mediaList 媒体资源列表 (支持 Base64, URL, 或 File API 的 URI)
|
||||
* @return AI回答
|
||||
*/
|
||||
String chatMultimodal(String prompt, final List<String> mediaList);
|
||||
|
||||
/**
|
||||
* 全模态理解-SSE流式输出
|
||||
*
|
||||
* @param prompt 指令
|
||||
* @param mediaList 媒体资源列表
|
||||
* @param callback 流式数据回调函数
|
||||
*/
|
||||
void chatMultimodal(String prompt, final List<String> mediaList, final Consumer<String> callback);
|
||||
|
||||
/**
|
||||
* 结构化输出:强制要求模型返回 JSON 格式
|
||||
*
|
||||
* @param messages 消息列表
|
||||
* @return AI回答
|
||||
*/
|
||||
String chatJson(final List<Message> messages);
|
||||
|
||||
/**
|
||||
* 生成图像 (Imagen 模型集成)
|
||||
*
|
||||
* @param prompt 图像描述词
|
||||
* @return 包含图片数据的响应 (通常为 Base64)
|
||||
*/
|
||||
String predictImage(String prompt);
|
||||
|
||||
/**
|
||||
* 生成视频:根据文本提示语生成视频
|
||||
*
|
||||
* @param prompt 视频描述词
|
||||
* @return 包含 operationName 的 JSON 字符串
|
||||
*/
|
||||
String predictVideo(String prompt);
|
||||
|
||||
/**
|
||||
* 获取视频生成状态:用于轮询视频生成进度
|
||||
*
|
||||
* @param operationName 生成视频接口返回的任务名称
|
||||
* @return 包含视频状态(done)及结果的 JSON 字符串
|
||||
*/
|
||||
String getVideoOperation(String operationName);
|
||||
|
||||
|
||||
/**
|
||||
* 下载生成的视频文件
|
||||
*
|
||||
* @param videoUri 视频文件的 URI
|
||||
* @param filePath 保存视频的文件路径
|
||||
*/
|
||||
void downLoadVideo(String videoUri, String filePath);
|
||||
|
||||
/**
|
||||
* 文本转语音 (TTS)
|
||||
*
|
||||
* @param prompt 文本或带有导演备注的内容
|
||||
* @return 语音文件的 Base64 编码字符串
|
||||
*/
|
||||
String textToSpeech(String prompt);
|
||||
|
||||
/**
|
||||
* 文本转语音 (TTS) - 指定音色
|
||||
*
|
||||
* @param prompt 文本或带有导演备注的内容
|
||||
* @param voice 预定义的音色常量
|
||||
* @return 语音文件的 Base64 编码字符串
|
||||
*/
|
||||
String textToSpeech(String prompt, String voice);
|
||||
|
||||
/**
|
||||
* 上传大文件到Gemini File API
|
||||
*
|
||||
* @param file 本地文件
|
||||
* @return 上传后的文件对象信息
|
||||
*/
|
||||
String uploadFile(final File file);
|
||||
|
||||
/**
|
||||
* 为原始 PCM 音频数据添加 WAV 头
|
||||
*
|
||||
* @param rawPcm 原始 PCM 音频字节数组
|
||||
* @return 带有 WAV 头的音频字节数组
|
||||
*/
|
||||
byte[] addWavHeader(final byte[] rawPcm);
|
||||
}
|
||||
@@ -0,0 +1,568 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
|
||||
|
||||
import cn.hutool.v7.ai.AIException;
|
||||
import cn.hutool.v7.ai.core.AIConfig;
|
||||
import cn.hutool.v7.ai.core.BaseAIService;
|
||||
import cn.hutool.v7.ai.core.Message;
|
||||
import cn.hutool.v7.core.codec.binary.Base64;
|
||||
import cn.hutool.v7.core.io.file.FileUtil;
|
||||
import cn.hutool.v7.core.io.resource.FileResource;
|
||||
import cn.hutool.v7.core.io.resource.HttpResource;
|
||||
import cn.hutool.v7.core.map.MapUtil;
|
||||
import cn.hutool.v7.core.text.StrUtil;
|
||||
import cn.hutool.v7.core.thread.ThreadUtil;
|
||||
import cn.hutool.v7.http.HttpGlobalConfig;
|
||||
import cn.hutool.v7.http.HttpUtil;
|
||||
import cn.hutool.v7.http.client.ClientConfig;
|
||||
import cn.hutool.v7.http.client.Request;
|
||||
import cn.hutool.v7.http.client.Response;
|
||||
import cn.hutool.v7.http.client.body.ResourceBody;
|
||||
import cn.hutool.v7.http.client.engine.ClientEngine;
|
||||
import cn.hutool.v7.http.client.engine.ClientEngineFactory;
|
||||
import cn.hutool.v7.http.meta.HeaderName;
|
||||
import cn.hutool.v7.http.meta.Method;
|
||||
import cn.hutool.v7.json.JSONObject;
|
||||
import cn.hutool.v7.json.JSONUtil;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.util.*;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Gemini服务,AI具体功能的实现
|
||||
*
|
||||
* @author elichow
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public class GeminiServiceImpl extends BaseAIService implements GeminiService {
|
||||
|
||||
private final String GENERATE_CONTENT = ":generateContent";
|
||||
private final String STREAM_GENERATE_CONTENT = ":streamGenerateContent";
|
||||
private final String PREDICT = ":predict";
|
||||
private final String PREDICT_LONG_RUNNING = ":predictLongRunning";
|
||||
private final String UPLOAD_BASE_URL = "https://generativelanguage.googleapis.com/upload/v1beta/files";
|
||||
|
||||
/**
 * Creates the service on top of the given configuration.
 *
 * @param config AI configuration passed on to the base service
 */
public GeminiServiceImpl(final AIConfig config) {
    super(config);
}
|
||||
|
||||
private String getEndpoint(final boolean stream) {
|
||||
String action = stream ? STREAM_GENERATE_CONTENT : GENERATE_CONTENT;
|
||||
return "/models/" + config.getModel() + action;
|
||||
}
|
||||
|
||||
private String getPredictImageEndpoint() {
|
||||
return "/models/" + config.getModel() + PREDICT;
|
||||
}
|
||||
|
||||
private String getPredictVideoEndpoint() {
|
||||
return "/models/" + config.getModel() + PREDICT_LONG_RUNNING;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String chat(final List<Message> messages) {
|
||||
final Map<String, Object> paramMap = buildChatRequestMap(messages);
|
||||
final Response response = sendPost(getEndpoint(false), JSONUtil.toJsonStr(paramMap));
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void chat(final List<Message> messages, final Consumer<String> callback) {
|
||||
final Map<String, Object> paramMap = buildChatRequestMap(messages);
|
||||
final String endpoint = getEndpoint(true) + "?alt=sse";
|
||||
ThreadUtil.newThread(() -> sendPostStream(endpoint, paramMap, callback), "gemini-chat-sse").start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String chatMultimodal(String prompt, final List<String> mediaList) {
|
||||
final Map<String, Object> paramMap = buildMultimodalRequestMap(prompt, mediaList);
|
||||
final Response response = sendPost(getEndpoint(false), JSONUtil.toJsonStr(paramMap));
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void chatMultimodal(String prompt, final List<String> mediaList, final Consumer<String> callback) {
|
||||
final Map<String, Object> paramMap = buildMultimodalRequestMap(prompt, mediaList);
|
||||
final String endpoint = getEndpoint(true) + "?alt=sse";
|
||||
ThreadUtil.newThread(() -> sendPostStream(endpoint, paramMap, callback), "gemini-m-sse").start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String chatJson(final List<Message> messages) {
|
||||
final Map<String, Object> paramMap = buildChatRequestMap(messages);
|
||||
Map<String, Object> genConfig = MapUtil.get(paramMap, "generationConfig", Map.class);
|
||||
if (genConfig == null) {
|
||||
genConfig = new HashMap<>();
|
||||
}
|
||||
//指定响应MIME类型为JSON
|
||||
genConfig.put("response_mime_type", "application/json");
|
||||
|
||||
final Response response = sendPost(getEndpoint(false), JSONUtil.toJsonStr(paramMap));
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String predictImage(String prompt) {
|
||||
final Map<String, Object> paramMap = buildPredictImageRequestMap(prompt);
|
||||
final Response response = sendPost(getPredictImageEndpoint(), JSONUtil.toJsonStr(paramMap));
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String predictVideo(String prompt) {
|
||||
final Map<String, Object> paramMap = buildPredictVideoRequestMap(prompt);
|
||||
final Response response = sendPost(getPredictVideoEndpoint(), JSONUtil.toJsonStr(paramMap));
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVideoOperation(String operationName) {
|
||||
String endPoint = "/" + operationName;
|
||||
final Response response = sendGet(endPoint);
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void downLoadVideo(String videoUri, String filePath) {
|
||||
if (StrUtil.isBlank(videoUri)) {
|
||||
throw new AIException("Video URI is empty");
|
||||
}
|
||||
//设置超时
|
||||
final Response response = HttpUtil.createGet(videoUri)
|
||||
.header("x-goog-api-key", config.getApiKey())
|
||||
.setMaxRedirects(1)
|
||||
.send().sync();
|
||||
if (response.isOk()) {
|
||||
FileUtil.writeFromStream(response.bodyStream(), FileUtil.file(filePath));
|
||||
} else {
|
||||
throw new AIException("Download failed with status: " + response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String textToSpeech(String prompt) {
|
||||
final Map<String, Object> paramMap = buildTextToSpeechRequestMap(prompt);
|
||||
final Response response = sendPost(getEndpoint(false), JSONUtil.toJsonStr(paramMap));
|
||||
return response.bodyStr();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String textToSpeech(String prompt, String voice) {
|
||||
final Map<String, Object> voiceConfig = MapUtil.of("prebuilt_voice_config", MapUtil.of("voice_name", voice));
|
||||
config.putAdditionalConfigByKey("speech_config", MapUtil.of("voice_config", voiceConfig));
|
||||
return this.textToSpeech(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String uploadFile(final File file) {
|
||||
if (null == file || !file.exists()) {
|
||||
throw new AIException("File not found!");
|
||||
}
|
||||
try {
|
||||
String mimeType = FileUtil.getMimeType(file.getName());
|
||||
if (StrUtil.isBlank(mimeType)) {
|
||||
mimeType = "application/octet-stream";
|
||||
}
|
||||
|
||||
final ClientConfig clientConfig = ClientConfig.of()
|
||||
.setConnectionTimeout(config.getTimeout())
|
||||
.setReadTimeout(config.getTimeout());
|
||||
final ClientEngine engine = ClientEngineFactory.createEngine().init(clientConfig);
|
||||
|
||||
String metadata = JSONUtil.toJsonStr(MapUtil.of("file", MapUtil.of("display_name", file.getName())));
|
||||
final Request initRequest = HttpUtil.createRequest(getUploadBaseUrl(), Method.POST)
|
||||
.header("x-goog-api-key", config.getApiKey())
|
||||
.header("X-Goog-Upload-Protocol", "resumable")
|
||||
.header("X-Goog-Upload-Command", "start")
|
||||
.header("X-Goog-Upload-Header-Content-Length", String.valueOf(file.length()))
|
||||
.header("X-Goog-Upload-Header-Content-Type", mimeType)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(metadata);
|
||||
|
||||
String sessionUrl;
|
||||
try (Response initRes = engine.send(initRequest)) {
|
||||
sessionUrl = initRes.header("X-Goog-Upload-URL");
|
||||
}
|
||||
|
||||
if (StrUtil.isBlank(sessionUrl)) {
|
||||
throw new AIException("Failed to get upload session URL");
|
||||
}
|
||||
|
||||
final Request uploadRequest = HttpUtil.createRequest(sessionUrl, Method.PUT)
|
||||
.header("X-Goog-Upload-Command", "upload, finalize")
|
||||
.header("X-Goog-Upload-Offset", "0")
|
||||
.header("Content-Length", String.valueOf(file.length()));
|
||||
|
||||
FileResource fileResource = new FileResource(file);
|
||||
|
||||
HttpResource httpResource = new HttpResource(fileResource, mimeType);
|
||||
|
||||
uploadRequest.body(new ResourceBody(httpResource));
|
||||
|
||||
try (Response uploadRes = engine.send(uploadRequest)) {
|
||||
if (uploadRes.isOk()) {
|
||||
return uploadRes.bodyStr();
|
||||
} else {
|
||||
throw new AIException("Upload failed with status: " + uploadRes.getStatus());
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new AIException("Gemini upload failed: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public byte[] addWavHeader(final byte[] pcmData) {
|
||||
final int totalDataLen = pcmData.length;
|
||||
final int totalAudioLen = totalDataLen + 36;
|
||||
// Gemini TTS 默认通常是 24k 或 16k
|
||||
final int sampleRate = 24000;
|
||||
// 单声道
|
||||
final int channels = 1;
|
||||
// 16bit
|
||||
final int byteRate = sampleRate * channels * 2;
|
||||
|
||||
final byte[] header = new byte[44];
|
||||
header[0] = 'R'; header[1] = 'I'; header[2] = 'F'; header[3] = 'F';
|
||||
header[4] = (byte) (totalAudioLen & 0xff);
|
||||
header[5] = (byte) ((totalAudioLen >> 8) & 0xff);
|
||||
header[6] = (byte) ((totalAudioLen >> 16) & 0xff);
|
||||
header[7] = (byte) ((totalAudioLen >> 24) & 0xff);
|
||||
header[8] = 'W'; header[9] = 'A'; header[10] = 'V'; header[11] = 'E';
|
||||
header[12] = 'f'; header[13] = 'm'; header[14] = 't'; header[15] = ' ';
|
||||
header[16] = 16; header[17] = 0; header[18] = 0; header[19] = 0;
|
||||
// PCM 格式
|
||||
header[20] = 1; header[21] = 0;
|
||||
header[22] = (byte) channels; header[23] = 0;
|
||||
header[24] = (byte) (sampleRate & 0xff);
|
||||
header[25] = (byte) ((sampleRate >> 8) & 0xff);
|
||||
header[26] = (byte) ((sampleRate >> 16) & 0xff);
|
||||
header[27] = (byte) ((sampleRate >> 24) & 0xff);
|
||||
header[28] = (byte) (byteRate & 0xff);
|
||||
header[29] = (byte) ((byteRate >> 8) & 0xff);
|
||||
header[30] = (byte) ((byteRate >> 16) & 0xff);
|
||||
header[31] = (byte) ((byteRate >> 24) & 0xff);
|
||||
header[32] = (byte) (channels * 2); header[33] = 0;
|
||||
// 16 bits per sample
|
||||
header[34] = 16; header[35] = 0;
|
||||
header[36] = 'd'; header[37] = 'a'; header[38] = 't'; header[39] = 'a';
|
||||
header[40] = (byte) (totalDataLen & 0xff);
|
||||
header[41] = (byte) ((totalDataLen >> 8) & 0xff);
|
||||
header[42] = (byte) ((totalDataLen >> 16) & 0xff);
|
||||
header[43] = (byte) ((totalDataLen >> 24) & 0xff);
|
||||
|
||||
final byte[] wavData = new byte[header.length + pcmData.length];
|
||||
System.arraycopy(header, 0, wavData, 0, header.length);
|
||||
System.arraycopy(pcmData, 0, wavData, header.length, pcmData.length);
|
||||
return wavData;
|
||||
}
|
||||
|
||||
/**
|
||||
* 动态根据 API 配置生成 Upload 地址
|
||||
*/
|
||||
private String getUploadBaseUrl() {
|
||||
String apiUrl = config.getApiUrl();
|
||||
//自动提取域名部分
|
||||
if (StrUtil.contains(apiUrl, "generativelanguage.googleapis.com")) {
|
||||
return "https://generativelanguage.googleapis.com/upload/v1beta/files";
|
||||
}
|
||||
//如果是反代或自定义节点,动态拼接
|
||||
try {
|
||||
final URL url = new URL(apiUrl);
|
||||
return new URL(url.getProtocol(), url.getHost(), url.getPort(), UPLOAD_BASE_URL).toString();
|
||||
} catch (Exception e) {
|
||||
return apiUrl.replace("/models/", "/upload/v1beta/files").split("/models")[0];
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Object> buildChatRequestMap(final List<Message> messages) {
|
||||
final Map<String, Object> paramMap = new HashMap<>();
|
||||
final List<Map<String, Object>> contents = new ArrayList<>();
|
||||
Map<String, Object> systemInstruction = null;
|
||||
|
||||
for (Message msg : messages) {
|
||||
if ("system".equalsIgnoreCase(msg.getRole())) {
|
||||
systemInstruction = MapUtil.ofEntries(MapUtil.entry("parts",
|
||||
Collections.singletonList(MapUtil.ofEntries(MapUtil.entry("text", msg.getContent())))));
|
||||
} else {
|
||||
contents.add(MapUtil.ofEntries(
|
||||
MapUtil.entry("role", "user".equalsIgnoreCase(msg.getRole()) ? "user" : "model"),
|
||||
MapUtil.entry("parts", Collections.singletonList(MapUtil.ofEntries(MapUtil.entry("text", msg.getContent()))))
|
||||
));
|
||||
}
|
||||
}
|
||||
paramMap.put("contents", contents);
|
||||
if (systemInstruction != null) {
|
||||
paramMap.put("system_instruction", systemInstruction);
|
||||
}
|
||||
paramMap.putAll(config.getAdditionalConfigMap());
|
||||
return paramMap;
|
||||
}
|
||||
|
||||
private Map<String, Object> buildMultimodalRequestMap(String prompt, final List<String> mediaList) {
|
||||
final List<Map<String, Object>> parts = new ArrayList<>();
|
||||
parts.add(MapUtil.ofEntries(MapUtil.entry("text", prompt)));
|
||||
|
||||
if (mediaList != null && !mediaList.isEmpty()) {
|
||||
for (String media : mediaList) {
|
||||
if (StrUtil.isBlank(media)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//Gemini File资源
|
||||
if (media.contains("files/")) {
|
||||
String fileUri = media;
|
||||
if (!media.startsWith("http")) {
|
||||
fileUri = "https://generativelanguage.googleapis.com/v1beta/" + media;
|
||||
}
|
||||
//直接从服务端获取该文件上传时真实记录的 mimeType
|
||||
String realMimeType = getRemoteFileMimeType(fileUri);
|
||||
parts.add(MapUtil.ofEntries(
|
||||
MapUtil.entry("file_data", MapUtil.ofEntries(
|
||||
MapUtil.entry("mime_type", realMimeType),
|
||||
MapUtil.entry("file_uri", fileUri)
|
||||
))
|
||||
));
|
||||
} else if (media.startsWith("http")) {
|
||||
//普通网络图片 (下载并转 Base64)
|
||||
try {
|
||||
final byte[] bytes = HttpUtil.createGet(media).send().bodyBytes();
|
||||
//尝试识别下载文件的 MIME,无法识别则不强加后缀逻辑,通过流内容自适应
|
||||
String mime = FileUtil.getMimeType(media);
|
||||
if (StrUtil.isBlank(mime)) {
|
||||
// 基础兜底
|
||||
mime = "image/jpeg";
|
||||
}
|
||||
parts.add(MapUtil.ofEntries(
|
||||
MapUtil.entry("inline_data", MapUtil.ofEntries(
|
||||
MapUtil.entry("mime_type", mime),
|
||||
MapUtil.entry("data", Base64.encode(bytes))
|
||||
))
|
||||
));
|
||||
} catch (Exception e) {
|
||||
throw new AIException("Failed to download media from URL: " + media, e.getMessage());
|
||||
}
|
||||
} else {
|
||||
//Base64 数据
|
||||
parts.add(MapUtil.ofEntries(
|
||||
MapUtil.entry("inline_data", MapUtil.ofEntries(
|
||||
MapUtil.entry("mime_type", "image/jpeg"),
|
||||
MapUtil.entry("data", media)
|
||||
))
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final Map<String, Object> paramMap = new HashMap<>();
|
||||
paramMap.put("contents", Collections.singletonList(MapUtil.ofEntries(
|
||||
MapUtil.entry("role", "user"),
|
||||
MapUtil.entry("parts", parts)
|
||||
)));
|
||||
|
||||
//合并其他参数
|
||||
if (MapUtil.isNotEmpty(config.getAdditionalConfigMap())) {
|
||||
paramMap.putAll(config.getAdditionalConfigMap());
|
||||
}
|
||||
return paramMap;
|
||||
}
|
||||
|
||||
private Map<String, Object> buildPredictVideoRequestMap(String prompt) {
|
||||
final Map<String, Object> instance = new HashMap<>();
|
||||
instance.put("prompt", prompt);
|
||||
|
||||
final Map<String, Object> parameters = new HashMap<>();
|
||||
parameters.put("durationSeconds", GeminiCommon.GeminiDurationSeconds.EIGHT.getValue());
|
||||
|
||||
//合并其他参数
|
||||
final Map<String, Object> additional = config.getAdditionalConfigMap();
|
||||
if (MapUtil.isNotEmpty(additional)) {
|
||||
parameters.putAll(additional);
|
||||
}
|
||||
|
||||
final Map<String, Object> paramMap = new HashMap<>();
|
||||
paramMap.put("instances", Collections.singletonList(instance));
|
||||
paramMap.put("parameters", parameters);
|
||||
return paramMap;
|
||||
}
|
||||
|
||||
private Map<String, Object> buildPredictImageRequestMap(String prompt) {
|
||||
final Map<String, Object> instance = new HashMap<>();
|
||||
instance.put("prompt", prompt);
|
||||
|
||||
final Map<String, Object> parameters = new HashMap<>();
|
||||
// 官方默认4,通常我们会按需设为1
|
||||
parameters.put("sampleCount", GeminiCommon.GeminiImageCount.ONE.getCount());
|
||||
// 默认 1:1
|
||||
parameters.put("aspectRatio", GeminiCommon.GeminiAspectRatio.SQUARE.getRatio());
|
||||
// 默认
|
||||
parameters.put("personGeneration", GeminiCommon.GeminiPersonGeneration.ALLOW_ADULT.getValue());
|
||||
|
||||
//合并其他参数
|
||||
final Map<String, Object> additional = config.getAdditionalConfigMap();
|
||||
if (MapUtil.isNotEmpty(additional)) {
|
||||
parameters.putAll(additional);
|
||||
if (additional.containsKey("numberOfImages")) {
|
||||
parameters.put("sampleCount", additional.get("numberOfImages"));
|
||||
}
|
||||
}
|
||||
|
||||
final Map<String, Object> paramMap = new HashMap<>();
|
||||
paramMap.put("instances", Collections.singletonList(instance));
|
||||
paramMap.put("parameters", parameters);
|
||||
return paramMap;
|
||||
}
|
||||
|
||||
private Map<String, Object> buildTextToSpeechRequestMap(String prompt) {
|
||||
final Map<String, Object> paramMap = new HashMap<>();
|
||||
final Map<String, Object> part = new HashMap<>();
|
||||
part.put("text", prompt);
|
||||
|
||||
final Map<String, Object> content = new HashMap<>();
|
||||
content.put("role", "user");
|
||||
content.put("parts", Collections.singletonList(part));
|
||||
paramMap.put("contents", Collections.singletonList(content));
|
||||
|
||||
final Map<String, Object> generationConfig = new HashMap<>();
|
||||
//基础固定参数:必须指定返回音频格式
|
||||
generationConfig.put("response_modalities", Collections.singletonList("AUDIO"));
|
||||
|
||||
//合并其他参数
|
||||
final Map<String, Object> additionalMap = config.getAdditionalConfigMap();
|
||||
if (MapUtil.isNotEmpty(additionalMap)) {
|
||||
generationConfig.putAll(additionalMap);
|
||||
}
|
||||
paramMap.put("generation_config", generationConfig);
|
||||
return paramMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取远程文件的 MIME 类型
|
||||
*
|
||||
* @param fileUri 文件URI
|
||||
* @return MIME类型
|
||||
*/
|
||||
private String getRemoteFileMimeType(String fileUri) {
|
||||
try {
|
||||
HttpGlobalConfig.setTimeout(config.getTimeout());
|
||||
Request httpRequest = HttpUtil.createGet(fileUri)
|
||||
.header(HeaderName.ACCEPT, "application/json")
|
||||
.header("x-goog-api-key", config.getApiKey());
|
||||
String responseBody = httpRequest.send().bodyStr();
|
||||
final JSONObject json = JSONUtil.parseObj(responseBody);
|
||||
|
||||
//提取服务端的mimeType
|
||||
String mimeType = json.getStr("mimeType");
|
||||
if (StrUtil.isNotBlank(mimeType)) {
|
||||
return mimeType;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new AIException("Failed to get remote file MIME type", e.getMessage());
|
||||
}
|
||||
return "application/octet-stream";
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送Get请求
|
||||
* @param endpoint 请求节点
|
||||
* @return 请求响应
|
||||
*/
|
||||
@Override
|
||||
protected Response sendGet(final String endpoint) {
|
||||
//链式构建请求
|
||||
try {
|
||||
//设置超时
|
||||
HttpGlobalConfig.setTimeout(config.getTimeout());
|
||||
return HttpUtil.createRequest(config.getApiUrl() + endpoint, Method.GET)
|
||||
.header(HeaderName.ACCEPT, "application/json")
|
||||
.header("x-goog-api-key", config.getApiKey())
|
||||
.send();
|
||||
} catch (final AIException e) {
|
||||
throw new AIException("Failed to send GET request: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Response sendPost(String endpoint, String paramJson) {
|
||||
//链式构建请求
|
||||
try {
|
||||
//设置超时3分钟
|
||||
HttpGlobalConfig.setTimeout(config.getTimeout());
|
||||
return HttpUtil.createRequest(config.getApiUrl() + endpoint, Method.POST)
|
||||
.header(HeaderName.CONTENT_TYPE, "application/json")
|
||||
.header(HeaderName.ACCEPT, "application/json")
|
||||
.header("x-goog-api-key", config.getApiKey())
|
||||
.body(paramJson)
|
||||
.send();
|
||||
} catch (final AIException e) {
|
||||
throw new AIException("Failed to send POST request:" + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 支持流式返回的 POST 请求
|
||||
*
|
||||
* @param endpoint 请求地址
|
||||
* @param paramMap 请求参数
|
||||
* @param callback 流式数据回调函数
|
||||
*/
|
||||
@Override
|
||||
protected void sendPostStream(String endpoint, final Map<String, Object> paramMap, final Consumer<String> callback) {
|
||||
HttpURLConnection connection = null;
|
||||
try {
|
||||
// 创建连接
|
||||
URL apiUrl = new URL(config.getApiUrl() + endpoint);
|
||||
connection = (HttpURLConnection) apiUrl.openConnection();
|
||||
connection.setRequestMethod(Method.POST.name());
|
||||
connection.setRequestProperty(HeaderName.CONTENT_TYPE.getValue(), "application/json");
|
||||
connection.setRequestProperty("x-goog-api-key", config.getApiKey());
|
||||
connection.setDoOutput(true);
|
||||
//设置读取超时
|
||||
connection.setReadTimeout(config.getReadTimeout());
|
||||
//设置连接超时
|
||||
connection.setConnectTimeout(config.getTimeout());
|
||||
// 发送请求体
|
||||
try (OutputStream os = connection.getOutputStream()) {
|
||||
String jsonInputString = JSONUtil.toJsonStr(paramMap);
|
||||
os.write(jsonInputString.getBytes());
|
||||
os.flush();
|
||||
}
|
||||
|
||||
// 读取流式响应
|
||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream()))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
// 调用回调函数处理每一行数据
|
||||
callback.accept(line);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
callback.accept("{\"error\": \"" + e.getMessage() + "\"}");
|
||||
} finally {
|
||||
// 关闭连接
|
||||
if (connection != null) {
|
||||
connection.disconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* 对gemini的封装实现
|
||||
*
|
||||
* @author elichow
|
||||
* @since 6.0.0
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
@@ -20,3 +20,4 @@ cn.hutool.v7.ai.model.openai.OpenaiConfig
|
||||
cn.hutool.v7.ai.model.doubao.DoubaoConfig
|
||||
cn.hutool.v7.ai.model.grok.GrokConfig
|
||||
cn.hutool.v7.ai.model.ollama.OllamaConfig
|
||||
cn.hutool.v7.ai.model.gemini.GeminiConfig
|
||||
|
||||
@@ -20,3 +20,4 @@ cn.hutool.v7.ai.model.openai.OpenaiProvider
|
||||
cn.hutool.v7.ai.model.doubao.DoubaoProvider
|
||||
cn.hutool.v7.ai.model.grok.GrokProvider
|
||||
cn.hutool.v7.ai.model.ollama.OllamaProvider
|
||||
cn.hutool.v7.ai.model.gemini.GeminiProvider
|
||||
|
||||
@@ -21,6 +21,7 @@ import cn.hutool.v7.ai.core.AIService;
|
||||
import cn.hutool.v7.ai.core.Message;
|
||||
import cn.hutool.v7.ai.model.deepseek.DeepSeekService;
|
||||
import cn.hutool.v7.ai.model.doubao.DoubaoService;
|
||||
import cn.hutool.v7.ai.model.gemini.GeminiService;
|
||||
import cn.hutool.v7.ai.model.grok.GrokService;
|
||||
import cn.hutool.v7.ai.model.hutool.HutoolService;
|
||||
import cn.hutool.v7.ai.model.openai.OpenaiService;
|
||||
@@ -77,6 +78,12 @@ class AIUtilTest {
|
||||
assertNotNull(openAIService);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getGeminiService() {
|
||||
final GeminiService geminiService = AIUtil.getGeminiService(new AIConfigBuilder(ModelName.GEMINI.getValue()).setApiKey(key).build());
|
||||
assertNotNull(geminiService);
|
||||
}
|
||||
|
||||
@Test
|
||||
void chat() {
|
||||
final String chat = AIUtil.chat(new AIConfigBuilder(ModelName.DEEPSEEK.getValue()).setApiKey(key).build(), "写一首赞美我的诗");
|
||||
|
||||
@@ -128,7 +128,7 @@ class DoubaoServiceTest {
|
||||
@Disabled
|
||||
void videoTasks() {
|
||||
final DoubaoService doubaoService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.DOUBAO.getValue())
|
||||
.setApiKey(key).setModel(Models.Doubao.DOUBAO_SEEDDANCE_1_0_lite_I2V.getModel()).build(), DoubaoService.class);
|
||||
.setApiKey(key).setModel(Models.Doubao.DOUBAO_SEEDANCE_1_0_LITE_I2V.getModel()).build(), DoubaoService.class);
|
||||
final String videoTasks = doubaoService.videoTasks("生成一段动画视频,主角是大耳朵图图,一个活泼可爱的小男孩。视频中图图在公园里玩耍," +
|
||||
"画面采用明亮温暖的卡通风格,色彩鲜艳,动作流畅。背景音乐轻快活泼,带有冒险感,音效包括鸟叫声、欢笑声和山洞回声。", "https://img2.baidu.com/it/u=862000265,4064861820&fm=253&fmt=auto&app=138&f=JPEG?w=800&h=1544");
|
||||
assertNotNull(videoTasks);//cgt-20250306170051-6r9gk
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.hutool.v7.ai.model.gemini;
|
||||
|
||||
import cn.hutool.v7.ai.AIServiceFactory;
|
||||
import cn.hutool.v7.ai.ModelName;
|
||||
import cn.hutool.v7.ai.Models;
|
||||
import cn.hutool.v7.ai.core.AIConfig;
|
||||
import cn.hutool.v7.ai.core.AIConfigBuilder;
|
||||
import cn.hutool.v7.ai.core.Message;
|
||||
import cn.hutool.v7.core.codec.binary.Base64;
|
||||
import cn.hutool.v7.core.io.file.FileUtil;
|
||||
import cn.hutool.v7.core.text.StrUtil;
|
||||
import cn.hutool.v7.core.thread.ThreadUtil;
|
||||
import cn.hutool.v7.json.JSONArray;
|
||||
import cn.hutool.v7.json.JSONObject;
|
||||
import cn.hutool.v7.json.JSONUtil;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
|
||||
|
||||
class GeminiServiceTest {
|
||||
|
||||
String key = "your key";
|
||||
GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue()).setApiKey(key).build(), GeminiService.class);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chat() {
|
||||
final String chat = geminiService.chat("我应该怎么度过2025年的最后一天?");
|
||||
assertNotNull(chat);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatStream() {
|
||||
String prompt = "写一个疯狂星期四广告词";
|
||||
// 使用AtomicBoolean作为结束标志
|
||||
AtomicBoolean isDone = new AtomicBoolean(false);
|
||||
|
||||
geminiService.chat(prompt, data -> {
|
||||
assertNotNull(data);
|
||||
if (data.contains("finishReason")) {
|
||||
// 设置结束标志
|
||||
isDone.set(true);
|
||||
} else if (data.contains("\"error\"")) {
|
||||
isDone.set(true);
|
||||
}
|
||||
|
||||
});
|
||||
// 轮询检查结束标志
|
||||
while (!isDone.get()) {
|
||||
ThreadUtil.sleep(100);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void testUpload() {
|
||||
String uploadFile = geminiService.uploadFile(new File("/Users/hdbuoge/Desktop/111.mov"));
|
||||
assertNotNull(uploadFile);
|
||||
//https://generativelanguage.googleapis.com/v1beta/files/a8kkc6263yth
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatMultimodalImage() {
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.GEMINI_2_0_FLASH.getModel()).build(), GeminiService.class);
|
||||
final String chatVision = geminiService.chatMultimodal("图片上有些什么内容?", Arrays.asList("https://generativelanguage.googleapis.com/v1beta/files/a8kkc6263yth"));
|
||||
assertNotNull(chatVision);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatMultimodalImageSteam() {
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.GEMINI_2_0_FLASH.getModel()).build(), GeminiService.class);
|
||||
String prompt = "图片上有些什么内容?中文回答";
|
||||
// 使用AtomicBoolean作为结束标志
|
||||
AtomicBoolean isDone = new AtomicBoolean(false);
|
||||
|
||||
geminiService.chatMultimodal(prompt, Arrays.asList("https://generativelanguage.googleapis.com/v1beta/files/a8kkc6263yth"), data -> {
|
||||
assertNotNull(data);
|
||||
if (data.contains("finishReason")) {
|
||||
// 设置结束标志
|
||||
isDone.set(true);
|
||||
} else if (data.contains("\"error\"")) {
|
||||
isDone.set(true);
|
||||
}
|
||||
|
||||
});
|
||||
// 轮询检查结束标志
|
||||
while (!isDone.get()) {
|
||||
ThreadUtil.sleep(100);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatMultimodalVideo() {
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.GEMINI_2_0_FLASH.getModel()).build(), GeminiService.class);
|
||||
final String chatVision = geminiService.chatMultimodal("视频中第3秒发生了什么?", Arrays.asList("https://generativelanguage.googleapis.com/v1beta/files/k1whwbqznecz"));
|
||||
assertNotNull(chatVision);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatMultimodalVideoStream() {
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.GEMINI_2_0_FLASH.getModel()).build(), GeminiService.class);
|
||||
String prompt = "视频中第3秒有什么物品?";
|
||||
// 使用AtomicBoolean作为结束标志
|
||||
AtomicBoolean isDone = new AtomicBoolean(false);
|
||||
|
||||
geminiService.chatMultimodal(prompt, Arrays.asList("https://generativelanguage.googleapis.com/v1beta/files/k1whwbqznecz"), data -> {
|
||||
assertNotNull(data);
|
||||
if (data.contains("finishReason")) {
|
||||
// 设置结束标志
|
||||
isDone.set(true);
|
||||
} else if (data.contains("\"error\"")) {
|
||||
isDone.set(true);
|
||||
}
|
||||
|
||||
});
|
||||
// 轮询检查结束标志
|
||||
while (!isDone.get()) {
|
||||
ThreadUtil.sleep(100);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatJson() {
|
||||
// 测试结构化输出
|
||||
final List<Message> messages = new ArrayList<>();
|
||||
messages.add(new Message("user", "提取以下信息:张三,男,25岁。返回JSON格式。"));
|
||||
final String jsonResponse = geminiService.chatJson(messages);
|
||||
assertNotNull(jsonResponse);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void chatImage() {
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.GEMINI_2_5_FLASH_IMAGE.getModel()).build(), GeminiService.class);
|
||||
final String response = geminiService.chat("一只在太空中行走的赛博朋克风格的猫");
|
||||
// 注意:Gemini返回是包含base64数据的响应体
|
||||
assertNotNull(response);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void predictImage() {
|
||||
// 测试 Imagen 4 原生图片生成
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.IMAGEN_4_0_GENERATE_001.getModel()).build(), GeminiService.class);
|
||||
//暂时只支持英文提示词
|
||||
final String response = geminiService.predictImage("Oil painting of New Year's greetings");
|
||||
assertNotNull(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void predictImageAndSave() {
|
||||
AIConfig config = new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.IMAGEN_4_0_GENERATE_001.getModel()).build();
|
||||
config.putAdditionalConfigByKey("numberOfImages", GeminiCommon.GeminiImageCount.TWO.getCount());
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(config, GeminiService.class);
|
||||
|
||||
//暂时只支持英文提示词
|
||||
final String response = geminiService.predictImage("A park in the spring next to a lake, the sun sets across the lake, golden hour, red wildflowers");
|
||||
|
||||
//解析JSON 结构
|
||||
JSONObject jsonObject = JSONUtil.parseObj(response);
|
||||
JSONArray predictions = jsonObject.getJSONArray("predictions");
|
||||
|
||||
if (predictions != null) {
|
||||
for (int i = 0; i < predictions.size(); i++) {
|
||||
JSONObject item = predictions.getJSONObject(i);
|
||||
|
||||
//提取Base64数据
|
||||
String base64Data = item.getStr("bytesBase64Encoded");
|
||||
|
||||
//划分并保存到本地文件
|
||||
String fileName = "generated_image_" + i + ".png";
|
||||
FileUtil.writeBytes(Base64.decode(base64Data), fileName);
|
||||
FileUtil.writeBytes(Base64.decode(base64Data), "your filePath" + fileName);
|
||||
|
||||
assertNotNull(base64Data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void generateVideoTest() {
|
||||
AIConfig config = new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.VEO_3_1_GENERATE_PREVIEW.getModel()).build();
|
||||
config.putAdditionalConfigByKey("aspectRatio", GeminiCommon.GeminiAspectRatio.LANDSCAPE_16_9.getRatio());
|
||||
config.putAdditionalConfigByKey("durationSeconds", GeminiCommon.GeminiDurationSeconds.EIGHT.getValue());
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(config, GeminiService.class);
|
||||
|
||||
|
||||
// 4. 发起异步生成请求
|
||||
String initialRes = geminiService.predictVideo("在一艘即将沉没的轮船上,男主角从后面抱起展开双手的女主角,女主角说:”u jump, i jump“");
|
||||
JSONObject resObj = JSONUtil.parseObj(initialRes);
|
||||
|
||||
String operationName = resObj.getStr("name");
|
||||
|
||||
// 5. 轮询获取结果 (LRO 模式)
|
||||
String videoUri = null;
|
||||
int maxRetries = 30; // 约等待 5-10 分钟
|
||||
for (int i = 0; i < maxRetries; i++) {
|
||||
ThreadUtil.sleep(20000); // 每 20 秒查询一次
|
||||
|
||||
String statusJson = geminiService.getVideoOperation(operationName);
|
||||
JSONObject statusObj = JSONUtil.parseObj(statusJson);
|
||||
|
||||
// 判断是否完成
|
||||
if (statusObj.getBool("done", false)) {
|
||||
// 路径参考:response.generateVideoResponse.generatedSamples[0].video.uri
|
||||
videoUri = statusObj.getByPath("response.generateVideoResponse.generatedSamples[0].video.uri", String.class);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assertNotNull(videoUri);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void downLoadVideo() {
|
||||
String videoUri = "geminiService.getVideoOperation返回的videoUri";
|
||||
geminiService.downLoadVideo(videoUri, "your filePath");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
void testTTSWithBuildMethod() {
|
||||
AIConfig config = new AIConfigBuilder(ModelName.GEMINI.getValue())
|
||||
.setApiKey(key).setModel(Models.Gemini.GEMINI_2_5_PRO_PREVIEW_TTS.getModel()).build();
|
||||
config.putAdditionalConfigByKey("temperature", 0.7);
|
||||
final GeminiService geminiService = AIServiceFactory.getAIService(config, GeminiService.class);
|
||||
|
||||
String prompt = "Hello, this is a test of the native text to speech system.";
|
||||
String result = geminiService.textToSpeech(prompt, GeminiCommon.GeminiVoice.AOEDE.getValue());
|
||||
|
||||
JSONObject json = JSONUtil.parseObj(result);
|
||||
String base64Data = json.getByPath("candidates[0].content.parts[0].inlineData.data", String.class);
|
||||
byte[] rawPcm = Base64.decode(base64Data);
|
||||
|
||||
//接口返回的是裸PCM流
|
||||
byte[] wavFile = geminiService.addWavHeader(rawPcm);
|
||||
|
||||
|
||||
if (StrUtil.isNotBlank(base64Data)) {
|
||||
FileUtil.writeBytes(wavFile, "your filePath");
|
||||
}
|
||||
|
||||
assertNotNull(wavFile);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user