diff --git a/admin/pom.xml b/admin/pom.xml
index 24d9e59cd6d05f7bbdaf2025663b5c5b2f0a1092..3e1df20e9227a829a4352d83b5b9d0e3df6a0893 100644
--- a/admin/pom.xml
+++ b/admin/pom.xml
@@ -43,7 +43,7 @@
org.springframework.ai
spring-ai-bom
- 1.0.0
+ 1.0.1
pom
import
@@ -139,10 +139,6 @@
-
- org.springframework.ai
- spring-ai-starter-mcp-client
-
org.springframework.ai
spring-ai-starter-mcp-client-webflux
@@ -155,17 +151,26 @@
org.springframework.ai
- spring-ai-starter-model-ollama
+ spring-ai-ollama
+
+
+
+ org.springframework.ai
+ spring-ai-openai
org.springframework.ai
spring-ai-starter-vector-store-elasticsearch
+
+
+
+
- org.elasticsearch.client
- elasticsearch-rest-client
- 8.13.3
+ co.elastic.clients
+ elasticsearch-java
+ 8.19.3
diff --git a/admin/src/main/java/com/wt/admin/config/AIConfig.java b/admin/src/main/java/com/wt/admin/config/AIConfig.java
new file mode 100644
index 0000000000000000000000000000000000000000..f8d0ffb4aa22b49f56e460d514e55387a4151263
--- /dev/null
+++ b/admin/src/main/java/com/wt/admin/config/AIConfig.java
@@ -0,0 +1,253 @@
+package com.wt.admin.config;
+
+import com.wt.admin.config.cache.Cache;
+import com.wt.admin.config.cache.CacheManager;
+import com.wt.admin.config.cache.impl.ChatContentCache;
+import com.wt.admin.config.prop.AIProp;
+import com.wt.admin.domain.vo.ai.ChatModelContentVO;
+import com.wt.admin.service.vector.Vector;
+import com.wt.admin.service.vector.impl.ESVectorImpl;
+import com.wt.admin.service.vector.impl.MemoryVectorImpl;
+import io.micrometer.observation.ObservationRegistry;
+import jakarta.annotation.Resource;
+import lombok.Data;
+import org.dromara.easyai.config.SentenceConfig;
+import org.dromara.easyai.config.TfConfig;
+import org.dromara.easyai.entity.KeyWordForSentence;
+import org.dromara.easyai.naturalLanguage.TalkToTalk;
+import org.dromara.easyai.naturalLanguage.languageCreator.CatchKeyWord;
+import org.dromara.easyai.naturalLanguage.word.MyKeyWord;
+import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
+import org.dromara.easyai.rnnJumpNerveCenter.RRNerveManager;
+import org.dromara.easyai.yolo.FastYolo;
+import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
+import org.springframework.ai.chat.memory.MessageWindowChatMemory;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.embedding.TokenCountBatchingStrategy;
+import org.springframework.ai.ollama.OllamaEmbeddingModel;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.ollama.management.ModelManagementOptions;
+import org.springframework.ai.openai.OpenAiEmbeddingModel;
+import org.springframework.ai.openai.OpenAiEmbeddingOptions;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.vectorstore.SimpleVectorStore;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.web.servlet.FilterRegistrationBean;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.filter.RequestContextFilter;
+
+import java.util.ArrayList;
+import java.util.List;
+
+@Configuration
+public class AIConfig {
+
+ @Resource
+ private AIProp aiProp;
+
+ /**
+ * 词向量嵌入器,语义神经
+ * @return
+ */
+ @Bean("wordEmbedding")
+ public Cache wordEmbedding(){
+ return CacheManager.getCache("wordEmbedding");
+ }
+
+ @Bean("myKeyWord")
+ public Cache myKeyWordMap(){
+ return CacheManager.getCache("myKeyWord");
+ }
+
+ @Bean("catchKeyWord")
+ public Cache catchKeyWordMap(){
+ return CacheManager.getCache("catchKeyWord");
+ }
+
+ @Bean("sensorKeyWordMapper")
+ public Cache>> sensorKeyWordMapper(){
+ return CacheManager.getCache("sensorKeyWordMapper");
+ }
+
+ @Bean("imageYoloManager>")
+ public Cache imageYoloManager(){
+ return CacheManager.getCache("imageYoloManager");
+ }
+
+ @Bean("keyWordValueList")
+ public List keyWordValueList(){
+ return new ArrayList<>();
+ }
+
+ /**
+ * 缓存聊天内容
+ * @return
+ */
+ @Bean("chatContents")
+ public ChatContentCache chatContents(){
+ return new ChatContentCache<>();
+ }
+
+ @Bean("embeddingModel")
+ @ConditionalOnProperty(name = "spring.embedding.method", havingValue = "ollama")
+ public EmbeddingModel ollamaEmbeddingModel(){
+ OllamaApi build = OllamaApi.builder().baseUrl(aiProp.getEmbedding().getUrl()).build();
+ return new OllamaEmbeddingModel(
+ build,
+ OllamaOptions.builder()
+ .model(aiProp.getEmbedding().getModel())
+ .build(),
+ ObservationRegistry.NOOP,
+ ModelManagementOptions.defaults()
+ );
+ }
+
+ @Bean("embeddingModel")
+ @ConditionalOnProperty(name = "spring.embedding.method", havingValue = "openai")
+ public EmbeddingModel openAiEmbeddingModel(){
+ var openAiApi = OpenAiApi.builder()
+ .baseUrl(aiProp.getEmbedding().getUrl())
+ .apiKey(aiProp.getEmbedding().getApiKey())
+ .build();
+ return new OpenAiEmbeddingModel(
+ openAiApi,
+ MetadataMode.EMBED,
+ OpenAiEmbeddingOptions.builder()
+ .model(aiProp.getEmbedding().getModel())
+ .build(),
+ RetryUtils.DEFAULT_RETRY_TEMPLATE);
+ }
+
+ @Bean
+ public SimpleVectorStore simpleVectorStore(@Qualifier("embeddingModel") EmbeddingModel embeddingModel) {
+ return SimpleVectorStore.builder(embeddingModel)
+ .batchingStrategy(new TokenCountBatchingStrategy())
+ .build();
+ }
+
+// @Bean
+// @ConditionalOnProperty(name = "spring.vector.method", havingValue = "es")
+// public ElasticsearchVectorStore elasticsearchVectorStore(EmbeddingModel embeddingModel, ElasticsearchProperties elasticsearchProperties) {
+// HttpHost[] hosts = elasticsearchProperties.getUris().stream()
+// .map(HttpHost::create)
+// .toArray(HttpHost[]::new);
+// RestClientBuilder builder = RestClient.builder(hosts);
+//
+// builder.setHttpClientConfigCallback(httpClientBuilder -> {
+// try {
+// // 创建信任所有证书的SSL上下文 禁用SSL验证
+// SSLContext sslContext = SSLContextBuilder
+// .create()
+// .loadTrustMaterial(null, (chain, authType) -> true)
+// .build();
+//
+// return httpClientBuilder
+// .setSSLContext(sslContext)
+// .setSSLHostnameVerifier((hostname, session) -> true)
+// .setKeepAliveStrategy((response, context) -> {
+// // 设置keep-alive时间为30秒
+// return 30000;
+// })
+// .setMaxConnTotal(20)// 配置连接池和超时
+// .setMaxConnPerRoute(20)
+// .setConnectionTimeToLive(60, java.util.concurrent.TimeUnit.SECONDS)
+// .setDefaultRequestConfig(
+// org.apache.http.client.config.RequestConfig.custom()
+// .setConnectTimeout(10000) // 10秒连接超时
+// .setSocketTimeout(30000) // 30秒Socket超时
+// .build()
+// );
+// } catch (Exception e) {
+// throw new RuntimeException("Failed to create SSL context", e);
+// }
+// });
+// // 使用Basic认证方式
+// String auth = elasticsearchProperties.getUsername() + ":" + elasticsearchProperties.getPassword();
+// String encodedAuth = java.util.Base64.getEncoder().encodeToString(auth.getBytes());
+// builder.setDefaultHeaders(new Header[]{
+// new BasicHeader("Authorization", "Basic " + encodedAuth),
+// new BasicHeader("Connection", "keep-alive")
+// });
+//
+// ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions();
+// options.setIndexName("custom-index"); // Optional: defaults to "spring-ai-document-index"
+// options.setSimilarity(SimilarityFunction.cosine); // Optional: defaults to COSINE
+// options.setDimensions(1536); // Optional: defaults to model dimensions or 1536
+//
+// return ElasticsearchVectorStore.builder(builder.build(), embeddingModel)
+// .options(options) // Optional: use custom options
+// .initializeSchema(true) // Optional: defaults to false
+// .batchingStrategy(new TokenCountBatchingStrategy()) // Optional: defaults to TokenCountBatchingStrategy
+// .build();
+// }
+
+ @Bean
+ @ConditionalOnProperty(name = "spring.vector.method", havingValue = "es")
+ public Vector es() {
+ return new ESVectorImpl();
+ }
+
+ @Bean
+ @ConditionalOnProperty(name = "spring.vector.method", havingValue = "memory")
+ public Vector memory() {
+ return new MemoryVectorImpl();
+ }
+
+ @Bean
+ public ChatMemory chatMemory() {
+ return MessageWindowChatMemory.builder()
+ .maxMessages(10)
+ .chatMemoryRepository(new InMemoryChatMemoryRepository())
+ .build();
+ }
+
+
+ @Bean
+ public FilterRegistrationBean requestContextFilter() {
+ FilterRegistrationBean registrationBean = new FilterRegistrationBean<>();
+ registrationBean.setFilter(new RequestContextFilter());
+ registrationBean.setOrder(1);
+ registrationBean.addUrlPatterns("/*");
+ return registrationBean;
+ }
+
+
+ @Data
+ public static class KeywordValue{
+ private String keywordValue;
+ private Integer typeId;
+ private Long index;//该关键词索引id
+ public KeywordValue(String keywordValue, Integer typeId, Long index) {
+ this.keywordValue = keywordValue;
+ this.typeId = typeId;
+ this.index = index;
+ }
+ public KeywordValue(String keywordValue, Integer typeId) {
+ this.keywordValue = keywordValue;
+ this.typeId = typeId;
+ }
+ }
+
+ @Data
+ public static class WordAndRRManager {
+
+ private WordEmbedding wordEmbedding;
+ private RRNerveManager rrNerveManager;
+ private TalkToTalk talkToTalk;
+ private SentenceConfig sentenceConfig;
+ private TfConfig tfConfig;
+
+
+ public WordAndRRManager() {
+ this.wordEmbedding = new WordEmbedding();
+ this.rrNerveManager = new RRNerveManager(wordEmbedding);
+ }
+
+ }
+}
diff --git a/admin/src/main/java/com/wt/admin/config/ConstVar.java b/admin/src/main/java/com/wt/admin/config/ConstVar.java
index 51220033f2a89d7794dfb39f89e25d9a5d42e2b9..23695b88352af1d9fab9c2d4a2eca1582fe129f3 100644
--- a/admin/src/main/java/com/wt/admin/config/ConstVar.java
+++ b/admin/src/main/java/com/wt/admin/config/ConstVar.java
@@ -13,7 +13,6 @@ public class ConstVar {
Integer LANGUAGE = 1;
Integer QA = 2;
Integer IMAGE = 3;
- Integer VIDEO = 4;
}
public interface Socket{
diff --git a/admin/src/main/java/com/wt/admin/config/GlobalBeanConfig.java b/admin/src/main/java/com/wt/admin/config/GlobalBeanConfig.java
index defdeb0918ccbe646932df80e37e8da691114f14..86acd521fdb6ed3438c7190419cd8f46a778d917 100644
--- a/admin/src/main/java/com/wt/admin/config/GlobalBeanConfig.java
+++ b/admin/src/main/java/com/wt/admin/config/GlobalBeanConfig.java
@@ -2,43 +2,23 @@ package com.wt.admin.config;
import com.wt.admin.config.cache.Cache;
import com.wt.admin.config.cache.CacheManager;
-import com.wt.admin.config.cache.impl.ChatContentCache;
import com.wt.admin.domain.entity.sys.SysSettingEntity;
-import com.wt.admin.domain.vo.ai.ChatModelContentVO;
import com.wt.admin.domain.vo.sys.UserVO;
-import com.wt.admin.service.vector.Vector;
-import com.wt.admin.service.vector.impl.ESVectorImpl;
-import com.wt.admin.service.vector.impl.MemoryVectorImpl;
-import lombok.Data;
-import org.dromara.easyai.config.SentenceConfig;
-import org.dromara.easyai.config.TfConfig;
-import org.dromara.easyai.entity.KeyWordForSentence;
-import org.dromara.easyai.naturalLanguage.TalkToTalk;
-import org.dromara.easyai.naturalLanguage.languageCreator.CatchKeyWord;
-import org.dromara.easyai.naturalLanguage.word.MyKeyWord;
-import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
-import org.dromara.easyai.rnnJumpNerveCenter.RRNerveManager;
-import org.dromara.easyai.yolo.FastYolo;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.embedding.TokenCountBatchingStrategy;
-import org.springframework.ai.vectorstore.SimpleVectorStore;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
-import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.MessageSource;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.support.ReloadableResourceBundleMessageSource;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
-import org.springframework.web.filter.RequestContextFilter;
-import java.util.ArrayList;
-import java.util.List;
import java.util.concurrent.Executor;
@Configuration
public class GlobalBeanConfig {
+ /**
+ * 国际化
+ * @return
+ */
@Bean("messageSource")
public MessageSource messageSource() {
ReloadableResourceBundleMessageSource messageSource = new ReloadableResourceBundleMessageSource();
@@ -47,6 +27,10 @@ public class GlobalBeanConfig {
return messageSource;
}
+ /**
+ * 用户缓存
+ * @return
+ */
@Bean
public Cache userCache(){
return CacheManager.getCache("user");
@@ -61,6 +45,10 @@ public class GlobalBeanConfig {
return CacheManager.getCache("setting");
}
+ /**
+ * 异步线程
+ * @return
+ */
@Bean("newAsyncExecutor")
public Executor newAsyncExecutor() {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
@@ -73,6 +61,10 @@ public class GlobalBeanConfig {
return taskExecutor;
}
+ /**
+ * 通用线程池
+ * @return
+ */
@Bean("publicThread")
public Executor publicThread() {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
@@ -85,108 +77,7 @@ public class GlobalBeanConfig {
return taskExecutor;
}
- /**
- * 词向量嵌入器,语义神经
- * @return
- */
- @Bean("wordEmbedding")
- public Cache wordEmbedding(){
- return CacheManager.getCache("wordEmbedding");
- }
-
- @Bean("myKeyWord")
- public Cache myKeyWordMap(){
- return CacheManager.getCache("myKeyWord");
- }
-
- @Bean("catchKeyWord")
- public Cache catchKeyWordMap(){
- return CacheManager.getCache("catchKeyWord");
- }
-
- @Bean("sensorKeyWordMapper")
- public Cache>> sensorKeyWordMapper(){
- return CacheManager.getCache("sensorKeyWordMapper");
- }
-
- @Bean("chatContents")
- public ChatContentCache chatContents(){
- return new ChatContentCache<>();
- }
-
- @Bean
- public VectorStore memoryVectorStore(EmbeddingModel embeddingModel) {
- return SimpleVectorStore.builder(embeddingModel)
- .batchingStrategy(new TokenCountBatchingStrategy())
- .build();
- }
-
- @Bean
- @ConditionalOnProperty(name = "spring.vector.es", havingValue = "true")
- public Vector es() {
- return new ESVectorImpl();
- }
-
- @Bean
- @ConditionalOnProperty(name = "spring.vector.es", havingValue = "false")
- public Vector memory() {
- return new MemoryVectorImpl();
- }
-
- @Bean("imageYoloManager>")
- public Cache imageYoloManager(){
- return CacheManager.getCache("imageYoloManager");
- }
-
- @Bean("keyWordValueList")
- public List keyWordValueList(){
- return new ArrayList<>();
- }
-
-
- @Bean
- public FilterRegistrationBean requestContextFilter() {
- FilterRegistrationBean registrationBean = new FilterRegistrationBean<>();
- registrationBean.setFilter(new RequestContextFilter());
- registrationBean.setOrder(1);
- registrationBean.addUrlPatterns("/*");
- return registrationBean;
- }
-
-
- @Data
- public static class KeywordValue{
- private String keywordValue;
- private Integer typeId;
- private Long index;//该关键词索引id
- public KeywordValue(String keywordValue, Integer typeId, Long index) {
- this.keywordValue = keywordValue;
- this.typeId = typeId;
- this.index = index;
- }
- public KeywordValue(String keywordValue, Integer typeId) {
- this.keywordValue = keywordValue;
- this.typeId = typeId;
- }
- }
-
- @Data
- public static class WordAndRRManager {
-
- private WordEmbedding wordEmbedding;
- private RRNerveManager rrNerveManager;
- private TalkToTalk talkToTalk;
- private SentenceConfig sentenceConfig;
- private TfConfig tfConfig;
-
-
- public WordAndRRManager() {
- this.wordEmbedding = new WordEmbedding();
- this.rrNerveManager = new RRNerveManager(wordEmbedding);
- }
-
- }
}
diff --git a/admin/src/main/java/com/wt/admin/config/prop/AIProp.java b/admin/src/main/java/com/wt/admin/config/prop/AIProp.java
new file mode 100644
index 0000000000000000000000000000000000000000..efd0fc75a896e9f70eb2b834a13da7dfd30d2920
--- /dev/null
+++ b/admin/src/main/java/com/wt/admin/config/prop/AIProp.java
@@ -0,0 +1,23 @@
+package com.wt.admin.config.prop;
+
+import lombok.Data;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.context.annotation.Configuration;
+
+@Data
+@Configuration
+@ConfigurationProperties(prefix = "spring")
+public class AIProp {
+
+ private Embedding embedding;
+
+ @Data
+ public static class Embedding{
+ private String method;
+ private String url;
+ private String model;
+ private String apiKey;
+
+ }
+
+}
diff --git a/admin/src/main/java/com/wt/admin/config/prop/ModelPathProp.java b/admin/src/main/java/com/wt/admin/config/prop/ModelPathProp.java
index 6399d4fa55a867b3aa0f551c18e44d4e5cb01c0c..9d0e29263438023eb0cd2d61753f5af414680102 100644
--- a/admin/src/main/java/com/wt/admin/config/prop/ModelPathProp.java
+++ b/admin/src/main/java/com/wt/admin/config/prop/ModelPathProp.java
@@ -4,6 +4,7 @@ import lombok.Data;
@Data
public class ModelPathProp {
+ private String basePath;
private String langeModel;
private String qaModel;
private String yoloModel;
diff --git a/admin/src/main/java/com/wt/admin/controller/ai/ChatController.java b/admin/src/main/java/com/wt/admin/controller/ai/ChatController.java
index fa50b324c7d3edc84f3827c73173f59da325c245..11fe7919037e1a7ac91bfb7e7368592c65e51f93 100644
--- a/admin/src/main/java/com/wt/admin/controller/ai/ChatController.java
+++ b/admin/src/main/java/com/wt/admin/controller/ai/ChatController.java
@@ -6,11 +6,15 @@ import com.aizuda.easy.security.util.LocalUtil;
import com.wt.admin.domain.dto.ai.AgentsInfoDTO;
import com.wt.admin.domain.dto.ai.ChatDTO;
import com.wt.admin.domain.dto.ai.ChatModelContentDTO;
-import com.wt.admin.domain.vo.ai.*;
+import com.wt.admin.domain.vo.ai.AgentsInfoVO;
+import com.wt.admin.domain.vo.ai.ChatModelContentVO;
import com.wt.admin.domain.vo.sys.UserVO;
import com.wt.admin.service.ai.ChatProxyService;
import jakarta.annotation.Resource;
-import org.springframework.web.bind.annotation.*;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import java.util.List;
@@ -52,13 +56,6 @@ public class ChatController {
@PostMapping("question")
public Flux question(@RequestBody ChatDTO data) {
UserVO user = LocalUtil.getUser();
- if(data.getChatConfig().getEnableMCP()){
- String content = chatProxyService.question(data, user)
- .call()
- .content();
- chatProxyService.reply(content,data,user);
- return Flux.just(content);
- }
final StringBuilder msg = new StringBuilder();
return chatProxyService.question(data,user)
.stream()
diff --git a/admin/src/main/java/com/wt/admin/controller/image/ImageClassificationController.java b/admin/src/main/java/com/wt/admin/controller/image/ImageClassificationController.java
index 82066915648501a77eea9bb2821da1169a21acfa..3853dbdf7443eb68155edb1640560ef2ef5976af 100644
--- a/admin/src/main/java/com/wt/admin/controller/image/ImageClassificationController.java
+++ b/admin/src/main/java/com/wt/admin/controller/image/ImageClassificationController.java
@@ -10,7 +10,6 @@ import com.wt.admin.domain.vo.image.ImageClassificationListVO;
import com.wt.admin.domain.vo.image.ImageClassificationVO;
import com.wt.admin.domain.vo.image.ImageFeaturesVO;
import com.wt.admin.domain.vo.image.ImageItemVO;
-import com.wt.admin.domain.vo.sys.UserVO;
import com.wt.admin.service.image.ImageProxyService;
import com.wt.admin.util.PageUtil;
import jakarta.annotation.Resource;
@@ -83,8 +82,8 @@ public class ImageClassificationController {
}
@PostMapping("/testTraining")
- public Rep testTraining(@RequestParam("file") MultipartFile file,@RequestParam("id") Integer id) throws IOException {
- return Rep.ok(imageProxyService.testTraining(file,id));
+ public Rep testTraining(@RequestParam("file") MultipartFile file,@RequestParam("id") Integer id,@RequestParam("model") String model) throws IOException {
+ return Rep.ok(imageProxyService.testTraining(file,id,model));
}
}
diff --git a/admin/src/main/java/com/wt/admin/controller/image/ImageVideoController.java b/admin/src/main/java/com/wt/admin/controller/image/ImageVideoController.java
index 9297a6251e6fc4d0d57a10c19bf1665ac5c4b57f..b01c308211a711d9cfced0df736c7603957b1664 100644
--- a/admin/src/main/java/com/wt/admin/controller/image/ImageVideoController.java
+++ b/admin/src/main/java/com/wt/admin/controller/image/ImageVideoController.java
@@ -42,8 +42,8 @@ public class ImageVideoController {
}
@PostMapping("/training")
- public Rep training(@RequestParam("file") MultipartFile file, @RequestParam("id") Integer id) throws IOException {
- return Rep.ok(imageProxyService.videoTraining(file,id));
+ public Rep training(@RequestParam("file") MultipartFile file, @RequestParam("id") Integer id,@RequestParam("model") String model) throws IOException {
+ return Rep.ok(imageProxyService.videoTraining(file,id,model));
}
@Deprecated
diff --git a/admin/src/main/java/com/wt/admin/controller/language/ClassificationController.java b/admin/src/main/java/com/wt/admin/controller/language/ClassificationController.java
index 539c7767b4ca77e359ca7a7d300d8c23f55204d0..38c001089a8fb4b545eea36fb43114b4aa26156f 100644
--- a/admin/src/main/java/com/wt/admin/controller/language/ClassificationController.java
+++ b/admin/src/main/java/com/wt/admin/controller/language/ClassificationController.java
@@ -7,18 +7,19 @@ import com.wt.admin.code.language.Tagging2100;
import com.wt.admin.config.aspect.annotation.LogAno;
import com.wt.admin.domain.dto.language.ClassTrainingDTO;
import com.wt.admin.domain.dto.language.ClassificationDTO;
+import com.wt.admin.domain.dto.language.ClassificationTestTrainingDTO;
import com.wt.admin.domain.dto.language.SentenceConfigDTO;
import com.wt.admin.domain.vo.language.ClassificationVO;
import com.wt.admin.domain.vo.language.ParseSentenceVO;
import com.wt.admin.service.language.LanguageProxyService;
import com.wt.admin.util.AssertUtil;
+import jakarta.annotation.Resource;
import org.dromara.easyai.config.SentenceConfig;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
-import jakarta.annotation.Resource;
import java.util.List;
import static com.wt.admin.service.language.impl.LanguageTrainingService.CLASS;
@@ -65,7 +66,7 @@ public class ClassificationController {
@LogAno(name = "测试样本")
@PostMapping("testTraining")
- public Rep testTraining(@RequestBody String data){
+ public Rep testTraining(@RequestBody ClassificationTestTrainingDTO data){
return Rep.ok(languageProxyService.testTraining(data,LocalUtil.getUser()));
}
diff --git a/admin/src/main/java/com/wt/admin/controller/language/QAController.java b/admin/src/main/java/com/wt/admin/controller/language/QAController.java
index ba24f5f3bdbcdf5bbb2b229c436800c94ffe5189..bc43364f2d3f23e335568b3f144be2a7e94d84aa 100644
--- a/admin/src/main/java/com/wt/admin/controller/language/QAController.java
+++ b/admin/src/main/java/com/wt/admin/controller/language/QAController.java
@@ -1,28 +1,31 @@
package com.wt.admin.controller.language;
-import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.StrUtil;
import com.aizuda.easy.security.domain.Rep;
import com.aizuda.easy.security.util.LocalUtil;
import com.wt.admin.code.language.QA2200;
import com.wt.admin.config.aspect.annotation.LogAno;
import com.wt.admin.domain.dto.language.QADTO;
+import com.wt.admin.domain.dto.language.QATestTrainingDTO;
import com.wt.admin.domain.dto.language.QATrainingDTO;
import com.wt.admin.domain.dto.language.SentenceConfigDTO;
-import com.wt.admin.domain.vo.language.*;
+import com.wt.admin.domain.vo.language.ClassificationVO;
+import com.wt.admin.domain.vo.language.QAListVO;
+import com.wt.admin.domain.vo.language.QAParseSentenceVO;
+import com.wt.admin.domain.vo.language.QAVO;
import com.wt.admin.service.ai.impl.agents.easyai.api.EasyAIApi;
import com.wt.admin.service.language.LanguageProxyService;
import com.wt.admin.util.AssertUtil;
import com.wt.admin.util.PageUtil;
+import jakarta.annotation.Resource;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
-
-import jakarta.annotation.Resource;
import reactor.core.publisher.Flux;
+import java.time.Instant;
import java.util.List;
import static com.wt.admin.service.language.impl.LanguageTrainingService.QA;
@@ -68,7 +71,7 @@ public class QAController {
@LogAno(name = "测试样本")
@PostMapping("qaTestTraining")
- public Rep qaTestTraining(@RequestBody String data){
+ public Rep qaTestTraining(@RequestBody QATestTrainingDTO data){
return Rep.ok(languageProxyService.qaTestTraining(data,LocalUtil.getUser()));
}
@@ -81,26 +84,28 @@ public class QAController {
@PostMapping("/api/chat")
EasyAIApi.ChatResponse chat(@RequestBody EasyAIApi.ChatRequest request) {
+ long start = System.currentTimeMillis();
String meg = request.messages().get(request.messages().size()-1).content();
String answer = languageProxyService.qaTestTraining(meg).getAnswer();
if(StrUtil.isBlank(answer)){
answer = "没有相关内容。";
}
return new EasyAIApi.ChatResponse(request.model(),
- null, new EasyAIApi.Message(EasyAIApi.Message.Role.ASSISTANT,answer, List.of(),List.of()), "stop", true, null, null,
- null, null, null, null);
+ Instant.now(), new EasyAIApi.Message(EasyAIApi.Message.Role.ASSISTANT,answer, List.of(),List.of()), "stop", true, System.currentTimeMillis() - start, System.currentTimeMillis() - start,
+ 0, 0L, 0, 0L);
}
@PostMapping(value = "/api/streamChat",produces = "application/json;charset=UTF-8")
Flux streamChat(@RequestBody EasyAIApi.ChatRequest request) {
+ long start = System.currentTimeMillis();
String meg = request.messages().get(request.messages().size()-1).content();
String answer = languageProxyService.qaTestTraining(meg).getAnswer();
if(StrUtil.isBlank(answer)){
answer = "没有相关内容。";
}
return Flux.just(new EasyAIApi.ChatResponse(request.model(),
- null, new EasyAIApi.Message(EasyAIApi.Message.Role.ASSISTANT,answer,List.of(),List.of()), "stop", true, null, null,
- null, null, null, null));
+ Instant.now(), new EasyAIApi.Message(EasyAIApi.Message.Role.ASSISTANT,answer,List.of(),List.of()), "stop", true, System.currentTimeMillis() - start, System.currentTimeMillis() - start,
+ 0, 0L, 0, 0L));
}
}
diff --git a/admin/src/main/java/com/wt/admin/controller/model/ModelController.java b/admin/src/main/java/com/wt/admin/controller/model/ModelController.java
index f674adbfa7f38b400a3a7b76758ec50165b998db..8c74cdc75863731dac011ce0ed3f14f248a8f0b7 100644
--- a/admin/src/main/java/com/wt/admin/controller/model/ModelController.java
+++ b/admin/src/main/java/com/wt/admin/controller/model/ModelController.java
@@ -4,7 +4,6 @@ import com.aizuda.easy.security.domain.Rep;
import com.wt.admin.config.aspect.annotation.LogAno;
import com.wt.admin.domain.dto.model.ModelListDTO;
import com.wt.admin.domain.vo.model.ModelListVO;
-import com.wt.admin.domain.vo.sys.UserVO;
import com.wt.admin.service.model.ModelProxyService;
import com.wt.admin.util.PageUtil;
import jakarta.annotation.Resource;
@@ -13,6 +12,8 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
+import java.util.List;
+
@RestController
@RequestMapping("model")
public class ModelController {
@@ -44,4 +45,10 @@ public class ModelController {
return Rep.ok(modelProxyService.modelDel(data));
}
+ @LogAno(name = "显示不同类型的模型列表")
+ @PostMapping("modelAll")
+ public Rep> modelAll(@RequestBody ModelListDTO data){
+ return Rep.ok(modelProxyService.modelAll(data));
+ }
+
}
diff --git a/admin/src/main/java/com/wt/admin/domain/dto/ai/ChatDTO.java b/admin/src/main/java/com/wt/admin/domain/dto/ai/ChatDTO.java
index e020cbdb13b9c7a2278703a850c872ad1448bd20..77c425a01619a8a85212ef15804195de40f952a5 100644
--- a/admin/src/main/java/com/wt/admin/domain/dto/ai/ChatDTO.java
+++ b/admin/src/main/java/com/wt/admin/domain/dto/ai/ChatDTO.java
@@ -29,7 +29,5 @@ public class ChatDTO {
private String prompt = "";
// 是否启用协同
private Boolean enableSynergism = false;
- // 是否启用MCP
- private Boolean enableMCP = false;
}
}
diff --git a/admin/src/main/java/com/wt/admin/domain/dto/language/ClassificationTestTrainingDTO.java b/admin/src/main/java/com/wt/admin/domain/dto/language/ClassificationTestTrainingDTO.java
new file mode 100644
index 0000000000000000000000000000000000000000..83aeedaa58609cd0d6c42c9f92814500c097cebf
--- /dev/null
+++ b/admin/src/main/java/com/wt/admin/domain/dto/language/ClassificationTestTrainingDTO.java
@@ -0,0 +1,14 @@
+package com.wt.admin.domain.dto.language;
+
+import lombok.Data;
+
+import java.util.List;
+
+@Data
+public class ClassificationTestTrainingDTO {
+
+ private String data;
+ private List ids;
+ private String model;
+
+}
diff --git a/admin/src/main/java/com/wt/admin/domain/dto/language/QATestTrainingDTO.java b/admin/src/main/java/com/wt/admin/domain/dto/language/QATestTrainingDTO.java
new file mode 100644
index 0000000000000000000000000000000000000000..b8f4912c420a4ed2b9f1b949c1d6201b52d854cc
--- /dev/null
+++ b/admin/src/main/java/com/wt/admin/domain/dto/language/QATestTrainingDTO.java
@@ -0,0 +1,15 @@
+package com.wt.admin.domain.dto.language;
+
+
+import lombok.Data;
+
+import java.util.List;
+
+@Data
+public class QATestTrainingDTO {
+
+ private Integer classificationId;
+ private List id;
+ private String data;
+ private String model;
+}
diff --git a/admin/src/main/java/com/wt/admin/domain/entity/model/ModelListEntity.java b/admin/src/main/java/com/wt/admin/domain/entity/model/ModelListEntity.java
index 80c2aca1fc497fbc20c73e8264f9a07db615243b..9ed6478885cd310be05b6148ea6f061133b085fd 100644
--- a/admin/src/main/java/com/wt/admin/domain/entity/model/ModelListEntity.java
+++ b/admin/src/main/java/com/wt/admin/domain/entity/model/ModelListEntity.java
@@ -1,7 +1,6 @@
package com.wt.admin.domain.entity.model;
-import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@@ -13,15 +12,12 @@ import lombok.Data;
@TableName(value = "model_list")
public class ModelListEntity extends PublicEntity {
- @TableId(value = "id",type = IdType.AUTO)
- private Integer id;
+ @TableId(value = "remark")
+ private String remark;
@TableField(value = "user_id")
private Integer userId;
- @TableField(value = "remark")
- private String remark;
-
@TableField(value = "model_config",typeHandler = JacksonTypeHandler.class)
private Object modelConfig;
diff --git a/admin/src/main/java/com/wt/admin/domain/model/ImageYoloModel.java b/admin/src/main/java/com/wt/admin/domain/model/ImageYoloModel.java
new file mode 100644
index 0000000000000000000000000000000000000000..fe5b1744ef5e90d833dcc5c7d035b871f849dbd9
--- /dev/null
+++ b/admin/src/main/java/com/wt/admin/domain/model/ImageYoloModel.java
@@ -0,0 +1,12 @@
+package com.wt.admin.domain.model;
+
+import lombok.Data;
+import org.dromara.easyai.yolo.YoloConfig;
+import org.dromara.easyai.yolo.YoloModel;
+
+@Data
+public class ImageYoloModel {
+
+ private YoloModel yoloModel;
+ private YoloConfig yoloConfig;
+}
diff --git a/admin/src/main/java/com/wt/admin/domain/model/LanguageModel.java b/admin/src/main/java/com/wt/admin/domain/model/LanguageModel.java
index 796db302cf52ccc817bc792ee1d633a9e35a7866..bec351179d757ce8bbef446c645be179d99e9b38 100644
--- a/admin/src/main/java/com/wt/admin/domain/model/LanguageModel.java
+++ b/admin/src/main/java/com/wt/admin/domain/model/LanguageModel.java
@@ -1,6 +1,6 @@
package com.wt.admin.domain.model;
-import com.wt.admin.config.GlobalBeanConfig;
+import com.wt.admin.config.AIConfig;
import com.wt.admin.domain.vo.language.KeyParameterModelMapperVO;
import com.wt.admin.domain.vo.language.KeyWordModelMapperVO;
import lombok.Data;
@@ -23,6 +23,6 @@ public class LanguageModel {
// MySentenceModel
private RandomModel randomModel;
// keyword
- private List keyWordValueList;
+ private List keyWordValueList;
private SentenceConfig config;
}
diff --git a/admin/src/main/java/com/wt/admin/domain/vo/image/ImageClassificationListVO.java b/admin/src/main/java/com/wt/admin/domain/vo/image/ImageClassificationListVO.java
index fae789c2e2514d8a34fd37f4babd75df85b963bd..1d30e245a8df115811c304c416357850f35fad57 100644
--- a/admin/src/main/java/com/wt/admin/domain/vo/image/ImageClassificationListVO.java
+++ b/admin/src/main/java/com/wt/admin/domain/vo/image/ImageClassificationListVO.java
@@ -2,7 +2,7 @@ package com.wt.admin.domain.vo.image;
import com.wt.admin.domain.vo.language.ClassificationVO;
-import com.wt.admin.util.PageUtil;
+import com.wt.admin.domain.vo.model.ModelListVO;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -20,4 +20,5 @@ public class ImageClassificationListVO {
private List imageClasss;
+ private List models;
}
diff --git a/admin/src/main/java/com/wt/admin/mapper/language/xml/SentenceMapper.xml b/admin/src/main/java/com/wt/admin/mapper/language/xml/SentenceMapper.xml
index 90f85730f3eca42c86998f291c2d319cf72f2bbc..31d30118e0276a21a3bbeb87b42644ec86e31498 100644
--- a/admin/src/main/java/com/wt/admin/mapper/language/xml/SentenceMapper.xml
+++ b/admin/src/main/java/com/wt/admin/mapper/language/xml/SentenceMapper.xml
@@ -26,7 +26,7 @@
+
+
diff --git a/admin/src/main/java/com/wt/admin/service/InitService.java b/admin/src/main/java/com/wt/admin/service/InitService.java
index ffd40614b933a4a9b6390ef87deabbeabde7e0d7..4e9ce0cb3c13674097b717320815ddf9bf62d49d 100644
--- a/admin/src/main/java/com/wt/admin/service/InitService.java
+++ b/admin/src/main/java/com/wt/admin/service/InitService.java
@@ -40,7 +40,6 @@ public class InitService implements ApplicationListener {
logAspects.init();
fileService.init();
sysSettingService.init();
- publicThread.execute(() -> languageProxyService.init());
publicThread.execute(() ->imageProxyService.init());
publicThread.execute(() ->agentsProxyService.init());
publicThread.execute(() ->knowledgeProxyService.init());
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/ChatProxyServiceImpl.java b/admin/src/main/java/com/wt/admin/service/ai/impl/ChatProxyServiceImpl.java
index 4dd8bd9b19ef1f2c94ceaff6efe90613ff1e7303..62f4f37c9a6dd81e907345132f560bf0d555111d 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/ChatProxyServiceImpl.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/ChatProxyServiceImpl.java
@@ -23,7 +23,6 @@ import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.beans.factory.annotation.Autowired;
@@ -40,7 +39,7 @@ import java.util.stream.Collectors;
public class ChatProxyServiceImpl implements ChatProxyService {
@Autowired
- private Vector vectorStore;
+ private Vector vector;
@Resource
private ChatModelContentService chatModelContentService;
@Resource
@@ -77,9 +76,9 @@ public class ChatProxyServiceImpl implements ChatProxyService {
.topK(dto.getChatConfig().getTopK())
.similarityThreshold(dto.getChatConfig().getSimilarityThreshold())
.build();
- List documents = vectorStore.getVectorStore().similaritySearch(build);
+ List documents = vector.getElasticsearchVectorStore().similaritySearch(build);
log.debug("找到相似文档 {}", documents.size());
- chatClient.advisors(QuestionAnswerAdvisor.builder(vectorStore.getVectorStore()).searchRequest(build).build());
+ chatClient.advisors(QuestionAnswerAdvisor.builder(vector.getElasticsearchVectorStore()).searchRequest(build).build());
}
chatClient.advisors(MessageChatMemoryAdvisor.builder(chatMemory).conversationId(dto.getContentId().toString()).build());
if(dto.getChatConfig().getEnableSynergism()){
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/KnowledgeInfoServiceImpl.java b/admin/src/main/java/com/wt/admin/service/ai/impl/KnowledgeInfoServiceImpl.java
index a0fa6196cddb6dfdbbd5163b3bbdb85f8b07af1f..4c8988ab8cf909d718e86be5aa84db40c03cf50d 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/KnowledgeInfoServiceImpl.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/KnowledgeInfoServiceImpl.java
@@ -33,7 +33,7 @@ public class KnowledgeInfoServiceImpl extends ServiceImpl(){
+ vector.add(FileUtil.file(FileUtils.getImageAbsURL(path)),new HashMap(){
{
put("knowledgeTitleId",knowledgeTitleId);
put("knowledgeInfoId",knowledgeInfoEntity.getUrl());
@@ -97,7 +97,7 @@ public class KnowledgeInfoServiceImpl extends ServiceImpl vectorStore.add(FileUtil.file(FileUtils.getImageAbsURL(i.getUrl())),new HashMap(){
+ list.forEach(i -> vector.add(FileUtil.file(FileUtils.getImageAbsURL(i.getUrl())),new HashMap(){
{
put("knowledgeTitleId",knowledgeTitleId);
put("knowledgeInfoId",i.getUrl());
@@ -108,13 +108,13 @@ public class KnowledgeInfoServiceImpl extends ServiceImpl list){
- if(!(vectorStore instanceof MemoryVectorImpl)){
+ if(!(vector instanceof MemoryVectorImpl)){
return;
}
log.debug("添加知识库到内存");
list.forEach(i -> {
List infos = knowledgeInfoMapper.findList(i.getId());
- infos.forEach(j -> vectorStore.add(FileUtil.file(FileUtils.getImageAbsURL(j.getUrl())),new HashMap(){
+ infos.forEach(j -> vector.add(FileUtil.file(FileUtils.getImageAbsURL(j.getUrl())),new HashMap(){
{
put("knowledgeTitleId",i.getId());
put("knowledgeInfoId",j.getUrl());
@@ -125,7 +125,7 @@ public class KnowledgeInfoServiceImpl extends ServiceImpl 0) {
builder.defaultToolCallbacks(allTools);
}
-
-
// 不建议使用工具,1. 很多大模型不支持 2. 不方便与外部建立连接
// ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools(),new WeatherTools());
return new AgentsManager.ChatClientMapper(builder.build(),options);
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/EasyAIBuilder.java b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/EasyAIBuilder.java
index 23bc682d8cef9c9e6ccfc128a261019df43f2aa0..f4cab6235478229cc0f8dde8caeb18abb95b96bb 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/EasyAIBuilder.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/EasyAIBuilder.java
@@ -6,9 +6,7 @@ import com.wt.admin.domain.entity.ai.ModelConfigEntity;
import com.wt.admin.service.ai.impl.agents.easyai.EasyAIChatModel;
import com.wt.admin.service.ai.impl.agents.easyai.api.EasyAIApi;
import com.wt.admin.service.ai.impl.agents.easyai.api.EasyAIOptions;
-import com.wt.admin.service.ai.impl.mcp.MCPStart;
import io.micrometer.observation.ObservationRegistry;
-import jakarta.annotation.Resource;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.beans.factory.annotation.Value;
@@ -28,7 +26,7 @@ public class EasyAIBuilder extends AbstractAgentsBuilderService{
// 查询模型信息
EasyAIOptions options = new EasyAIOptions();
options.setModel(model.getModel());
- EasyAIApi easyAIApi = new EasyAIApi("http://localhost:"+port+"/qa");
+ EasyAIApi easyAIApi = EasyAIApi.builder().baseUrl("http://localhost:"+port+"/qa").build();
EasyAIChatModel easyAIModel = new EasyAIChatModel(easyAIApi,options,
ToolCallingManager.builder().build(),
ObservationRegistry.NOOP,
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OllamaBuilder.java b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OllamaBuilder.java
index b6584944e35129bf81d88446f5aff9b60c975034..ccc1fa8a73d188439507a2ff6dfdb033ce1c7ebe 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OllamaBuilder.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OllamaBuilder.java
@@ -3,16 +3,13 @@ package com.wt.admin.service.ai.impl.agents;
import com.wt.admin.domain.dto.ai.AgentsInfoDTO;
import com.wt.admin.domain.entity.ai.MCPEntity;
import com.wt.admin.domain.entity.ai.ModelConfigEntity;
-import com.wt.admin.service.ai.impl.mcp.MCPStart;
import io.micrometer.observation.ObservationRegistry;
-import jakarta.annotation.Resource;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.stereotype.Component;
-import org.springframework.web.reactive.function.client.WebClient;
import java.util.List;
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OpenAIBuilder.java b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OpenAIBuilder.java
new file mode 100644
index 0000000000000000000000000000000000000000..bacc97917c418725786e41960b5beb4260447363
--- /dev/null
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/OpenAIBuilder.java
@@ -0,0 +1,38 @@
+package com.wt.admin.service.ai.impl.agents;
+
+import com.wt.admin.domain.dto.ai.AgentsInfoDTO;
+import com.wt.admin.domain.entity.ai.MCPEntity;
+import com.wt.admin.domain.entity.ai.ModelConfigEntity;
+import io.micrometer.observation.ObservationRegistry;
+import org.springframework.ai.model.tool.ToolCallingManager;
+import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.retry.support.RetryTemplate;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+
+@Component("OpenAI")
+public class OpenAIBuilder extends AbstractAgentsBuilderService {
+
+ @Override
+ public AgentsManager.ChatClientMapper builder(AgentsInfoDTO agents, ModelConfigEntity model, List mcpEntities) {
+ // 查询模型信息
+ var openAiApi = OpenAiApi.builder()
+ .baseUrl(model.getBaseUrl())
+ .apiKey(model.getAppKey())
+ .build();
+ OpenAiChatOptions options = OpenAiChatOptions.builder().model(model.getModel()).build();
+ OpenAiChatModel openAi = new OpenAiChatModel(
+ openAiApi,
+ options,
+ ToolCallingManager.builder().build(),
+ RetryTemplate.defaultInstance(),
+ ObservationRegistry.NOOP
+ );
+ // 在构建完模型后,对模型做其他初始化
+ return build(openAi, agents, mcpEntities, options);
+ }
+
+}
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/EasyAIChatModel.java b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/EasyAIChatModel.java
index c1749a18ab210e210e80126674243b982a8ec8f8..c8ccbdf33b976a21501b4b76805ed0677b4e70f5 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/EasyAIChatModel.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/EasyAIChatModel.java
@@ -5,7 +5,6 @@ import com.wt.admin.service.ai.impl.agents.easyai.api.EasyAIApi;
import com.wt.admin.service.ai.impl.agents.easyai.api.EasyAIOptions;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
-import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -27,15 +26,17 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.*;
-import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
import org.springframework.ai.ollama.management.ModelManagementOptions;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
+import reactor.core.scheduler.Schedulers;
import java.time.Duration;
import java.util.*;
@@ -51,8 +52,9 @@ public class EasyAIChatModel implements ChatModel {
private final ToolCallingManager toolCallingManager;
private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
private ChatModelObservationConvention observationConvention;
+ private final RetryTemplate retryTemplate;
- public EasyAIChatModel(EasyAIApi easyAIApi, EasyAIOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
+ public EasyAIChatModel(EasyAIApi easyAIApi, EasyAIOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {
this.observationConvention = DEFAULT_OBSERVATION_CONVENTION;
Assert.notNull(easyAIApi, "easyAIApi must not be null");
Assert.notNull(defaultOptions, "easyAIOptions must not be null");
@@ -60,16 +62,18 @@ public class EasyAIChatModel implements ChatModel {
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.easyAIApi = easyAIApi;
this.defaultOptions = defaultOptions;
this.toolCallingManager = toolCallingManager;
this.observationRegistry = observationRegistry;
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
+ this.retryTemplate = retryTemplate;
}
@Deprecated
public EasyAIChatModel(EasyAIApi easyAIApi, EasyAIOptions defaultOptions,ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
- this(easyAIApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions, new DefaultToolExecutionEligibilityPredicate());
+ this(easyAIApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions, new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
logger.warn("This constructor is deprecated and will be removed in the next milestone. Please use the easyAIChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}
@@ -85,29 +89,24 @@ public class EasyAIChatModel implements ChatModel {
return this.internalStream(requestPrompt, null);
}
-
private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
- EasyAIApi.ChatRequest request = this.easyAIChatRequest(prompt, false);
+ EasyAIApi.ChatRequest request = this.chatRequest(prompt, false);
ChatModelObservationContext observationContext = ChatModelObservationContext.builder().prompt(prompt).provider(OllamaApiConstants.PROVIDER_NAME).build();
- ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
- .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
- .observe(() -> {
- EasyAIApi.ChatResponse easyAIResponse = this.easyAIApi.chat(request);
- List toolCalls = easyAIResponse.message().toolCalls() == null ? List.of() : easyAIResponse.message().toolCalls().stream()
- .map((toolCall) -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
- .toList();
- AssistantMessage assistantMessage = new AssistantMessage(easyAIResponse.message().content(), Map.of(), toolCalls);
- ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
- if (easyAIResponse.promptEvalCount() != null && easyAIResponse.evalCount() != null) {
- generationMetadata = ChatGenerationMetadata.builder().finishReason(easyAIResponse.doneReason()).build();
- }
-
- Generation generator = new Generation(assistantMessage, generationMetadata);
- ChatResponse chatResponse = new ChatResponse(List.of(generator), from(easyAIResponse, previousChatResponse));
- observationContext.setResponse(chatResponse);
- return chatResponse;
- });
- if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null && response.hasToolCalls()) {
+ ChatResponse response = (ChatResponse)ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry).observe(() -> {
+ EasyAIApi.ChatResponse ollamaResponse = (EasyAIApi.ChatResponse)this.retryTemplate.execute((ctx) -> this.easyAIApi.chat(request));
+ List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() : ollamaResponse.message().toolCalls().stream().map((toolCall) -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), ModelOptionsUtils.toJsonString(toolCall.function().arguments()))).toList();
+ AssistantMessage assistantMessage = new AssistantMessage(ollamaResponse.message().content(), Map.of(), toolCalls);
+ ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
+ if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
+ generationMetadata = ChatGenerationMetadata.builder().finishReason(ollamaResponse.doneReason()).build();
+ }
+
+ Generation generator = new Generation(assistantMessage, generationMetadata);
+ ChatResponse chatResponse = new ChatResponse(List.of(generator), from(ollamaResponse, previousChatResponse));
+ observationContext.setResponse(chatResponse);
+ return chatResponse;
+ });
+ if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
return toolExecutionResult.returnDirect() ? ChatResponse.builder().from(response).generations(ToolExecutionResult.buildGenerations(toolExecutionResult)).build() : this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
} else {
@@ -116,72 +115,41 @@ public class EasyAIChatModel implements ChatModel {
}
private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) {
- return Flux.deferContextual(contextView -> {
- EasyAIApi.ChatRequest request = ollamaChatRequest(prompt, true);
+ return Flux.deferContextual((contextView) -> {
+ EasyAIApi.ChatRequest request = this.chatRequest(prompt, true);
ChatModelObservationContext observationContext = ChatModelObservationContext.builder().prompt(prompt).provider(OllamaApiConstants.PROVIDER_NAME).build();
- Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
- this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
- this.observationRegistry);
-
- observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
-
- Flux ollamaResponse = this.easyAIApi.streamingChat(request);
-
- Flux chatResponse = ollamaResponse.map(chunk -> {
- String content = (chunk.message() != null) ? chunk.message().content() : "";
-
+ Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
+ observation.parentObservation((Observation)contextView.getOrDefault("micrometer.observation", (Object)null)).start();
+ Flux easyaiResponse = this.easyAIApi.streamingChat(request);
+ Flux chatResponse = easyaiResponse.map((chunk) -> {
+ String content = chunk.message() != null ? chunk.message().content() : "";
List toolCalls = List.of();
-
- // Added null checks to prevent NPE when accessing tool calls
if (chunk.message() != null && chunk.message().toolCalls() != null) {
- toolCalls = chunk.message()
- .toolCalls()
- .stream()
- .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
- ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
- .toList();
+ toolCalls = chunk.message().toolCalls().stream().map((toolCall) -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(), ModelOptionsUtils.toJsonString(toolCall.function().arguments()))).toList();
}
- var assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);
-
+ AssistantMessage assistantMessage = new AssistantMessage(content, Map.of(), toolCalls);
ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
if (chunk.promptEvalCount() != null && chunk.evalCount() != null) {
generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build();
}
- var generator = new Generation(assistantMessage, generationMetadata);
+ Generation generator = new Generation(assistantMessage, generationMetadata);
return new ChatResponse(List.of(generator), from(chunk, previousChatResponse));
});
-
- // @formatter:off
- Flux chatResponseFlux = chatResponse.flatMap(response -> {
- if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
- var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
- if (toolExecutionResult.returnDirect()) {
- // Return tool execution result directly to the client.
- return Flux.just(ChatResponse.builder().from(response)
- .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
- .build());
- } else {
- // Send the tool execution result back to the model.
- return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
- response);
- }
- }
- else {
- return Flux.just(response);
- }
- })
- .doOnError(observation::error)
- .doFinally(s ->
- observation.stop()
- )
- .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
- return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
+ Flux var10000 = chatResponse.flatMap((response) -> this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response) ? Flux.defer(() -> {
+ ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
+ return toolExecutionResult.returnDirect() ? Flux.just(ChatResponse.builder().from(response).generations(ToolExecutionResult.buildGenerations(toolExecutionResult)).build()) : this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
+ }).subscribeOn(Schedulers.boundedElastic()) : Flux.just(response));
+ Objects.requireNonNull(observation);
+ Flux chatResponseFlux = var10000.doOnError(observation::error).doFinally((s) -> observation.stop()).contextWrite((ctx) -> ctx.put("micrometer.observation", observation));
+ MessageAggregator var10 = new MessageAggregator();
+ Objects.requireNonNull(observationContext);
+ return var10.aggregate(chatResponseFlux, observationContext::setResponse);
});
}
- EasyAIApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
+ EasyAIApi.ChatRequest chatRequest(Prompt prompt, boolean stream) {
List ollamaMessages = prompt.getInstructions().stream().map(message -> {
if (message instanceof UserMessage userMessage) {
@@ -395,10 +363,13 @@ public class EasyAIChatModel implements ChatModel {
private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
private ObservationRegistry observationRegistry;
private ModelManagementOptions modelManagementOptions;
+ private RetryTemplate retryTemplate;
+
private Builder() {
this.observationRegistry = ObservationRegistry.NOOP;
this.modelManagementOptions = ModelManagementOptions.defaults();
+ this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
}
public Builder easyAIAPI(EasyAIApi easyAIAPI) {
@@ -433,8 +404,13 @@ public class EasyAIChatModel implements ChatModel {
return this;
}
+ public Builder retryTemplate(RetryTemplate retryTemplate) {
+ this.retryTemplate = retryTemplate;
+ return this;
+ }
+
public EasyAIChatModel build() {
- return this.toolCallingManager != null ? new EasyAIChatModel(this.easyAIAPI, this.defaultOptions, this.toolCallingManager, this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate) : new EasyAIChatModel(this.easyAIAPI, this.defaultOptions, EasyAIChatModel.DEFAULT_TOOL_CALLING_MANAGER, this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
+ return this.toolCallingManager != null ? new EasyAIChatModel(this.easyAIAPI, this.defaultOptions, this.toolCallingManager, this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,this.retryTemplate) : new EasyAIChatModel(this.easyAIAPI, this.defaultOptions, EasyAIChatModel.DEFAULT_TOOL_CALLING_MANAGER, this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,this.retryTemplate);
}
}
}
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/api/EasyAIApi.java b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/api/EasyAIApi.java
index 9340c9506de69155abc68a005239b71028e26ba3..d516da282e91265cb5184092e1ef055f6703a667 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/api/EasyAIApi.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/agents/easyai/api/EasyAIApi.java
@@ -1,24 +1,23 @@
package com.wt.admin.service.ai.impl.agents.easyai.api;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
-import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
-import org.springframework.util.StreamUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import java.io.IOException;
-import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
@@ -29,108 +28,201 @@ import java.util.function.Consumer;
public class EasyAIApi {
- public static final String PROVIDER_NAME;
public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null.";
- private static final Log logger;
- private final ResponseErrorHandler responseErrorHandler;
+ private static final Log logger = LogFactory.getLog(EasyAIApi.class);
+ private static final String DEFAULT_BASE_URL = "http://localhost:11434";
private final RestClient restClient;
private final WebClient webClient;
- /**
- * http://localhost:9000
- * @param baseUrl
- */
- public EasyAIApi(String baseUrl) {
- this(baseUrl, RestClient.builder(), WebClient.builder());
+ public static EasyAIApi.Builder builder() {
+ return new EasyAIApi.Builder();
}
- public EasyAIApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
- this.responseErrorHandler = new EasyAIResponseErrorHandler();
+ private EasyAIApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
Consumer defaultHeaders = (headers) -> {
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
};
- this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
- this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
+ this.restClient = restClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(defaultHeaders).defaultStatusHandler(responseErrorHandler).build();
+ this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
}
- public ChatResponse chat(ChatRequest chatRequest) {
- Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR);
+ public EasyAIApi.ChatResponse chat(EasyAIApi.ChatRequest chatRequest) {
+ Assert.notNull(chatRequest, "The request body can not be null.");
Assert.isTrue(!chatRequest.stream(), "Stream mode must be disabled.");
- return this.restClient.post().uri("/api/chat", new Object[0])
- .body(chatRequest)
- .retrieve()
- .onStatus(this.responseErrorHandler)
- .body(ChatResponse.class);
+ return (EasyAIApi.ChatResponse)((RestClient.RequestBodySpec)this.restClient.post().uri("/api/chat", new Object[0])).body(chatRequest).retrieve().body(EasyAIApi.ChatResponse.class);
}
- public Flux streamingChat(ChatRequest chatRequest) {
- Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR);
+ public Flux streamingChat(EasyAIApi.ChatRequest chatRequest) {
+ Assert.notNull(chatRequest, "The request body can not be null.");
Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");
AtomicBoolean isInsideTool = new AtomicBoolean(false);
- return this.webClient.post().uri("/api/streamChat", new Object[0])
- .body(Mono.just(chatRequest), ChatRequest.class)
- .retrieve()
- .bodyToFlux(ChatResponse.class)
- .map((chunk) -> {
- if (EasyAIApiHelper.isStreamingToolCall(chunk)) {
- isInsideTool.set(true);
- }
- return chunk;
- })
- .windowUntil((chunk) -> {
- if (isInsideTool.get() && EasyAIApiHelper.isStreamingDone(chunk)) {
- isInsideTool.set(false);
- return true;
- } else {
- return !isInsideTool.get();
- }
- })
- .concatMapIterable((window) -> {
- Mono monoChunk = window.reduce(new ChatResponse(), (previous, current) -> EasyAIApiHelper.merge(previous, current));
- return List.of(monoChunk);
- })
- .flatMap((mono) -> mono)
- .handle((data, sink) -> {
- if (logger.isTraceEnabled()) {
- logger.trace(data);
- }
-
- sink.next(data);
- });
+ return ((WebClient.RequestBodySpec)this.webClient.post().uri("/api/chat", new Object[0])).body(Mono.just(chatRequest), EasyAIApi.ChatRequest.class).retrieve().bodyToFlux(EasyAIApi.ChatResponse.class).map((chunk) -> {
+ if (EasyAIApiHelper.isStreamingToolCall(chunk)) {
+ isInsideTool.set(true);
+ }
+
+ return chunk;
+ }).windowUntil((chunk) -> {
+ if (isInsideTool.get() && EasyAIApiHelper.isStreamingDone(chunk)) {
+ isInsideTool.set(false);
+ return true;
+ } else {
+ return !isInsideTool.get();
+ }
+ }).concatMapIterable((window) -> {
+ Mono monoChunk = window.reduce(new EasyAIApi.ChatResponse(), (previous, current) -> EasyAIApiHelper.merge(previous, current));
+ return List.of(monoChunk);
+ }).flatMap((mono) -> mono).handle((data, sink) -> {
+ if (logger.isTraceEnabled()) {
+ logger.trace(data);
+ }
+
+ sink.next(data);
+ });
+ }
+
+ public EasyAIApi.EmbeddingsResponse embed(EasyAIApi.EmbeddingsRequest embeddingsRequest) {
+ Assert.notNull(embeddingsRequest, "The request body can not be null.");
+ return (EasyAIApi.EmbeddingsResponse)((RestClient.RequestBodySpec)this.restClient.post().uri("/api/embed", new Object[0])).body(embeddingsRequest).retrieve().body(EasyAIApi.EmbeddingsResponse.class);
+ }
+
+ public EasyAIApi.ListModelResponse listModels() {
+ return (EasyAIApi.ListModelResponse)this.restClient.get().uri("/api/tags", new Object[0]).retrieve().body(EasyAIApi.ListModelResponse.class);
+ }
+
+ public EasyAIApi.ShowModelResponse showModel(EasyAIApi.ShowModelRequest showModelRequest) {
+ Assert.notNull(showModelRequest, "showModelRequest must not be null");
+ return (EasyAIApi.ShowModelResponse)((RestClient.RequestBodySpec)this.restClient.post().uri("/api/show", new Object[0])).body(showModelRequest).retrieve().body(EasyAIApi.ShowModelResponse.class);
+ }
+
+ public ResponseEntity copyModel(EasyAIApi.CopyModelRequest copyModelRequest) {
+ Assert.notNull(copyModelRequest, "copyModelRequest must not be null");
+ return ((RestClient.RequestBodySpec)this.restClient.post().uri("/api/copy", new Object[0])).body(copyModelRequest).retrieve().toBodilessEntity();
}
+ public ResponseEntity deleteModel(EasyAIApi.DeleteModelRequest deleteModelRequest) {
+ Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null");
+ return ((RestClient.RequestBodySpec)this.restClient.method(HttpMethod.DELETE).uri("/api/delete", new Object[0])).body(deleteModelRequest).retrieve().toBodilessEntity();
+ }
- static {
- PROVIDER_NAME = "easyAI";
- logger = LogFactory.getLog(EasyAIApi.class);
+ public Flux pullModel(EasyAIApi.PullModelRequest pullModelRequest) {
+ Assert.notNull(pullModelRequest, "pullModelRequest must not be null");
+ Assert.isTrue(pullModelRequest.stream(), "Request must set the stream property to true.");
+ return ((WebClient.RequestBodySpec)this.webClient.post().uri("/api/pull", new Object[0])).bodyValue(pullModelRequest).retrieve().bodyToFlux(EasyAIApi.ProgressResponse.class);
}
- private static class EasyAIResponseErrorHandler implements ResponseErrorHandler {
- private EasyAIResponseErrorHandler() {
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record Message(EasyAIApi.Message.Role role, String content, List images, List toolCalls) {
+ public Message(@JsonProperty("role") EasyAIApi.Message.Role role, @JsonProperty("content") String content, @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) {
+ this.role = role;
+ this.content = content;
+ this.images = images;
+ this.toolCalls = toolCalls;
+ }
+
+ public static EasyAIApi.Message.Builder builder(EasyAIApi.Message.Role role) {
+ return new EasyAIApi.Message.Builder(role);
+ }
+
+ @JsonProperty("role")
+ public EasyAIApi.Message.Role role() {
+ return this.role;
+ }
+
+ @JsonProperty("content")
+ public String content() {
+ return this.content;
+ }
+
+ @JsonProperty("images")
+ public List images() {
+ return this.images;
+ }
+
+ @JsonProperty("tool_calls")
+ public List toolCalls() {
+ return this.toolCalls;
+ }
+
+ public static enum Role {
+ @JsonProperty("system")
+ SYSTEM,
+ @JsonProperty("user")
+ USER,
+ @JsonProperty("assistant")
+ ASSISTANT,
+ @JsonProperty("tool")
+ TOOL;
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public static record ToolCall(EasyAIApi.Message.ToolCallFunction function) {
+ public ToolCall(@JsonProperty("function") EasyAIApi.Message.ToolCallFunction function) {
+ this.function = function;
+ }
+
+ @JsonProperty("function")
+ public EasyAIApi.Message.ToolCallFunction function() {
+ return this.function;
+ }
}
- @Override
- public boolean hasError(ClientHttpResponse response) throws IOException {
- return response.getStatusCode().isError();
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public static record ToolCallFunction(String name, Map arguments) {
+ public ToolCallFunction(@JsonProperty("name") String name, @JsonProperty("arguments") Map arguments) {
+ this.name = name;
+ this.arguments = arguments;
+ }
+
+ @JsonProperty("name")
+ public String name() {
+ return this.name;
+ }
+
+ @JsonProperty("arguments")
+ public Map arguments() {
+ return this.arguments;
+ }
}
- @Override
- public void handleError(ClientHttpResponse response) throws IOException {
- if (response.getStatusCode().isError()) {
- int statusCode = response.getStatusCode().value();
- String statusText = response.getStatusText();
- String message = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
- logger.warn(String.format("[%s] %s - %s", statusCode, statusText, message));
- throw new RuntimeException(String.format("[%s] %s - %s", statusCode, statusText, message));
+ public static class Builder {
+ private final EasyAIApi.Message.Role role;
+ private String content;
+ private List images;
+ private List toolCalls;
+
+ public Builder(EasyAIApi.Message.Role role) {
+ this.role = role;
+ }
+
+ public EasyAIApi.Message.Builder content(String content) {
+ this.content = content;
+ return this;
+ }
+
+ public EasyAIApi.Message.Builder images(List images) {
+ this.images = images;
+ return this;
+ }
+
+ public EasyAIApi.Message.Builder toolCalls(List toolCalls) {
+ this.toolCalls = toolCalls;
+ return this;
+ }
+
+ public EasyAIApi.Message build() {
+ return new EasyAIApi.Message(this.role, this.content, this.images, this.toolCalls);
}
}
}
@JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ChatRequest(String model, List messages, Boolean stream, Object format,
- String keepAlive, List tools, Map options) {
- public ChatRequest(@JsonProperty("model") String model, @JsonProperty("messages") List messages, @JsonProperty("stream") Boolean stream, @JsonProperty("format") Object format, @JsonProperty("keep_alive") String keepAlive, @JsonProperty("tools") List tools, @JsonProperty("options") Map options) {
+ public static record ChatRequest(String model, List messages, Boolean stream, Object format, String keepAlive, List tools, Map options) {
+ public ChatRequest(@JsonProperty("model") String model, @JsonProperty("messages") List messages, @JsonProperty("stream") Boolean stream, @JsonProperty("format") Object format, @JsonProperty("keep_alive") String keepAlive, @JsonProperty("tools") List tools, @JsonProperty("options") Map options) {
this.model = model;
this.messages = messages;
this.stream = stream;
@@ -140,8 +232,8 @@ public class EasyAIApi {
this.options = options;
}
- public static Builder builder(String model) {
- return new Builder(model);
+ public static EasyAIApi.ChatRequest.Builder builder(String model) {
+ return new EasyAIApi.ChatRequest.Builder(model);
}
@JsonProperty("model")
@@ -150,7 +242,7 @@ public class EasyAIApi {
}
@JsonProperty("messages")
- public List messages() {
+ public List messages() {
return this.messages;
}
@@ -170,7 +262,7 @@ public class EasyAIApi {
}
@JsonProperty("tools")
- public List tools() {
+ public List tools() {
return this.tools;
}
@@ -179,89 +271,30 @@ public class EasyAIApi {
return this.options;
}
- public static class Builder {
- private final String model;
- private List messages = List.of();
- private boolean stream = false;
- private Object format;
- private String keepAlive;
- private List tools = List.of();
- private Map options = Map.of();
-
- public Builder(String model) {
- Assert.notNull(model, "The model can not be null.");
- this.model = model;
- }
-
- public Builder messages(List messages) {
- this.messages = messages;
- return this;
- }
-
- public Builder stream(boolean stream) {
- this.stream = stream;
- return this;
- }
-
- public Builder format(Object format) {
- this.format = format;
- return this;
- }
-
- public Builder keepAlive(String keepAlive) {
- this.keepAlive = keepAlive;
- return this;
- }
-
- public Builder tools(List tools) {
- this.tools = tools;
- return this;
- }
-
- public Builder options(Map options) {
- Objects.requireNonNull(options, "The options can not be null.");
- this.options = EasyAIOptions.filterNonSupportedFields(options);
- return this;
- }
-
- public Builder options(EasyAIOptions options) {
- Objects.requireNonNull(options, "The options can not be null.");
- this.options = EasyAIOptions.filterNonSupportedFields(options.toMap());
- return this;
- }
-
- public ChatRequest build() {
- return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options);
- }
- }
-
@JsonInclude(JsonInclude.Include.NON_NULL)
- public static record Tool(Type type, Function function) {
- public Tool(Function function) {
- this(Type.FUNCTION, function);
+ public static record Tool(EasyAIApi.ChatRequest.Tool.Type type, EasyAIApi.ChatRequest.Tool.Function function) {
+ public Tool(EasyAIApi.ChatRequest.Tool.Function function) {
+ this(EasyAIApi.ChatRequest.Tool.Type.FUNCTION, function);
}
- public Tool(@JsonProperty("type") Type type, @JsonProperty("function") Function function) {
+ public Tool(@JsonProperty("type") EasyAIApi.ChatRequest.Tool.Type type, @JsonProperty("function") EasyAIApi.ChatRequest.Tool.Function function) {
this.type = type;
this.function = function;
}
@JsonProperty("type")
- public Type type() {
+ public EasyAIApi.ChatRequest.Tool.Type type() {
return this.type;
}
@JsonProperty("function")
- public Function function() {
+ public EasyAIApi.ChatRequest.Tool.Function function() {
return this.function;
}
public static enum Type {
@JsonProperty("function")
FUNCTION;
-
- private Type() {
- }
}
public static record Function(String name, String description, Map parameters) {
@@ -291,29 +324,85 @@ public class EasyAIApi {
}
}
}
+
+ public static class Builder {
+ private final String model;
+ private List messages = List.of();
+ private boolean stream = false;
+ private Object format;
+ private String keepAlive;
+ private List tools = List.of();
+ private Map options = Map.of();
+
+ public Builder(String model) {
+ Assert.notNull(model, "The model can not be null.");
+ this.model = model;
+ }
+
+ public EasyAIApi.ChatRequest.Builder messages(List messages) {
+ this.messages = messages;
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest.Builder stream(boolean stream) {
+ this.stream = stream;
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest.Builder format(Object format) {
+ this.format = format;
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest.Builder keepAlive(String keepAlive) {
+ this.keepAlive = keepAlive;
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest.Builder tools(List tools) {
+ this.tools = tools;
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest.Builder options(Map options) {
+ Objects.requireNonNull(options, "The options can not be null.");
+ this.options = EasyAIOptions.filterNonSupportedFields(options);
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest.Builder options(EasyAIOptions options) {
+ Objects.requireNonNull(options, "The options can not be null.");
+ this.options = EasyAIOptions.filterNonSupportedFields(options.toMap());
+ return this;
+ }
+
+ public EasyAIApi.ChatRequest build() {
+ return new EasyAIApi.ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options);
+ }
+ }
}
@JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ChatResponse(String model, Instant createdAt, Message message, String doneReason, Boolean done,
- Long totalDuration, Long loadDuration, Integer promptEvalCount,
- Long promptEvalDuration, Integer evalCount, Long evalDuration) {
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record ChatResponse(String model, Instant createdAt, EasyAIApi.Message message, String doneReason, Boolean done, Long totalDuration, Long loadDuration, Integer promptEvalCount, Long promptEvalDuration, Integer evalCount, Long evalDuration) {
ChatResponse() {
- this(null, null, null, null, null, null,
- null,null, null, null, null);
+ this((String)null, (Instant)null, (EasyAIApi.Message)null, (String)null, (Boolean)null, (Long)null, (Long)null, (Integer)null, (Long)null, (Integer)null, (Long)null);
}
- public ChatResponse(@JsonProperty("model") String model, @JsonProperty("created_at") Instant createdAt, @JsonProperty("message") Message message, @JsonProperty("done_reason") String doneReason, @JsonProperty("done") Boolean done, @JsonProperty("total_duration") Long totalDuration, @JsonProperty("load_duration") Long loadDuration, @JsonProperty("prompt_eval_count") Integer promptEvalCount, @JsonProperty("prompt_eval_duration") Long promptEvalDuration, @JsonProperty("eval_count") Integer evalCount, @JsonProperty("eval_duration") Long evalDuration) {
- this.model = model; // 模型名称(如使用的AI模型类型)
- this.createdAt = createdAt; // 响应生成时间戳
- this.message = message; // 包含对话内容的消息对象
- this.doneReason = doneReason; // 任务结束原因(如完成/错误原因)
- this.done = done; // 任务是否完成的标记
- this.totalDuration = totalDuration; // 总处理时长(纳秒)
- this.loadDuration = loadDuration; // 模型加载耗时
- this.promptEvalCount = promptEvalCount; // 提示词评估次数
- this.promptEvalDuration = promptEvalDuration; // 提示词评估耗时
- this.evalCount = evalCount; // 推理次数
- this.evalDuration = evalDuration; // 推理总耗时
+ public ChatResponse(@JsonProperty("model") String model, @JsonProperty("created_at") Instant createdAt, @JsonProperty("message") EasyAIApi.Message message, @JsonProperty("done_reason") String doneReason, @JsonProperty("done") Boolean done, @JsonProperty("total_duration") Long totalDuration, @JsonProperty("load_duration") Long loadDuration, @JsonProperty("prompt_eval_count") Integer promptEvalCount, @JsonProperty("prompt_eval_duration") Long promptEvalDuration, @JsonProperty("eval_count") Integer evalCount, @JsonProperty("eval_duration") Long evalDuration) {
+ this.model = model;
+ this.createdAt = createdAt;
+ this.message = message;
+ this.doneReason = doneReason;
+ this.done = done;
+ this.totalDuration = totalDuration;
+ this.loadDuration = loadDuration;
+ this.promptEvalCount = promptEvalCount;
+ this.promptEvalDuration = promptEvalDuration;
+ this.evalCount = evalCount;
+ this.evalDuration = evalDuration;
}
public Duration getTotalDuration() {
@@ -343,7 +432,7 @@ public class EasyAIApi {
}
@JsonProperty("message")
- public Message message() {
+ public EasyAIApi.Message message() {
return this.message;
}
@@ -389,8 +478,50 @@ public class EasyAIApi {
}
@JsonInclude(JsonInclude.Include.NON_NULL)
- public static record EmbeddingsResponse(String model, List embeddings, Long totalDuration,
- Long loadDuration, Integer promptEvalCount) {
+ public static record EmbeddingsRequest(String model, List input, Duration keepAlive, Map options, Boolean truncate) {
+ public EmbeddingsRequest(String model, String input) {
+ this(model, List.of(input), (Duration)null, (Map)null, (Boolean)null);
+ }
+
+ public EmbeddingsRequest(@JsonProperty("model") String model, @JsonProperty("input") List input, @JsonProperty("keep_alive") Duration keepAlive, @JsonProperty("options") Map options, @JsonProperty("truncate") Boolean truncate) {
+ this.model = model;
+ this.input = input;
+ this.keepAlive = keepAlive;
+ this.options = options;
+ this.truncate = truncate;
+ }
+
+ @JsonProperty("model")
+ public String model() {
+ return this.model;
+ }
+
+ @JsonProperty("input")
+ public List input() {
+ return this.input;
+ }
+
+ @JsonProperty("keep_alive")
+ public Duration keepAlive() {
+ return this.keepAlive;
+ }
+
+ @JsonProperty("options")
+ public Map options() {
+ return this.options;
+ }
+
+ @JsonProperty("truncate")
+ public Boolean truncate() {
+ return this.truncate;
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record EmbeddingsResponse(String model, List embeddings, Long totalDuration, Long loadDuration, Integer promptEvalCount) {
public EmbeddingsResponse(@JsonProperty("model") String model, @JsonProperty("embeddings") List embeddings, @JsonProperty("total_duration") Long totalDuration, @JsonProperty("load_duration") Long loadDuration, @JsonProperty("prompt_eval_count") Integer promptEvalCount) {
this.model = model;
this.embeddings = embeddings;
@@ -426,36 +557,164 @@ public class EasyAIApi {
}
@JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ListModelResponse(List models) {
- public ListModelResponse(@JsonProperty("models") List models) {
- this.models = models;
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record Model(String name, String model, Instant modifiedAt, Long size, String digest, EasyAIApi.Model.Details details) {
+ public Model(@JsonProperty("name") String name, @JsonProperty("model") String model, @JsonProperty("modified_at") Instant modifiedAt, @JsonProperty("size") Long size, @JsonProperty("digest") String digest, @JsonProperty("details") EasyAIApi.Model.Details details) {
+ this.name = name;
+ this.model = model;
+ this.modifiedAt = modifiedAt;
+ this.size = size;
+ this.digest = digest;
+ this.details = details;
}
- @JsonProperty("models")
- public List models() {
- return this.models;
+ @JsonProperty("name")
+ public String name() {
+ return this.name;
}
- }
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ShowModelResponse(String license, String modelfile, String parameters, String template,
- String system, Model.Details details, List messages,
- Map modelInfo, Map projectorInfo,
- Instant modifiedAt) {
- public ShowModelResponse(@JsonProperty("license") String license, @JsonProperty("modelfile") String modelfile, @JsonProperty("parameters") String parameters, @JsonProperty("template") String template, @JsonProperty("system") String system, @JsonProperty("details") Model.Details details, @JsonProperty("messages") List messages, @JsonProperty("model_info") Map modelInfo, @JsonProperty("projector_info") Map projectorInfo, @JsonProperty("modified_at") Instant modifiedAt) {
- this.license = license;
- this.modelfile = modelfile;
- this.parameters = parameters;
- this.template = template;
- this.system = system;
- this.details = details;
- this.messages = messages;
- this.modelInfo = modelInfo;
- this.projectorInfo = projectorInfo;
- this.modifiedAt = modifiedAt;
+ @JsonProperty("model")
+ public String model() {
+ return this.model;
}
- @JsonProperty("license")
+ @JsonProperty("modified_at")
+ public Instant modifiedAt() {
+ return this.modifiedAt;
+ }
+
+ @JsonProperty("size")
+ public Long size() {
+ return this.size;
+ }
+
+ @JsonProperty("digest")
+ public String digest() {
+ return this.digest;
+ }
+
+ @JsonProperty("details")
+ public EasyAIApi.Model.Details details() {
+ return this.details;
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record Details(String parentModel, String format, String family, List families, String parameterSize, String quantizationLevel) {
+ public Details(@JsonProperty("parent_model") String parentModel, @JsonProperty("format") String format, @JsonProperty("family") String family, @JsonProperty("families") List families, @JsonProperty("parameter_size") String parameterSize, @JsonProperty("quantization_level") String quantizationLevel) {
+ this.parentModel = parentModel;
+ this.format = format;
+ this.family = family;
+ this.families = families;
+ this.parameterSize = parameterSize;
+ this.quantizationLevel = quantizationLevel;
+ }
+
+ @JsonProperty("parent_model")
+ public String parentModel() {
+ return this.parentModel;
+ }
+
+ @JsonProperty("format")
+ public String format() {
+ return this.format;
+ }
+
+ @JsonProperty("family")
+ public String family() {
+ return this.family;
+ }
+
+ @JsonProperty("families")
+ public List families() {
+ return this.families;
+ }
+
+ @JsonProperty("parameter_size")
+ public String parameterSize() {
+ return this.parameterSize;
+ }
+
+ @JsonProperty("quantization_level")
+ public String quantizationLevel() {
+ return this.quantizationLevel;
+ }
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record ListModelResponse(List models) {
+ public ListModelResponse(@JsonProperty("models") List models) {
+ this.models = models;
+ }
+
+ @JsonProperty("models")
+ public List models() {
+ return this.models;
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public static record ShowModelRequest(String model, String system, Boolean verbose, Map options) {
+ public ShowModelRequest(String model) {
+ this(model, (String)null, (Boolean)null, (Map)null);
+ }
+
+ public ShowModelRequest(@JsonProperty("model") String model, @JsonProperty("system") String system, @JsonProperty("verbose") Boolean verbose, @JsonProperty("options") Map options) {
+ this.model = model;
+ this.system = system;
+ this.verbose = verbose;
+ this.options = options;
+ }
+
+ @JsonProperty("model")
+ public String model() {
+ return this.model;
+ }
+
+ @JsonProperty("system")
+ public String system() {
+ return this.system;
+ }
+
+ @JsonProperty("verbose")
+ public Boolean verbose() {
+ return this.verbose;
+ }
+
+ @JsonProperty("options")
+ public Map options() {
+ return this.options;
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
+ public static record ShowModelResponse(String license, String modelfile, String parameters, String template, String system, EasyAIApi.Model.Details details, List messages, Map modelInfo, Map projectorInfo, List capabilities, Instant modifiedAt) {
+ public ShowModelResponse(@JsonProperty("license") String license, @JsonProperty("modelfile") String modelfile, @JsonProperty("parameters") String parameters, @JsonProperty("template") String template, @JsonProperty("system") String system, @JsonProperty("details") EasyAIApi.Model.Details details, @JsonProperty("messages") List messages, @JsonProperty("model_info") Map modelInfo, @JsonProperty("projector_info") Map projectorInfo, @JsonProperty("capabilities") List capabilities, @JsonProperty("modified_at") Instant modifiedAt) {
+ this.license = license;
+ this.modelfile = modelfile;
+ this.parameters = parameters;
+ this.template = template;
+ this.system = system;
+ this.details = details;
+ this.messages = messages;
+ this.modelInfo = modelInfo;
+ this.projectorInfo = projectorInfo;
+ this.capabilities = capabilities;
+ this.modifiedAt = modifiedAt;
+ }
+
+ @JsonProperty("license")
public String license() {
return this.license;
}
@@ -481,12 +740,12 @@ public class EasyAIApi {
}
@JsonProperty("details")
- public Model.Details details() {
+ public EasyAIApi.Model.Details details() {
return this.details;
}
@JsonProperty("messages")
- public List messages() {
+ public List messages() {
return this.messages;
}
@@ -500,6 +759,11 @@ public class EasyAIApi {
return this.projectorInfo;
}
+ @JsonProperty("capabilities")
+ public List capabilities() {
+ return this.capabilities;
+ }
+
@JsonProperty("modified_at")
public Instant modifiedAt() {
return this.modifiedAt;
@@ -507,11 +771,40 @@ public class EasyAIApi {
}
@JsonInclude(JsonInclude.Include.NON_NULL)
- public static record PullModelRequest(String model, boolean insecure, String username, String password,
- boolean stream) {
+ public static record CopyModelRequest(String source, String destination) {
+ public CopyModelRequest(@JsonProperty("source") String source, @JsonProperty("destination") String destination) {
+ this.source = source;
+ this.destination = destination;
+ }
+
+ @JsonProperty("source")
+ public String source() {
+ return this.source;
+ }
+
+ @JsonProperty("destination")
+ public String destination() {
+ return this.destination;
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public static record DeleteModelRequest(String model) {
+ public DeleteModelRequest(@JsonProperty("model") String model) {
+ this.model = model;
+ }
+
+ @JsonProperty("model")
+ public String model() {
+ return this.model;
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public static record PullModelRequest(String model, boolean insecure, String username, String password, boolean stream) {
public PullModelRequest(@JsonProperty("model") String model, @JsonProperty("insecure") boolean insecure, @JsonProperty("username") String username, @JsonProperty("password") String password, @JsonProperty("stream") boolean stream) {
if (!stream) {
- logger.warn("Enforcing streaming of the model pull request");
+ EasyAIApi.logger.warn("Enforcing streaming of the model pull request");
}
stream = true;
@@ -523,7 +816,7 @@ public class EasyAIApi {
}
public PullModelRequest(String model) {
- this(model, false, (String) null, (String) null, true);
+ this(model, false, (String)null, (String)null, true);
}
@JsonProperty("model")
@@ -553,6 +846,9 @@ public class EasyAIApi {
}
@JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonIgnoreProperties(
+ ignoreUnknown = true
+ )
public static record ProgressResponse(String status, String digest, Long total, Long completed) {
public ProgressResponse(@JsonProperty("status") String status, @JsonProperty("digest") String digest, @JsonProperty("total") Long total, @JsonProperty("completed") Long completed) {
this.status = status;
@@ -582,304 +878,42 @@ public class EasyAIApi {
}
}
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record DeleteModelRequest(String model) {
- public DeleteModelRequest(@JsonProperty("model") String model) {
- this.model = model;
- }
-
- @JsonProperty("model")
- public String model() {
- return this.model;
- }
- }
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record CopyModelRequest(String source, String destination) {
- public CopyModelRequest(@JsonProperty("source") String source, @JsonProperty("destination") String destination) {
- this.source = source;
- this.destination = destination;
- }
+ public static class Builder {
+ private String baseUrl = "http://localhost:8083/qa";
+ private RestClient.Builder restClientBuilder = RestClient.builder();
+ private WebClient.Builder webClientBuilder = WebClient.builder();
+ private ResponseErrorHandler responseErrorHandler;
- @JsonProperty("source")
- public String source() {
- return this.source;
+ public Builder() {
+ this.responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
}
- @JsonProperty("destination")
- public String destination() {
- return this.destination;
+ public EasyAIApi.Builder baseUrl(String baseUrl) {
+ Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
+ this.baseUrl = baseUrl;
+ return this;
}
- }
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ShowModelRequest(String model, String system, Boolean verbose, Map options) {
- public ShowModelRequest(String model) {
- this(model, (String) null, (Boolean) null, (Map) null);
+ public EasyAIApi.Builder restClientBuilder(RestClient.Builder restClientBuilder) {
+ Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
+ this.restClientBuilder = restClientBuilder;
+ return this;
}
- public ShowModelRequest(@JsonProperty("model") String model, @JsonProperty("system") String system, @JsonProperty("verbose") Boolean verbose, @JsonProperty("options") Map options) {
- this.model = model;
- this.system = system;
- this.verbose = verbose;
- this.options = options;
+ public EasyAIApi.Builder webClientBuilder(WebClient.Builder webClientBuilder) {
+ Assert.notNull(webClientBuilder, "webClientBuilder cannot be null");
+ this.webClientBuilder = webClientBuilder;
+ return this;
}
- @JsonProperty("model")
- public String model() {
- return this.model;
+ public EasyAIApi.Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
+ Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
+ this.responseErrorHandler = responseErrorHandler;
+ return this;
}
- @JsonProperty("system")
- public String system() {
- return this.system;
- }
-
- @JsonProperty("verbose")
- public Boolean verbose() {
- return this.verbose;
- }
-
- @JsonProperty("options")
- public Map options() {
- return this.options;
+ public EasyAIApi build() {
+ return new EasyAIApi(this.baseUrl, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
}
}
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record Model(String name, String model, Instant modifiedAt, Long size, String digest,
- Details details) {
- public Model(@JsonProperty("name") String name, @JsonProperty("model") String model, @JsonProperty("modified_at") Instant modifiedAt, @JsonProperty("size") Long size, @JsonProperty("digest") String digest, @JsonProperty("details") Details details) {
- this.name = name;
- this.model = model;
- this.modifiedAt = modifiedAt;
- this.size = size;
- this.digest = digest;
- this.details = details;
- }
-
- @JsonProperty("name")
- public String name() {
- return this.name;
- }
-
- @JsonProperty("model")
- public String model() {
- return this.model;
- }
-
- @JsonProperty("modified_at")
- public Instant modifiedAt() {
- return this.modifiedAt;
- }
-
- @JsonProperty("size")
- public Long size() {
- return this.size;
- }
-
- @JsonProperty("digest")
- public String digest() {
- return this.digest;
- }
-
- @JsonProperty("details")
- public Details details() {
- return this.details;
- }
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record Details(String parentModel, String format, String family, List families,
- String parameterSize, String quantizationLevel) {
- public Details(@JsonProperty("parent_model") String parentModel, @JsonProperty("format") String format, @JsonProperty("family") String family, @JsonProperty("families") List families, @JsonProperty("parameter_size") String parameterSize, @JsonProperty("quantization_level") String quantizationLevel) {
- this.parentModel = parentModel;
- this.format = format;
- this.family = family;
- this.families = families;
- this.parameterSize = parameterSize;
- this.quantizationLevel = quantizationLevel;
- }
-
- @JsonProperty("parent_model")
- public String parentModel() {
- return this.parentModel;
- }
-
- @JsonProperty("format")
- public String format() {
- return this.format;
- }
-
- @JsonProperty("family")
- public String family() {
- return this.family;
- }
-
- @JsonProperty("families")
- public List families() {
- return this.families;
- }
-
- @JsonProperty("parameter_size")
- public String parameterSize() {
- return this.parameterSize;
- }
-
- @JsonProperty("quantization_level")
- public String quantizationLevel() {
- return this.quantizationLevel;
- }
- }
- }
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record EmbeddingsRequest(String model, List input, Duration keepAlive,
- Map options, Boolean truncate) {
- public EmbeddingsRequest(String model, String input) {
- this(model, List.of(input), (Duration) null, (Map) null, (Boolean) null);
- }
-
- public EmbeddingsRequest(@JsonProperty("model") String model, @JsonProperty("input") List input, @JsonProperty("keep_alive") Duration keepAlive, @JsonProperty("options") Map options, @JsonProperty("truncate") Boolean truncate) {
- this.model = model;
- this.input = input;
- this.keepAlive = keepAlive;
- this.options = options;
- this.truncate = truncate;
- }
-
- @JsonProperty("model")
- public String model() {
- return this.model;
- }
-
- @JsonProperty("input")
- public List input() {
- return this.input;
- }
-
- @JsonProperty("keep_alive")
- public Duration keepAlive() {
- return this.keepAlive;
- }
-
- @JsonProperty("options")
- public Map options() {
- return this.options;
- }
-
- @JsonProperty("truncate")
- public Boolean truncate() {
- return this.truncate;
- }
- }
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record Message(Role role, String content, List images,
- List toolCalls) {
- public Message(@JsonProperty("role") Role role, @JsonProperty("content") String content, @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) {
- this.role = role;
- this.content = content;
- this.images = images;
- this.toolCalls = toolCalls;
- }
-
- public static Builder builder(Role role) {
- return new Builder(role);
- }
-
- @JsonProperty("role")
- public Role role() {
- return this.role;
- }
-
- @JsonProperty("content")
- public String content() {
- return this.content;
- }
-
- @JsonProperty("images")
- public List images() {
- return this.images;
- }
-
- @JsonProperty("tool_calls")
- public List toolCalls() {
- return this.toolCalls;
- }
-
- public static enum Role {
- @JsonProperty("system")
- SYSTEM,
- @JsonProperty("user")
- USER,
- @JsonProperty("assistant")
- ASSISTANT,
- @JsonProperty("tool")
- TOOL;
-
- private Role() {
- }
- }
-
- public static class Builder {
- private final Role role;
- private String content;
- private List images;
- private List toolCalls;
-
- public Builder(Role role) {
- this.role = role;
- }
-
- public Builder content(String content) {
- this.content = content;
- return this;
- }
-
- public Builder images(List images) {
- this.images = images;
- return this;
- }
-
- public Builder toolCalls(List toolCalls) {
- this.toolCalls = toolCalls;
- return this;
- }
-
- public Message build() {
- return new Message(this.role, this.content, this.images, this.toolCalls);
- }
- }
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ToolCallFunction(String name, Map arguments) {
- public ToolCallFunction(@JsonProperty("name") String name, @JsonProperty("arguments") Map arguments) {
- this.name = name;
- this.arguments = arguments;
- }
-
- @JsonProperty("name")
- public String name() {
- return this.name;
- }
-
- @JsonProperty("arguments")
- public Map arguments() {
- return this.arguments;
- }
- }
-
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public static record ToolCall(ToolCallFunction function) {
- public ToolCall(@JsonProperty("function") ToolCallFunction function) {
- this.function = function;
- }
-
- @JsonProperty("function")
- public ToolCallFunction function() {
- return this.function;
- }
- }
- }
-
}
diff --git a/admin/src/main/java/com/wt/admin/service/ai/impl/mcp/MCPTransport.java b/admin/src/main/java/com/wt/admin/service/ai/impl/mcp/MCPTransport.java
index 13d8b118f9920928d53d7d224ca5f3d7cc146def..36c5a1ade88c03cac7b2c4434567ea4daa896eb4 100644
--- a/admin/src/main/java/com/wt/admin/service/ai/impl/mcp/MCPTransport.java
+++ b/admin/src/main/java/com/wt/admin/service/ai/impl/mcp/MCPTransport.java
@@ -3,7 +3,6 @@ package com.wt.admin.service.ai.impl.mcp;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
-import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.client.transport.StdioClientTransport;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import org.springframework.ai.mcp.client.autoconfigure.NamedClientMcpTransport;
diff --git a/admin/src/main/java/com/wt/admin/service/image/ImageProxyService.java b/admin/src/main/java/com/wt/admin/service/image/ImageProxyService.java
index b71cf8bcb721142c7e0191d3729a1397bb7283e9..ddb0b47fb862b473f398d6e4055c6fc743351d1c 100644
--- a/admin/src/main/java/com/wt/admin/service/image/ImageProxyService.java
+++ b/admin/src/main/java/com/wt/admin/service/image/ImageProxyService.java
@@ -32,7 +32,7 @@ public interface ImageProxyService {
void imageTraining(ImageClassificationDTO data, UserVO user);
- String testTraining(MultipartFile file,Integer id);
+ String testTraining(MultipartFile file,Integer id,String model);
void init();
@@ -40,7 +40,7 @@ public interface ImageProxyService {
ImageVideoVO delVideo(List data);
- String videoTraining(MultipartFile file, Integer id);
+ String videoTraining(MultipartFile file, Integer id,String model);
String videoTraining(String url, Integer id);
diff --git a/admin/src/main/java/com/wt/admin/service/image/impl/ImageProxyServiceImpl.java b/admin/src/main/java/com/wt/admin/service/image/impl/ImageProxyServiceImpl.java
index a7f776ba52126b747cffde3eca90baf223f22dc1..f25d8982c3024f05db01ed0efc5a68d15755877e 100644
--- a/admin/src/main/java/com/wt/admin/service/image/impl/ImageProxyServiceImpl.java
+++ b/admin/src/main/java/com/wt/admin/service/image/impl/ImageProxyServiceImpl.java
@@ -1,6 +1,5 @@
package com.wt.admin.service.image.impl;
-import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.RandomUtil;
@@ -31,8 +30,8 @@ import jakarta.annotation.Resource;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.coobird.thumbnailator.Thumbnails;
+import org.dromara.easyai.yolo.FastYolo;
import org.dromara.easyai.yolo.OutBox;
-import org.dromara.easyai.yolo.YoloConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
@@ -145,9 +144,12 @@ public class ImageProxyServiceImpl implements ImageProxyService, UploadFileServi
@Override
public ImageClassificationListVO list(ImageClassificationDTO data) {
+ ModelListDTO modelListDTO = new ModelListDTO();
+ modelListDTO.setModelType(ConstVar.ModelType.IMAGE);
return ImageClassificationListVO.builder()
.imageClasss(imageClassificationService.list(data))
.classs(classificationService.list(null))
+ .models(modelListService.modelAll(modelListDTO))
.build();
}
@@ -166,7 +168,7 @@ public class ImageProxyServiceImpl implements ImageProxyService, UploadFileServi
data.setTraining(true);
send(data.getTag() + "训练", 30);
imageClassificationService.edit(data);
- imageTrainingService.training(data, imageItemVOList,(b) ->
+ imageTrainingService.training(data, imageItemVOList, data.getRemark(),(b) ->
modelListService.addModel(user.getId(),data.getRemark(),data.getYoloConfig(), ConstVar.ModelType.IMAGE));
send(data.getTag() + "训练", 90);
imageItemService.updateBatch(imageItemVOList);
@@ -179,7 +181,7 @@ public class ImageProxyServiceImpl implements ImageProxyService, UploadFileServi
@SneakyThrows
@Override
- public String testTraining(MultipartFile file, Integer id) {
+ public String testTraining(MultipartFile file, Integer id,String model) {
String name = file.getOriginalFilename();
String path = "test/";
name = RandomUtil.randomNumbers(12) + (name.contains(".") ? name.substring(name.indexOf(".")) : "");
@@ -188,13 +190,13 @@ public class ImageProxyServiceImpl implements ImageProxyService, UploadFileServi
file.transferTo(tempFile);
ImageClassificationEntity byId = imageClassificationService.byId(id);
checkImagePixel(Map.of(absPath, file), byId);
- imageTrainingService.testTraining(id, tempFile, indexProp.getFiles().getSavePath() + path + name);
+ imageTrainingService.testTraining(model, tempFile, indexProp.getFiles().getSavePath() + path + name);
return indexProp.getFiles().getSource().replace("*", "") + path + "/" + name;
}
@SneakyThrows
@Override
- public String videoTraining(MultipartFile file, Integer id) {
+ public String videoTraining(MultipartFile file, Integer id,String model) {
String name = file.getOriginalFilename();
String path = "test/";
String random = RandomUtil.randomNumbers(12);
@@ -204,11 +206,15 @@ public class ImageProxyServiceImpl implements ImageProxyService, UploadFileServi
String newPath = indexProp.getFiles().getSavePath() + path + m3u8;
ImageClassificationEntity byId = imageClassificationService.byId(id);
VideoUtil.resizeVideoWH(file.getInputStream(), oldPath, byId.getImageWidth().intValue(), byId.getImageHeight().intValue());
+ FastYolo fastYolo = imageTrainingService.toFastYolo(model);
+ if(ObjectUtil.isEmpty(fastYolo)){
+ return null;
+ }
publicThread.execute(() -> {
try {
VideoUtil.videoToStream(oldPath, newPath, byId.getImageWidth().intValue(), byId.getImageHeight().intValue(), (img) -> {
InputStream inputStream = VideoUtil.bufferedImageToInputStream(img);
- List outBoxes = imageTrainingService.testTraining(id, inputStream);
+ List outBoxes = imageTrainingService.testTraining(fastYolo, inputStream);
Graphics2D g2d = img.createGraphics();
outBoxes.forEach(i -> g2d.drawRect(i.getX(), i.getY(), i.getWidth(), i.getHeight()));
return g2d;
diff --git a/admin/src/main/java/com/wt/admin/service/image/impl/ImageTrainingService.java b/admin/src/main/java/com/wt/admin/service/image/impl/ImageTrainingService.java
index 1d221980d30ee59188d75e0857030212a6612fc6..74d6801a9d2602a054b809cad750fd4d0db140d0 100644
--- a/admin/src/main/java/com/wt/admin/service/image/impl/ImageTrainingService.java
+++ b/admin/src/main/java/com/wt/admin/service/image/impl/ImageTrainingService.java
@@ -6,16 +6,17 @@ import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.json.JSONUtil;
import com.wt.admin.config.ConstVar;
-import com.wt.admin.config.cache.Cache;
import com.wt.admin.config.prop.IndexProp;
import com.wt.admin.config.socket.WebSocketSessionManager;
import com.wt.admin.domain.dto.image.ImageClassificationDTO;
import com.wt.admin.domain.entity.image.ImageItemEntity;
import com.wt.admin.domain.entity.model.ModelListEntity;
+import com.wt.admin.domain.model.ImageYoloModel;
import com.wt.admin.domain.vo.image.ImageClassificationVO;
import com.wt.admin.domain.vo.socket.ProgressVO;
import com.wt.admin.domain.vo.socket.SocketVO;
import com.wt.admin.util.FileUtils;
+import jakarta.annotation.Resource;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.dromara.easyai.entity.ThreeChannelMatrix;
@@ -24,7 +25,6 @@ import org.dromara.easyai.tools.Picture;
import org.dromara.easyai.yolo.*;
import org.springframework.stereotype.Service;
-import jakarta.annotation.Resource;
import java.io.File;
import java.io.InputStream;
import java.util.ArrayList;
@@ -35,15 +35,12 @@ import java.util.function.Function;
@Slf4j
public class ImageTrainingService {
- @Resource
- private Cache imageYoloManager;
@Resource
private IndexProp indexProp;
@SneakyThrows
- public void training(ImageClassificationDTO data, List imageItemVOList, Function fun) {
+ public void training(ImageClassificationDTO data, List imageItemVOList,String modelName, Function fun) {
FastYolo fastYolo = new FastYolo(data.getYoloConfig());
- imageYoloManager.put(data.getId(),fastYolo);
List yoloSamples = new ArrayList<>();
imageItemVOList.forEach(image -> {
String imageAbsURL = FileUtils.getImageAbsURL(image.getImageUrl());
@@ -64,8 +61,11 @@ public class ImageTrainingService {
});
fastYolo.toStudy(yoloSamples);
YoloModel model = fastYolo.getModel();
+ ImageYoloModel imageYoloModel = new ImageYoloModel();
+ imageYoloModel.setYoloConfig(data.getYoloConfig());
+ imageYoloModel.setYoloModel(model);
ModelListEntity entity = fun.apply(true);
- FileUtil.writeUtf8String(JSONUtil.toJsonStr(model), String.format(indexProp.getModelPath().getYoloModel(), data.getId()));
+ FileUtil.writeUtf8String(JSONUtil.toJsonStr(imageYoloModel), String.format(indexProp.getModelPath().getYoloModel(), modelName));
}
@SneakyThrows
@@ -74,7 +74,6 @@ public class ImageTrainingService {
return;
}
FastYolo fastYolo = new FastYolo(data.getYoloConfig());
- imageYoloManager.put(data.getId(), fastYolo);
File file = FileUtil.file(String.format(indexProp.getModelPath().getYoloModel(),data.getId()));
if(!file.exists()){
return;
@@ -84,8 +83,8 @@ public class ImageTrainingService {
}
@SneakyThrows
- public List testTraining(Integer id,String fromUrl,String toUrl) {
- FastYolo fastYolo = imageYoloManager.get(id);
+ public List testTraining(String model,String fromUrl,String toUrl) {
+ FastYolo fastYolo = toFastYolo(model);
ThreeChannelMatrix matrix = Picture.getThreeMatrix(fromUrl, false);
List look = fastYolo.look(matrix, System.currentTimeMillis());
if(CollUtil.isEmpty(look)){
@@ -97,8 +96,8 @@ public class ImageTrainingService {
}
@SneakyThrows
- public List testTraining(Integer id,File file,String toUrl) {
- FastYolo fastYolo = imageYoloManager.get(id);
+ public List testTraining(String model,File file,String toUrl) {
+ FastYolo fastYolo = toFastYolo(model);
ThreeChannelMatrix matrix = Picture.getThreeMatrix(file, false);
List look = fastYolo.look(matrix, System.currentTimeMillis());
if(CollUtil.isEmpty(look)){
@@ -110,8 +109,8 @@ public class ImageTrainingService {
}
@SneakyThrows
- public List testTraining(Integer id,File file) {
- FastYolo fastYolo = imageYoloManager.get(id);
+ public List testTraining(String model,File file) {
+ FastYolo fastYolo = toFastYolo(model);
ThreeChannelMatrix matrix = Picture.getThreeMatrix(file, false);
List look = fastYolo.look(matrix, System.currentTimeMillis());
if(CollUtil.isEmpty(look)){
@@ -121,8 +120,8 @@ public class ImageTrainingService {
}
@SneakyThrows
- public List testTraining(Integer id, InputStream file) {
- FastYolo fastYolo = imageYoloManager.get(id);
+ public List testTraining(String model, InputStream file) {
+ FastYolo fastYolo = toFastYolo(model);
ThreeChannelMatrix matrix = Picture.getThreeMatrix(file, false);
List look = fastYolo.look(matrix, System.currentTimeMillis());
if(CollUtil.isEmpty(look)){
@@ -132,8 +131,18 @@ public class ImageTrainingService {
}
@SneakyThrows
- public List testTraining(Integer id,String path) {
- FastYolo fastYolo = imageYoloManager.get(id);
+ public List testTraining(FastYolo fastYolo, InputStream file) {
+ ThreeChannelMatrix matrix = Picture.getThreeMatrix(file, false);
+ List look = fastYolo.look(matrix, System.currentTimeMillis());
+ if(CollUtil.isEmpty(look)){
+ look = new ArrayList<>();
+ }
+ return look;
+ }
+
+ @SneakyThrows
+ public List testTraining(String model,String path) {
+ FastYolo fastYolo = toFastYolo(model);
ThreeChannelMatrix matrix = Picture.getThreeMatrix(path, false);
List look = fastYolo.look(matrix, System.currentTimeMillis());
if(CollUtil.isEmpty(look)){
@@ -142,6 +151,18 @@ public class ImageTrainingService {
return look;
}
+ public FastYolo toFastYolo(String model) throws Exception {
+ File file = FileUtil.file(String.format(indexProp.getModelPath().getYoloModel(),model));
+ if(!file.exists()){
+ return null;
+ }
+ String json = FileUtil.readUtf8String(file);
+ ImageYoloModel bean = JSONUtil.toBean(json, ImageYoloModel.class);
+ FastYolo fastYolo = new FastYolo(bean.getYoloConfig());
+ fastYolo.insertModel(bean.getYoloModel());
+ return fastYolo;
+ }
+
private void send(int current){
WebSocketSessionManager.sendToAll(new SocketVO
diff --git a/admin/src/main/java/com/wt/admin/service/language/LanguageProxyService.java b/admin/src/main/java/com/wt/admin/service/language/LanguageProxyService.java
index e8e364a1121721b45d743976a0db44fdc912eb40..bf2cef8a08595f1858709898ba3685e14547250e 100644
--- a/admin/src/main/java/com/wt/admin/service/language/LanguageProxyService.java
+++ b/admin/src/main/java/com/wt/admin/service/language/LanguageProxyService.java
@@ -23,19 +23,18 @@ public interface LanguageProxyService {
void classTraining(ClassTrainingDTO data, UserVO userVO);
- ParseSentenceVO testTraining(String data, UserVO user);
+ ParseSentenceVO testTraining(ClassificationTestTrainingDTO data, UserVO user);
void qaTraining(QATrainingDTO data, UserVO user);
- QAParseSentenceVO qaTestTraining(String data, UserVO user);
+ QAParseSentenceVO qaTestTraining(QATestTrainingDTO data, UserVO user);
QAParseSentenceVO qaTestTraining(String data);
- void init();
-
ModelListVO classTest(ModelListDTO data);
ModelListVO qaTest(ModelListDTO data);
SentenceConfigDTO findConfig(String tag);
+
}
diff --git a/admin/src/main/java/com/wt/admin/service/language/impl/LanguageProxyServiceImpl.java b/admin/src/main/java/com/wt/admin/service/language/impl/LanguageProxyServiceImpl.java
index a99755aa939679f1787b449503aa2178424461bb..2b3985172286a54dbe992a3af14f056d32106cfb 100644
--- a/admin/src/main/java/com/wt/admin/service/language/impl/LanguageProxyServiceImpl.java
+++ b/admin/src/main/java/com/wt/admin/service/language/impl/LanguageProxyServiceImpl.java
@@ -1,12 +1,9 @@
package com.wt.admin.service.language.impl;
-import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.ObjectUtil;
import com.wt.admin.code.language.QA2200;
import com.wt.admin.code.language.Tagging2100;
import com.wt.admin.config.ConstVar;
-import com.wt.admin.config.GlobalBeanConfig;
-import com.wt.admin.config.cache.Cache;
import com.wt.admin.config.socket.WebSocketSessionManager;
import com.wt.admin.domain.dto.language.*;
import com.wt.admin.domain.dto.model.ModelListDTO;
@@ -17,13 +14,15 @@ import com.wt.admin.domain.vo.model.ModelListVO;
import com.wt.admin.domain.vo.socket.ProgressVO;
import com.wt.admin.domain.vo.socket.SocketVO;
import com.wt.admin.domain.vo.sys.UserVO;
-import com.wt.admin.service.language.*;
+import com.wt.admin.service.language.ClassificationService;
+import com.wt.admin.service.language.LanguageProxyService;
+import com.wt.admin.service.language.QAService;
+import com.wt.admin.service.language.TaggingService;
import com.wt.admin.service.model.ModelListService;
import com.wt.admin.util.AssertUtil;
import com.wt.admin.util.PageUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
-import org.dromara.easyai.config.SentenceConfig;
import org.dromara.easyai.config.TfConfig;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
@@ -31,8 +30,6 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.List;
-import static com.wt.admin.service.language.impl.LanguageTrainingService.QA;
-
@Slf4j
@Service
public class LanguageProxyServiceImpl implements LanguageProxyService {
@@ -47,21 +44,6 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
private LanguageTrainingService languageTrainingService;
@Resource
private ModelListService modelListService;
- @Resource
- private Cache wordEmbeddings;
-
-
- @Override
- public void init(){
- log.debug("初始化关键词和QA数据");
- List list = taggingService.findByIds(null);
- WebSocketSessionManager.sendToAll(new SocketVO
- (ConstVar.Socket.PROGRESS,ProgressVO.set("初始化训练",0)));
- languageTrainingService.sentenceAndKeywordInit(list);
- log.debug("关键词初始化结束");
- languageTrainingService.QAInit();
- log.debug("QA初始化结束");
- }
@Override
@@ -80,18 +62,9 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
@Override
public SentenceConfigDTO findConfig(String tag) {
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.get(tag);
- SentenceConfig sentenceConfig = wordAndRRManager.getSentenceConfig();
SentenceConfigDTO sentenceConfigDTO = new SentenceConfigDTO();
- if(ObjectUtil.isNotEmpty(sentenceConfig)){
- BeanUtil.copyProperties(sentenceConfig,sentenceConfigDTO);
- }
- TfConfig tf = wordAndRRManager.getTfConfig();
- if(ObjectUtil.isEmpty(tf)){
- TfConfig tfConfig = new TfConfig();
- tf = BeanUtil.copyProperties(tfConfig, TfConfig.class);
- }
- sentenceConfigDTO.setTf(tf);
+ // TODO 查询最近一次训练的配置
+ sentenceConfigDTO.setTf(new TfConfig());
return sentenceConfigDTO;
}
@@ -143,7 +116,7 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
// 根据ids 查询所有样本
List list = taggingService.findByIds(data.getId());
// 对样本标记状态更新
- languageTrainingService.classTraining(list,data.getConfig(),(flag) -> {
+ languageTrainingService.classTraining(list,data.getConfig(),data.getRemark(),(flag) -> {
if(!flag){
// 训练失败
log.debug("准确率 < 0.1 没有价值,因该修改配置参数去调整");
@@ -163,11 +136,12 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
}
@Override
- public ParseSentenceVO testTraining(String data, UserVO user) {
- AssertUtil.Str.isEmpty(data, Tagging2100.CODE_2105);
+ public ParseSentenceVO testTraining(ClassificationTestTrainingDTO data, UserVO user) {
+ AssertUtil.Str.isEmpty(data.getData(), Tagging2100.CODE_2105);
ParseSentenceVO parseSentenceVO = new ParseSentenceVO();
try {
- parseSentenceVO = languageTrainingService.parseSentence(data, LanguageTrainingService.CLASS);
+ List list = taggingService.findByIds(data.getIds());
+ parseSentenceVO = languageTrainingService.parseSentence(list,data, LanguageTrainingService.CLASS);
ClassificationEntity classE = classificationService.findById(parseSentenceVO.getId());
if (ObjectUtil.isEmpty(classE)) {
return parseSentenceVO;
@@ -188,7 +162,7 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
send("QA-训练样本", 0);
// 根据ids 查询所有样本
List list = qaService.findByIds(data.getId());
- languageTrainingService.QATraining(list,data.getConfig(),(b) -> modelListService.addModel(user.getId(),data.getRemark() ,data.getConfig(),ConstVar.ModelType.QA));
+ languageTrainingService.QATraining(list,data.getConfig(),data.getRemark(),(b) -> modelListService.addModel(user.getId(),data.getRemark() ,data.getConfig(),ConstVar.ModelType.QA));
// 对样本标记状态更新
// taggingService.updateBatch(list);
send( null, 100);
@@ -201,8 +175,8 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
}
@Override
- public QAParseSentenceVO qaTestTraining(String data, UserVO user) {
- AssertUtil.Str.isEmpty(data, QA2200.CODE_2206);
+ public QAParseSentenceVO qaTestTraining(QATestTrainingDTO data, UserVO user) {
+ AssertUtil.Str.isEmpty(data.getData(), QA2200.CODE_2206);
QAParseSentenceVO qa = new QAParseSentenceVO();
qa.setAnswer(languageTrainingService.qaParseSentence(data));
return qa;
@@ -210,7 +184,9 @@ public class LanguageProxyServiceImpl implements LanguageProxyService {
@Override
public QAParseSentenceVO qaTestTraining(String data) {
- return qaTestTraining(data,null);
+ QATestTrainingDTO qaTestTrainingDTO = new QATestTrainingDTO();
+ qaTestTrainingDTO.setData(data);
+ return qaTestTraining(qaTestTrainingDTO,null);
}
private void send(String title, int progress){
diff --git a/admin/src/main/java/com/wt/admin/service/language/impl/LanguageTrainingService.java b/admin/src/main/java/com/wt/admin/service/language/impl/LanguageTrainingService.java
index 193757142c452b1d5e48840d354a808a5a7117b7..181cb12d7ad309c18f2cfd867d7b086a2f33413d 100644
--- a/admin/src/main/java/com/wt/admin/service/language/impl/LanguageTrainingService.java
+++ b/admin/src/main/java/com/wt/admin/service/language/impl/LanguageTrainingService.java
@@ -9,18 +9,21 @@ import cn.hutool.json.JSONUtil;
import com.aizuda.easy.security.code.BasicCode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.wt.admin.config.ConstVar;
-import com.wt.admin.config.GlobalBeanConfig;
-import com.wt.admin.config.cache.Cache;
import com.wt.admin.config.prop.IndexProp;
import com.wt.admin.config.socket.WebSocketSessionManager;
+import com.wt.admin.domain.dto.language.ClassificationTestTrainingDTO;
import com.wt.admin.domain.dto.language.KeyWordForSentenceDTO;
+import com.wt.admin.domain.dto.language.QATestTrainingDTO;
import com.wt.admin.domain.dto.language.SentenceConfigDTO;
import com.wt.admin.domain.entity.language.KeywordsEntity;
import com.wt.admin.domain.entity.language.QAEntity;
import com.wt.admin.domain.entity.model.ModelListEntity;
import com.wt.admin.domain.model.LanguageModel;
import com.wt.admin.domain.model.QAModel;
-import com.wt.admin.domain.vo.language.*;
+import com.wt.admin.domain.vo.language.KeyParameterModelMapperVO;
+import com.wt.admin.domain.vo.language.KeyWordModelMapperVO;
+import com.wt.admin.domain.vo.language.ParseSentenceVO;
+import com.wt.admin.domain.vo.language.SentenceVO;
import com.wt.admin.domain.vo.socket.ProgressVO;
import com.wt.admin.domain.vo.socket.SocketVO;
import com.wt.admin.util.AssertUtil;
@@ -37,11 +40,11 @@ import org.dromara.easyai.naturalLanguage.languageCreator.KeyWordModel;
import org.dromara.easyai.naturalLanguage.word.MyKeyWord;
import org.dromara.easyai.naturalLanguage.word.WordEmbedding;
import org.dromara.easyai.rnnJumpNerveCenter.RRNerveManager;
-import org.dromara.easyai.transFormer.model.TransFormerModel;
import org.springframework.stereotype.Service;
import java.io.File;
import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -51,16 +54,8 @@ import java.util.regex.Pattern;
@Slf4j
public class LanguageTrainingService {
- @Resource
- private Cache wordEmbeddings;
@Resource
private IndexProp indexProp;
- @Resource
- private Cache myKeyWordMap;
- @Resource
- private Cache catchKeyWordMap;
- @Resource
- private Cache>> sensorKeyWordMapper;
private final ObjectMapper mapper = new ObjectMapper();
public static final String CLASS = "classification";
public static final String QA = "qa";
@@ -69,19 +64,17 @@ public class LanguageTrainingService {
* 语句和关键词的学习
*/
@SneakyThrows
- public void classTraining(List list, SentenceConfigDTO config, Function fun){
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.get(CLASS);
- WordEmbedding wordEmbedding = wordAndRRManager.getWordEmbedding();
+ public void classTraining(List list, SentenceConfigDTO config,String remark, Function fun) {
AssertUtil.objIsNull(config, BasicCode.BASIC_CODE_99999);
AssertUtil.List.isEmpty(list, BasicCode.BASIC_CODE_99999);
- AssertUtil.objIsNull(wordEmbedding, BasicCode.BASIC_CODE_99999);
Collections.shuffle(list);
List sentence = new ArrayList<>(list.size());
Map> typeIdBySentences = new HashMap<>();
+ Map>> sensorKeyWordMapper = new ConcurrentHashMap<>();
for (int i = 0; i < list.size(); i++) {
SentenceVO sentenceVO = list.get(i);
- Cache> cache = sensorKeyWordMapper
- .computeIfAbsent(sentenceVO.getClassificationId(), k -> new Cache<>());
+ Map> cache = sensorKeyWordMapper
+ .computeIfAbsent(sentenceVO.getClassificationId(), k -> new HashMap<>());
// 获取语句
sentence.add(sentenceVO.getSentence());
// 对语句进行类型分组
@@ -99,29 +92,29 @@ public class LanguageTrainingService {
key.add(keyWordForSentence);
});
}
+ WordEmbedding wordEmbedding = new WordEmbedding();
+ RRNerveManager rrNerveManager = new RRNerveManager(wordEmbedding);
LanguageModel models = new LanguageModel();
config.setTypeNub(typeIdBySentences.keySet().size());
models.setConfig(config);
send(30);
wordEmbedding.setConfig(config);
- log.debug("训练配置信息:{}",JSONUtil.toJsonStr(config));
- RRNerveManager rrNerveManager = wordAndRRManager.getRrNerveManager();
+ log.debug("训练配置信息:{}", JSONUtil.toJsonStr(config));
rrNerveManager.init(config);
- wordEmbedding(sentence,config,wordEmbedding,models);
- wordAndRRManager.setSentenceConfig(config);
+ wordEmbedding(sentence, config, wordEmbedding, models);
log.debug("随机神经网络学习 每个分类样本不够300条,则重复数据到300条,20 * 300 = 6000");
models.setRandomModel(rrNerveManager.studyType(typeIdBySentences));
- keyWordMapperMap(config,wordEmbedding,models);
+ keyWordMapperMap(config, wordEmbedding, models,sensorKeyWordMapper);
boolean b = selfChecking(list, rrNerveManager);
- if(b){
+ if (b) {
ModelListEntity entity = fun.apply(b);
- if(ObjectUtil.isNotEmpty(entity)){
- FileUtil.writeUtf8String(JSONUtil.toJsonPrettyStr(models),indexProp.getModelPath().getLangeModel());
+ if (ObjectUtil.isNotEmpty(entity)) {
+ FileUtil.writeUtf8String(JSONUtil.toJsonPrettyStr(models), String.format(indexProp.getModelPath().getLangeModel(),remark));
}
}
}
- private boolean selfChecking(List list,RRNerveManager rrNerveManager) throws Exception {
+ private boolean selfChecking(List list, RRNerveManager rrNerveManager) throws Exception {
int num = 0;
for (int i = 0; i < list.size(); i++) {
SentenceVO sentence = list.get(i);
@@ -134,21 +127,19 @@ public class LanguageTrainingService {
return point > 0.1;
}
- /**
- * 语句及关键词初始化
- */
- @SneakyThrows
- public void sentenceAndKeywordInit(List list){
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.computeIfAbsent(CLASS
- , k -> new GlobalBeanConfig.WordAndRRManager());
- WordEmbedding wordEmbedding = wordAndRRManager.getWordEmbedding();
- if(CollUtil.isEmpty(list)){
- return;
+ public ParseSentenceVO parseSentence(List list,ClassificationTestTrainingDTO data, String tag) throws Exception {
+
+ File file = new File(String.format(indexProp.getModelPath().getLangeModel(),data.getModel()));
+ if(!file.exists()){
+ return null;
}
- for (SentenceVO sentenceVO : list) {
- Cache> cache = sensorKeyWordMapper
- .computeIfAbsent(sentenceVO.getClassificationId(), k -> new Cache<>());
- for (Object k : sentenceVO.getTaggings()) {
+
+ Map>> sensorKeyWordMapper = new ConcurrentHashMap<>();
+ for (int i = 0; i < list.size(); i++) {
+ SentenceVO sentenceVO = list.get(i);
+ Map> cache = sensorKeyWordMapper
+ .computeIfAbsent(sentenceVO.getClassificationId(), k -> new HashMap<>());
+ sentenceVO.getTaggings().forEach(k -> {
KeywordsEntity keywordsEntity = BeanUtil.toBean(k, KeywordsEntity.class);
KeyWordForSentenceDTO keyWordForSentence = new KeyWordForSentenceDTO();
keyWordForSentence.setSentence(sentenceVO.getSentence());
@@ -156,36 +147,29 @@ public class LanguageTrainingService {
keyWordForSentence.setId(keywordsEntity.getId());
keyWordForSentence.setReply(keywordsEntity.getReplies());
List key = cache.computeIfAbsent(keywordsEntity.getId(), j -> new ArrayList<>());
- if (!key.isEmpty()) {
- continue;
- }
key.add(keyWordForSentence);
- }
- }
- File file = FileUtil.file(indexProp.getModelPath().getLangeModel());
- if(!file.exists()){
- return;
+ });
}
+
+ WordEmbedding wordEmbedding = new WordEmbedding();
+ RRNerveManager rrNerveManager = new RRNerveManager(wordEmbedding);
LanguageModel model = JSONUtil.toBean(FileUtil.readUtf8String(file), LanguageModel.class);
SentenceConfig config = model.getConfig();
- wordAndRRManager.setSentenceConfig(config);
wordEmbedding.setConfig(config);
- RRNerveManager rrNerveManager = wordAndRRManager.getRrNerveManager();
rrNerveManager.init(config);
wordEmbedding.insertModel(model.getWordTwoVectorModel(), config.getWordVectorDimension());
rrNerveManager.insertModel(model.getRandomModel());
- keyWordMapperMapDeserialize(config,wordEmbedding,model);
- keyWordDeserialize(model);
- send(100);
- }
- public ParseSentenceVO parseSentence(String data,String tag) throws Exception {
+ Map myKeyWordMap = new HashMap<>();
+ keyWordMapperMapDeserialize(config,wordEmbedding,model,myKeyWordMap);
+
+ Map catchKeyWordMap = new HashMap<>();
+ keyWordDeserialize(model,catchKeyWordMap);
+
// 获得语句对应的id
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.get(tag);
- RRNerveManager rrNerveManager = wordAndRRManager.getRrNerveManager();
- int type = rrNerveManager.getType(data, System.currentTimeMillis());
+ int type = rrNerveManager.getType(data.getData(), System.currentTimeMillis());
MyKeyWord myKeyWord = myKeyWordMap.get(type);
- if(ObjectUtil.isEmpty(myKeyWord)){
+ if (ObjectUtil.isEmpty(myKeyWord)) {
return new ParseSentenceVO(type);
}
// 语句是否有关键词
@@ -194,31 +178,27 @@ public class LanguageTrainingService {
// return new ParseSentenceVO(type);
// }
List keyWordList = new ArrayList<>();
- Cache> integerListCache = sensorKeyWordMapper.get(type);
- integerListCache.forEach((key, value) -> {
- CatchKeyWord catchKeyWord = catchKeyWordMap.get(key);
- if(ObjectUtil.isEmpty(catchKeyWord)){
+ Map> integerListMap = sensorKeyWordMapper.get(type);
+ integerListMap.forEach((K,V) -> {
+ CatchKeyWord catchKeyWord = catchKeyWordMap.get(K);
+ if (ObjectUtil.isEmpty(catchKeyWord)) {
return;
}
- KeyWordForSentenceDTO keyWordForSentenceDTO = value.get(0);
- Set keyWordSet = catchKeyWord.getKeyWord(data);
+ KeyWordForSentenceDTO keyWordForSentenceDTO = V.get(0);
+ Set keyWordSet = catchKeyWord.getKeyWord(data.getData());
boolean b = keyWordSet.stream().anyMatch(StrUtil::isBlank);
String str = null;
- if(keyWordSet.isEmpty() || b){
+ if (keyWordSet.isEmpty() || b) {
str = keyWordForSentenceDTO.getReply();
}
- keyWordList.add(new ParseSentenceVO.Keywords(key,keyWordSet,str));
+ keyWordList.add(new ParseSentenceVO.Keywords(K, keyWordSet, str));
});
- return new ParseSentenceVO(type,keyWordList);
+ return new ParseSentenceVO(type, keyWordList);
}
@SneakyThrows
- public void QATraining(List list, SentenceConfigDTO config, Function fun){
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.get(QA);
- WordEmbedding wordEmbedding = new WordEmbedding();
- wordAndRRManager.setWordEmbedding(wordEmbedding);
- if(ObjectUtil.isEmpty(config) || CollUtil.isEmpty(list) ||
- ObjectUtil.isEmpty(wordEmbedding)){
+ public void QATraining(List list, SentenceConfigDTO config,String remark, Function fun) {
+ if (ObjectUtil.isEmpty(config) || CollUtil.isEmpty(list)) {
return;
}
Pattern pattern = Pattern.compile("@我:\\s*(.*?)(?=@AI:|$)@AI:\\s*(.*?)(?=@我:|$)", Pattern.DOTALL);
@@ -237,67 +217,42 @@ public class LanguageTrainingService {
sentences.add(talkBody);
}
});
+ send(50);
Collections.shuffle(sentences);
- send(40);
config.setTypeNub(0);
- wordEmbedding.setConfig(config);
-// wordEmbedding.init(sentenceModel, config.getWordVectorDimension());
- log.debug("qa 训练配置 {}",JSONUtil.toJsonStr(config));
-// WordTwoVectorModel wordTwoVectorModel = wordEmbedding.start();//词向量开始学习
- send(60);
+ log.debug("qa 训练配置 {}", JSONUtil.toJsonStr(config));
+
TalkToTalk talkToTalk = new TalkToTalk(config.getTf());
- TransFormerModel transFormerModel = talkToTalk.study(sentences);
- wordAndRRManager.setTalkToTalk(talkToTalk);
- wordAndRRManager.setTfConfig(config.getTf());
// 写文件
- QAModel model = new QAModel(null,transFormerModel);
+ QAModel model = new QAModel(null, talkToTalk.study(sentences));
model.setConfig(config.getTf());
String jsonModel = mapper.writeValueAsString(model);
ModelListEntity entity = fun.apply(true);
- FileUtil.writeUtf8String(jsonModel,indexProp.getModelPath().getQaModel());
- send(90);
+ FileUtil.writeUtf8String(jsonModel, String.format(indexProp.getModelPath().getQaModel(), remark));
+ send(95);
}
@SneakyThrows
- public void QAInit(){
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.computeIfAbsent(QA
- , k -> new GlobalBeanConfig.WordAndRRManager());
- WordEmbedding wordEmbedding = wordAndRRManager.getWordEmbedding();
- if(ObjectUtil.isEmpty(wordEmbedding)){
- return;
- }
- // 反序列化
- File file = FileUtil.file(indexProp.getModelPath().getQaModel());
+ public String qaParseSentence(QATestTrainingDTO data) {
+ File file = new File(String.format(indexProp.getModelPath().getQaModel(),data.getModel()));
if(!file.exists()){
- return;
+ return null;
}
QAModel model = mapper.readValue(FileUtil.readUtf8String(file), QAModel.class);
-// sentenceConfig.setTypeNub(config.getTypeNub());
-// wordEmbedding.setConfig(sentenceConfig);
-// wordEmbedding.insertModel(model.getWordTwoVectorModel(),sentenceConfig.getWordVectorDimension());
TalkToTalk talkToTalk = new TalkToTalk(model.getConfig());
talkToTalk.insertModel(model.getTransFormerModel());
- wordAndRRManager.setTfConfig(model.getConfig());
- wordAndRRManager.setTalkToTalk(talkToTalk);
+ String answer = talkToTalk.getAnswer(data.getData(), System.currentTimeMillis());
+ log.debug("question={} answer={}",data.getData(),answer);
+ return answer;
}
- @SneakyThrows
- public String qaParseSentence(String data) {
- GlobalBeanConfig.WordAndRRManager wordAndRRManager = wordEmbeddings.get(QA);
- TalkToTalk talkToTalk = wordAndRRManager.getTalkToTalk();
- if(ObjectUtil.isEmpty(talkToTalk)){
- return "您还未进行任何的数据训练";
- }
- return talkToTalk.getAnswer(data,System.currentTimeMillis());
- }
-
-
/**
* 词向量学习
+ *
* @param sentence 语句
*/
- private void wordEmbedding(List sentence,SentenceConfig sentenceConfig,WordEmbedding wordEmbedding,LanguageModel models) throws Exception {
- log.debug("词向量学习 {}",sentence.size());
+ private void wordEmbedding(List sentence, SentenceConfig sentenceConfig, WordEmbedding wordEmbedding, LanguageModel models) throws Exception {
+ log.debug("词向量学习 {}", sentence.size());
SentenceModel sentenceModel = new SentenceModel();
sentence.forEach(sentenceModel::setSentence);
wordEmbedding.init(sentenceModel, sentenceConfig.getWordVectorDimension());// 放入语句 和 词向量维度
@@ -309,26 +264,24 @@ public class LanguageTrainingService {
/**
* 关键词敏感性嗅探模型
*/
- private void keyWordMapperMap(SentenceConfig sentenceConfig,WordEmbedding wordEmbedding,LanguageModel models) {
+ private void keyWordMapperMap(SentenceConfig sentenceConfig, WordEmbedding wordEmbedding, LanguageModel models,Map>> sensorKeyWordMapper) {
// 键词敏感性嗅探模型
List keyParameterModelMapperVOS = new ArrayList<>();
// 关键词抓取模型
List keyWordModelMapperVOS = new ArrayList<>();
- sensorKeyWordMapper.forEach((key,value)->{
- value.forEach((key1,value1) -> {
+ sensorKeyWordMapper.forEach((key, value) -> {
+ value.forEach((key1, value1) -> {
// 一个CatchKeyWord只能对一个种类的关键词类别进行捕获,它的关键词类别ID与嗅探类MyKeyWord的ID是一一对应的。
// 主键是设定好的关键词类别ID,值是该类别ID下的句子与它的关键词集合。
try {
List list = value1.stream().map(i -> (KeyWordForSentence) i).toList();
//键词敏感性嗅探模型 sentenceConfig.setShowLog(false); 有效
MyKeyWord mk = new MyKeyWord(sentenceConfig, wordEmbedding);
- keyParameterModelMapperVOS.add(new KeyParameterModelMapperVO(key1,mk.study(list)));
- myKeyWordMap.put(key1,mk);
+ keyParameterModelMapperVOS.add(new KeyParameterModelMapperVO(key1, mk.study(list)));
// 关键词抓取模型
CatchKeyWord catchKeyWord = new CatchKeyWord();
catchKeyWord.study(list); //耗时的过程
KeyWordModel keyWordModel = catchKeyWord.getModel();
- catchKeyWordMap.put(key1, catchKeyWord); //吃内存
keyWordModelMapperVOS.add(new KeyWordModelMapperVO(key, keyWordModel));
} catch (Exception e) {
throw new RuntimeException(e);
@@ -340,29 +293,27 @@ public class LanguageTrainingService {
send(80);
}
- private void keyWordMapperMapDeserialize(SentenceConfig sentenceConfig,WordEmbedding wordEmbedding,LanguageModel model) throws Exception {
+ private void keyWordMapperMapDeserialize(SentenceConfig sentenceConfig, WordEmbedding wordEmbedding, LanguageModel model,Map myKeyWordMap) throws Exception {
for (KeyParameterModelMapperVO haveKey : model.getKeyParameter()) {
MyKeyWord myKeyWord = new MyKeyWord(sentenceConfig, wordEmbedding);
myKeyWord.insertModel(haveKey.getParameter());
myKeyWordMap.put(haveKey.getKey(), myKeyWord);
}
- send(60);
}
- private void keyWordDeserialize(LanguageModel model){
+ private void keyWordDeserialize(LanguageModel model,Map catchKeyWordMap) {
for (KeyWordModelMapperVO keyWordModelMapping : model.getKeyWord()) {
int key = keyWordModelMapping.getKey();
CatchKeyWord catchKeyWord = new CatchKeyWord();
catchKeyWordMap.put(key, catchKeyWord);
catchKeyWord.insertModel(keyWordModelMapping.getModel());
}
- send(80);
}
- private void send(int current){
+ private void send(int current) {
WebSocketSessionManager.sendToAll(new SocketVO<>
- (ConstVar.Socket.PROGRESS,ProgressVO.set(current)));
+ (ConstVar.Socket.PROGRESS, ProgressVO.set(current)));
}
diff --git a/admin/src/main/java/com/wt/admin/service/model/ModelListService.java b/admin/src/main/java/com/wt/admin/service/model/ModelListService.java
index b40cd0035821befaf1f8a88521bbcfdec4165f23..8a7ddc0ce1da743be564b97f05fe4e200ec69d56 100644
--- a/admin/src/main/java/com/wt/admin/service/model/ModelListService.java
+++ b/admin/src/main/java/com/wt/admin/service/model/ModelListService.java
@@ -5,6 +5,8 @@ import com.wt.admin.domain.entity.model.ModelListEntity;
import com.wt.admin.domain.vo.model.ModelListVO;
import com.wt.admin.util.PageUtil;
+import java.util.List;
+
public interface ModelListService {
ModelListEntity addModel(Integer userId,String remark,Object config, Integer type);
@@ -12,4 +14,6 @@ public interface ModelListService {
PageUtil.PageVO modelList(PageUtil.PageDTO data);
ModelListVO modelDel(ModelListDTO data);
+
+ List modelAll(ModelListDTO data);
}
diff --git a/admin/src/main/java/com/wt/admin/service/model/ModelProxyService.java b/admin/src/main/java/com/wt/admin/service/model/ModelProxyService.java
index 2499924f537751ff79bd0582517d373f36469359..c2d4c6e727c512329a16ad662f2b21e5b4cb3023 100644
--- a/admin/src/main/java/com/wt/admin/service/model/ModelProxyService.java
+++ b/admin/src/main/java/com/wt/admin/service/model/ModelProxyService.java
@@ -4,6 +4,8 @@ import com.wt.admin.domain.dto.model.ModelListDTO;
import com.wt.admin.domain.vo.model.ModelListVO;
import com.wt.admin.util.PageUtil;
+import java.util.List;
+
public interface ModelProxyService {
PageUtil.PageVO modelList(PageUtil.PageDTO