Commit 0306580c authored by ligaowei's avatar ligaowei

refactor(agent): 移除未使用的回调方法和冗余代码

重构ReAct相关组件,移除未使用的onFinalAnswer回调方法和冗余的getParameters/getContent方法。优化SseTokenEmitter为无状态设计,通过构造函数一次性传入所有状态。简化JwtAuthenticationFilter和DefaultPermissionEvaluator的权限检查逻辑。改进ErrorHandlerService的错误处理代码复用性。调整代码结构以提高可维护性。
parent 40bd44a9
...@@ -31,8 +31,7 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -31,8 +31,7 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
@Autowired @Autowired
private RagService ragService; private RagService ragService;
@Autowired
private ReactCallback defaultReactCallback;
@Autowired @Autowired
private ReactExecutor defaultReactExecutor; private ReactExecutor defaultReactExecutor;
...@@ -40,6 +39,9 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -40,6 +39,9 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
@Autowired @Autowired
private AgentToolManager agentToolManager; private AgentToolManager agentToolManager;
@Autowired
private ReactCallback defaultReactCallback;
@Override @Override
public String processRequest(Agent agent, AgentRequest request, String userId) { public String processRequest(Agent agent, AgentRequest request, String userId) {
log.info("使用ReAct Agent处理请求"); log.info("使用ReAct Agent处理请求");
...@@ -72,10 +74,6 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -72,10 +74,6 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
// 处理请求的通用前置逻辑 // 处理请求的通用前置逻辑
String ragResponse = handlePreProcessing(agent, userMessage, userId, ragService, null); String ragResponse = handlePreProcessing(agent, userMessage, userId, ragService, null);
if (ragResponse != null) { if (ragResponse != null) {
// 触发最终答案回调
if (defaultReactCallback != null) {
defaultReactCallback.onFinalAnswer(ragResponse);
}
return ragResponse; return ragResponse;
} }
...@@ -83,10 +81,7 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -83,10 +81,7 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
ChatClient client = ChatClient.builder(agentService.getChatModelForAgent(agent)).build(); ChatClient client = ChatClient.builder(agentService.getChatModelForAgent(agent)).build();
List<Object> tools = agentToolManager.getAvailableToolInstances(agent); List<Object> tools = agentToolManager.getAvailableToolInstances(agent);
// 添加自定义回调到ReAct执行器
if (defaultReactExecutor != null && defaultReactCallback != null) {
defaultReactExecutor.addReactCallback(defaultReactCallback);
}
// 使用ReAct执行器执行流程,传递Agent对象和用户ID以支持记忆功能 // 使用ReAct执行器执行流程,传递Agent对象和用户ID以支持记忆功能
String finalAnswer = defaultReactExecutor.execute(client, userMessage, tools, agent, userId); String finalAnswer = defaultReactExecutor.execute(client, userMessage, tools, agent, userId);
...@@ -114,10 +109,6 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -114,10 +109,6 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
// 处理请求的通用前置逻辑 // 处理请求的通用前置逻辑
String ragResponse = handlePreProcessing(agent, userMessage, userId, ragService, tokenConsumer); String ragResponse = handlePreProcessing(agent, userMessage, userId, ragService, tokenConsumer);
if (ragResponse != null) { if (ragResponse != null) {
// 触发最终答案回调
if (defaultReactCallback != null) {
defaultReactCallback.onFinalAnswer(ragResponse);
}
return; return;
} }
......
...@@ -24,12 +24,6 @@ public class DefaultReactCallback implements ReactCallback { ...@@ -24,12 +24,6 @@ public class DefaultReactCallback implements ReactCallback {
recordReactStepToWorkPanel(reactStep); recordReactStepToWorkPanel(reactStep);
} }
@Override
public void onFinalAnswer(String finalAnswer) {
ReactStep finalStep = new ReactStep(0, ReactStepType.FINAL_ANSWER, finalAnswer);
recordReactStepToWorkPanel(finalStep);
}
private void recordReactStepToWorkPanel(ReactStep reactStep) { private void recordReactStepToWorkPanel(ReactStep reactStep) {
if (workPanelCollector == null) { if (workPanelCollector == null) {
return; return;
...@@ -46,7 +40,7 @@ public class DefaultReactCallback implements ReactCallback { ...@@ -46,7 +40,7 @@ public class DefaultReactCallback implements ReactCallback {
if (reactStep.getAction() != null) { if (reactStep.getAction() != null) {
// 记录工具调用动作 // 记录工具调用动作
String toolName = reactStep.getAction().getToolName(); String toolName = reactStep.getAction().getToolName();
Object parameters = reactStep.getAction().getParameters(); Object parameters = reactStep.getAction().getToolArgs();
// 记录工具调用,初始状态为pending // 记录工具调用,初始状态为pending
workPanelCollector.recordToolCallAction( workPanelCollector.recordToolCallAction(
...@@ -73,15 +67,15 @@ public class DefaultReactCallback implements ReactCallback { ...@@ -73,15 +67,15 @@ public class DefaultReactCallback implements ReactCallback {
// 使用动作信息更新工具调用结果 // 使用动作信息更新工具调用结果
workPanelCollector.recordToolCallAction( workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(), reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(), reactStep.getAction().getToolArgs(),
reactStep.getObservation().getContent(), reactStep.getObservation().getResult(),
"success", // 状态为success "success", // 状态为success
null // 无错误信息 null // 无错误信息
); );
log.info("[WorkPanel] 更新工具调用结果: 工具={} 结果摘要={}", log.info("[WorkPanel] 更新工具调用结果: 工具={} 结果摘要={}",
reactStep.getAction().getToolName(), reactStep.getAction().getToolName(),
reactStep.getObservation().getContent().substring(0, Math.min(50, reactStep.getObservation().getContent().length()))); reactStep.getObservation().getResult().substring(0, Math.min(50, reactStep.getObservation().getResult().length())));
} else { } else {
// 如果没有动作信息,记录为观察结果 // 如果没有动作信息,记录为观察结果
workPanelCollector.recordThinking(reactStep.getContent(), "observation"); workPanelCollector.recordThinking(reactStep.getContent(), "observation");
...@@ -108,7 +102,7 @@ public class DefaultReactCallback implements ReactCallback { ...@@ -108,7 +102,7 @@ public class DefaultReactCallback implements ReactCallback {
if (reactStep != null && reactStep.getAction() != null) { if (reactStep != null && reactStep.getAction() != null) {
workPanelCollector.recordToolCallAction( workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(), reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(), reactStep.getAction().getToolArgs(),
"记录失败: " + e.getMessage(), "记录失败: " + e.getMessage(),
"error", "error",
System.currentTimeMillis() // 使用当前时间戳作为执行时间 System.currentTimeMillis() // 使用当前时间戳作为执行时间
......
...@@ -10,6 +10,4 @@ public interface ReactCallback { ...@@ -10,6 +10,4 @@ public interface ReactCallback {
* @param reactStep ReAct步骤对象,包含步骤的所有核心信息 * @param reactStep ReAct步骤对象,包含步骤的所有核心信息
*/ */
void onStep(ReactStep reactStep); void onStep(ReactStep reactStep);
void onFinalAnswer(String ragResponse);
} }
\ No newline at end of file
...@@ -49,9 +49,6 @@ public class ReactStep { ...@@ -49,9 +49,6 @@ public class ReactStep {
public Object getToolArgs() { return toolArgs; } public Object getToolArgs() { return toolArgs; }
public void setToolArgs(Object toolArgs) { this.toolArgs = toolArgs; } public void setToolArgs(Object toolArgs) { this.toolArgs = toolArgs; }
// 根据DefaultReactCallback.java中的使用情况添加getParameters方法
public Object getParameters() { return toolArgs; }
} }
/** /**
...@@ -66,8 +63,5 @@ public class ReactStep { ...@@ -66,8 +63,5 @@ public class ReactStep {
public String getResult() { return result; } public String getResult() { return result; }
public void setResult(String result) { this.result = result; } public void setResult(String result) { this.result = result; }
// 根据DefaultReactCallback.java中的使用情况添加getContent方法
public String getContent() { return result; }
} }
} }
\ No newline at end of file
...@@ -27,7 +27,7 @@ public class AgentChatService { ...@@ -27,7 +27,7 @@ public class AgentChatService {
private final ErrorHandlerService errorHandlerService; private final ErrorHandlerService errorHandlerService;
private final AgentProcessorFactory agentProcessorFactory; private final AgentProcessorFactory agentProcessorFactory;
private final AgentToolManager agentToolManager; private final AgentToolManager agentToolManager;
private final UserSseService userSseSerivce; private final UserSseService userSseService;
private final pangea.hiagent.web.service.AgentService agentService; private final pangea.hiagent.web.service.AgentService agentService;
private final SseTokenEmitter sseTokenEmitter; private final SseTokenEmitter sseTokenEmitter;
...@@ -36,13 +36,13 @@ public class AgentChatService { ...@@ -36,13 +36,13 @@ public class AgentChatService {
ErrorHandlerService errorHandlerService, ErrorHandlerService errorHandlerService,
AgentProcessorFactory agentProcessorFactory, AgentProcessorFactory agentProcessorFactory,
AgentToolManager agentToolManager, AgentToolManager agentToolManager,
UserSseService workPanelSseService, UserSseService userSseService,
pangea.hiagent.web.service.AgentService agentService, pangea.hiagent.web.service.AgentService agentService,
SseTokenEmitter sseTokenEmitter) { SseTokenEmitter sseTokenEmitter) {
this.errorHandlerService = errorHandlerService; this.errorHandlerService = errorHandlerService;
this.agentProcessorFactory = agentProcessorFactory; this.agentProcessorFactory = agentProcessorFactory;
this.agentToolManager = agentToolManager; this.agentToolManager = agentToolManager;
this.userSseSerivce = workPanelSseService; this.userSseService = userSseService;
this.agentService = agentService; this.agentService = agentService;
this.sseTokenEmitter = sseTokenEmitter; this.sseTokenEmitter = sseTokenEmitter;
} }
...@@ -54,8 +54,10 @@ public class AgentChatService { ...@@ -54,8 +54,10 @@ public class AgentChatService {
// * @param userId 用户ID // * @param userId 用户ID
// * @return 处理结果 // * @return 处理结果
// */ // */
// public String handleChatSync(Agent agent, AgentRequest request, String userId) { // public String handleChatSync(Agent agent, AgentRequest request, String
// log.info("开始处理同步对话请求,AgentId: {}, 用户消息: {}", agent.getId(), request.getUserMessage()); // userId) {
// log.info("开始处理同步对话请求,AgentId: {}, 用户消息: {}", agent.getId(),
// request.getUserMessage());
// //
// try { // try {
// // 获取处理器 // // 获取处理器
...@@ -94,14 +96,14 @@ public class AgentChatService { ...@@ -94,14 +96,14 @@ public class AgentChatService {
if (userId == null) { if (userId == null) {
log.error("用户未认证"); log.error("用户未认证");
SseEmitter emitter = userSseSerivce.createEmitter(); SseEmitter emitter = userSseService.createEmitter();
// 检查响应是否已经提交 // 检查响应是否已经提交
if (!response.isCommitted()) { if (!response.isCommitted()) {
errorHandlerService.handleChatError(emitter, "用户未认证,请重新登录"); errorHandlerService.handleChatError(emitter, "用户未认证,请重新登录");
} else { } else {
log.warn("响应已提交,无法发送用户未认证错误信息"); log.warn("响应已提交,无法发送用户未认证错误信息");
// 检查emitter是否已经完成,避免重复关闭 // 检查emitter是否已经完成,避免重复关闭
if (!userSseSerivce.isEmitterCompleted(emitter)) { if (!userSseService.isEmitterCompleted(emitter)) {
emitter.complete(); emitter.complete();
} }
} }
...@@ -112,14 +114,14 @@ public class AgentChatService { ...@@ -112,14 +114,14 @@ public class AgentChatService {
Agent agent = agentService.getAgent(agentId); Agent agent = agentService.getAgent(agentId);
if (agent == null) { if (agent == null) {
log.warn("Agent不存在: {}", agentId); log.warn("Agent不存在: {}", agentId);
SseEmitter emitter = userSseSerivce.createEmitter(); SseEmitter emitter = userSseService.createEmitter();
// 检查响应是否已经提交 // 检查响应是否已经提交
if (!response.isCommitted()) { if (!response.isCommitted()) {
errorHandlerService.handleChatError(emitter, "Agent不存在"); errorHandlerService.handleChatError(emitter, "Agent不存在");
} else { } else {
log.warn("响应已提交,无法发送Agent不存在错误信息"); log.warn("响应已提交,无法发送Agent不存在错误信息");
// 检查emitter是否已经完成,避免重复关闭 // 检查emitter是否已经完成,避免重复关闭
if (!userSseSerivce.isEmitterCompleted(emitter)) { if (!userSseService.isEmitterCompleted(emitter)) {
emitter.complete(); emitter.complete();
} }
} }
...@@ -127,7 +129,7 @@ public class AgentChatService { ...@@ -127,7 +129,7 @@ public class AgentChatService {
} }
// 创建 SSE emitter // 创建 SSE emitter
SseEmitter emitter = userSseSerivce.createEmitter(); SseEmitter emitter = userSseService.createEmitter();
// 异步处理对话,避免阻塞HTTP连接 // 异步处理对话,避免阻塞HTTP连接
processChatStreamAsync(emitter, agent, chatRequest, userId); processChatStreamAsync(emitter, agent, chatRequest, userId);
...@@ -144,15 +146,10 @@ public class AgentChatService { ...@@ -144,15 +146,10 @@ public class AgentChatService {
processChatRequest(emitter, agent, chatRequest, userId); processChatRequest(emitter, agent, chatRequest, userId);
} catch (Exception e) { } catch (Exception e) {
log.error("处理聊天请求时发生异常", e); log.error("处理聊天请求时发生异常", e);
try {
// 检查响应是否已经提交 // 检查响应是否已经提交
if (emitter != null && !userSseSerivce.isEmitterCompleted(emitter)) { if (emitter != null && !userSseService.isEmitterCompleted(emitter)) {
errorHandlerService.handleChatError(emitter, "处理请求时发生错误", e, null); errorHandlerService.handleChatError(emitter, "处理请求时发生错误", e, null);
} else {
log.warn("响应已提交或emitter已完成,无法发送处理请求错误信息");
}
} catch (Exception handlerException) {
log.error("处理错误信息时发生异常", handlerException);
} }
} }
} }
...@@ -184,17 +181,11 @@ public class AgentChatService { ...@@ -184,17 +181,11 @@ public class AgentChatService {
// 转换请求对象 // 转换请求对象
AgentRequest request = chatRequest.toAgentRequest(agent.getId(), agent, agentToolManager); AgentRequest request = chatRequest.toAgentRequest(agent.getId(), agent, agentToolManager);
// 设置SSE发射器到token发射器 // 创建新的SseTokenEmitter实例
sseTokenEmitter.setEmitter(emitter); SseTokenEmitter tokenEmitter = sseTokenEmitter.createNewInstance(emitter, agent, request, userId, this::handleCompletion);
// 设置上下文信息
sseTokenEmitter.setContext(agent, request, userId);
// 设置完成回调
sseTokenEmitter.setCompletionCallback(this::handleCompletion);
// 处理流式请求 // 处理流式请求
processor.processStreamRequest(request, agent, userId, sseTokenEmitter); processor.processStreamRequest(request, agent, userId, tokenEmitter);
} catch (Exception e) { } catch (Exception e) {
log.error("处理聊天请求时发生异常", e); log.error("处理聊天请求时发生异常", e);
errorHandlerService.handleChatError(emitter, "处理请求时发生错误", e, null); errorHandlerService.handleChatError(emitter, "处理请求时发生错误", e, null);
...@@ -210,7 +201,8 @@ public class AgentChatService { ...@@ -210,7 +201,8 @@ public class AgentChatService {
* @param userId 用户ID * @param userId 用户ID
* @param fullContent 完整内容 * @param fullContent 完整内容
*/ */
private void handleCompletion(SseEmitter emitter, Agent agent, AgentRequest request, String userId, String fullContent) { private void handleCompletion(SseEmitter emitter, Agent agent, AgentRequest request, String userId,
String fullContent) {
log.info("Agent处理完成,总字符数: {}", fullContent != null ? fullContent.length() : 0); log.info("Agent处理完成,总字符数: {}", fullContent != null ? fullContent.length() : 0);
// 保存对话记录 // 保存对话记录
......
...@@ -102,7 +102,7 @@ public class ErrorHandlerService { ...@@ -102,7 +102,7 @@ public class ErrorHandlerService {
* *
* @param emitter SSE发射器 * @param emitter SSE发射器
* @param errorMessage 错误信息 * @param errorMessage 错误信息
* @param exception 异常对象 * @param exception 异常对象(可选)
* @param processorType 处理器类型(可选) * @param processorType 处理器类型(可选)
*/ */
public void handleChatError(SseEmitter emitter, String errorMessage, Exception exception, String processorType) { public void handleChatError(SseEmitter emitter, String errorMessage, Exception exception, String processorType) {
...@@ -142,44 +142,25 @@ public class ErrorHandlerService { ...@@ -142,44 +142,25 @@ public class ErrorHandlerService {
} }
/** /**
* 处理聊天过程中的异常() * 处理聊天过程中的异常(简化版
* *
* @param emitter SSE发射器 * @param emitter SSE发射器
* @param errorMessage 错误信息 * @param errorMessage 错误信息
*/ */
public void handleChatError(SseEmitter emitter, String errorMessage) { public void handleChatError(SseEmitter emitter, String errorMessage) {
// 参数验证 handleChatError(emitter, errorMessage, null, null);
if (errorMessage == null || errorMessage.isEmpty()) {
errorMessage = "未知错误";
}
// 生成错误跟踪ID
String errorId = generateErrorId();
log.error("[{}] 处理聊天请求时发生错误: {}", errorId, errorMessage);
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String fullErrorMessage = buildFullErrorMessage(errorMessage, null, errorId, null);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
} }
/** /**
* 处理Token处理过程中的异常 * 处理带完成状态标记的异常
* *
* @param emitter SSE发射器 * @param emitter SSE发射器
* @param errorMessage 错误信息
* @param processorType 处理器类型 * @param processorType 处理器类型
* @param exception 异常对象 * @param exception 异常对象
* @param isCompleted 完成状态标记 * @param isCompleted 完成状态标记
*/ */
public void handleTokenError(SseEmitter emitter, String processorType, Exception exception, AtomicBoolean isCompleted) { private void handleErrorWithCompletion(SseEmitter emitter, String errorMessage, String processorType, Exception exception, AtomicBoolean isCompleted) {
// 参数验证 // 参数验证
if (processorType == null || processorType.isEmpty()) { if (processorType == null || processorType.isEmpty()) {
processorType = "未知处理器"; processorType = "未知处理器";
...@@ -192,17 +173,17 @@ public class ErrorHandlerService { ...@@ -192,17 +173,17 @@ public class ErrorHandlerService {
if (exception != null) { if (exception != null) {
exceptionMonitoringService.recordException( exceptionMonitoringService.recordException(
exception.getClass().getSimpleName(), exception.getClass().getSimpleName(),
"处理token时发生错误", errorMessage,
java.util.Arrays.toString(exception.getStackTrace()) java.util.Arrays.toString(exception.getStackTrace())
); );
} }
log.error("[{}] {}处理token时发生错误", errorId, processorType, exception); log.error("[{}] {}: {}", errorId, processorType, errorMessage, exception);
if (!isCompleted.getAndSet(true)) { if (!isCompleted.getAndSet(true)) {
try { try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息 // 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) { if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "处理响应时发生错误";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, processorType); String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, processorType);
userSseService.sendErrorEvent(emitter, fullErrorMessage); userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else { } else {
...@@ -216,6 +197,18 @@ public class ErrorHandlerService { ...@@ -216,6 +197,18 @@ public class ErrorHandlerService {
} }
} }
/**
* 处理Token处理过程中的异常
*
* @param emitter SSE发射器
* @param processorType 处理器类型
* @param exception 异常对象
* @param isCompleted 完成状态标记
*/
public void handleTokenError(SseEmitter emitter, String processorType, Exception exception, AtomicBoolean isCompleted) {
handleErrorWithCompletion(emitter, "处理token时发生错误", processorType, exception, isCompleted);
}
/** /**
* 处理完成回调过程中的异常 * 处理完成回调过程中的异常
* *
...@@ -252,15 +245,13 @@ public class ErrorHandlerService { ...@@ -252,15 +245,13 @@ public class ErrorHandlerService {
} }
/** /**
* 处理流式处理中的错误 * 处理基于Consumer的流式错误
* *
* @param e 异常对象 * @param e 异常对象
* @param tokenConsumer token处理回调函数 * @param tokenConsumer token处理回调函数
* @param errorMessagePrefix 错误消息前缀 * @param errorMessage 完整错误消息
*/ */
public void handleStreamError(Throwable e, Consumer<String> tokenConsumer, String errorMessagePrefix) { private void handleConsumerError(Throwable e, Consumer<String> tokenConsumer, String errorMessage) {
String errorMessage = errorMessagePrefix + ": " + e.getMessage();
// 记录异常到监控服务 // 记录异常到监控服务
exceptionMonitoringService.recordException( exceptionMonitoringService.recordException(
e.getClass().getSimpleName(), e.getClass().getSimpleName(),
...@@ -268,12 +259,24 @@ public class ErrorHandlerService { ...@@ -268,12 +259,24 @@ public class ErrorHandlerService {
java.util.Arrays.toString(e.getStackTrace()) java.util.Arrays.toString(e.getStackTrace())
); );
log.error("流式处理错误: {}", errorMessage, e); log.error(errorMessage, e);
if (tokenConsumer != null) { if (tokenConsumer != null) {
tokenConsumer.accept("[ERROR] " + errorMessage); tokenConsumer.accept("[ERROR] " + errorMessage);
} }
} }
/**
* 处理流式处理中的错误
*
* @param e 异常对象
* @param tokenConsumer token处理回调函数
* @param errorMessagePrefix 错误消息前缀
*/
public void handleStreamError(Throwable e, Consumer<String> tokenConsumer, String errorMessagePrefix) {
String errorMessage = errorMessagePrefix + ": " + e.getMessage();
handleConsumerError(e, tokenConsumer, errorMessage);
}
/** /**
* 发送错误信息给客户端 * 发送错误信息给客户端
* *
...@@ -294,18 +297,7 @@ public class ErrorHandlerService { ...@@ -294,18 +297,7 @@ public class ErrorHandlerService {
*/ */
public void handleReactFlowError(Exception e, Consumer<String> tokenConsumer) { public void handleReactFlowError(Exception e, Consumer<String> tokenConsumer) {
String errorMessage = "处理ReAct流程时发生错误: " + e.getMessage(); String errorMessage = "处理ReAct流程时发生错误: " + e.getMessage();
handleConsumerError(e, tokenConsumer, errorMessage);
// 记录异常到监控服务
exceptionMonitoringService.recordException(
e.getClass().getSimpleName(),
errorMessage,
java.util.Arrays.toString(e.getStackTrace())
);
log.error("ReAct流程错误: {}", errorMessage, e);
if (tokenConsumer != null) {
tokenConsumer.accept("[ERROR] " + errorMessage);
}
} }
/** /**
...@@ -337,33 +329,6 @@ public class ErrorHandlerService { ...@@ -337,33 +329,6 @@ public class ErrorHandlerService {
* @param isCompleted 完成状态标记 * @param isCompleted 完成状态标记
*/ */
public void handleSaveDialogueError(SseEmitter emitter, Exception exception, AtomicBoolean isCompleted) { public void handleSaveDialogueError(SseEmitter emitter, Exception exception, AtomicBoolean isCompleted) {
// 生成错误跟踪ID handleErrorWithCompletion(emitter, "保存对话记录失败", "对话记录", exception, isCompleted);
String errorId = generateErrorId();
// 记录异常到监控服务
if (exception != null) {
exceptionMonitoringService.recordException(
exception.getClass().getSimpleName(),
"保存对话记录失败",
java.util.Arrays.toString(exception.getStackTrace())
);
}
log.error("[{}] 保存对话记录失败", errorId, exception);
if (!isCompleted.getAndSet(true)) {
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "保存对话记录失败,请联系技术支持";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, "对话记录");
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
}
} }
} }
\ No newline at end of file
...@@ -2,9 +2,10 @@ package pangea.hiagent.agent.service; ...@@ -2,9 +2,10 @@ package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.Map; import java.util.concurrent.locks.ReentrantReadWriteLock;
/** /**
* 异常监控服务 * 异常监控服务
...@@ -17,12 +18,18 @@ public class ExceptionMonitoringService { ...@@ -17,12 +18,18 @@ public class ExceptionMonitoringService {
// 异常统计信息 // 异常统计信息
private final Map<String, AtomicLong> exceptionCounters = new ConcurrentHashMap<>(); private final Map<String, AtomicLong> exceptionCounters = new ConcurrentHashMap<>();
// 异常详细信息缓存 // 异常详细信息缓存,使用时间戳作为键,便于按时间排序
private final Map<String, String> exceptionDetails = new ConcurrentHashMap<>(); private final Map<Long, String> exceptionDetails = new ConcurrentHashMap<>();
// 锁,用于保护缓存清理操作
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
// 最大缓存条目数 // 最大缓存条目数
private static final int MAX_CACHE_SIZE = 1000; private static final int MAX_CACHE_SIZE = 1000;
// 清理阈值,当缓存超过最大值时,清理到这个值
private static final int CLEANUP_THRESHOLD = MAX_CACHE_SIZE - 200;
/** /**
* 记录异常信息 * 记录异常信息
* *
...@@ -37,14 +44,31 @@ public class ExceptionMonitoringService { ...@@ -37,14 +44,31 @@ public class ExceptionMonitoringService {
counter.incrementAndGet(); counter.incrementAndGet();
// 记录异常详细信息(保留最新的) // 记录异常详细信息(保留最新的)
String detailKey = exceptionType + "_" + System.currentTimeMillis(); long timestamp = System.currentTimeMillis();
exceptionDetails.put(detailKey, formatExceptionDetail(exceptionType, errorMessage, stackTrace)); exceptionDetails.put(timestamp, formatExceptionDetail(exceptionType, errorMessage, stackTrace));
// 控制缓存大小 // 控制缓存大小,使用写锁保护清理操作
if (exceptionDetails.size() > MAX_CACHE_SIZE) {
lock.writeLock().lock();
try {
// 再次检查,避免竞态条件
if (exceptionDetails.size() > MAX_CACHE_SIZE) { if (exceptionDetails.size() > MAX_CACHE_SIZE) {
// 移除最老的条目 // 找出最老的条目并移除,直到达到清理阈值
String oldestKey = exceptionDetails.keySet().iterator().next(); while (exceptionDetails.size() > CLEANUP_THRESHOLD) {
exceptionDetails.remove(oldestKey); // 找出最小的时间戳(最老的条目)
Long oldestTimestamp = exceptionDetails.keySet().stream()
.min(Long::compare)
.orElse(null);
if (oldestTimestamp != null) {
exceptionDetails.remove(oldestTimestamp);
} else {
break;
}
}
}
} finally {
lock.writeLock().unlock();
}
} }
// 记录日志 // 记录日志
...@@ -102,7 +126,11 @@ public class ExceptionMonitoringService { ...@@ -102,7 +126,11 @@ public class ExceptionMonitoringService {
* @return 异常详细信息 * @return 异常详细信息
*/ */
public Map<String, String> getExceptionDetails() { public Map<String, String> getExceptionDetails() {
return new ConcurrentHashMap<>(exceptionDetails); Map<String, String> result = new ConcurrentHashMap<>();
for (Map.Entry<Long, String> entry : exceptionDetails.entrySet()) {
result.put(entry.getKey().toString(), entry.getValue());
}
return result;
} }
/** /**
......
...@@ -10,6 +10,7 @@ import pangea.hiagent.web.dto.AgentRequest; ...@@ -10,6 +10,7 @@ import pangea.hiagent.web.dto.AgentRequest;
/** /**
* SSE Token发射器 * SSE Token发射器
* 专注于将token转换为SSE事件并发送 * 专注于将token转换为SSE事件并发送
* 无状态设计,每次使用时创建新实例
*/ */
@Slf4j @Slf4j
@Component @Component
...@@ -17,42 +18,51 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion { ...@@ -17,42 +18,51 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion {
private final UserSseService userSseService; private final UserSseService userSseService;
// 当前处理的emitter // 所有状态通过构造函数一次性传入
private SseEmitter emitter; private final SseEmitter emitter;
private final Agent agent;
// 上下文信息 private final AgentRequest request;
private Agent agent; private final String userId;
private AgentRequest request; private final CompletionCallback completionCallback;
private String userId;
// 完成回调
private CompletionCallback completionCallback;
public SseTokenEmitter(UserSseService userSseService) {
this.userSseService = userSseService;
}
/** /**
* 设置当前使用的SSE发射器 * 构造函数
* @param userSseService SSE服务
* @param emitter SSE发射器
* @param agent Agent对象
* @param request 请求对象
* @param userId 用户ID
* @param completionCallback 完成回调
*/ */
public void setEmitter(SseEmitter emitter) { public SseTokenEmitter(UserSseService userSseService, SseEmitter emitter, Agent agent,
AgentRequest request, String userId, CompletionCallback completionCallback) {
this.userSseService = userSseService;
this.emitter = emitter; this.emitter = emitter;
this.agent = agent;
this.request = request;
this.userId = userId;
this.completionCallback = completionCallback;
} }
/** /**
* 设置上下文信息 * 无参构造函数,用于Spring容器初始化
*/ */
public void setContext(Agent agent, AgentRequest request, String userId) { public SseTokenEmitter(UserSseService userSseService) {
this.agent = agent; this(userSseService, null, null, null, null, null);
this.request = request;
this.userId = userId;
} }
/** /**
* 设置完成回调 * 创建新的SseTokenEmitter实例
* @param emitter SSE发射器
* @param agent Agent对象
* @param request 请求对象
* @param userId 用户ID
* @param completionCallback 完成回调
* @return 新的SseTokenEmitter实例
*/ */
public void setCompletionCallback(CompletionCallback completionCallback) { public SseTokenEmitter createNewInstance(SseEmitter emitter, Agent agent, AgentRequest request,
this.completionCallback = completionCallback; String userId, CompletionCallback completionCallback) {
return new SseTokenEmitter(userSseService, emitter, agent, request, userId, completionCallback);
} }
@Override @Override
......
package pangea.hiagent.agent.service; package pangea.hiagent.agent.service;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.workpanel.event.EventService;
/** /**
* Token消费者接口,支持完成回调 * Token消费者接口,支持完成回调
...@@ -17,17 +14,4 @@ public interface TokenConsumerWithCompletion extends Consumer<String> { ...@@ -17,17 +14,4 @@ public interface TokenConsumerWithCompletion extends Consumer<String> {
default void onComplete(String fullContent) { default void onComplete(String fullContent) {
// 默认实现为空 // 默认实现为空
} }
/**
* 当流式处理完成时调用,发送完成事件到前端
* @param fullContent 完整的内容
* @param emitter SSE发射器
* @param sseEventSender SSE事件发送器
* @param isCompleted 完成状态标记
*/
default void onComplete(String fullContent, SseEmitter emitter,
EventService eventService,
AtomicBoolean isCompleted) {
// 默认实现将在子类中覆盖
}
} }
\ No newline at end of file
...@@ -4,11 +4,10 @@ import lombok.extern.slf4j.Slf4j; ...@@ -4,11 +4,10 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.access.PermissionEvaluator;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.web.service.TimerService;
import pangea.hiagent.model.Agent; import pangea.hiagent.model.Agent;
import pangea.hiagent.model.TimerConfig; import pangea.hiagent.model.TimerConfig;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.web.service.TimerService;
import java.io.Serializable; import java.io.Serializable;
...@@ -20,6 +19,9 @@ import java.io.Serializable; ...@@ -20,6 +19,9 @@ import java.io.Serializable;
@Component("permissionEvaluator") @Component("permissionEvaluator")
public class DefaultPermissionEvaluator implements PermissionEvaluator { public class DefaultPermissionEvaluator implements PermissionEvaluator {
private static final String AGENT_TYPE = "Agent";
private static final String TIMER_CONFIG_TYPE = "TimerConfig";
private final AgentService agentService; private final AgentService agentService;
private final TimerService timerService; private final TimerService timerService;
...@@ -37,33 +39,21 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -37,33 +39,21 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return false; return false;
} }
Object principal = authentication.getPrincipal(); String userId = authentication.getPrincipal().toString();
if (principal == null) {
return false;
}
String userId = principal.toString();
String perm = (String) permission; String perm = (String) permission;
try { try {
// 处理Agent访问权限 // 处理Agent访问权限
if (targetDomainObject instanceof Agent) { if (targetDomainObject instanceof Agent) {
Agent agent = (Agent) targetDomainObject; return checkAgentAccess(userId, (Agent) targetDomainObject, perm);
return checkAgentAccess(userId, agent, perm);
} }
// 处理TimerConfig访问权限 // 处理TimerConfig访问权限
else if (targetDomainObject instanceof TimerConfig) { else if (targetDomainObject instanceof TimerConfig) {
TimerConfig timer = (TimerConfig) targetDomainObject; return checkTimerAccess(userId, (TimerConfig) targetDomainObject, perm);
return checkTimerAccess(userId, timer, perm);
}
// 处理基于ID的资源访问
else if (targetDomainObject instanceof String) {
// 这种情况在hasPermission(Authentication, Serializable, String, Object)方法中处理
return false;
} }
} catch (Exception e) { } catch (Exception e) {
log.error("权限检查过程中发生异常: userId={}, targetDomainObject={}, permission={}", userId, targetDomainObject, permission, e); log.error("权限检查异常: userId={}, target={}, permission={}, error={}",
return false; userId, targetDomainObject.getClass().getSimpleName(), perm, e.getMessage());
} }
return false; return false;
...@@ -75,36 +65,23 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -75,36 +65,23 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return false; return false;
} }
Object principal = authentication.getPrincipal(); String userId = authentication.getPrincipal().toString();
if (principal == null) {
return false;
}
String userId = principal.toString();
String perm = (String) permission; String perm = (String) permission;
try { try {
// 处理基于ID的权限检查 // 处理基于ID的权限检查
if ("Agent".equals(targetType)) { if (AGENT_TYPE.equals(targetType)) {
Agent agent = agentService.getAgent(targetId.toString()); Agent agent = agentService.getAgent(targetId.toString());
if (agent == null) { return agent != null && checkAgentAccess(userId, agent, perm);
log.warn("未找到ID为 {} 的Agent", targetId);
return false;
}
return checkAgentAccess(userId, agent, perm);
} }
// 处理TimerConfig资源的权限检查 // 处理TimerConfig资源的权限检查
else if ("TimerConfig".equals(targetType)) { else if (TIMER_CONFIG_TYPE.equals(targetType)) {
TimerConfig timer = timerService.getTimerById(targetId.toString()); TimerConfig timer = timerService.getTimerById(targetId.toString());
if (timer == null) { return timer != null && checkTimerAccess(userId, timer, perm);
log.warn("未找到ID为 {} 的TimerConfig", targetId);
return false;
}
return checkTimerAccess(userId, timer, perm);
} }
} catch (Exception e) { } catch (Exception e) {
log.error("基于ID的权限检查过程中发生异常: userId={}, targetId={}, targetType={}, permission={}", userId, targetId, targetType, permission, e); log.error("基于ID的权限检查异常: userId={}, targetId={}, targetType={}, permission={}, error={}",
return false; userId, targetId, targetType, perm, e.getMessage());
} }
return false; return false;
...@@ -119,24 +96,17 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -119,24 +96,17 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return true; return true;
} }
// 检查Agent所有者 // 所有者可以访问
if (agent.getOwner().equals(userId)) { if (agent.getOwner().equals(userId)) {
return true; return true;
} }
// 根据权限类型进行检查 // 根据权限类型进行检查(目前只支持所有者访问)
switch (permission.toLowerCase()) { String permissionLower = permission.toLowerCase();
case "read": return switch (permissionLower) {
// 所有用户都可以读取公开的Agent(如果有此概念) case "read", "write", "delete", "execute" -> agent.getOwner().equals(userId);
return false; // 暂时不支持公开Agent default -> false;
case "write": };
case "delete":
case "execute":
// 只有所有者可以写入、删除或执行Agent
return agent.getOwner().equals(userId);
default:
return false;
}
} }
/** /**
...@@ -148,32 +118,24 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -148,32 +118,24 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return true; return true;
} }
// 检查定时器创建者 // 创建者可以访问
if (timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId)) { if (timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId)) {
return true; return true;
} }
// 根据权限类型进行检查 // 根据权限类型进行检查(目前只支持创建者访问)
switch (permission.toLowerCase()) { String permissionLower = permission.toLowerCase();
case "read": return switch (permissionLower) {
// 所有用户都可以读取公开的定时器(如果有此概念) case "read", "write", "delete" -> timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId);
return false; // 暂时不支持公开定时器 default -> false;
case "write": };
case "delete":
// 只有创建者可以修改或删除定时器
return timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId);
default:
return false;
}
} }
/** /**
* 检查是否为管理员用户 * 检查是否为管理员用户
*/ */
private boolean isAdminUser(String userId) { private boolean isAdminUser(String userId) {
// 这里可以根据实际需求实现管理员检查逻辑 // 管理员用户检查,可扩展为从配置或数据库读取
// 例如查询数据库或检查特殊用户ID
// 当前实现保留原有逻辑,但可以通过配置或数据库来管理管理员用户
return "admin".equals(userId) || "user-001".equals(userId); return "admin".equals(userId) || "user-001".equals(userId);
} }
} }
\ No newline at end of file
...@@ -5,18 +5,16 @@ import jakarta.servlet.ServletException; ...@@ -5,18 +5,16 @@ import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.common.utils.JwtUtil;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.common.utils.JwtUtil;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.List;
/** /**
* JWT认证过滤器 * JWT认证过滤器
...@@ -26,6 +24,8 @@ import java.util.List; ...@@ -26,6 +24,8 @@ import java.util.List;
@Component @Component
public class JwtAuthenticationFilter extends OncePerRequestFilter { public class JwtAuthenticationFilter extends OncePerRequestFilter {
private static final String BEARER_PREFIX = "Bearer ";
private final JwtUtil jwtUtil; private final JwtUtil jwtUtil;
public JwtAuthenticationFilter(JwtUtil jwtUtil) { public JwtAuthenticationFilter(JwtUtil jwtUtil) {
...@@ -35,19 +35,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -35,19 +35,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
boolean isStreamEndpoint = request.getRequestURI().contains("/api/v1/agent/chat-stream");
boolean isTimelineEndpoint = request.getRequestURI().contains("/api/v1/agent/timeline-events");
if (isStreamEndpoint) {
log.info("处理Agent流式对话请求: {} {}", request.getMethod(), request.getRequestURI());
}
if (isTimelineEndpoint) {
log.info("处理时间轴事件订阅请求: {} {}", request.getMethod(), request.getRequestURI());
}
// 对于OPTIONS请求,直接放行 // 对于OPTIONS请求,直接放行
if ("OPTIONS".equalsIgnoreCase(request.getMethod())) { if ("OPTIONS".equalsIgnoreCase(request.getMethod())) {
log.debug("OPTIONS请求,直接放行");
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
} }
...@@ -55,66 +44,25 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -55,66 +44,25 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
try { try {
String token = extractTokenFromRequest(request); String token = extractTokenFromRequest(request);
log.debug("JWT过滤器处理请求: {} {},提取到token: {}", request.getMethod(), request.getRequestURI(), token);
if (StringUtils.hasText(token)) { if (StringUtils.hasText(token)) {
log.debug("开始JWT验证,token长度: {}", token.length());
// 验证token是否有效 // 验证token是否有效
boolean isValid = jwtUtil.validateToken(token); if (jwtUtil.validateToken(token)) {
log.debug("JWT验证结果: {}", isValid);
if (isValid) {
String userId = jwtUtil.getUserIdFromToken(token); String userId = jwtUtil.getUserIdFromToken(token);
log.debug("JWT验证通过,用户ID: {}", userId);
if (userId != null) { if (userId != null) {
// 创建认证对象,添加基本权限 // 创建认证对象,添加基本权限
List<SimpleGrantedAuthority> authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER")); var authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"));
UsernamePasswordAuthenticationToken authentication = var authentication = new UsernamePasswordAuthenticationToken(userId, null, authorities);
new UsernamePasswordAuthenticationToken(userId, null, authorities);
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
log.debug("已设置SecurityContext中的认证信息,用户ID: {}, 权限: {}", userId, authentication.getAuthorities());
} else {
log.warn("从token中提取的用户ID为空");
} }
} else {
log.warn("JWT验证失败,token可能已过期或无效");
// 检查token是否过期
boolean isExpired = jwtUtil.isTokenExpired(token);
log.warn("Token过期状态: {}", isExpired);
} }
} else {
log.debug("未找到有效的token");
// 记录请求信息以便调试
log.debug("请求URL: {}", request.getRequestURL());
log.debug("请求方法: {}", request.getMethod());
log.debug("Authorization头: {}", request.getHeader("Authorization"));
log.debug("token参数: {}", request.getParameter("token"));
} }
} catch (Exception e) { } catch (Exception e) {
log.error("JWT认证处理异常", e); log.error("JWT认证处理异常: {}", e.getMessage());
// 不在此处发送错误响应,让Spring Security的ExceptionTranslationFilter处理 // 不在此处发送错误响应,让Spring Security的ExceptionTranslationFilter处理
// 这样可以避免响应被提前提交
}
// 特别处理流式端点的权限问题
if (isStreamEndpoint || isTimelineEndpoint) {
// 检查是否已认证
if (SecurityContextHolder.getContext().getAuthentication() == null) {
log.warn("流式端点未认证访问: {} {}", request.getMethod(), request.getRequestURI());
// 对于SSE端点,如果未认证,我们不立即返回错误,而是让后续处理决定
// 因为客户端可能会在重新连接时带上token
} }
// 对于SSE端点,直接执行过滤器链,不进行额外的响应检查 // 继续执行过滤器链
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
log.debug("JwtAuthenticationFilter处理完成(SSE端点): {} {}", request.getMethod(), request.getRequestURI());
return;
}
// 继续执行过滤器链,让Spring Security的其他过滤器处理认证和授权
// 这样可以让ExceptionTranslationFilter和AuthorizationFilter正确处理认证失败和权限拒绝
filterChain.doFilter(request, response);
log.debug("JwtAuthenticationFilter处理完成: {} {}", request.getMethod(), request.getRequestURI());
} }
/** /**
...@@ -124,23 +72,11 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -124,23 +72,11 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
private String extractTokenFromRequest(HttpServletRequest request) { private String extractTokenFromRequest(HttpServletRequest request) {
// 首先尝试从请求头中提取Token // 首先尝试从请求头中提取Token
String authHeader = request.getHeader("Authorization"); String authHeader = request.getHeader("Authorization");
log.debug("从请求头中提取Authorization: {}", authHeader); if (StringUtils.hasText(authHeader) && authHeader.startsWith(BEARER_PREFIX)) {
if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) { return authHeader.substring(BEARER_PREFIX.length());
String token = authHeader.substring(7);
log.debug("从Authorization头中提取到token");
return token;
} }
// 如果请求头中没有Token,则尝试从URL参数中提取 // 如果请求头中没有Token,则尝试从URL参数中提取
// 这对于SSE连接特别有用,因为浏览器在自动重连时可能不会发送Authorization头 return request.getParameter("token");
String tokenParam = request.getParameter("token");
log.debug("从URL参数中提取token参数: {}", tokenParam);
if (StringUtils.hasText(tokenParam)) {
log.debug("从URL参数中提取到token");
return tokenParam;
}
log.debug("未找到有效的token");
return null;
} }
} }
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment