Commit 8bff979e authored by ligaowei's avatar ligaowei

feat: 重构用户ID获取逻辑并优化ReAct执行流程

重构UserUtils类,提供静态方法支持并优化线程安全
新增EventSplitter组件用于实时分割ReAct事件流
统一所有Controller和Service使用静态方法获取用户ID
移除冗余的SseEventBroadcaster组件,简化事件发送逻辑
更新.gitignore排除数据库文件
parent b230dbdc
...@@ -217,4 +217,6 @@ Thumbs.db ...@@ -217,4 +217,6 @@ Thumbs.db
.Trashes .Trashes
ehthumbs.db ehthumbs.db
Icon? Icon?
*.icon? *.icon?
\ No newline at end of file backend/data/hiagent_dev_db.trace.db
backend/data/hiagent_dev_db.mv.db
...@@ -3,6 +3,7 @@ package pangea.hiagent.agent.react; ...@@ -3,6 +3,7 @@ package pangea.hiagent.agent.react;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.agent.service.UserSseService;
import pangea.hiagent.workpanel.IWorkPanelDataCollector; import pangea.hiagent.workpanel.IWorkPanelDataCollector;
/** /**
...@@ -15,6 +16,9 @@ public class DefaultReactCallback implements ReactCallback { ...@@ -15,6 +16,9 @@ public class DefaultReactCallback implements ReactCallback {
@Autowired @Autowired
private IWorkPanelDataCollector workPanelCollector; private IWorkPanelDataCollector workPanelCollector;
@Autowired
private UserSseService userSseService;
@Override @Override
public void onStep(ReactStep reactStep) { public void onStep(ReactStep reactStep) {
log.info("ReAct步骤触发: 类型={}, 内容摘要={}", log.info("ReAct步骤触发: 类型={}, 内容摘要={}",
...@@ -32,7 +36,9 @@ public class DefaultReactCallback implements ReactCallback { ...@@ -32,7 +36,9 @@ public class DefaultReactCallback implements ReactCallback {
try { try {
switch (reactStep.getStepType()) { switch (reactStep.getStepType()) {
case THOUGHT: case THOUGHT:
workPanelCollector.recordThinking(reactStep.getContent(), "thought");
// userSseService.sendWorkPanelEvent(reactStep.getContent(), "thought");
// workPanelCollector.recordThinking(reactStep.getContent(), "thought");
log.info("[WorkPanel] 记录思考步骤: {}", log.info("[WorkPanel] 记录思考步骤: {}",
reactStep.getContent().substring(0, Math.min(100, reactStep.getContent().length()))); reactStep.getContent().substring(0, Math.min(100, reactStep.getContent().length())));
break; break;
......
...@@ -5,7 +5,6 @@ import org.springframework.ai.chat.client.ChatClient; ...@@ -5,7 +5,6 @@ import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import pangea.hiagent.agent.service.ErrorHandlerService; import pangea.hiagent.agent.service.ErrorHandlerService;
...@@ -13,7 +12,6 @@ import pangea.hiagent.agent.service.TokenConsumerWithCompletion; ...@@ -13,7 +12,6 @@ import pangea.hiagent.agent.service.TokenConsumerWithCompletion;
import pangea.hiagent.memory.MemoryService; import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.model.Agent; import pangea.hiagent.model.Agent;
import pangea.hiagent.tool.AgentToolManager; import pangea.hiagent.tool.AgentToolManager;
import pangea.hiagent.tool.impl.DateTimeTools;
import pangea.hiagent.common.utils.UserUtils; import pangea.hiagent.common.utils.UserUtils;
import java.util.List; import java.util.List;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -30,20 +28,22 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -30,20 +28,22 @@ public class DefaultReactExecutor implements ReactExecutor {
private String defaultSystemPrompt; private String defaultSystemPrompt;
private final List<ReactCallback> reactCallbacks = new ArrayList<>(); private final List<ReactCallback> reactCallbacks = new ArrayList<>();
private final EventSplitter eventSplitter;
@Autowired
private DateTimeTools dateTimeTools;
@Autowired
private MemoryService memoryService; private MemoryService memoryService;
@Autowired
private ErrorHandlerService errorHandlerService; private ErrorHandlerService errorHandlerService;
private final AgentToolManager agentToolManager; private final AgentToolManager agentToolManager;
public DefaultReactExecutor(AgentToolManager agentToolManager) { public DefaultReactExecutor(EventSplitter eventSplitter, AgentToolManager agentToolManager ,
MemoryService memoryService, ErrorHandlerService errorHandlerService) {
this.eventSplitter = eventSplitter;
this.agentToolManager = agentToolManager; this.agentToolManager = agentToolManager;
this.memoryService = memoryService;
this.errorHandlerService = errorHandlerService;
} }
@Override @Override
...@@ -56,7 +56,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -56,7 +56,7 @@ public class DefaultReactExecutor implements ReactExecutor {
@Override @Override
public String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent) { public String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent) {
// 调用带用户ID的方法,首先尝试获取当前用户ID // 调用带用户ID的方法,首先尝试获取当前用户ID
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
return execute(chatClient, userInput, tools, agent, userId); return execute(chatClient, userInput, tools, agent, userId);
} }
...@@ -117,7 +117,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -117,7 +117,7 @@ public class DefaultReactExecutor implements ReactExecutor {
try { try {
// 如果没有提供用户ID,则尝试获取当前用户ID // 如果没有提供用户ID,则尝试获取当前用户ID
if (userId == null) { if (userId == null) {
userId = UserUtils.getCurrentUserId(); userId = UserUtils.getCurrentUserIdStatic();
} }
String sessionId = memoryService.generateSessionId(agent, userId); String sessionId = memoryService.generateSessionId(agent, userId);
...@@ -142,7 +142,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -142,7 +142,7 @@ public class DefaultReactExecutor implements ReactExecutor {
@Override @Override
public void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) { public void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) {
// 调用带用户ID的方法,但首先尝试获取当前用户ID // 调用带用户ID的方法,但首先尝试获取当前用户ID
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
executeStream(chatClient, userInput, tools, tokenConsumer, agent, userId); executeStream(chatClient, userInput, tools, tokenConsumer, agent, userId);
} }
...@@ -190,9 +190,12 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -190,9 +190,12 @@ public class DefaultReactExecutor implements ReactExecutor {
if (tokenConsumer != null) { if (tokenConsumer != null) {
tokenConsumer.accept(token); tokenConsumer.accept(token);
} }
eventSplitter.feedToken(token);
} }
} catch (Exception e) { } catch (Exception e) {
log.error("处理token时发生错误", e); log.error("处理token时发生错误", e);
errorHandlerService.handleReactFlowError(e, tokenConsumer);
} }
} }
...@@ -217,17 +220,6 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -217,17 +220,6 @@ public class DefaultReactExecutor implements ReactExecutor {
} }
} }
/**
* 检查是否已经触发了Final Answer步骤
*
* @param fullResponse 完整响应内容
* @return 如果已经触发了Final Answer则返回true,否则返回false
*/
private boolean hasFinalAnswerBeenTriggered(String fullResponse) {
// 使用正则表达式进行高效的不区分大小写匹配
return fullResponse.matches("(?i).*(Final Answer:|Final_Answer:|最终答案:).*");
}
/** /**
* 将助手的回复保存到内存中 * 将助手的回复保存到内存中
* *
...@@ -336,11 +328,6 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -336,11 +328,6 @@ public class DefaultReactExecutor implements ReactExecutor {
} }
} }
// 添加默认的日期时间工具(如果尚未添加)
if (dateTimeTools != null && !tools.contains(dateTimeTools)) {
tools.add(dateTimeTools);
}
return tools; return tools;
} }
} }
\ No newline at end of file
package pangea.hiagent.agent.react;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.springframework.stereotype.Component;
@Component
public class EventSplitter {
private final List<String> keywords = Arrays.asList(
"Thought", "Action", "Observation", "Iteration_Decision", "Final_Answer"
);
private final Pattern keywordPattern = Pattern.compile(
String.format("(%s):", String.join("|", keywords))
);
private String currentType = null;
private StringBuilder currentContent = new StringBuilder();
private StringBuilder buffer = new StringBuilder();
private final ReactCallback callback;
private volatile int stepNumber = 0;
public EventSplitter(ReactCallback callback) {
this.callback = callback;
}
// 每收到一个token/字符,调用此方法
public void feedToken(String token) {
buffer.append(token);
Matcher matcher = keywordPattern.matcher(buffer);
if (matcher.find()) {
// 发现新事件
if (currentType != null && currentContent.length() > 0) {
// 实时输出已分割事件
callback.onStep(new ReactStep(stepNumber++, ReactStepType.fromString(currentType), currentContent.toString()));
}
// 更新事件类型
currentType = matcher.group(1);
currentContent.setLength(0);
// 移除关键词和冒号
buffer.delete(0, matcher.end());
}
// 累积内容
currentContent.append(buffer);
buffer.setLength(0);
}
// 流式结束时,调用此方法输出最后一个事件
public void endStream() {
if (currentType != null && currentContent.length() > 0) {
callback.onStep(new ReactStep(stepNumber++, ReactStepType.fromString(currentType), currentContent.toString()));
}
}
}
...@@ -22,5 +22,9 @@ public enum ReactStepType { ...@@ -22,5 +22,9 @@ public enum ReactStepType {
/** /**
* 最终答案步骤:结合工具结果生成最终回答 * 最终答案步骤:结合工具结果生成最终回答
*/ */
FINAL_ANSWER FINAL_ANSWER;
public static ReactStepType fromString(String currentType) {
return ReactStepType.valueOf(currentType.toUpperCase());
}
} }
\ No newline at end of file
...@@ -84,11 +84,11 @@ public class AgentChatService { ...@@ -84,11 +84,11 @@ public class AgentChatService {
log.info("开始处理流式对话请求,AgentId: {}, 用户消息: {}", agentId, chatRequest.getMessage()); log.info("开始处理流式对话请求,AgentId: {}, 用户消息: {}", agentId, chatRequest.getMessage());
// 尝试获取当前用户ID,优先从SecurityContext获取,其次从请求中解析JWT // 尝试获取当前用户ID,优先从SecurityContext获取,其次从请求中解析JWT
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
// 如果在主线程中未能获取到用户ID,尝试在异步环境中获取 // 如果在主线程中未能获取到用户ID,再次尝试获取(支持异步环境)
if (userId == null) { if (userId == null) {
userId = UserUtils.getCurrentUserIdInAsync(); userId = UserUtils.getCurrentUserIdStatic();
} }
if (userId == null) { if (userId == null) {
......
package pangea.hiagent.agent.service; package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.web.dto.ToolEvent;
import pangea.hiagent.web.dto.WorkPanelEvent; import pangea.hiagent.web.dto.WorkPanelEvent;
import pangea.hiagent.workpanel.event.EventService; import pangea.hiagent.workpanel.event.EventService;
import pangea.hiagent.workpanel.data.TokenEventDataBuilder; import pangea.hiagent.workpanel.data.TokenEventDataBuilder;
...@@ -587,30 +589,6 @@ public class UserSseService { ...@@ -587,30 +589,6 @@ public class UserSseService {
} }
} }
/**
* 发送工作面板事件给指定用户
*
* @param userId 用户ID
* @param event 工作面板事件
*/
public void sendWorkPanelEventToUser(String userId, WorkPanelEvent event) {
log.debug("开始向用户 {} 发送工作面板事件: {}", userId, event.getType());
// 检查连接是否仍然有效
SseEmitter emitter = getSession(userId);
if (emitter != null) {
try {
// 直接向当前 emitter 发送事件
sendWorkPanelEvent(emitter, event);
log.debug("已发送工作面板事件到客户端: {}", event.getType());
} catch (IOException e) {
log.error("发送工作面板事件失败: {}", e.getMessage(), e);
}
} else {
log.debug("连接已失效,跳过发送事件: {}", event.getType());
}
}
/** /**
* 发送连接成功事件 * 发送连接成功事件
* *
......
...@@ -112,9 +112,9 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler { ...@@ -112,9 +112,9 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
*/ */
private String getCurrentUserIdWithContext() { private String getCurrentUserIdWithContext() {
try { try {
// 直接调用UserUtils.getCurrentUserId(),该方法已经包含了所有获取用户ID的方式 // 直接调用UserUtils.getCurrentUserIdStatic(),该方法已经包含了所有获取用户ID的方式
// 并且优先从ThreadLocal获取,支持异步线程 // 并且优先从ThreadLocal获取,支持异步线程
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId != null) { if (userId != null) {
log.debug("成功获取用户ID: {}", userId); log.debug("成功获取用户ID: {}", userId);
return userId; return userId;
......
...@@ -99,7 +99,7 @@ public class AsyncUserContextDecorator { ...@@ -99,7 +99,7 @@ public class AsyncUserContextDecorator {
// 捕获当前线程的用户上下文 // 捕获当前线程的用户上下文
UserContextHolder userContext = captureUserContext(); UserContextHolder userContext = captureUserContext();
// 同时捕获当前线程的用户ID(用于ThreadLocal传播) // 同时捕获当前线程的用户ID(用于ThreadLocal传播)
String currentUserId = UserUtils.getCurrentUserId(); String currentUserId = UserUtils.getCurrentUserIdStatic();
return () -> { return () -> {
try { try {
...@@ -107,7 +107,7 @@ public class AsyncUserContextDecorator { ...@@ -107,7 +107,7 @@ public class AsyncUserContextDecorator {
propagateUserContext(userContext); propagateUserContext(userContext);
// 将用户ID设置到ThreadLocal中,增强可靠性 // 将用户ID设置到ThreadLocal中,增强可靠性
if (currentUserId != null) { if (currentUserId != null) {
UserUtils.setCurrentUserId(currentUserId); UserUtils.setCurrentUserIdStatic(currentUserId);
} }
// 执行原始任务 // 执行原始任务
...@@ -116,7 +116,7 @@ public class AsyncUserContextDecorator { ...@@ -116,7 +116,7 @@ public class AsyncUserContextDecorator {
// 清理当前线程的用户上下文 // 清理当前线程的用户上下文
clearUserContext(); clearUserContext();
// 清理ThreadLocal中的用户ID // 清理ThreadLocal中的用户ID
UserUtils.clearCurrentUserId(); UserUtils.clearCurrentUserIdStatic();
} }
}; };
} }
...@@ -131,7 +131,7 @@ public class AsyncUserContextDecorator { ...@@ -131,7 +131,7 @@ public class AsyncUserContextDecorator {
// 捕获当前线程的用户上下文 // 捕获当前线程的用户上下文
UserContextHolder userContext = captureUserContext(); UserContextHolder userContext = captureUserContext();
// 同时捕获当前线程的用户ID(用于ThreadLocal传播) // 同时捕获当前线程的用户ID(用于ThreadLocal传播)
String currentUserId = UserUtils.getCurrentUserId(); String currentUserId = UserUtils.getCurrentUserIdStatic();
return () -> { return () -> {
try { try {
...@@ -139,7 +139,7 @@ public class AsyncUserContextDecorator { ...@@ -139,7 +139,7 @@ public class AsyncUserContextDecorator {
propagateUserContext(userContext); propagateUserContext(userContext);
// 将用户ID设置到ThreadLocal中,增强可靠性 // 将用户ID设置到ThreadLocal中,增强可靠性
if (currentUserId != null) { if (currentUserId != null) {
UserUtils.setCurrentUserId(currentUserId); UserUtils.setCurrentUserIdStatic(currentUserId);
} }
// 执行原始任务 // 执行原始任务
...@@ -148,7 +148,7 @@ public class AsyncUserContextDecorator { ...@@ -148,7 +148,7 @@ public class AsyncUserContextDecorator {
// 清理当前线程的用户上下文 // 清理当前线程的用户上下文
clearUserContext(); clearUserContext();
// 清理ThreadLocal中的用户ID // 清理ThreadLocal中的用户ID
UserUtils.clearCurrentUserId(); UserUtils.clearCurrentUserIdStatic();
} }
}; };
} }
......
...@@ -59,7 +59,7 @@ public class MemoryService { ...@@ -59,7 +59,7 @@ public class MemoryService {
* @return 用户ID * @return 用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("无法通过UserUtils获取当前用户ID"); log.warn("无法通过UserUtils获取当前用户ID");
} }
......
...@@ -12,6 +12,7 @@ import org.springframework.stereotype.Component; ...@@ -12,6 +12,7 @@ import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.common.utils.JwtUtil; import pangea.hiagent.common.utils.JwtUtil;
import pangea.hiagent.common.utils.UserUtils;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
...@@ -27,9 +28,12 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -27,9 +28,12 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
private static final String BEARER_PREFIX = "Bearer "; private static final String BEARER_PREFIX = "Bearer ";
private final JwtUtil jwtUtil; private final JwtUtil jwtUtil;
private final UserUtils userUtils;
public JwtAuthenticationFilter(JwtUtil jwtUtil) { public JwtAuthenticationFilter(JwtUtil jwtUtil, UserUtils userUtils) {
this.jwtUtil = jwtUtil; this.jwtUtil = jwtUtil;
this.userUtils = userUtils;
} }
@Override @Override
...@@ -47,15 +51,17 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -47,15 +51,17 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
if (StringUtils.hasText(token)) { if (StringUtils.hasText(token)) {
// 验证token是否有效 // 验证token是否有效
if (jwtUtil.validateToken(token)) { if (jwtUtil.validateToken(token)) {
String userId = jwtUtil.getUserIdFromToken(token); String userId = jwtUtil.getUserIdFromToken(token);
if (userId != null) { if (userId != null) {
// 创建认证对象,添加基本权限 // 创建认证对象,添加基本权限
var authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER")); var authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"));
var authentication = new UsernamePasswordAuthenticationToken(userId, null, authorities); var authentication = new UsernamePasswordAuthenticationToken(userId, null, authorities);
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
}
userUtils.setCurrentUserId(userId);
} }
} }
}
} catch (Exception e) { } catch (Exception e) {
log.error("JWT认证处理异常: {}", e.getMessage()); log.error("JWT认证处理异常: {}", e.getMessage());
// 不在此处发送错误响应,让Spring Security的ExceptionTranslationFilter处理 // 不在此处发送错误响应,让Spring Security的ExceptionTranslationFilter处理
......
...@@ -2,7 +2,10 @@ package pangea.hiagent.tool; ...@@ -2,7 +2,10 @@ package pangea.hiagent.tool;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import pangea.hiagent.workpanel.event.SseEventBroadcaster; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.service.UserSseService;
import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.web.dto.WorkPanelEvent; import pangea.hiagent.web.dto.WorkPanelEvent;
import java.util.HashMap; import java.util.HashMap;
...@@ -16,7 +19,10 @@ import java.util.Map; ...@@ -16,7 +19,10 @@ import java.util.Map;
public abstract class BaseTool { public abstract class BaseTool {
@Autowired @Autowired
private SseEventBroadcaster sseEventBroadcaster; private UserSseService userSseService;
@Autowired
private UserUtils userUtils;
/** /**
* 工具执行包装方法 * 工具执行包装方法
...@@ -31,8 +37,11 @@ public abstract class BaseTool { ...@@ -31,8 +37,11 @@ public abstract class BaseTool {
String toolName = this.getClass().getSimpleName(); String toolName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
// 在方法开始时获取用户ID,此时线程通常是原始请求线程,能够正确获取
String userId = userUtils.getCurrentUserId();
// 1. 发送工具开始执行事件 // 1. 发送工具开始执行事件
sendToolEvent(toolName, methodName, params, null, "执行中", startTime, null); sendToolEvent(toolName, methodName, params, null, "执行中", startTime, null,null, userId);
T result = null; T result = null;
String status = "成功"; String status = "成功";
...@@ -51,7 +60,7 @@ public abstract class BaseTool { ...@@ -51,7 +60,7 @@ public abstract class BaseTool {
long duration = endTime - startTime; long duration = endTime - startTime;
// 3. 发送工具执行完成事件 // 3. 发送工具执行完成事件
sendToolEvent(toolName, methodName, params, result, status, startTime, duration, exception); sendToolEvent(toolName, methodName, params, result, status, startTime, duration, exception, userId);
} }
return result; return result;
...@@ -78,10 +87,11 @@ public abstract class BaseTool { ...@@ -78,10 +87,11 @@ public abstract class BaseTool {
* @param startTime 开始时间戳 * @param startTime 开始时间戳
* @param duration 执行耗时(毫秒) * @param duration 执行耗时(毫秒)
* @param exception 异常信息(可选) * @param exception 异常信息(可选)
* @param userId 用户ID,从方法开始时传递
*/ */
private void sendToolEvent(String toolName, String methodName, private void sendToolEvent(String toolName, String methodName,
Map<String, Object> params, Object result, String status, Map<String, Object> params, Object result, String status,
Long startTime, Long duration, Exception... exception) { Long startTime, Long duration, Exception exception, String userId) {
try { try {
Map<String, Object> eventData = new HashMap<>(); Map<String, Object> eventData = new HashMap<>();
eventData.put("toolName", toolName); eventData.put("toolName", toolName);
...@@ -92,9 +102,9 @@ public abstract class BaseTool { ...@@ -92,9 +102,9 @@ public abstract class BaseTool {
eventData.put("startTime", startTime); eventData.put("startTime", startTime);
eventData.put("duration", duration); eventData.put("duration", duration);
if (exception != null && exception.length > 0 && exception[0] != null) { if (exception != null) {
eventData.put("error", exception[0].getMessage()); eventData.put("error", exception.getMessage());
eventData.put("errorType", exception[0].getClass().getSimpleName()); eventData.put("errorType", exception.getClass().getSimpleName());
} }
WorkPanelEvent event = WorkPanelEvent.builder() WorkPanelEvent event = WorkPanelEvent.builder()
...@@ -102,9 +112,16 @@ public abstract class BaseTool { ...@@ -102,9 +112,16 @@ public abstract class BaseTool {
.title(toolName + "." + methodName) .title(toolName + "." + methodName)
.timestamp(System.currentTimeMillis()) .timestamp(System.currentTimeMillis())
.metadata(eventData) .metadata(eventData)
.userId(userId)
.build(); .build();
sseEventBroadcaster.broadcastWorkPanelEvent(event); // 获取用户的SSE发射器
SseEmitter emitter = userSseService.getSession(userId);
if (emitter != null) {
userSseService.sendWorkPanelEvent(emitter, event);
} else {
log.debug("未找到用户 {} 的SSE连接,跳过发送事件", userId);
}
log.debug("已发送工具事件: {}#{}, 状态: {}", toolName, methodName, status); log.debug("已发送工具事件: {}#{}, 状态: {}", toolName, methodName, status);
} catch (Exception e) { } catch (Exception e) {
......
...@@ -44,7 +44,7 @@ public class AgentController { ...@@ -44,7 +44,7 @@ public class AgentController {
@PostMapping @PostMapping
public ApiResponse<Agent> createAgent(@RequestBody Agent agent) { public ApiResponse<Agent> createAgent(@RequestBody Agent agent) {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
return ApiResponse.error(4001, "用户未认证"); return ApiResponse.error(4001, "用户未认证");
} }
...@@ -67,7 +67,7 @@ public class AgentController { ...@@ -67,7 +67,7 @@ public class AgentController {
@PostMapping("/with-tools") @PostMapping("/with-tools")
public ApiResponse<Agent> createAgentWithTools(@RequestBody AgentWithToolsDTO agentWithToolsDTO) { public ApiResponse<Agent> createAgentWithTools(@RequestBody AgentWithToolsDTO agentWithToolsDTO) {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
return ApiResponse.error(4001, "用户未认证"); return ApiResponse.error(4001, "用户未认证");
} }
...@@ -109,7 +109,7 @@ public class AgentController { ...@@ -109,7 +109,7 @@ public class AgentController {
@PreAuthorize("@permissionEvaluator.hasPermission(authentication, #id, 'Agent', 'write')") @PreAuthorize("@permissionEvaluator.hasPermission(authentication, #id, 'Agent', 'write')")
@PutMapping("/{id}") @PutMapping("/{id}")
public ApiResponse<Agent> updateAgent(@PathVariable(name = "id") String id, @RequestBody Agent agent) { public ApiResponse<Agent> updateAgent(@PathVariable(name = "id") String id, @RequestBody Agent agent) {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法更新Agent: {}", id); log.warn("用户未认证,无法更新Agent: {}", id);
return ApiResponse.error(4001, "用户未认证"); return ApiResponse.error(4001, "用户未认证");
...@@ -163,7 +163,7 @@ public class AgentController { ...@@ -163,7 +163,7 @@ public class AgentController {
@PreAuthorize("@permissionEvaluator.hasPermission(authentication, #id, 'Agent', 'write')") @PreAuthorize("@permissionEvaluator.hasPermission(authentication, #id, 'Agent', 'write')")
@PutMapping("/{id}/with-tools") @PutMapping("/{id}/with-tools")
public ApiResponse<Agent> updateAgentWithTools(@PathVariable(name = "id") String id, @RequestBody AgentWithToolsDTO agentWithToolsDTO) { public ApiResponse<Agent> updateAgentWithTools(@PathVariable(name = "id") String id, @RequestBody AgentWithToolsDTO agentWithToolsDTO) {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法更新Agent: {}", id); log.warn("用户未认证,无法更新Agent: {}", id);
return ApiResponse.error(4001, "用户未认证"); return ApiResponse.error(4001, "用户未认证");
...@@ -238,7 +238,7 @@ public class AgentController { ...@@ -238,7 +238,7 @@ public class AgentController {
@DeleteMapping("/{id}") @DeleteMapping("/{id}")
public ApiResponse<Void> deleteAgent(@PathVariable(name = "id") String id) { public ApiResponse<Void> deleteAgent(@PathVariable(name = "id") String id) {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
log.info("用户 {} 开始删除Agent: {}", userId, id); log.info("用户 {} 开始删除Agent: {}", userId, id);
agentService.deleteAgent(id); agentService.deleteAgent(id);
log.info("用户 {} 成功删除Agent: {}", userId, id); log.info("用户 {} 成功删除Agent: {}", userId, id);
...@@ -292,7 +292,7 @@ public class AgentController { ...@@ -292,7 +292,7 @@ public class AgentController {
@PreAuthorize("isAuthenticated()") @PreAuthorize("isAuthenticated()")
public ApiResponse<java.util.List<Agent>> getUserAgents() { public ApiResponse<java.util.List<Agent>> getUserAgents() {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
return ApiResponse.error(4001, "用户未认证"); return ApiResponse.error(4001, "用户未认证");
} }
......
...@@ -40,7 +40,7 @@ public class MemoryController { ...@@ -40,7 +40,7 @@ public class MemoryController {
@GetMapping("/dialogue") @GetMapping("/dialogue")
public ApiResponse<List<Map<String, Object>>> getDialogueMemories() { public ApiResponse<List<Map<String, Object>>> getDialogueMemories() {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法获取对话记忆列表"); log.warn("用户未认证,无法获取对话记忆列表");
return ApiResponse.error(401, "用户未认证"); return ApiResponse.error(401, "用户未认证");
...@@ -82,7 +82,7 @@ public class MemoryController { ...@@ -82,7 +82,7 @@ public class MemoryController {
@GetMapping("/knowledge") @GetMapping("/knowledge")
public ApiResponse<List<Map<String, Object>>> getKnowledgeMemories() { public ApiResponse<List<Map<String, Object>>> getKnowledgeMemories() {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法获取知识记忆列表"); log.warn("用户未认证,无法获取知识记忆列表");
return ApiResponse.error(401, "用户未认证"); return ApiResponse.error(401, "用户未认证");
...@@ -110,7 +110,7 @@ public class MemoryController { ...@@ -110,7 +110,7 @@ public class MemoryController {
@GetMapping("/dialogue/agent/{agentId}") @GetMapping("/dialogue/agent/{agentId}")
public ApiResponse<Map<String, Object>> getDialogueMemoryDetail(@PathVariable String agentId) { public ApiResponse<Map<String, Object>> getDialogueMemoryDetail(@PathVariable String agentId) {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法获取对话记忆详情"); log.warn("用户未认证,无法获取对话记忆详情");
return ApiResponse.error(401, "用户未认证"); return ApiResponse.error(401, "用户未认证");
...@@ -190,7 +190,7 @@ public class MemoryController { ...@@ -190,7 +190,7 @@ public class MemoryController {
@DeleteMapping("/dialogue/{sessionId}") @DeleteMapping("/dialogue/{sessionId}")
public ApiResponse<Void> clearDialogueMemory(@PathVariable String sessionId) { public ApiResponse<Void> clearDialogueMemory(@PathVariable String sessionId) {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法清空对话记忆"); log.warn("用户未认证,无法清空对话记忆");
return ApiResponse.error(401, "用户未认证"); return ApiResponse.error(401, "用户未认证");
...@@ -223,7 +223,7 @@ public class MemoryController { ...@@ -223,7 +223,7 @@ public class MemoryController {
@DeleteMapping("/knowledge/{id}") @DeleteMapping("/knowledge/{id}")
public ApiResponse<Void> deleteKnowledgeMemory(@PathVariable String id) { public ApiResponse<Void> deleteKnowledgeMemory(@PathVariable String id) {
try { try {
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserIdStatic();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法删除知识记忆"); log.warn("用户未认证,无法删除知识记忆");
return ApiResponse.error(401, "用户未认证"); return ApiResponse.error(401, "用户未认证");
......
...@@ -258,7 +258,7 @@ public class TimerController { ...@@ -258,7 +258,7 @@ public class TimerController {
* 获取当前认证用户ID * 获取当前认证用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
return UserUtils.getCurrentUserId(); return UserUtils.getCurrentUserIdStatic();
} }
/** /**
......
...@@ -39,7 +39,7 @@ public class ToolController { ...@@ -39,7 +39,7 @@ public class ToolController {
* @return 用户ID * @return 用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
return UserUtils.getCurrentUserId(); return UserUtils.getCurrentUserIdStatic();
} }
/** /**
......
...@@ -38,4 +38,9 @@ public class WorkPanelEvent implements Serializable { ...@@ -38,4 +38,9 @@ public class WorkPanelEvent implements Serializable {
* 元数据 * 元数据
*/ */
private Map<String, Object> metadata; private Map<String, Object> metadata;
/**
* 触发事件的用户ID
*/
private String userId;
} }
\ No newline at end of file
...@@ -145,7 +145,7 @@ public class AgentService { ...@@ -145,7 +145,7 @@ public class AgentService {
} }
// 验证用户权限(确保用户是所有者) // 验证用户权限(确保用户是所有者)
String currentUserId = UserUtils.getCurrentUserId(); String currentUserId = UserUtils.getCurrentUserIdStatic();
if (currentUserId == null) { if (currentUserId == null) {
log.warn("用户未认证,无法更新Agent: {}", agent.getId()); log.warn("用户未认证,无法更新Agent: {}", agent.getId());
throw new BusinessException(ErrorCode.UNAUTHORIZED.getCode(), "用户未认证"); throw new BusinessException(ErrorCode.UNAUTHORIZED.getCode(), "用户未认证");
......
...@@ -89,7 +89,7 @@ public class WebSocketConnectionManager { ...@@ -89,7 +89,7 @@ public class WebSocketConnectionManager {
String userId = (String) session.getAttributes().get("userId"); String userId = (String) session.getAttributes().get("userId");
if (userId == null || userId.isEmpty()) { if (userId == null || userId.isEmpty()) {
// 如果没有有效的用户ID,尝试从SecurityContext获取 // 如果没有有效的用户ID,尝试从SecurityContext获取
userId = UserUtils.getCurrentUserId(); userId = UserUtils.getCurrentUserIdStatic();
if (userId == null || userId.isEmpty()) { if (userId == null || userId.isEmpty()) {
// 如果仍然无法获取用户ID,使用默认值 // 如果仍然无法获取用户ID,使用默认值
userId = "unknown-user"; userId = "unknown-user";
......
package pangea.hiagent.workpanel.event;
import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.agent.service.UserSseService;
import pangea.hiagent.web.dto.ToolEvent;
import pangea.hiagent.web.dto.WorkPanelEvent;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* SSE事件广播器
* 专门负责广播事件给所有订阅者
*/
@Slf4j
@Component
public class SseEventBroadcaster {
@Autowired
private UserSseService unifiedSseService;
@Autowired
private EventService eventService;
/**
* 广播工作面板事件给所有订阅者
*
* @param event 工作面板事件
*/
public void broadcastWorkPanelEvent(WorkPanelEvent event) {
if (event == null) {
log.warn("广播事件时接收到null事件");
return;
}
try {
// 预构建事件数据,避免重复构建
Map<String, Object> eventData = eventService.buildWorkPanelEventData(event);
try {
// 获取所有emitter并广播
List<SseEmitter> emitters = unifiedSseService.getEmitters();
int successCount = 0;
int failureCount = 0;
// 使用CopyOnWriteArrayList避免并发修改异常
for (SseEmitter emitter : new CopyOnWriteArrayList<>(emitters)) {
try {
// 检查emitter是否仍然有效
if (unifiedSseService.isEmitterValid(emitter)) {
emitter.send(SseEmitter.event().name("message").data(eventData));
successCount++;
} else {
// 移除无效的emitter
log.debug("移除无效的SSE连接");
unifiedSseService.removeEmitter(emitter);
failureCount++;
}
} catch (IOException e) {
log.error("发送事件失败,移除失效连接: {}", e.getMessage());
unifiedSseService.removeEmitter(emitter);
failureCount++;
} catch (IllegalStateException e) {
log.debug("Emitter已关闭,移除连接: {}", e.getMessage());
unifiedSseService.removeEmitter(emitter);
failureCount++;
} catch (Exception e) {
log.error("发送事件时发生未知异常,移除连接: {}", e.getMessage(), e);
unifiedSseService.removeEmitter(emitter);
failureCount++;
}
}
if (failureCount > 0) {
log.warn("事件广播部分失败: 成功={}, 失败={}", successCount, failureCount);
}
// 记录对象池使用统计信息(每100次广播记录一次)
if ((successCount + failureCount) % 100 == 0) {
log.debug("对象池使用统计: {}", eventService.getMapPoolStatistics());
}
} finally {
// 确保eventData被归还到对象池
eventService.releaseMap(eventData);
}
} catch (Exception e) {
String toolName = null;
if (event instanceof ToolEvent) {
toolName = ((ToolEvent) event).getToolName();
}
log.error("广播事件失败: 事件类型={}, 工具={}, 错误信息={}",
event.getType(),
toolName,
e.getMessage(),
e);
}
}
}
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment