Commit 0306580c authored by ligaowei's avatar ligaowei

refactor(agent): 移除未使用的回调方法和冗余代码

重构ReAct相关组件,移除未使用的onFinalAnswer回调方法和冗余的getParameters/getContent方法。优化SseTokenEmitter为无状态设计,通过构造函数一次性传入所有状态。简化JwtAuthenticationFilter和DefaultPermissionEvaluator的权限检查逻辑。改进ErrorHandlerService的错误处理代码复用性。调整代码结构以提高可维护性。
parent 40bd44a9
......@@ -31,8 +31,7 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
@Autowired
private RagService ragService;
@Autowired
private ReactCallback defaultReactCallback;
@Autowired
private ReactExecutor defaultReactExecutor;
......@@ -40,6 +39,9 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
@Autowired
private AgentToolManager agentToolManager;
@Autowired
private ReactCallback defaultReactCallback;
@Override
public String processRequest(Agent agent, AgentRequest request, String userId) {
log.info("使用ReAct Agent处理请求");
......@@ -72,10 +74,6 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
// 处理请求的通用前置逻辑
String ragResponse = handlePreProcessing(agent, userMessage, userId, ragService, null);
if (ragResponse != null) {
// 触发最终答案回调
if (defaultReactCallback != null) {
defaultReactCallback.onFinalAnswer(ragResponse);
}
return ragResponse;
}
......@@ -83,10 +81,7 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
ChatClient client = ChatClient.builder(agentService.getChatModelForAgent(agent)).build();
List<Object> tools = agentToolManager.getAvailableToolInstances(agent);
// 添加自定义回调到ReAct执行器
if (defaultReactExecutor != null && defaultReactCallback != null) {
defaultReactExecutor.addReactCallback(defaultReactCallback);
}
// 使用ReAct执行器执行流程,传递Agent对象和用户ID以支持记忆功能
String finalAnswer = defaultReactExecutor.execute(client, userMessage, tools, agent, userId);
......@@ -114,10 +109,6 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
// 处理请求的通用前置逻辑
String ragResponse = handlePreProcessing(agent, userMessage, userId, ragService, tokenConsumer);
if (ragResponse != null) {
// 触发最终答案回调
if (defaultReactCallback != null) {
defaultReactCallback.onFinalAnswer(ragResponse);
}
return;
}
......
......@@ -24,12 +24,6 @@ public class DefaultReactCallback implements ReactCallback {
recordReactStepToWorkPanel(reactStep);
}
@Override
public void onFinalAnswer(String finalAnswer) {
ReactStep finalStep = new ReactStep(0, ReactStepType.FINAL_ANSWER, finalAnswer);
recordReactStepToWorkPanel(finalStep);
}
private void recordReactStepToWorkPanel(ReactStep reactStep) {
if (workPanelCollector == null) {
return;
......@@ -46,7 +40,7 @@ public class DefaultReactCallback implements ReactCallback {
if (reactStep.getAction() != null) {
// 记录工具调用动作
String toolName = reactStep.getAction().getToolName();
Object parameters = reactStep.getAction().getParameters();
Object parameters = reactStep.getAction().getToolArgs();
// 记录工具调用,初始状态为pending
workPanelCollector.recordToolCallAction(
......@@ -73,15 +67,15 @@ public class DefaultReactCallback implements ReactCallback {
// 使用动作信息更新工具调用结果
workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(),
reactStep.getObservation().getContent(),
reactStep.getAction().getToolArgs(),
reactStep.getObservation().getResult(),
"success", // 状态为success
null // 无错误信息
);
log.info("[WorkPanel] 更新工具调用结果: 工具={} 结果摘要={}",
reactStep.getAction().getToolName(),
reactStep.getObservation().getContent().substring(0, Math.min(50, reactStep.getObservation().getContent().length())));
reactStep.getObservation().getResult().substring(0, Math.min(50, reactStep.getObservation().getResult().length())));
} else {
// 如果没有动作信息,记录为观察结果
workPanelCollector.recordThinking(reactStep.getContent(), "observation");
......@@ -108,7 +102,7 @@ public class DefaultReactCallback implements ReactCallback {
if (reactStep != null && reactStep.getAction() != null) {
workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(),
reactStep.getAction().getToolArgs(),
"记录失败: " + e.getMessage(),
"error",
System.currentTimeMillis() // 使用当前时间戳作为执行时间
......
......@@ -10,6 +10,4 @@ public interface ReactCallback {
* @param reactStep ReAct步骤对象,包含步骤的所有核心信息
*/
void onStep(ReactStep reactStep);
void onFinalAnswer(String ragResponse);
}
\ No newline at end of file
......@@ -49,9 +49,6 @@ public class ReactStep {
public Object getToolArgs() { return toolArgs; }
public void setToolArgs(Object toolArgs) { this.toolArgs = toolArgs; }
// 根据DefaultReactCallback.java中的使用情况添加getParameters方法
public Object getParameters() { return toolArgs; }
}
/**
......@@ -66,8 +63,5 @@ public class ReactStep {
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
// 根据DefaultReactCallback.java中的使用情况添加getContent方法
public String getContent() { return result; }
}
}
\ No newline at end of file
......@@ -27,7 +27,7 @@ public class AgentChatService {
private final ErrorHandlerService errorHandlerService;
private final AgentProcessorFactory agentProcessorFactory;
private final AgentToolManager agentToolManager;
private final UserSseService userSseSerivce;
private final UserSseService userSseService;
private final pangea.hiagent.web.service.AgentService agentService;
private final SseTokenEmitter sseTokenEmitter;
......@@ -36,13 +36,13 @@ public class AgentChatService {
ErrorHandlerService errorHandlerService,
AgentProcessorFactory agentProcessorFactory,
AgentToolManager agentToolManager,
UserSseService workPanelSseService,
UserSseService userSseService,
pangea.hiagent.web.service.AgentService agentService,
SseTokenEmitter sseTokenEmitter) {
this.errorHandlerService = errorHandlerService;
this.agentProcessorFactory = agentProcessorFactory;
this.agentToolManager = agentToolManager;
this.userSseSerivce = workPanelSseService;
this.userSseService = userSseService;
this.agentService = agentService;
this.sseTokenEmitter = sseTokenEmitter;
}
......@@ -54,8 +54,10 @@ public class AgentChatService {
// * @param userId 用户ID
// * @return 处理结果
// */
// public String handleChatSync(Agent agent, AgentRequest request, String userId) {
// log.info("开始处理同步对话请求,AgentId: {}, 用户消息: {}", agent.getId(), request.getUserMessage());
// public String handleChatSync(Agent agent, AgentRequest request, String
// userId) {
// log.info("开始处理同步对话请求,AgentId: {}, 用户消息: {}", agent.getId(),
// request.getUserMessage());
//
// try {
// // 获取处理器
......@@ -94,14 +96,14 @@ public class AgentChatService {
if (userId == null) {
log.error("用户未认证");
SseEmitter emitter = userSseSerivce.createEmitter();
SseEmitter emitter = userSseService.createEmitter();
// 检查响应是否已经提交
if (!response.isCommitted()) {
errorHandlerService.handleChatError(emitter, "用户未认证,请重新登录");
} else {
log.warn("响应已提交,无法发送用户未认证错误信息");
// 检查emitter是否已经完成,避免重复关闭
if (!userSseSerivce.isEmitterCompleted(emitter)) {
if (!userSseService.isEmitterCompleted(emitter)) {
emitter.complete();
}
}
......@@ -112,14 +114,14 @@ public class AgentChatService {
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
log.warn("Agent不存在: {}", agentId);
SseEmitter emitter = userSseSerivce.createEmitter();
SseEmitter emitter = userSseService.createEmitter();
// 检查响应是否已经提交
if (!response.isCommitted()) {
errorHandlerService.handleChatError(emitter, "Agent不存在");
} else {
log.warn("响应已提交,无法发送Agent不存在错误信息");
// 检查emitter是否已经完成,避免重复关闭
if (!userSseSerivce.isEmitterCompleted(emitter)) {
if (!userSseService.isEmitterCompleted(emitter)) {
emitter.complete();
}
}
......@@ -127,7 +129,7 @@ public class AgentChatService {
}
// 创建 SSE emitter
SseEmitter emitter = userSseSerivce.createEmitter();
SseEmitter emitter = userSseService.createEmitter();
// 异步处理对话,避免阻塞HTTP连接
processChatStreamAsync(emitter, agent, chatRequest, userId);
......@@ -144,15 +146,10 @@ public class AgentChatService {
processChatRequest(emitter, agent, chatRequest, userId);
} catch (Exception e) {
log.error("处理聊天请求时发生异常", e);
try {
// 检查响应是否已经提交
if (emitter != null && !userSseSerivce.isEmitterCompleted(emitter)) {
if (emitter != null && !userSseService.isEmitterCompleted(emitter)) {
errorHandlerService.handleChatError(emitter, "处理请求时发生错误", e, null);
} else {
log.warn("响应已提交或emitter已完成,无法发送处理请求错误信息");
}
} catch (Exception handlerException) {
log.error("处理错误信息时发生异常", handlerException);
}
}
}
......@@ -184,17 +181,11 @@ public class AgentChatService {
// 转换请求对象
AgentRequest request = chatRequest.toAgentRequest(agent.getId(), agent, agentToolManager);
// 设置SSE发射器到token发射器
sseTokenEmitter.setEmitter(emitter);
// 设置上下文信息
sseTokenEmitter.setContext(agent, request, userId);
// 设置完成回调
sseTokenEmitter.setCompletionCallback(this::handleCompletion);
// 创建新的SseTokenEmitter实例
SseTokenEmitter tokenEmitter = sseTokenEmitter.createNewInstance(emitter, agent, request, userId, this::handleCompletion);
// 处理流式请求
processor.processStreamRequest(request, agent, userId, sseTokenEmitter);
processor.processStreamRequest(request, agent, userId, tokenEmitter);
} catch (Exception e) {
log.error("处理聊天请求时发生异常", e);
errorHandlerService.handleChatError(emitter, "处理请求时发生错误", e, null);
......@@ -210,7 +201,8 @@ public class AgentChatService {
* @param userId 用户ID
* @param fullContent 完整内容
*/
private void handleCompletion(SseEmitter emitter, Agent agent, AgentRequest request, String userId, String fullContent) {
private void handleCompletion(SseEmitter emitter, Agent agent, AgentRequest request, String userId,
String fullContent) {
log.info("Agent处理完成,总字符数: {}", fullContent != null ? fullContent.length() : 0);
// 保存对话记录
......
......@@ -102,7 +102,7 @@ public class ErrorHandlerService {
*
* @param emitter SSE发射器
* @param errorMessage 错误信息
* @param exception 异常对象
* @param exception 异常对象(可选)
* @param processorType 处理器类型(可选)
*/
public void handleChatError(SseEmitter emitter, String errorMessage, Exception exception, String processorType) {
......@@ -142,44 +142,25 @@ public class ErrorHandlerService {
}
/**
* 处理聊天过程中的异常()
* 处理聊天过程中的异常(简化版
*
* @param emitter SSE发射器
* @param errorMessage 错误信息
*/
public void handleChatError(SseEmitter emitter, String errorMessage) {
// 参数验证
if (errorMessage == null || errorMessage.isEmpty()) {
errorMessage = "未知错误";
}
// 生成错误跟踪ID
String errorId = generateErrorId();
log.error("[{}] 处理聊天请求时发生错误: {}", errorId, errorMessage);
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String fullErrorMessage = buildFullErrorMessage(errorMessage, null, errorId, null);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
handleChatError(emitter, errorMessage, null, null);
}
/**
* 处理Token处理过程中的异常
* 处理带完成状态标记的异常
*
* @param emitter SSE发射器
* @param errorMessage 错误信息
* @param processorType 处理器类型
* @param exception 异常对象
* @param isCompleted 完成状态标记
*/
public void handleTokenError(SseEmitter emitter, String processorType, Exception exception, AtomicBoolean isCompleted) {
private void handleErrorWithCompletion(SseEmitter emitter, String errorMessage, String processorType, Exception exception, AtomicBoolean isCompleted) {
// 参数验证
if (processorType == null || processorType.isEmpty()) {
processorType = "未知处理器";
......@@ -192,17 +173,17 @@ public class ErrorHandlerService {
if (exception != null) {
exceptionMonitoringService.recordException(
exception.getClass().getSimpleName(),
"处理token时发生错误",
errorMessage,
java.util.Arrays.toString(exception.getStackTrace())
);
}
log.error("[{}] {}处理token时发生错误", errorId, processorType, exception);
log.error("[{}] {}: {}", errorId, processorType, errorMessage, exception);
if (!isCompleted.getAndSet(true)) {
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "处理响应时发生错误";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, processorType);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
......@@ -216,6 +197,18 @@ public class ErrorHandlerService {
}
}
/**
* 处理Token处理过程中的异常
*
* @param emitter SSE发射器
* @param processorType 处理器类型
* @param exception 异常对象
* @param isCompleted 完成状态标记
*/
public void handleTokenError(SseEmitter emitter, String processorType, Exception exception, AtomicBoolean isCompleted) {
handleErrorWithCompletion(emitter, "处理token时发生错误", processorType, exception, isCompleted);
}
/**
* 处理完成回调过程中的异常
*
......@@ -252,15 +245,13 @@ public class ErrorHandlerService {
}
/**
* 处理流式处理中的错误
* 处理基于Consumer的流式错误
*
* @param e 异常对象
* @param tokenConsumer token处理回调函数
* @param errorMessagePrefix 错误消息前缀
* @param errorMessage 完整错误消息
*/
public void handleStreamError(Throwable e, Consumer<String> tokenConsumer, String errorMessagePrefix) {
String errorMessage = errorMessagePrefix + ": " + e.getMessage();
private void handleConsumerError(Throwable e, Consumer<String> tokenConsumer, String errorMessage) {
// 记录异常到监控服务
exceptionMonitoringService.recordException(
e.getClass().getSimpleName(),
......@@ -268,12 +259,24 @@ public class ErrorHandlerService {
java.util.Arrays.toString(e.getStackTrace())
);
log.error("流式处理错误: {}", errorMessage, e);
log.error(errorMessage, e);
if (tokenConsumer != null) {
tokenConsumer.accept("[ERROR] " + errorMessage);
}
}
/**
* 处理流式处理中的错误
*
* @param e 异常对象
* @param tokenConsumer token处理回调函数
* @param errorMessagePrefix 错误消息前缀
*/
public void handleStreamError(Throwable e, Consumer<String> tokenConsumer, String errorMessagePrefix) {
String errorMessage = errorMessagePrefix + ": " + e.getMessage();
handleConsumerError(e, tokenConsumer, errorMessage);
}
/**
* 发送错误信息给客户端
*
......@@ -294,18 +297,7 @@ public class ErrorHandlerService {
*/
public void handleReactFlowError(Exception e, Consumer<String> tokenConsumer) {
String errorMessage = "处理ReAct流程时发生错误: " + e.getMessage();
// 记录异常到监控服务
exceptionMonitoringService.recordException(
e.getClass().getSimpleName(),
errorMessage,
java.util.Arrays.toString(e.getStackTrace())
);
log.error("ReAct流程错误: {}", errorMessage, e);
if (tokenConsumer != null) {
tokenConsumer.accept("[ERROR] " + errorMessage);
}
handleConsumerError(e, tokenConsumer, errorMessage);
}
/**
......@@ -337,33 +329,6 @@ public class ErrorHandlerService {
* @param isCompleted 完成状态标记
*/
public void handleSaveDialogueError(SseEmitter emitter, Exception exception, AtomicBoolean isCompleted) {
// 生成错误跟踪ID
String errorId = generateErrorId();
// 记录异常到监控服务
if (exception != null) {
exceptionMonitoringService.recordException(
exception.getClass().getSimpleName(),
"保存对话记录失败",
java.util.Arrays.toString(exception.getStackTrace())
);
}
log.error("[{}] 保存对话记录失败", errorId, exception);
if (!isCompleted.getAndSet(true)) {
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "保存对话记录失败,请联系技术支持";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, "对话记录");
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
}
handleErrorWithCompletion(emitter, "保存对话记录失败", "对话记录", exception, isCompleted);
}
}
\ No newline at end of file
......@@ -2,9 +2,10 @@ package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* 异常监控服务
......@@ -17,12 +18,18 @@ public class ExceptionMonitoringService {
// 异常统计信息
private final Map<String, AtomicLong> exceptionCounters = new ConcurrentHashMap<>();
// 异常详细信息缓存
private final Map<String, String> exceptionDetails = new ConcurrentHashMap<>();
// 异常详细信息缓存,使用时间戳作为键,便于按时间排序
private final Map<Long, String> exceptionDetails = new ConcurrentHashMap<>();
// 锁,用于保护缓存清理操作
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
// 最大缓存条目数
private static final int MAX_CACHE_SIZE = 1000;
// 清理阈值,当缓存超过最大值时,清理到这个值
private static final int CLEANUP_THRESHOLD = MAX_CACHE_SIZE - 200;
/**
* 记录异常信息
*
......@@ -37,14 +44,31 @@ public class ExceptionMonitoringService {
counter.incrementAndGet();
// 记录异常详细信息(保留最新的)
String detailKey = exceptionType + "_" + System.currentTimeMillis();
exceptionDetails.put(detailKey, formatExceptionDetail(exceptionType, errorMessage, stackTrace));
long timestamp = System.currentTimeMillis();
exceptionDetails.put(timestamp, formatExceptionDetail(exceptionType, errorMessage, stackTrace));
// 控制缓存大小
// 控制缓存大小,使用写锁保护清理操作
if (exceptionDetails.size() > MAX_CACHE_SIZE) {
lock.writeLock().lock();
try {
// 再次检查,避免竞态条件
if (exceptionDetails.size() > MAX_CACHE_SIZE) {
// 移除最老的条目
String oldestKey = exceptionDetails.keySet().iterator().next();
exceptionDetails.remove(oldestKey);
// 找出最老的条目并移除,直到达到清理阈值
while (exceptionDetails.size() > CLEANUP_THRESHOLD) {
// 找出最小的时间戳(最老的条目)
Long oldestTimestamp = exceptionDetails.keySet().stream()
.min(Long::compare)
.orElse(null);
if (oldestTimestamp != null) {
exceptionDetails.remove(oldestTimestamp);
} else {
break;
}
}
}
} finally {
lock.writeLock().unlock();
}
}
// 记录日志
......@@ -102,7 +126,11 @@ public class ExceptionMonitoringService {
* @return 异常详细信息
*/
public Map<String, String> getExceptionDetails() {
return new ConcurrentHashMap<>(exceptionDetails);
Map<String, String> result = new ConcurrentHashMap<>();
for (Map.Entry<Long, String> entry : exceptionDetails.entrySet()) {
result.put(entry.getKey().toString(), entry.getValue());
}
return result;
}
/**
......
......@@ -10,6 +10,7 @@ import pangea.hiagent.web.dto.AgentRequest;
/**
* SSE Token发射器
* 专注于将token转换为SSE事件并发送
* 无状态设计,每次使用时创建新实例
*/
@Slf4j
@Component
......@@ -17,42 +18,51 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion {
private final UserSseService userSseService;
// 当前处理的emitter
private SseEmitter emitter;
// 上下文信息
private Agent agent;
private AgentRequest request;
private String userId;
// 完成回调
private CompletionCallback completionCallback;
public SseTokenEmitter(UserSseService userSseService) {
this.userSseService = userSseService;
}
// 所有状态通过构造函数一次性传入
private final SseEmitter emitter;
private final Agent agent;
private final AgentRequest request;
private final String userId;
private final CompletionCallback completionCallback;
/**
* 设置当前使用的SSE发射器
* 构造函数
* @param userSseService SSE服务
* @param emitter SSE发射器
* @param agent Agent对象
* @param request 请求对象
* @param userId 用户ID
* @param completionCallback 完成回调
*/
public void setEmitter(SseEmitter emitter) {
public SseTokenEmitter(UserSseService userSseService, SseEmitter emitter, Agent agent,
AgentRequest request, String userId, CompletionCallback completionCallback) {
this.userSseService = userSseService;
this.emitter = emitter;
this.agent = agent;
this.request = request;
this.userId = userId;
this.completionCallback = completionCallback;
}
/**
* 设置上下文信息
* 无参构造函数,用于Spring容器初始化
*/
public void setContext(Agent agent, AgentRequest request, String userId) {
this.agent = agent;
this.request = request;
this.userId = userId;
public SseTokenEmitter(UserSseService userSseService) {
this(userSseService, null, null, null, null, null);
}
/**
* 设置完成回调
* 创建新的SseTokenEmitter实例
* @param emitter SSE发射器
* @param agent Agent对象
* @param request 请求对象
* @param userId 用户ID
* @param completionCallback 完成回调
* @return 新的SseTokenEmitter实例
*/
public void setCompletionCallback(CompletionCallback completionCallback) {
this.completionCallback = completionCallback;
public SseTokenEmitter createNewInstance(SseEmitter emitter, Agent agent, AgentRequest request,
String userId, CompletionCallback completionCallback) {
return new SseTokenEmitter(userSseService, emitter, agent, request, userId, completionCallback);
}
@Override
......
package pangea.hiagent.agent.service;
import java.util.function.Consumer;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.workpanel.event.EventService;
/**
* Token消费者接口,支持完成回调
......@@ -17,17 +14,4 @@ public interface TokenConsumerWithCompletion extends Consumer<String> {
default void onComplete(String fullContent) {
// 默认实现为空
}
/**
* 当流式处理完成时调用,发送完成事件到前端
* @param fullContent 完整的内容
* @param emitter SSE发射器
* @param sseEventSender SSE事件发送器
* @param isCompleted 完成状态标记
*/
default void onComplete(String fullContent, SseEmitter emitter,
EventService eventService,
AtomicBoolean isCompleted) {
// 默认实现将在子类中覆盖
}
}
\ No newline at end of file
......@@ -4,11 +4,10 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.security.access.PermissionEvaluator;
import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.web.service.TimerService;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.TimerConfig;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.web.service.TimerService;
import java.io.Serializable;
......@@ -20,6 +19,9 @@ import java.io.Serializable;
@Component("permissionEvaluator")
public class DefaultPermissionEvaluator implements PermissionEvaluator {
private static final String AGENT_TYPE = "Agent";
private static final String TIMER_CONFIG_TYPE = "TimerConfig";
private final AgentService agentService;
private final TimerService timerService;
......@@ -37,33 +39,21 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return false;
}
Object principal = authentication.getPrincipal();
if (principal == null) {
return false;
}
String userId = principal.toString();
String userId = authentication.getPrincipal().toString();
String perm = (String) permission;
try {
// 处理Agent访问权限
if (targetDomainObject instanceof Agent) {
Agent agent = (Agent) targetDomainObject;
return checkAgentAccess(userId, agent, perm);
return checkAgentAccess(userId, (Agent) targetDomainObject, perm);
}
// 处理TimerConfig访问权限
else if (targetDomainObject instanceof TimerConfig) {
TimerConfig timer = (TimerConfig) targetDomainObject;
return checkTimerAccess(userId, timer, perm);
}
// 处理基于ID的资源访问
else if (targetDomainObject instanceof String) {
// 这种情况在hasPermission(Authentication, Serializable, String, Object)方法中处理
return false;
return checkTimerAccess(userId, (TimerConfig) targetDomainObject, perm);
}
} catch (Exception e) {
log.error("权限检查过程中发生异常: userId={}, targetDomainObject={}, permission={}", userId, targetDomainObject, permission, e);
return false;
log.error("权限检查异常: userId={}, target={}, permission={}, error={}",
userId, targetDomainObject.getClass().getSimpleName(), perm, e.getMessage());
}
return false;
......@@ -75,36 +65,23 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return false;
}
Object principal = authentication.getPrincipal();
if (principal == null) {
return false;
}
String userId = principal.toString();
String userId = authentication.getPrincipal().toString();
String perm = (String) permission;
try {
// 处理基于ID的权限检查
if ("Agent".equals(targetType)) {
if (AGENT_TYPE.equals(targetType)) {
Agent agent = agentService.getAgent(targetId.toString());
if (agent == null) {
log.warn("未找到ID为 {} 的Agent", targetId);
return false;
}
return checkAgentAccess(userId, agent, perm);
return agent != null && checkAgentAccess(userId, agent, perm);
}
// 处理TimerConfig资源的权限检查
else if ("TimerConfig".equals(targetType)) {
else if (TIMER_CONFIG_TYPE.equals(targetType)) {
TimerConfig timer = timerService.getTimerById(targetId.toString());
if (timer == null) {
log.warn("未找到ID为 {} 的TimerConfig", targetId);
return false;
}
return checkTimerAccess(userId, timer, perm);
return timer != null && checkTimerAccess(userId, timer, perm);
}
} catch (Exception e) {
log.error("基于ID的权限检查过程中发生异常: userId={}, targetId={}, targetType={}, permission={}", userId, targetId, targetType, permission, e);
return false;
log.error("基于ID的权限检查异常: userId={}, targetId={}, targetType={}, permission={}, error={}",
userId, targetId, targetType, perm, e.getMessage());
}
return false;
......@@ -119,24 +96,17 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return true;
}
// 检查Agent所有者
// 所有者可以访问
if (agent.getOwner().equals(userId)) {
return true;
}
// 根据权限类型进行检查
switch (permission.toLowerCase()) {
case "read":
// 所有用户都可以读取公开的Agent(如果有此概念)
return false; // 暂时不支持公开Agent
case "write":
case "delete":
case "execute":
// 只有所有者可以写入、删除或执行Agent
return agent.getOwner().equals(userId);
default:
return false;
}
// 根据权限类型进行检查(目前只支持所有者访问)
String permissionLower = permission.toLowerCase();
return switch (permissionLower) {
case "read", "write", "delete", "execute" -> agent.getOwner().equals(userId);
default -> false;
};
}
/**
......@@ -148,32 +118,24 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return true;
}
// 检查定时器创建者
// 创建者可以访问
if (timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId)) {
return true;
}
// 根据权限类型进行检查
switch (permission.toLowerCase()) {
case "read":
// 所有用户都可以读取公开的定时器(如果有此概念)
return false; // 暂时不支持公开定时器
case "write":
case "delete":
// 只有创建者可以修改或删除定时器
return timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId);
default:
return false;
}
// 根据权限类型进行检查(目前只支持创建者访问)
String permissionLower = permission.toLowerCase();
return switch (permissionLower) {
case "read", "write", "delete" -> timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId);
default -> false;
};
}
/**
* 检查是否为管理员用户
*/
private boolean isAdminUser(String userId) {
// 这里可以根据实际需求实现管理员检查逻辑
// 例如查询数据库或检查特殊用户ID
// 当前实现保留原有逻辑,但可以通过配置或数据库来管理管理员用户
// 管理员用户检查,可扩展为从配置或数据库读取
return "admin".equals(userId) || "user-001".equals(userId);
}
}
\ No newline at end of file
......@@ -5,18 +5,16 @@ import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.common.utils.JwtUtil;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.common.utils.JwtUtil;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
/**
* JWT认证过滤器
......@@ -26,6 +24,8 @@ import java.util.List;
@Component
public class JwtAuthenticationFilter extends OncePerRequestFilter {
private static final String BEARER_PREFIX = "Bearer ";
private final JwtUtil jwtUtil;
public JwtAuthenticationFilter(JwtUtil jwtUtil) {
......@@ -35,19 +35,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
boolean isStreamEndpoint = request.getRequestURI().contains("/api/v1/agent/chat-stream");
boolean isTimelineEndpoint = request.getRequestURI().contains("/api/v1/agent/timeline-events");
if (isStreamEndpoint) {
log.info("处理Agent流式对话请求: {} {}", request.getMethod(), request.getRequestURI());
}
if (isTimelineEndpoint) {
log.info("处理时间轴事件订阅请求: {} {}", request.getMethod(), request.getRequestURI());
}
// 对于OPTIONS请求,直接放行
if ("OPTIONS".equalsIgnoreCase(request.getMethod())) {
log.debug("OPTIONS请求,直接放行");
filterChain.doFilter(request, response);
return;
}
......@@ -55,66 +44,25 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
try {
String token = extractTokenFromRequest(request);
log.debug("JWT过滤器处理请求: {} {},提取到token: {}", request.getMethod(), request.getRequestURI(), token);
if (StringUtils.hasText(token)) {
log.debug("开始JWT验证,token长度: {}", token.length());
// 验证token是否有效
boolean isValid = jwtUtil.validateToken(token);
log.debug("JWT验证结果: {}", isValid);
if (isValid) {
if (jwtUtil.validateToken(token)) {
String userId = jwtUtil.getUserIdFromToken(token);
log.debug("JWT验证通过,用户ID: {}", userId);
if (userId != null) {
// 创建认证对象,添加基本权限
List<SimpleGrantedAuthority> authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"));
UsernamePasswordAuthenticationToken authentication =
new UsernamePasswordAuthenticationToken(userId, null, authorities);
var authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"));
var authentication = new UsernamePasswordAuthenticationToken(userId, null, authorities);
SecurityContextHolder.getContext().setAuthentication(authentication);
log.debug("已设置SecurityContext中的认证信息,用户ID: {}, 权限: {}", userId, authentication.getAuthorities());
} else {
log.warn("从token中提取的用户ID为空");
}
} else {
log.warn("JWT验证失败,token可能已过期或无效");
// 检查token是否过期
boolean isExpired = jwtUtil.isTokenExpired(token);
log.warn("Token过期状态: {}", isExpired);
}
} else {
log.debug("未找到有效的token");
// 记录请求信息以便调试
log.debug("请求URL: {}", request.getRequestURL());
log.debug("请求方法: {}", request.getMethod());
log.debug("Authorization头: {}", request.getHeader("Authorization"));
log.debug("token参数: {}", request.getParameter("token"));
}
} catch (Exception e) {
log.error("JWT认证处理异常", e);
log.error("JWT认证处理异常: {}", e.getMessage());
// 不在此处发送错误响应,让Spring Security的ExceptionTranslationFilter处理
// 这样可以避免响应被提前提交
}
// 特别处理流式端点的权限问题
if (isStreamEndpoint || isTimelineEndpoint) {
// 检查是否已认证
if (SecurityContextHolder.getContext().getAuthentication() == null) {
log.warn("流式端点未认证访问: {} {}", request.getMethod(), request.getRequestURI());
// 对于SSE端点,如果未认证,我们不立即返回错误,而是让后续处理决定
// 因为客户端可能会在重新连接时带上token
}
// 对于SSE端点,直接执行过滤器链,不进行额外的响应检查
// 继续执行过滤器链
filterChain.doFilter(request, response);
log.debug("JwtAuthenticationFilter处理完成(SSE端点): {} {}", request.getMethod(), request.getRequestURI());
return;
}
// 继续执行过滤器链,让Spring Security的其他过滤器处理认证和授权
// 这样可以让ExceptionTranslationFilter和AuthorizationFilter正确处理认证失败和权限拒绝
filterChain.doFilter(request, response);
log.debug("JwtAuthenticationFilter处理完成: {} {}", request.getMethod(), request.getRequestURI());
}
/**
......@@ -124,23 +72,11 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
private String extractTokenFromRequest(HttpServletRequest request) {
// 首先尝试从请求头中提取Token
String authHeader = request.getHeader("Authorization");
log.debug("从请求头中提取Authorization: {}", authHeader);
if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) {
String token = authHeader.substring(7);
log.debug("从Authorization头中提取到token");
return token;
if (StringUtils.hasText(authHeader) && authHeader.startsWith(BEARER_PREFIX)) {
return authHeader.substring(BEARER_PREFIX.length());
}
// 如果请求头中没有Token,则尝试从URL参数中提取
// 这对于SSE连接特别有用,因为浏览器在自动重连时可能不会发送Authorization头
String tokenParam = request.getParameter("token");
log.debug("从URL参数中提取token参数: {}", tokenParam);
if (StringUtils.hasText(tokenParam)) {
log.debug("从URL参数中提取到token");
return tokenParam;
}
log.debug("未找到有效的token");
return null;
return request.getParameter("token");
}
}
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment