Commit fa515e2b authored by youxiaoji's avatar youxiaoji

* [调整emitterId为随机生成,后续需要将原生emitter进行封装,增加ID等参数进行标识]

parent 99f9a094
...@@ -8,6 +8,7 @@ import org.springframework.ai.chat.prompt.Prompt; ...@@ -8,6 +8,7 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import pangea.hiagent.agent.service.StreamRequestService;
import pangea.hiagent.agent.sse.UserSseService; import pangea.hiagent.agent.sse.UserSseService;
import pangea.hiagent.model.UserToken; import pangea.hiagent.model.UserToken;
import pangea.hiagent.tool.impl.HisenseTripTool; import pangea.hiagent.tool.impl.HisenseTripTool;
...@@ -24,6 +25,7 @@ import pangea.hiagent.tool.AgentToolManager; ...@@ -24,6 +25,7 @@ import pangea.hiagent.tool.AgentToolManager;
import pangea.hiagent.tool.impl.DateTimeTools; import pangea.hiagent.tool.impl.DateTimeTools;
import java.util.List; import java.util.List;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer; import java.util.function.Consumer;
...@@ -237,21 +239,23 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -237,21 +239,23 @@ public class DefaultReactExecutor implements ReactExecutor {
try { try {
// 触发思考步骤 // 触发思考步骤
triggerThinkStep("开始处理用户请求: " + userInput); triggerThinkStep("开始处理用户请求: " + userInput);
StreamRequestService.StreamTokenConsumer consumer = (StreamRequestService.StreamTokenConsumer)tokenConsumer;
String emitterId = consumer.getEmitterId();
// 构建Prompt,包含历史对话记录 // 构建Prompt,包含历史对话记录
Prompt prompt = buildPromptWithHistory(agent.getSystemPrompt(), userInput, agent); Prompt prompt = buildPromptWithHistory(agent.getSystemPrompt(), userInput, agent);
UserToken userToken = userTokenService.getUserToken("admin","pangea"); UserToken userToken = userTokenService.getUserToken(consumer.getUserId(),"pangea");
VisitorAppointmentTool hisenseTripTool = new VisitorAppointmentTool(userToken,agentService,infoCollectorService,userSseService); VisitorAppointmentTool hisenseTripTool = new VisitorAppointmentTool(userToken,agentService,infoCollectorService,userSseService);
hisenseTripTool.initialize(); hisenseTripTool.initialize();
// 订阅流式响应 // 订阅流式响应
chatClient.prompt(prompt) chatClient.prompt(prompt)
.tools(hisenseTripTool) .tools(hisenseTripTool)
.toolContext(Map.of("emitterId",emitterId))
.stream() .stream()
.chatResponse() .chatResponse()
.subscribe( .subscribe(
chatResponse -> handleTokenResponse(chatResponse, tokenConsumer, fullResponse), chatResponse -> handleTokenResponse(chatResponse, tokenConsumer, fullResponse),
throwable -> handleStreamError(throwable, tokenConsumer), throwable -> handleStreamError(throwable, tokenConsumer,emitterId),
() -> handleStreamCompletion(tokenConsumer, fullResponse, agent) () -> handleStreamCompletion(tokenConsumer, fullResponse, agent,emitterId)
); );
} catch (Exception e) { } catch (Exception e) {
...@@ -300,7 +304,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -300,7 +304,7 @@ public class DefaultReactExecutor implements ReactExecutor {
* @param fullResponse 完整响应构建器 * @param fullResponse 完整响应构建器
* @param agent Agent对象 * @param agent Agent对象
*/ */
private void handleStreamCompletion(Consumer<String> tokenConsumer, StringBuilder fullResponse, Agent agent) { private void handleStreamCompletion(Consumer<String> tokenConsumer, StringBuilder fullResponse, Agent agent,String emitterId) {
try { try {
log.info("流式处理完成"); log.info("流式处理完成");
// 触发最终答案步骤 // 触发最终答案步骤
...@@ -308,7 +312,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -308,7 +312,7 @@ public class DefaultReactExecutor implements ReactExecutor {
// 将助理回复添加到ChatMemory // 将助理回复添加到ChatMemory
saveAssistantResponseToMemory(agent, fullResponse.toString()); saveAssistantResponseToMemory(agent, fullResponse.toString());
log.info("complete, remove emitterId {}",emitterId);
// 发送完成事件,包含完整内容 // 发送完成事件,包含完整内容
sendCompletionEvent(tokenConsumer, fullResponse.toString()); sendCompletionEvent(tokenConsumer, fullResponse.toString());
} catch (Exception e) { } catch (Exception e) {
...@@ -391,7 +395,8 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -391,7 +395,8 @@ public class DefaultReactExecutor implements ReactExecutor {
* @param throwable 异常对象 * @param throwable 异常对象
* @param tokenConsumer token消费者 * @param tokenConsumer token消费者
*/ */
private void handleStreamError(Throwable throwable, Consumer<String> tokenConsumer) { private void handleStreamError(Throwable throwable, Consumer<String> tokenConsumer,String emitterId) {
log.info("error,remove emitterId:{}", emitterId);
errorHandlerService.handleStreamError(throwable, tokenConsumer, "ReAct流式处理"); errorHandlerService.handleStreamError(throwable, tokenConsumer, "ReAct流式处理");
} }
......
...@@ -15,6 +15,8 @@ import pangea.hiagent.tool.AgentToolManager; ...@@ -15,6 +15,8 @@ import pangea.hiagent.tool.AgentToolManager;
import pangea.hiagent.web.dto.AgentRequest; import pangea.hiagent.web.dto.AgentRequest;
import pangea.hiagent.workpanel.event.EventService; import pangea.hiagent.workpanel.event.EventService;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import java.util.UUID;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
...@@ -121,7 +123,10 @@ public class AgentChatService { ...@@ -121,7 +123,10 @@ public class AgentChatService {
// 创建 SSE emitter // 创建 SSE emitter
SseEmitter emitter = workPanelSseService.createEmitter(); SseEmitter emitter = workPanelSseService.createEmitter();
workPanelSseService.registerEmitter("worker1", emitter);
String emitterId = UUID.randomUUID().toString();
log.info("emitterId: {}", emitterId);
workPanelSseService.registerEmitter(emitterId, emitter);
// 将userId设为final以在Lambda表达式中使用 // 将userId设为final以在Lambda表达式中使用
final String finalUserId = userId; final String finalUserId = userId;
...@@ -129,7 +134,7 @@ public class AgentChatService { ...@@ -129,7 +134,7 @@ public class AgentChatService {
// 异步处理对话,避免阻塞HTTP连接 // 异步处理对话,避免阻塞HTTP连接
executorService.execute(() -> { executorService.execute(() -> {
try { try {
processChatRequest(emitter, agentId, chatRequest, finalUserId); processChatRequest(emitter, agentId, chatRequest, finalUserId,emitterId);
} catch (Exception e) { } catch (Exception e) {
log.error("处理聊天请求时发生异常", e); log.error("处理聊天请求时发生异常", e);
// 检查响应是否已经提交 // 检查响应是否已经提交
...@@ -152,7 +157,7 @@ public class AgentChatService { ...@@ -152,7 +157,7 @@ public class AgentChatService {
* @param chatRequest 聊天请求 * @param chatRequest 聊天请求
* @param userId 用户ID * @param userId 用户ID
*/ */
private void processChatRequest(SseEmitter emitter, String agentId, ChatRequest chatRequest, String userId) { private void processChatRequest(SseEmitter emitter, String agentId, ChatRequest chatRequest, String userId,String emitterId) {
try { try {
// 获取Agent信息并进行权限检查 // 获取Agent信息并进行权限检查
Agent agent = agentValidationService.validateAgentAndPermission(agentId, userId, emitter); Agent agent = agentValidationService.validateAgentAndPermission(agentId, userId, emitter);
...@@ -173,7 +178,7 @@ public class AgentChatService { ...@@ -173,7 +178,7 @@ public class AgentChatService {
AgentRequest request = chatRequest.toAgentRequest(agentId, agent, agentToolManager); AgentRequest request = chatRequest.toAgentRequest(agentId, agent, agentToolManager);
// 处理流式请求 // 处理流式请求
streamRequestService.handleStreamRequest(emitter, processor, request, agent, userId); streamRequestService.handleStreamRequest(emitter, processor, request, agent, userId,emitterId);
} catch (Exception e) { } catch (Exception e) {
chatErrorHandler.handleChatError(emitter, "处理请求时发生错误", e, null); chatErrorHandler.handleChatError(emitter, "处理请求时发生错误", e, null);
} }
......
...@@ -38,7 +38,7 @@ public class StreamRequestService { ...@@ -38,7 +38,7 @@ public class StreamRequestService {
* @param agent Agent对象 * @param agent Agent对象
* @param userId 用户ID * @param userId 用户ID
*/ */
public void handleStreamRequest(SseEmitter emitter, AgentProcessor processor, pangea.hiagent.web.dto.AgentRequest request, Agent agent, String userId) { public void handleStreamRequest(SseEmitter emitter, AgentProcessor processor, pangea.hiagent.web.dto.AgentRequest request, Agent agent, String userId,String emitterId) {
LogUtils.enterMethod("handleStreamRequest", emitter, processor, request, agent, userId); LogUtils.enterMethod("handleStreamRequest", emitter, processor, request, agent, userId);
// 参数验证 // 参数验证
...@@ -50,6 +50,7 @@ public class StreamRequestService { ...@@ -50,6 +50,7 @@ public class StreamRequestService {
StreamTokenConsumer tokenConsumer = new StreamTokenConsumer(emitter, processor, unifiedSseService, eventService, completionHandlerService); StreamTokenConsumer tokenConsumer = new StreamTokenConsumer(emitter, processor, unifiedSseService, eventService, completionHandlerService);
// 设置上下文信息,用于保存对话记录 // 设置上下文信息,用于保存对话记录
tokenConsumer.setContext(agent, request, userId); tokenConsumer.setContext(agent, request, userId);
tokenConsumer.setEmitterId(emitterId);
// 处理流式请求,将token缓冲和事件发送完全交给处理器实现 // 处理流式请求,将token缓冲和事件发送完全交给处理器实现
processor.processStreamRequest(request, agent, userId, tokenConsumer); processor.processStreamRequest(request, agent, userId, tokenConsumer);
...@@ -84,6 +85,7 @@ public class StreamRequestService { ...@@ -84,6 +85,7 @@ public class StreamRequestService {
private pangea.hiagent.web.dto.AgentRequest request; private pangea.hiagent.web.dto.AgentRequest request;
private String userId; private String userId;
private CompletionHandlerService completionHandlerService; private CompletionHandlerService completionHandlerService;
private String emitterId;
public StreamTokenConsumer(SseEmitter emitter, AgentProcessor processor, UserSseService unifiedSseService, EventService eventService, CompletionHandlerService completionHandlerService) { public StreamTokenConsumer(SseEmitter emitter, AgentProcessor processor, UserSseService unifiedSseService, EventService eventService, CompletionHandlerService completionHandlerService) {
this.emitter = emitter; this.emitter = emitter;
...@@ -97,6 +99,15 @@ public class StreamRequestService { ...@@ -97,6 +99,15 @@ public class StreamRequestService {
this.request = request; this.request = request;
this.userId = userId; this.userId = userId;
} }
public void setEmitterId(String emitterId) {
this.emitterId = emitterId;
}
public String getEmitterId() {
return emitterId;
}
public String getUserId() {
return userId;
}
@Override @Override
public void accept(String token) { public void accept(String token) {
......
...@@ -557,7 +557,12 @@ public class UserSseService { ...@@ -557,7 +557,12 @@ public class UserSseService {
this.userEmitters.put(id, emitter); this.userEmitters.put(id, emitter);
} }
public SseEmitter getEmitter(String id) { public SseEmitter getEmitter(String id) {
return userEmitters.get(id); return userEmitters.get(id);
} }
public boolean removeEmitter(String id) {
userEmitters.remove(id);
return true;
}
} }
\ No newline at end of file
...@@ -13,6 +13,7 @@ import jakarta.annotation.PreDestroy; ...@@ -13,6 +13,7 @@ import jakarta.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
...@@ -265,7 +266,7 @@ public class VisitorAppointmentTool { ...@@ -265,7 +266,7 @@ public class VisitorAppointmentTool {
} }
@Tool(description = "如果在用户信息中有任何与访客预约相关的信息,调用这个工具来保存访客预约信息") @Tool(description = "如果在用户信息中有任何与访客预约相关的信息,调用这个工具来保存访客预约信息")
public String applyInfoSave(@ToolParam(required = true) JSONObject infos) throws IOException { public String applyInfoSave(@ToolParam(required = true) JSONObject infos,ToolContext toolContext) throws IOException {
log.info("applyInfoSave(infos={})", infos); log.info("applyInfoSave(infos={})", infos);
infos.keySet().forEach(key -> { infos.keySet().forEach(key -> {
infoCollectorService.saveValue(key, infos.get(key)); infoCollectorService.saveValue(key, infos.get(key));
...@@ -281,7 +282,7 @@ public class VisitorAppointmentTool { ...@@ -281,7 +282,7 @@ public class VisitorAppointmentTool {
} }
JSONObject formMessage = new JSONObject(); JSONObject formMessage = new JSONObject();
formMessage.put("coms", lackJson); formMessage.put("coms", lackJson);
sendFormMessage(formMessage); sendFormMessage(formMessage,toolContext);
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
if (keys.isEmpty()) { if (keys.isEmpty()) {
sb.append("用户已提交全部数据,提示用户提交申请"); sb.append("用户已提交全部数据,提示用户提交申请");
...@@ -297,8 +298,9 @@ public class VisitorAppointmentTool { ...@@ -297,8 +298,9 @@ public class VisitorAppointmentTool {
} }
return sb.toString(); return sb.toString();
} }
private void sendFormMessage(JSONObject formMessage) throws IOException { private void sendFormMessage(JSONObject formMessage,ToolContext toolContext) throws IOException {
SseEmitter sseEmitter = userSseService.getEmitter("worker1"); String emitterId = toolContext.getContext().get("emitterId").toString();
SseEmitter sseEmitter = userSseService.getEmitter(emitterId);
log.info("Send Form Message {}", formMessage); log.info("Send Form Message {}", formMessage);
sseEmitter.send(SseEmitter.event().name("form").data(formMessage)); sseEmitter.send(SseEmitter.event().name("form").data(formMessage));
} }
...@@ -309,7 +311,9 @@ public class VisitorAppointmentTool { ...@@ -309,7 +311,9 @@ public class VisitorAppointmentTool {
* @return 页面内容(HTML文本) * @return 页面内容(HTML文本)
*/ */
@Tool(description = "获取访客预约申请必要信息") @Tool(description = "获取访客预约申请必要信息")
public String getAppointmentApplyNecessaryInfo() throws IOException { public String getAppointmentApplyNecessaryInfo(ToolContext toolContext) {
StringBuilder stringBuilder = new StringBuilder();
if(infoCollectorService.exists(pageId)){ if(infoCollectorService.exists(pageId)){
JSONArray jsonArray = infoCollectorService.getInfo(pageId); JSONArray jsonArray = infoCollectorService.getInfo(pageId);
JSONArray lackJson = new JSONArray(); JSONArray lackJson = new JSONArray();
...@@ -319,8 +323,15 @@ public class VisitorAppointmentTool { ...@@ -319,8 +323,15 @@ public class VisitorAppointmentTool {
} }
JSONObject formMessage = new JSONObject(); JSONObject formMessage = new JSONObject();
formMessage.put("coms", lackJson); formMessage.put("coms", lackJson);
sendFormMessage(formMessage); try {
return "已获取必要信息,保存用户提交的信息"; sendFormMessage(formMessage,toolContext);
}catch (Exception e){
e.printStackTrace();
}
stringBuilder.append(formMessage.toJSONString());
stringBuilder.append("提示用户以json格式提交信息;如果用户已提供部分信息,需要将这些信息与`props.name`属性的值进行匹配,并将匹配之后的信息以json格式提交到`applyInfoSave`以保存信息");
return stringBuilder.toString();
} }
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
...@@ -354,11 +365,10 @@ public class VisitorAppointmentTool { ...@@ -354,11 +365,10 @@ public class VisitorAppointmentTool {
} else { } else {
jsonArray = infoCollectorService.getInfo(pageId); jsonArray = infoCollectorService.getInfo(pageId);
} }
StringBuilder stringBuilder = new StringBuilder();
// 提取页面内容 // 提取页面内容
String content = stringBuilder.toString(); String content = stringBuilder.toString();
JSONObject formMessage = generateJson(jsonArray); JSONObject formMessage = generateJson(jsonArray);
sendFormMessage(formMessage); sendFormMessage(formMessage,toolContext);
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
log.info("成功获取海信出差申请页面内容,耗时: {} ms", endTime - startTime); log.info("成功获取海信出差申请页面内容,耗时: {} ms", endTime - startTime);
log.info("用户需要提交的信息包括:{}", formMessage); log.info("用户需要提交的信息包括:{}", formMessage);
......
...@@ -14,6 +14,6 @@ public interface UserTokenRepository extends BaseMapper<User> { ...@@ -14,6 +14,6 @@ public interface UserTokenRepository extends BaseMapper<User> {
@Select("SELECT * FROM sys_user_token WHERE user_name = #{userName} AND token_type=#{tokenType} ORDER BY created_at DESC") @Select("SELECT * FROM sys_user_token WHERE user_id = #{userName} AND token_type=#{tokenType} ORDER BY created_at DESC")
UserToken getTokenByUserNameAndTokenType(String userName, String tokenType); UserToken getTokenByUserIdAndTokenType(String userName, String tokenType);
} }
...@@ -11,7 +11,7 @@ public class UserTokenService { ...@@ -11,7 +11,7 @@ public class UserTokenService {
private UserTokenRepository userTokenRepository; private UserTokenRepository userTokenRepository;
public UserToken getUserToken(String userName,String tokenType) { public UserToken getUserToken(String userId,String tokenType) {
return userTokenRepository.getTokenByUserNameAndTokenType(userName,tokenType); return userTokenRepository.getTokenByUserIdAndTokenType(userId,tokenType);
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment