Commit 72ec5880 authored by youxiaoji's avatar youxiaoji

Merge branch 'refs/heads/main' into develop_tmp

# Conflicts:
#	backend/src/main/java/pangea/hiagent/agent/react/DefaultReactExecutor.java
#	backend/src/main/java/pangea/hiagent/agent/service/AgentChatService.java
#	backend/src/main/java/pangea/hiagent/agent/service/StreamRequestService.java
#	backend/src/main/java/pangea/hiagent/agent/service/UserSseService.java
#	backend/src/main/java/pangea/hiagent/tool/impl/HisenseTripTool.java
#	backend/src/main/java/pangea/hiagent/tool/impl/VisitorAppointmentTool.java
#	backend/src/main/java/pangea/hiagent/web/service/InfoCollectorService.java
#	frontend/src/components/ChatArea.vue
parents e80bffc7 40bd44a9
......@@ -193,6 +193,9 @@ backend/logs/
backend/storage/
backend/uploads/
backend/hiagentdb.mv.db
# H2 database files
backend/src/main/resources/hiagent_dev_db.*
./data/hiagent_dev_db.*
# Frontend files
frontend/node_modules/
......
现有架构分析
后端架构
核心组件
WorkPanelDataCollector - 实现IWorkPanelDataCollector接口,负责收集和存储事件
DefaultEventManager - 负责创建各种类型的事件对象
SseEventSender - 负责通过SSE发送事件到前端
EventDataManager - 负责构建事件数据,用于发送到前端
EventDeduplicationService - 负责事件去重
WorkPanelEventSubscriber - 负责订阅工作面板事件
WorkPanelService - 提供工作面板状态查询服务
数据流向
Agent执行过程中产生事件
事件被WorkPanelDataCollector收集和存储
通过订阅机制,事件被推送到SseEventSender
SseEventSender使用EventDataManager构建事件数据
通过SSE连接将事件数据发送到前端
主要问题
职责不清:
WorkPanelDataCollector既负责收集又负责订阅管理
DefaultEventManager和WorkPanelDataCollector之间职责重叠
SseEventSender和EventDataManager功能分离但紧密耦合
冗余代码:
多个组件都有相似的事件处理逻辑
事件去重逻辑分散在多个地方
重复的对象转换和序列化逻辑
过度工程化:
太多的小型服务类,增加了复杂性
过度使用依赖注入和组件拆分
前端架构
核心组件
TimelineService - 协调Timeline相关的各种服务
TimelineSseManager - 管理SSE连接和事件处理
SseConnectionManager - 管理SSE连接生命周期
SseEventListenerService - 解析SSE消息
SseEventProcessorService - 处理SSE事件
EventProcessingOptimizer - 优化事件处理
主要问题
服务层次过多:
简单的SSE事件处理被拆分成太多小的服务类
服务之间的依赖关系复杂
缺乏清晰的数据流:
事件处理流程在多个服务间跳跃,难以追踪
已更新待办
架构设计方案
基于对现有代码的分析,我提出以下方案:
后端架构
设计目标
减少不必要的抽象层和服务类
明确各组件职责边界
事件处理流程
降低组件间的耦合度
后的架构
EventService - 统一的事件服务
合并DefaultEventManager和EventDataManager的功能
负责创建、构建和发送事件
集成事件去重逻辑
WorkPanelService - 工作面板核心服务
合并WorkPanelDataCollector的部分功能
负责事件的收集、存储和订阅管理
直接依赖EventService进行事件发送
移除冗余组件
移除WorkPanelEventSubscriber(功能合并到WorkPanelService)
移除SseEventSender(功能合并到EventService)
移除EventDataManager(功能合并到EventService)
数据流向
Agent执行过程中产生事件
WorkPanelService收集和存储事件
WorkPanelService通过EventService发送事件到前端
EventService负责构建事件数据并通过SSE发送
前端架构
设计目标
服务层次结构
明确数据流路径
减少不必要的服务类
后的架构
TimelineService - 时间轴核心服务
合并TimelineSseManager和SseConnectionManager的功能
负责SSE连接管理和事件处理
直接处理事件解析和优化
移除冗余服务
移除TimelineSseManager(功能合并到TimelineService)
移除SseEventListenerService(功能合并到TimelineService)
移除SseEventProcessorService(功能合并到TimelineService)
数据流向
建立SSE连接
TimelineService接收和解析事件
TimelineService优化和处理事件
事件传递给UI组件展示
技术实现要点
事件去重:
在EventService中集中实现事件去重逻辑
使用更高效的缓存策略
对象池:
保留MapPoolService用于对象复用
优化对象池管理策略
错误处理:
统一异常处理机制
错误日志记录
性能优化:
批量处理事件发送
优化事件数据构建过程
\ No newline at end of file
# HiAgent 工具管理方案
## 1. 概述
本文档旨在详细说明 HiAgent 平台的工具管理机制,确保工具方法能够被 Spring AOP 正确代理,并支持手动扫描注册及 UI 配置功能。
通过对现有代码的分析,我们发现当前系统在工具管理方面还存在一些问题,主要包括:
1. 缺少手动触发工具扫描的前端界面
2. 工具无法被正确找到和调用的问题
3. 工具扫描API端点未暴露给前端使用
本文档将在分析这些问题的基础上,提出相应的改进建议.
## 2. Spring AOP 代理兼容性方案
### 2.1 工具类设计规范
为了确保工具方法能够被 Spring AOP 正确代理,所有工具类需要遵循以下规范:
1. **注解使用**
- 工具类必须使用 `@Component` 或其派生注解(如 `@Service`)进行标记
- 工具方法必须使用 `@org.springframework.ai.tool.annotation.Tool` 注解进行标记
2. **访问修饰符**
- 工具方法必须是 `public` 方法
- 避免在同一个类中直接调用其他带有 `@Tool` 注解的方法
3. **类设计**
- 工具类应该是无状态的,或者状态应该是线程安全的
- 避免使用 `final` 方法,因为这会影响 CGLIB 代理的创建
### 2.2 AOP 代理穿透机制
系统已经实现了 AOP 代理穿透机制,确保即使在使用 Spring AOP 代理的情况下也能正确获取工具信息:
1.[AgentToolManager.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/AgentToolManager.java) 中提供了 `getTargetClass()` 方法来获取代理对象的原始类:
```java
private Class<?> getTargetClass(Object bean) {
if (bean == null) {
return null;
}
return AopUtils.getTargetClass(bean);
}
```
2. 在工具匹配过程中,系统会穿透代理获取真实的类信息进行比较,确保匹配准确性。
### 2.3 工具执行日志切面
系统通过 [ToolExecutionLoggerAspect.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/aspect/ToolExecutionLoggerAspect.java) 实现了工具执行的日志记录和监控:
1. 使用 `@Around("@annotation(tool)")` 环绕通知拦截所有带有 `@Tool` 注解的方法
2. 自动记录工具执行的输入参数、输出结果、执行时间等信息
3. 将工具执行信息同步到 WorkPanel 进行可视化展示
## 3. 工具扫描与注册机制
### 3.1 自动扫描机制
系统通过 [ToolBeanNameInitializer.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/ToolBeanNameInitializer.java) 实现工具的自动扫描和注册:
1. **扫描范围**
- 扫描所有 Spring 容器中的 Bean
- 识别带有 `@Tool` 注解方法的类作为工具类
- 过滤掉 Spring 框架自带的 Bean
2. **工具识别规则**
- 类名包含 "Tool" 关键字
-`@Component``@Service` 标注
- 类中包含带有 `@Tool` 注解的方法
3. **工具名称推导**
- 从类名推导工具名称,去除 "Tool" 后缀
- 转换为小驼峰命名格式
### 3.2 手动触发扫描
系统支持通过管理界面手动触发工具扫描和注册:
1. 提供 `initializeToolBeanNamesManually()` 方法用于手动触发扫描
2. 扫描过程会与数据库中的工具记录进行同步:
- 如果数据库中已存在对应工具,则更新 beanName
- 如果数据库中不存在对应工具,则创建新的工具记录
- 如果数据库中有记录但 Spring 容器中不存在对应 Bean,则记录警告信息
目前系统已经实现了手动扫描功能的后端API端点,位于 `/api/v1/admin/system/initialize-tool-beans`,通过 POST 请求触发。但在前端界面上还未提供相应的人机交互界面。
### 3.3 数据库同步策略
工具信息会被持久化存储在数据库中,确保系统重启后配置不会丢失:
1. **工具实体**
- 工具名称(唯一标识)
- Spring Bean 名称(用于查找对应的实例)
- 工具显示名称
- 工具描述
- 工具状态(active/inactive)
- 工具所有者等信息
2. **同步机制**
- 系统启动时不自动执行扫描(避免影响启动速度)
- 通过管理界面手动触发扫描和同步
- 支持增量更新,只处理发生变化的工具
## 4. 当前存在的问题与改进建议
### 4.1 当前存在的主要问题
通过分析现有代码和功能实现,我们发现工具管理系统存在以下主要问题:
1. **缺少手动扫描的前端界面**
- 后端已经实现了手动扫描工具的API端点(`/api/v1/admin/system/initialize-tool-beans`
- 但前端尚未提供相应的用户界面来触发这一功能
2. **工具无法正确找到和调用**
-[AgentToolManager.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/AgentToolManager.java)`getAvailableToolInstances` 方法中,当工具的 beanName 为空或查找失败时,仅记录日志而没有提供有效的错误反馈机制
- 工具调用失败时缺乏详细的错误信息和调试手段
3. **工具管理页面功能不完善**
- 当前的 [ToolManagement.vue](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/frontend/src/pages/ToolManagement.vue) 页面仅支持基础的增删改查功能
- 缺少与后端扫描功能的集成
### 4.2 改进建议
针对上述问题,我们提出以下改进建议:
#### 4.2.1 完善前端工具管理界面
1. **增加手动扫描按钮**
- 在工具管理页面添加"扫描工具"按钮
- 点击后调用后端API `/api/v1/admin/system/initialize-tool-beans` 触发扫描
- 显示扫描进度和结果
- 提供扫描历史记录查看功能
2. **增强工具详情展示**
- 在工具列表中增加显示工具的Bean名称、状态等详细信息
- 提供工具测试功能,允许用户直接测试工具调用
- 显示工具的最后更新时间和创建者信息
3. **优化错误提示**
- 当工具无法找到或调用失败时,提供更明确的错误信息
- 增加工具诊断功能,帮助用户排查问题
- 提供常见问题解决方案链接和帮助文档
4. **增加工具诊断界面**
- 提供单个工具的详细诊断信息查看
- 支持批量工具状态检查
- 显示工具依赖关系图谱#### 4.2.2 后端功能优化
1. **完善工具调用错误处理**
-[AgentToolManager.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/AgentToolManager.java) 中增强错误处理机制
- 提供更详细的错误信息,便于前端展示和用户排查问题
- 增加结构化的错误信息返回,包含具体的原因和解决方案建议
2. **增加工具诊断API**
- 提供工具诊断端点,检查工具是否正确定义和注册
- 返回工具的详细信息和可能存在的问题
- 支持单个工具诊断和批量工具诊断功能
3. **优化日志记录**
- 增强工具调用过程中的日志记录
- 提供更详细的调试信息,便于问题追踪
- 结构化日志信息,方便后续分析和问题定位## 5. 实施步骤
### 5.1 后端实施
1. 完善 [ToolBeanNameInitializer.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/ToolBeanNameInitializer.java) 的手动扫描接口
2. 优化 [AgentToolManager.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/AgentToolManager.java) 的工具获取逻辑,增强错误处理和日志记录
3. 增强 [ToolExecutionLoggerAspect.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/tool/aspect/ToolExecutionLoggerAspect.java) 的日志记录功能
4. 增加工具诊断API端点,提供工具状态检查功能
5. 实现工具依赖关系分析功能
6. 增加工具使用统计和性能监控
### 5.2 前端实施
1. 在工具管理页面增加手动扫描按钮,调用 `/api/v1/admin/system/initialize-tool-beans` 端点
2. 增强工具列表展示,显示更多工具详细信息如Bean名称、状态等
3. 增加工具测试功能,允许用户直接测试工具调用
4. 优化错误提示,提供更明确的错误信息帮助用户排查问题
5. 实现工具诊断界面,支持单个和批量工具诊断
6. 增加工具依赖关系可视化展示
### 5.3 测试验证
1. 验证工具方法的 AOP 代理兼容性
2. 测试手动扫描和自动注册功能
3. 验证 UI 配置功能的完整性和易用性
4. 测试工具执行的日志记录和监控功能
5. 验证错误处理机制的有效性
6. 测试工具诊断功能的准确性和完整性
7. 验证工具依赖关系分析的正确性
## 6. 总结
本方案通过规范工具类设计、实现 AOP 代理穿透、建立完善的扫描注册机制以及提供友好的 UI 配置界面,全面解决了工具管理的相关需求。该方案既保证了系统的稳定性和扩展性,又提升了用户的使用体验。
通过对现有代码的分析,我们确认系统已经具备了良好的基础架构,包括:
1. 完善的 AOP 代理支持
2. 工具扫描和注册的核心功能实现
3. 工具调用的日志记录机制
接下来的工作重点应该放在完善前端界面和增强错误处理上,使系统更加易于使用和维护。特别需要关注的是工具诊断功能的实现,这将大大提高系统运维和问题排查的效率。
\ No newline at end of file
......@@ -14,7 +14,7 @@
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.5.8</version>
<version>3.5.9</version>
<relativePath/>
</parent>
......@@ -108,7 +108,6 @@
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-milvus-store</artifactId>
<version>${spring-ai.version}</version>
</dependency>
......@@ -155,14 +154,12 @@
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.0.33</version>
</dependency>
<!-- H2 Database -->
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>2.2.224</version>
</dependency>
<!-- Redis -->
......@@ -194,14 +191,12 @@
<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
<version>${caffeine.version}</version>
</dependency>
<!-- Lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.30</version>
<scope>provided</scope>
</dependency>
<!-- Jackson -->
......@@ -234,7 +229,6 @@
<dependency>
<groupId>org.hibernate.validator</groupId>
<artifactId>hibernate-validator</artifactId>
<version>8.0.1.Final</version>
</dependency>
<!-- SpringDoc OpenAPI for Swagger -->
......@@ -351,7 +345,6 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>17</source>
<target>17</target>
......@@ -370,7 +363,6 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.0.0</version>
<configuration>
<argLine>-Dfile.encoding=UTF-8</argLine>
</configuration>
......
package pangea.hiagent.web.repository;
package pangea.hiagent;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
......
package pangea.hiagent.agent.processor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import pangea.hiagent.model.Agent;
import pangea.hiagent.rag.RagService;
import pangea.hiagent.agent.service.AgentErrorHandler;
import pangea.hiagent.agent.service.TokenConsumerWithCompletion;
import java.util.function.Consumer;
/**
* Agent处理器抽象基类
* 封装所有Agent处理器的公共逻辑
* 职责:提供所有Agent处理器共享的基础功能
*/
@Slf4j
public abstract class AbstractAgentProcessor extends BaseAgentProcessor {
@Autowired
protected AgentErrorHandler agentErrorHandler;
/**
* 处理RAG响应的通用逻辑
*
* @param ragResponse RAG响应
* @param tokenConsumer token消费者(流式处理时使用)
* @return RAG响应
*/
protected String handleRagResponse(String ragResponse, Consumer<String> tokenConsumer) {
if (tokenConsumer != null) {
// 对于流式处理,我们需要将RAG响应作为token发送
tokenConsumer.accept(ragResponse);
// 发送完成信号
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(ragResponse);
}
}
return ragResponse;
}
/**
* 处理请求的通用前置逻辑
*
* @param agent Agent对象
* @param userMessage 用户消息
* @param userId 用户ID
* @param ragService RAG服务
* @param tokenConsumer token消费者(流式处理时使用)
* @return RAG响应,如果有的话;否则返回null继续正常处理流程
*/
protected String handlePreProcessing(Agent agent, String userMessage, String userId, RagService ragService, Consumer<String> tokenConsumer) {
// 为每个用户-Agent组合创建唯一的会话ID
String sessionId = generateSessionId(agent, userId);
// 添加用户消息到ChatMemory
addUserMessageToMemory(sessionId, userMessage);
// 检查是否启用RAG并尝试RAG增强
String ragResponse = tryRagEnhancement(agent, userMessage, ragService);
if (ragResponse != null) {
log.info("RAG增强返回结果,直接返回");
return handleRagResponse(ragResponse, tokenConsumer);
}
return null;
}
}
\ No newline at end of file
......@@ -9,7 +9,7 @@ import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.memory.SmartHistorySummarizer;
import pangea.hiagent.model.Agent;
import pangea.hiagent.rag.RagService;
import pangea.hiagent.agent.service.AgentErrorHandler;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.agent.service.ErrorHandlerService;
import pangea.hiagent.agent.service.TokenConsumerWithCompletion;
......@@ -38,8 +38,7 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
@Autowired
protected ErrorHandlerService errorHandlerService;
@Autowired
protected AgentErrorHandler agentErrorHandler;
// 默认系统提示词
protected static final String DEFAULT_SYSTEM_PROMPT = "你是一个智能助手";
......@@ -135,7 +134,7 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
* @return 是否为401错误
*/
protected boolean isUnauthorizedError(Throwable e) {
return agentErrorHandler.isUnauthorizedError(e);
return errorHandlerService.isUnauthorizedError(new Exception(e));
}
/**
......@@ -146,7 +145,17 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
* @return 错误消息
*/
protected String handleSyncError(Throwable e, String errorMessagePrefix) {
return agentErrorHandler.handleSyncError(e, errorMessagePrefix);
// 检查是否是401 Unauthorized错误
if (isUnauthorizedError(e)) {
log.error("LLM返回401未授权错误: {}", e.getMessage());
return "请配置API密钥";
} else {
String errorMessage = e.getMessage();
if (errorMessage == null || errorMessage.isEmpty()) {
errorMessage = "未知错误";
}
return errorMessagePrefix + ": " + errorMessage;
}
}
/**
......@@ -325,7 +334,7 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
hasError.set(true);
// 不再重新抛出异常,避免中断流式处理
// 但我们应该记录这个错误并向客户端发送错误信息
agentErrorHandler.sendErrorMessage(tokenConsumer, "[错误] 处理token时发生错误: " + e.getMessage());
errorHandlerService.sendErrorMessage(tokenConsumer, "[错误] 处理token时发生错误: " + e.getMessage());
}
}
} catch (Exception e) {
......@@ -344,7 +353,7 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
*/
private void handleStreamError(Throwable throwable, Consumer<String> tokenConsumer, AtomicBoolean hasError) {
hasError.set(true);
agentErrorHandler.handleStreamError(throwable, tokenConsumer, "流式调用出错");
errorHandlerService.handleStreamError(throwable, tokenConsumer, "流式调用出错");
}
/**
......@@ -385,7 +394,11 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
// 发送完成事件,包含完整内容
try {
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(fullText.toString());
} catch (NoClassDefFoundError e) {
log.error("TokenConsumerWithCompletion依赖类未找到,跳过完成回调: {}", e.getMessage());
}
if (log.isTraceEnabled()) {
log.trace("完成事件已发送");
}
......@@ -411,7 +424,7 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
* @param isCompleted 是否已完成
*/
private void handleStreamModelError(Consumer<String> tokenConsumer, AtomicBoolean isCompleted) {
agentErrorHandler.sendErrorMessage(tokenConsumer, "[错误] 流式模型或提示词为空,无法启动流式处理");
errorHandlerService.sendErrorMessage(tokenConsumer, "[错误] 流式模型或提示词为空,无法启动流式处理");
// 标记完成
isCompleted.set(true);
}
......@@ -426,8 +439,58 @@ public abstract class BaseAgentProcessor implements AgentProcessor {
*/
private void handleUnexpectedError(Exception e, Consumer<String> tokenConsumer, AtomicBoolean isCompleted) {
String errorMessage = handleSyncError(e, "处理流式响应时发生错误");
agentErrorHandler.sendErrorMessage(tokenConsumer, "[错误] " + errorMessage);
errorHandlerService.sendErrorMessage(tokenConsumer, "[错误] " + errorMessage);
// 确保标记为已完成
isCompleted.set(true);
}
/**
* 处理RAG响应的通用逻辑
*
* @param ragResponse RAG响应
* @param tokenConsumer token消费者(流式处理时使用)
* @return RAG响应
*/
protected String handleRagResponse(String ragResponse, Consumer<String> tokenConsumer) {
if (tokenConsumer != null) {
// 对于流式处理,我们需要将RAG响应作为token发送
tokenConsumer.accept(ragResponse);
// 发送完成信号
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(ragResponse);
} catch (NoClassDefFoundError e) {
log.error("TokenConsumerWithCompletion依赖类未找到,跳过完成回调: {}", e.getMessage());
}
}
}
return ragResponse;
}
/**
* 处理请求的通用前置逻辑
*
* @param agent Agent对象
* @param userMessage 用户消息
* @param userId 用户ID
* @param ragService RAG服务
* @param tokenConsumer token消费者(流式处理时使用)
* @return RAG响应,如果有的话;否则返回null继续正常处理流程
*/
protected String handlePreProcessing(Agent agent, String userMessage, String userId, RagService ragService, Consumer<String> tokenConsumer) {
// 为每个用户-Agent组合创建唯一的会话ID
String sessionId = generateSessionId(agent, userId);
// 添加用户消息到ChatMemory
addUserMessageToMemory(sessionId, userMessage);
// 检查是否启用RAG并尝试RAG增强
String ragResponse = tryRagEnhancement(agent, userMessage, ragService);
if (ragResponse != null) {
log.info("RAG增强返回结果,直接返回");
return handleRagResponse(ragResponse, tokenConsumer);
}
return null;
}
}
\ No newline at end of file
......@@ -11,6 +11,7 @@ import pangea.hiagent.rag.RagService;
import pangea.hiagent.web.dto.AgentRequest;
import java.util.function.Consumer;
import pangea.hiagent.agent.service.TokenConsumerWithCompletion;
/**
* 普通Agent处理器实现类
......@@ -18,7 +19,7 @@ import java.util.function.Consumer;
*/
@Slf4j
@Service
public class NormalAgentProcessor extends AbstractAgentProcessor {
public class NormalAgentProcessor extends BaseAgentProcessor {
@Autowired(required = false)
private RagService ragService;
......@@ -67,7 +68,7 @@ public class NormalAgentProcessor extends AbstractAgentProcessor {
return responseContent;
} catch (Exception e) {
return agentErrorHandler.handleSyncError(e, "模型调用失败");
return handleSyncError(e, "模型调用失败");
}
}
......@@ -101,8 +102,15 @@ public class NormalAgentProcessor extends AbstractAgentProcessor {
// 流式处理
handleStreamingResponse(tokenConsumer, prompt, streamingChatModel, sessionId);
} catch (Exception e) {
agentErrorHandler.handleStreamError(e, tokenConsumer, "普通Agent流式处理失败");
agentErrorHandler.ensureCompletionCallback(tokenConsumer, "处理请求时发生错误: " + e.getMessage());
errorHandlerService.handleStreamError(e, tokenConsumer, "普通Agent流式处理失败");
// 直接调用完成回调,不依赖AgentErrorHandler
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete("处理请求时发生错误: " + e.getMessage());
} catch (Exception ex) {
log.error("调用onComplete时发生错误: {}", ex.getMessage(), ex);
}
}
}
}
......@@ -114,9 +122,15 @@ public class NormalAgentProcessor extends AbstractAgentProcessor {
private void handleModelNotSupportStream(Consumer<String> tokenConsumer) {
String errorMessage = "[错误] 当前模型不支持流式输出";
// 发送错误信息
agentErrorHandler.sendErrorMessage(tokenConsumer, errorMessage);
errorHandlerService.sendErrorMessage(tokenConsumer, errorMessage);
// 确保在异常情况下也调用完成回调
agentErrorHandler.ensureCompletionCallback(tokenConsumer, errorMessage);
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(errorMessage);
} catch (Exception ex) {
log.error("调用onComplete时发生错误: {}", ex.getMessage(), ex);
}
}
}
@Override
......
......@@ -15,6 +15,7 @@ import pangea.hiagent.web.service.AgentService;
import java.util.List;
import java.util.function.Consumer;
import pangea.hiagent.agent.service.TokenConsumerWithCompletion;
/**
* ReAct Agent处理器实现类
......@@ -22,7 +23,7 @@ import java.util.function.Consumer;
*/
@Slf4j
@Service
public class ReActAgentProcessor extends AbstractAgentProcessor {
public class ReActAgentProcessor extends BaseAgentProcessor {
@Autowired
private AgentService agentService;
......@@ -87,16 +88,14 @@ public class ReActAgentProcessor extends AbstractAgentProcessor {
defaultReactExecutor.addReactCallback(defaultReactCallback);
}
// 使用ReAct执行器执行流程,传递Agent对象以支持记忆功能
String finalAnswer = defaultReactExecutor.executeWithAgent(client, userMessage, tools, agent);
// 使用ReAct执行器执行流程,传递Agent对象和用户ID以支持记忆功能
String finalAnswer = defaultReactExecutor.execute(client, userMessage, tools, agent, userId);
// 将助理回复添加到ChatMemory
String sessionId = generateSessionId(agent, userId);
addAssistantMessageToMemory(sessionId, finalAnswer);
// 助手回复已经由执行器保存到内存中,不需要重复保存
return finalAnswer;
} catch (Exception e) {
return agentErrorHandler.handleSyncError(e, "处理ReAct请求时发生错误");
return handleSyncError(e, "处理ReAct请求时发生错误");
}
}
......@@ -138,11 +137,18 @@ public class ReActAgentProcessor extends AbstractAgentProcessor {
return;
}
// 使用ReAct执行器流式执行流程,传递Agent对象以支持记忆功能
defaultReactExecutor.executeStreamWithAgent(client, userMessage, tools, tokenConsumer, agent);
// 使用ReAct执行器流式执行流程,传递Agent对象以支持记忆功能和用户ID以确保上下文传播
defaultReactExecutor.executeStream(client, userMessage, tools, tokenConsumer, agent, userId);
} catch (Exception e) {
agentErrorHandler.handleStreamError(e, tokenConsumer, "流式处理ReAct请求时发生错误");
agentErrorHandler.ensureCompletionCallback(tokenConsumer, "处理请求时发生错误: " + e.getMessage());
errorHandlerService.handleStreamError(e, tokenConsumer, "流式处理ReAct请求时发生错误");
// 直接调用完成回调,不依赖AgentErrorHandler
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete("处理请求时发生错误: " + e.getMessage());
} catch (Exception ex) {
log.error("调用onComplete时发生错误: {}", ex.getMessage(), ex);
}
}
}
}
......@@ -154,8 +160,14 @@ public class ReActAgentProcessor extends AbstractAgentProcessor {
private void handleModelNotAvailable(Consumer<String> tokenConsumer) {
String errorMessage = "[错误] 无法获取Agent的聊天模型";
// 发送错误信息
agentErrorHandler.sendErrorMessage(tokenConsumer, errorMessage);
errorHandlerService.sendErrorMessage(tokenConsumer, errorMessage);
// 确保在异常情况下也调用完成回调
agentErrorHandler.ensureCompletionCallback(tokenConsumer, errorMessage);
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(errorMessage);
} catch (Exception ex) {
log.error("调用onComplete时发生错误: {}", ex.getMessage(), ex);
}
}
}
}
\ No newline at end of file
......@@ -6,78 +6,96 @@ import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.workpanel.IWorkPanelDataCollector;
/**
* 自定义 ReAct 回调类,用于捕获并处理 ReAct 的每一步思维过程
* 适配项目现有的 ReAct 实现方式
* 简化的ReAct回调类
*/
@Slf4j
@Component // 注册为 Spring 组件,方便注入
@Component
public class DefaultReactCallback implements ReactCallback {
@Autowired
private IWorkPanelDataCollector workPanelCollector;
/**
* ReAct 每执行一个步骤,该方法会被触发
* @param reactStep ReAct 步骤对象,包含步骤的所有核心信息
*/
@Override
public void onStep(ReactStep reactStep) {
// 将信息记录到工作面板
log.info("ReAct步骤触发: 类型={}, 内容摘要={}",
reactStep.getStepType(),
reactStep.getContent() != null ?
reactStep.getContent().substring(0, Math.min(50, reactStep.getContent().length())) : "null");
recordReactStepToWorkPanel(reactStep);
}
/**
* 处理 ReAct 最终答案步骤
* @param finalAnswer 最终答案
*/
@Override
public void onFinalAnswer(String finalAnswer) {
// 创建一个FINAL_ANSWER类型的ReactStep并处理
ReactStep finalStep = new ReactStep(0, ReactStepType.FINAL_ANSWER, finalAnswer);
recordReactStepToWorkPanel(finalStep);
}
/**
* 将ReAct步骤记录到工作面板
* @param reactStep ReAct步骤
*/
private void recordReactStepToWorkPanel(ReactStep reactStep) {
if (workPanelCollector == null) {
log.debug("无法记录到工作面板:collector为null");
return;
}
try {
switch (reactStep.getStepType()) {
case THOUGHT:
workPanelCollector.recordThinking(reactStep.getContent(), "reasoning");
workPanelCollector.recordThinking(reactStep.getContent(), "thought");
log.info("[WorkPanel] 记录思考步骤: {}",
reactStep.getContent().substring(0, Math.min(100, reactStep.getContent().length())));
break;
case ACTION:
if (reactStep.getAction() != null) {
// 使用recordToolCallAction记录工具调用开始,状态为pending
// 记录工具调用动作
String toolName = reactStep.getAction().getToolName();
Object parameters = reactStep.getAction().getParameters();
// 记录工具调用,初始状态为pending
workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(),
null,
"pending",
null
toolName,
parameters,
null, // 结果为空
"pending", // 状态为pending
null // 错误信息为空
);
// 同时记录工具调用信息到日志
log.info("[WorkPanel] 记录工具调用: 工具={} 参数={}", toolName, parameters);
} else {
// 如果没有具体的工具信息,记录为一般动作
workPanelCollector.recordThinking(reactStep.getContent(), "action");
log.info("[WorkPanel] 记录动作步骤: {}",
reactStep.getContent().substring(0, Math.min(100, reactStep.getContent().length())));
}
break;
case OBSERVATION:
if (reactStep.getObservation() != null && reactStep.getAction() != null) {
// 使用recordToolCallAction记录工具调用完成,状态为success
if (reactStep.getObservation() != null) {
// 检查是否有对应的动作信息
if (reactStep.getAction() != null) {
// 使用动作信息更新工具调用结果
workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(),
reactStep.getObservation().getContent(),
"success",
null
"success", // 状态为success
null // 无错误信息
);
log.info("[WorkPanel] 更新工具调用结果: 工具={} 结果摘要={}",
reactStep.getAction().getToolName(),
reactStep.getObservation().getContent().substring(0, Math.min(50, reactStep.getObservation().getContent().length())));
} else {
// 如果没有动作信息,记录为观察结果
workPanelCollector.recordThinking(reactStep.getContent(), "observation");
log.info("[WorkPanel] 记录观察步骤: {}",
reactStep.getContent().substring(0, Math.min(100, reactStep.getContent().length())));
}
}
break;
case FINAL_ANSWER:
workPanelCollector.recordFinalAnswer(reactStep.getContent());
// 记录最终答案到日志
log.info("[WorkPanel] 记录最终答案: {}",
reactStep.getContent().substring(0, Math.min(100, reactStep.getContent().length())));
break;
default:
log.warn("未知的ReAct步骤类型: {}", reactStep.getStepType());
......@@ -85,6 +103,20 @@ public class DefaultReactCallback implements ReactCallback {
}
} catch (Exception e) {
log.error("记录ReAct步骤到工作面板失败", e);
// 即使发生异常,也尝试记录错误信息到工作面板
try {
if (reactStep != null && reactStep.getAction() != null) {
workPanelCollector.recordToolCallAction(
reactStep.getAction().getToolName(),
reactStep.getAction().getParameters(),
"记录失败: " + e.getMessage(),
"error",
System.currentTimeMillis() // 使用当前时间戳作为执行时间
);
}
} catch (Exception ex) {
log.error("记录错误信息到工作面板也失败", ex);
}
}
}
}
\ No newline at end of file
......@@ -15,21 +15,21 @@ public interface ReactExecutor {
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param agent Agent对象
* @return 最终答案
*/
String execute(ChatClient chatClient, String userInput, List<Object> tools);
String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent);
/**
* 执行ReAct流程(同步方式)- 支持Agent配置
* 执行ReAct流程(同步方式)
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param agent Agent对象(可选)
* @param agent Agent对象
* @param userId 用户ID
* @return 最终答案
*/
default String executeWithAgent(ChatClient chatClient, String userInput, List<Object> tools, Agent agent) {
return execute(chatClient, userInput, tools);
}
String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent, String userId);
/**
* 流式执行ReAct流程
......@@ -37,20 +37,20 @@ public interface ReactExecutor {
* @param userInput 用户输入
* @param tools 工具列表
* @param tokenConsumer token处理回调函数
* @param agent Agent对象
* @param userId 用户ID
*/
void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer);
void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent, String userId);
/**
* 流式执行ReAct流程 - 支持Agent配置
* 流式执行ReAct流程(旧方法,保持向后兼容)
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param tokenConsumer token处理回调函数
* @param agent Agent对象(可选)
* @param agent Agent对象
*/
default void executeStreamWithAgent(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) {
executeStream(chatClient, userInput, tools, tokenConsumer);
}
void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent);
/**
* 添加ReAct回调
......
package pangea.hiagent.agent.react;
import lombok.Data;
/**
* ReAct步骤对象,包含步骤的所有核心信息
* ReAct步骤类,用于表示ReAct执行过程中的单个步骤
*/
@Data
public class ReactStep {
/**
* 步骤编号
*/
private int stepNumber;
/**
* 步骤类型
*/
private ReactStepType stepType;
/**
* 步骤核心内容(思维描述、动作指令、观察结果等)
*/
private String content;
/**
* 工具调用信息(仅在ACTION步骤时有值)
*/
private ToolCallAction action;
/**
* 工具观察结果(仅在OBSERVATION步骤时有值)
*/
private ToolObservation observation;
/**
* 构造函数
*/
public ReactStep() {}
/**
* 构造函数
* @param stepNumber 步骤编号
* @param stepType 步骤类型
* @param content 步骤内容
*/
public ReactStep(int stepNumber, ReactStepType stepType, String content) {
this.stepNumber = stepNumber;
this.stepType = stepType;
this.content = content;
}
// Getters and Setters
public int getStepNumber() { return stepNumber; }
public void setStepNumber(int stepNumber) { this.stepNumber = stepNumber; }
public ReactStepType getStepType() { return stepType; }
public void setStepType(ReactStepType stepType) { this.stepType = stepType; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
public ToolCallAction getAction() { return action; }
public void setAction(ToolCallAction action) { this.action = action; }
public ToolObservation getObservation() { return observation; }
public void setObservation(ToolObservation observation) { this.observation = observation; }
/**
* 工具调用动作类
* 工具调用动作内部
*/
@Data
public static class ToolCallAction {
/**
* 工具名称
*/
private String toolName;
private Object toolArgs;
/**
* 工具调用参数
*/
private Object parameters;
public ToolCallAction() {}
public ToolCallAction(String toolName, Object parameters) {
public ToolCallAction(String toolName, Object toolArgs) {
this.toolName = toolName;
this.parameters = parameters;
this.toolArgs = toolArgs;
}
public String getToolName() { return toolName; }
public void setToolName(String toolName) { this.toolName = toolName; }
public Object getToolArgs() { return toolArgs; }
public void setToolArgs(Object toolArgs) { this.toolArgs = toolArgs; }
// 根据DefaultReactCallback.java中的使用情况添加getParameters方法
public Object getParameters() { return toolArgs; }
}
/**
* 工具观察结果类
* 工具观察结果内部
*/
@Data
public static class ToolObservation {
/**
* 观察内容
*/
private String content;
public ToolObservation() {}
private String result;
public ToolObservation(String content) {
this.content = content;
public ToolObservation(String result) {
this.result = result;
}
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
// 根据DefaultReactCallback.java中的使用情况添加getContent方法
public String getContent() { return result; }
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.function.Consumer;
/**
* Agent错误处理工具类
* 统一处理Agent处理器中的错误逻辑
*/
@Slf4j
@Component
public class AgentErrorHandler {
@Autowired
private ErrorHandlerService errorHandlerService;
/**
* 处理401未授权错误
*
* @param e 异常对象
* @return 是否为401错误
*/
public boolean isUnauthorizedError(Throwable e) {
return errorHandlerService.isUnauthorizedError(new Exception(e));
}
/**
* 处理流式处理中的错误
*
* @param e 异常对象
* @param tokenConsumer token处理回调函数
* @param errorMessagePrefix 错误消息前缀
*/
public void handleStreamError(Throwable e, Consumer<String> tokenConsumer, String errorMessagePrefix) {
errorHandlerService.handleStreamError(e, tokenConsumer, errorMessagePrefix);
}
/**
* 处理同步处理中的错误
*
* @param e 异常对象
* @param errorMessagePrefix 错误消息前缀
* @return 错误消息
*/
public String handleSyncError(Throwable e, String errorMessagePrefix) {
// 检查是否是401 Unauthorized错误
if (isUnauthorizedError(e)) {
log.error("LLM返回401未授权错误: {}", e.getMessage());
return "请配置API密钥";
} else {
String errorMessage = e.getMessage();
if (errorMessage == null || errorMessage.isEmpty()) {
errorMessage = "未知错误";
}
return errorMessagePrefix + ": " + errorMessage;
}
}
/**
* 发送错误信息给客户端
*
* @param tokenConsumer token处理回调函数
* @param errorMessage 错误消息
*/
public void sendErrorMessage(Consumer<String> tokenConsumer, String errorMessage) {
errorHandlerService.sendErrorMessage(tokenConsumer, errorMessage);
}
/**
* 确保在异常情况下也调用完成回调
*
* @param tokenConsumer token处理回调函数
* @param errorMessage 错误消息
*/
public void ensureCompletionCallback(Consumer<String> tokenConsumer, String errorMessage) {
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
try {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(errorMessage);
} catch (Exception ex) {
log.error("调用onComplete时发生错误: {}", ex.getMessage(), ex);
}
}
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.model.Agent;
import pangea.hiagent.agent.processor.AgentProcessor;
import pangea.hiagent.agent.processor.AgentProcessorFactory;
import pangea.hiagent.agent.sse.UserSseService;
import pangea.hiagent.common.utils.LogUtils;
import pangea.hiagent.common.utils.ValidationUtils;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Agent处理器服务
* 负责处理Agent处理器的获取和心跳机制
*/
@Slf4j
@Service
public class AgentProcessorService {
@Autowired
private AgentProcessorFactory agentProcessorFactory;
@Autowired
private UserSseService workPanelSseService;
@Autowired
private ChatErrorHandler chatErrorHandler;
/**
* 获取处理器并启动心跳保活机制
*
* @param agent Agent对象
* @param emitter SSE发射器
* @return Agent处理器,如果获取失败则返回null
*/
public AgentProcessor getProcessorAndStartHeartbeat(Agent agent, SseEmitter emitter) {
LogUtils.enterMethod("getProcessorAndStartHeartbeat", agent);
// 参数验证
if (ValidationUtils.isNull(agent, "agent")) {
chatErrorHandler.handleChatError(emitter, "Agent对象不能为空");
LogUtils.exitMethod("getProcessorAndStartHeartbeat", "Agent对象不能为空");
return null;
}
if (ValidationUtils.isNull(emitter, "emitter")) {
chatErrorHandler.handleChatError(emitter, "SSE发射器不能为空");
LogUtils.exitMethod("getProcessorAndStartHeartbeat", "SSE发射器不能为空");
return null;
}
try {
// 根据Agent类型选择处理器并处理请求
AgentProcessor processor = agentProcessorFactory.getProcessor(agent);
if (processor == null) {
chatErrorHandler.handleChatError(emitter, "无法获取Agent处理器");
LogUtils.exitMethod("getProcessorAndStartHeartbeat", "无法获取Agent处理器");
return null;
}
log.info("使用{} Agent处理器处理对话", processor.getProcessorType());
// 启动心跳保活机制
workPanelSseService.startHeartbeat(emitter, new AtomicBoolean(false));
LogUtils.exitMethod("getProcessorAndStartHeartbeat", processor);
return processor;
} catch (Exception e) {
chatErrorHandler.handleChatError(emitter, "获取处理器或启动心跳时发生错误", e, null);
LogUtils.exitMethod("getProcessorAndStartHeartbeat", e);
return null;
}
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.model.Agent;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.common.utils.LogUtils;
import pangea.hiagent.common.utils.ValidationUtils;
import pangea.hiagent.common.utils.UserUtils;
/**
* Agent验证服务
* 负责处理Agent的参数验证和权限检查
*/
@Slf4j
@Service
public class AgentValidationService {
@Autowired
private AgentService agentService;
@Autowired
private ChatErrorHandler chatErrorHandler;
/**
* 验证Agent存在性和用户权限
*
* @param agentId Agent ID
* @param userId 用户ID
* @param emitter SSE发射器
* @return Agent对象,如果验证失败则返回null
*/
public Agent validateAgentAndPermission(String agentId, String userId, SseEmitter emitter) {
LogUtils.enterMethod("validateAgentAndPermission", agentId, userId);
// 参数验证
if (ValidationUtils.isBlank(agentId, "agentId")) {
chatErrorHandler.handleChatError(emitter, "Agent ID不能为空");
LogUtils.exitMethod("validateAgentAndPermission", "Agent ID不能为空");
return null;
}
if (ValidationUtils.isBlank(userId, "userId")) {
chatErrorHandler.handleChatError(emitter, "用户ID不能为空");
LogUtils.exitMethod("validateAgentAndPermission", "用户ID不能为空");
return null;
}
try {
// 获取Agent信息
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
chatErrorHandler.handleChatError(emitter, "Agent不存在");
LogUtils.exitMethod("validateAgentAndPermission", "Agent不存在");
return null;
}
// 检查权限(可选)
if (!agent.getOwner().equals(userId) && !UserUtils.isAdminUser(userId)) {
chatErrorHandler.handleChatError(emitter, "无权限访问该Agent");
LogUtils.exitMethod("validateAgentAndPermission", "无权限访问该Agent");
return null;
}
LogUtils.exitMethod("validateAgentAndPermission", agent);
return agent;
} catch (Exception e) {
chatErrorHandler.handleChatError(emitter, "验证Agent和权限时发生错误", e, null);
LogUtils.exitMethod("validateAgentAndPermission", e);
return null;
}
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 聊天服务错误处理工具类
* 统一处理聊天过程中的各种异常情况
* 委托给ErrorHandlerService进行实际处理
*/
@Slf4j
@Component
public class ChatErrorHandler {
@Autowired
private ErrorHandlerService unifiedErrorHandlerService;
/**
* 处理聊天过程中的异常
*
* @param emitter SSE发射器
* @param errorMessage 错误信息
* @param exception 异常对象
* @param processorType 处理器类型(可选)
*/
public void handleChatError(SseEmitter emitter, String errorMessage, Exception exception, String processorType) {
unifiedErrorHandlerService.handleChatError(emitter, errorMessage, exception, processorType);
}
/**
* 处理聊天过程中的异常()
*
* @param emitter SSE发射器
* @param errorMessage 错误信息
*/
public void handleChatError(SseEmitter emitter, String errorMessage) {
unifiedErrorHandlerService.handleChatError(emitter, errorMessage);
}
/**
* 处理Token处理过程中的异常
*
* @param emitter SSE发射器
* @param processorType 处理器类型
* @param exception 异常对象
* @param isCompleted 完成状态标记
*/
public void handleTokenError(SseEmitter emitter, String processorType, Exception exception, AtomicBoolean isCompleted) {
unifiedErrorHandlerService.handleTokenError(emitter, processorType, exception, isCompleted);
}
/**
* 处理完成回调过程中的异常
*
* @param emitter SSE发射器
* @param exception 异常对象
*/
public void handleCompletionError(SseEmitter emitter, Exception exception) {
unifiedErrorHandlerService.handleCompletionError(emitter, exception);
}
/**
* 处理对话记录保存过程中的异常
*
* @param emitter SSE发射器
* @param exception 异常对象
* @param isCompleted 完成状态标记
*/
public void handleSaveDialogueError(SseEmitter emitter, Exception exception, AtomicBoolean isCompleted) {
unifiedErrorHandlerService.handleSaveDialogueError(emitter, exception, isCompleted);
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.AgentDialogue;
import pangea.hiagent.common.utils.ValidationUtils;
import pangea.hiagent.agent.processor.AgentProcessor;
import pangea.hiagent.agent.sse.UserSseService;
import pangea.hiagent.common.utils.LogUtils;
import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.web.dto.AgentRequest;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.workpanel.event.EventService;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 完成回调处理服务
* 负责处理流式输出完成后的回调操作
*/
@Slf4j
@Service
public class CompletionHandlerService {
@Autowired
private AgentService agentService;
@Autowired
private UserSseService unifiedSseService;
@Autowired
private EventService eventService;
@Autowired
private ErrorHandlerService errorHandlerService;
/**
* 处理完成回调
*
* @param emitter SSE发射器
* @param processor Agent处理器
* @param agent Agent对象
* @param request Agent请求
* @param userId 用户ID
* @param fullContent 完整内容
* @param isCompleted 完成状态标记
*/
public void handleCompletion(SseEmitter emitter, AgentProcessor processor, Agent agent,
AgentRequest request, String userId,
String fullContent, AtomicBoolean isCompleted) {
LogUtils.enterMethod("handleCompletion", emitter, processor, agent, request, userId);
// 参数验证
if (ValidationUtils.isNull(emitter, "emitter")) {
log.error("SSE发射器不能为空");
LogUtils.exitMethod("handleCompletion", "SSE发射器不能为空");
return;
}
if (ValidationUtils.isNull(processor, "processor")) {
log.error("Agent处理器不能为空");
LogUtils.exitMethod("handleCompletion", "Agent处理器不能为空");
return;
}
if (ValidationUtils.isNull(agent, "agent")) {
log.error("Agent对象不能为空");
LogUtils.exitMethod("handleCompletion", "Agent对象不能为空");
return;
}
if (ValidationUtils.isNull(request, "request")) {
log.error("Agent请求不能为空");
LogUtils.exitMethod("handleCompletion", "Agent请求不能为空");
return;
}
if (ValidationUtils.isBlank(userId, "userId")) {
log.error("用户ID不能为空");
LogUtils.exitMethod("handleCompletion", "用户ID不能为空");
return;
}
if (ValidationUtils.isNull(isCompleted, "isCompleted")) {
log.error("完成状态标记不能为空");
LogUtils.exitMethod("handleCompletion", "完成状态标记不能为空");
return;
}
log.info("{} Agent处理完成,总字符数: {}", processor.getProcessorType(), fullContent != null ? fullContent.length() : 0);
// 发送完成事件
try {
// 发送完整内容作为最后一个token
if (fullContent != null && !fullContent.isEmpty()) {
eventService.sendTokenEvent(emitter, fullContent);
}
// 发送完成信号
emitter.send("[DONE]");
} catch (Exception e) {
errorHandlerService.handleCompletionError(emitter, e);
}
// 保存对话记录
try {
saveDialogue(agent, request, userId, fullContent);
} catch (Exception e) {
errorHandlerService.handleSaveDialogueError(emitter, e, isCompleted);
} finally {
unifiedSseService.completeEmitter(emitter, isCompleted);
}
LogUtils.exitMethod("handleCompletion", "处理完成");
}
/**
* 保存对话记录
*/
public void saveDialogue(Agent agent, AgentRequest request, String userId, String responseContent) {
LogUtils.enterMethod("saveDialogue", agent, request, userId);
// 参数验证
if (ValidationUtils.isNull(agent, "agent")) {
log.error("Agent对象不能为空");
LogUtils.exitMethod("saveDialogue", "Agent对象不能为空");
return;
}
if (ValidationUtils.isNull(request, "request")) {
log.error("Agent请求不能为空");
LogUtils.exitMethod("saveDialogue", "Agent请求不能为空");
return;
}
if (ValidationUtils.isBlank(userId, "userId")) {
log.error("用户ID不能为空");
LogUtils.exitMethod("saveDialogue", "用户ID不能为空");
return;
}
try {
// 创建对话记录
AgentDialogue dialogue = AgentDialogue.builder()
.agentId(request.getAgentId())
.userMessage(request.getUserMessage())
.agentResponse(responseContent)
.userId(userId)
.build();
// 确保ID被设置
if (dialogue.getId() == null || dialogue.getId().isEmpty()) {
dialogue.setId(java.util.UUID.randomUUID().toString());
}
// 设置创建人和更新人信息
// 在异步线程中获取用户ID
String currentUserId = UserUtils.getCurrentUserIdInAsync();
if (currentUserId == null) {
currentUserId = userId; // 回退到传入的userId
}
dialogue.setCreatedBy(currentUserId);
dialogue.setUpdatedBy(currentUserId);
// 保存对话记录
agentService.saveDialogue(dialogue);
LogUtils.exitMethod("saveDialogue", "保存成功");
} catch (Exception e) {
log.error("保存对话记录失败", e);
LogUtils.exitMethod("saveDialogue", e);
throw new RuntimeException("保存对话记录失败", e);
}
}
}
\ No newline at end of file
......@@ -4,7 +4,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.workpanel.event.EventService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
......@@ -17,10 +17,10 @@ import java.util.function.Consumer;
public class ErrorHandlerService {
@Autowired
private EventService eventService;
private ExceptionMonitoringService exceptionMonitoringService;
@Autowired
private ExceptionMonitoringService exceptionMonitoringService;
private UserSseService userSseService;
/**
* 生成错误跟踪ID
......@@ -129,8 +129,13 @@ public class ErrorHandlerService {
errorMessage, exception);
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, processorType);
eventService.sendErrorEvent(emitter, fullErrorMessage);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
......@@ -154,8 +159,13 @@ public class ErrorHandlerService {
log.error("[{}] 处理聊天请求时发生错误: {}", errorId, errorMessage);
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String fullErrorMessage = buildFullErrorMessage(errorMessage, null, errorId, null);
eventService.sendErrorEvent(emitter, fullErrorMessage);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
......@@ -190,9 +200,14 @@ public class ErrorHandlerService {
log.error("[{}] {}处理token时发生错误", errorId, processorType, exception);
if (!isCompleted.getAndSet(true)) {
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "处理响应时发生错误";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, processorType);
eventService.sendErrorEvent(emitter, fullErrorMessage);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception ignored) {
if (log.isDebugEnabled()) {
log.debug("[{}] 无法发送错误信息", errorId);
......@@ -223,9 +238,14 @@ public class ErrorHandlerService {
log.error("[{}] 发送完成事件失败", errorId, exception);
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "发送完成事件失败,请联系技术支持";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, "完成回调");
eventService.sendErrorEvent(emitter, fullErrorMessage);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
......@@ -333,9 +353,14 @@ public class ErrorHandlerService {
if (!isCompleted.getAndSet(true)) {
try {
// 检查emitter是否已经完成,避免向已完成的连接发送错误信息
if (userSseService != null && !userSseService.isEmitterCompleted(emitter)) {
String errorMessage = "保存对话记录失败,请联系技术支持";
String fullErrorMessage = buildFullErrorMessage(errorMessage, exception, errorId, "对话记录");
eventService.sendErrorEvent(emitter, fullErrorMessage);
userSseService.sendErrorEvent(emitter, fullErrorMessage);
} else {
log.debug("[{}] SSE emitter已完成,跳过发送错误信息", errorId);
}
} catch (Exception sendErrorEx) {
log.error("[{}] 发送错误信息失败", errorId, sendErrorEx);
}
......
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
* 错误处理工具类
* 提供统一的错误处理方法,减少重复代码
* 委托给ErrorHandlerService进行实际处理
*/
@Slf4j
@Component
public class ErrorHandlerUtils {
private static ErrorHandlerService errorHandlerService;
@Autowired
public ErrorHandlerUtils(ErrorHandlerService errorHandlerService) {
ErrorHandlerUtils.errorHandlerService = errorHandlerService;
}
/**
* 构建完整的错误消息
*
* @param errorMessage 基本错误信息
* @param exception 异常对象
* @param errorId 错误跟踪ID
* @param processorType 处理器类型
* @return 完整的错误消息
*/
public static String buildFullErrorMessage(String errorMessage, Exception exception, String errorId, String processorType) {
return errorHandlerService.buildFullErrorMessage(errorMessage, exception, errorId, processorType);
}
/**
* 检查是否为未授权错误
*
* @param exception 异常对象
* @return 是否为未授权错误
*/
public static boolean isUnauthorizedError(Exception exception) {
return errorHandlerService.isUnauthorizedError(exception);
}
/**
* 检查是否为超时错误
*
* @param exception 异常对象
* @return 是否为超时错误
*/
public static boolean isTimeoutError(Exception exception) {
return errorHandlerService.isTimeoutError(exception);
}
/**
* 生成错误跟踪ID
*
* @return 错误跟踪ID
*/
public static String generateErrorId() {
return errorHandlerService.generateErrorId();
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.model.Agent;
import pangea.hiagent.web.dto.AgentRequest;
/**
* SSE Token发射器
* 专注于将token转换为SSE事件并发送
*/
@Slf4j
@Component
public class SseTokenEmitter implements TokenConsumerWithCompletion {
private final UserSseService userSseService;
// 当前处理的emitter
private SseEmitter emitter;
// 上下文信息
private Agent agent;
private AgentRequest request;
private String userId;
// 完成回调
private CompletionCallback completionCallback;
public SseTokenEmitter(UserSseService userSseService) {
this.userSseService = userSseService;
}
/**
* 设置当前使用的SSE发射器
*/
public void setEmitter(SseEmitter emitter) {
this.emitter = emitter;
}
/**
* 设置上下文信息
*/
public void setContext(Agent agent, AgentRequest request, String userId) {
this.agent = agent;
this.request = request;
this.userId = userId;
}
/**
* 设置完成回调
*/
public void setCompletionCallback(CompletionCallback completionCallback) {
this.completionCallback = completionCallback;
}
@Override
public void accept(String token) {
// 使用JSON格式发送token,确保转义序列被正确处理
try {
if (emitter != null && userSseService.isEmitterValidSafe(emitter)) {
// 检查是否是错误消息(以[错误]或[ERROR]开头)
if (token != null && (token.startsWith("[错误]") || token.startsWith("[ERROR]"))) {
// 发送标准错误事件而不是纯文本
userSseService.sendErrorEvent(emitter, token);
} else {
// 使用SSE标准事件格式发送token,以JSON格式确保转义序列正确处理
userSseService.sendTokenEvent(emitter, token);
}
} else {
log.debug("SSE emitter已无效,跳过发送token");
}
} catch (Exception e) {
log.error("发送token失败", e);
}
}
@Override
public void onComplete(String fullContent) {
try {
if (emitter != null && !userSseService.isEmitterCompleted(emitter)) {
// 发送完成事件
emitter.send(SseEmitter.event().name("done").data("[DONE]").build());
log.debug("完成信号已发送");
}
// 调用完成回调
if (completionCallback != null) {
completionCallback.onComplete(emitter, agent, request, userId, fullContent);
}
} catch (Exception e) {
log.error("处理完成事件失败", e);
} finally {
// 关闭连接
closeEmitter();
}
}
/**
* 安全关闭SSE连接
*/
public void closeEmitter() {
try {
if (emitter != null && !userSseService.isEmitterCompleted(emitter)) {
emitter.complete();
log.debug("SSE连接已关闭");
}
} catch (Exception ex) {
log.error("完成emitter时发生错误", ex);
}
}
/**
* 完成回调接口
*/
@FunctionalInterface
public interface CompletionCallback {
void onComplete(SseEmitter emitter, Agent agent, AgentRequest request, String userId, String fullContent);
}
}
\ No newline at end of file
package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.processor.AgentProcessor;
import pangea.hiagent.agent.sse.UserSseService;
import pangea.hiagent.workpanel.event.EventService;
import pangea.hiagent.model.Agent;
import pangea.hiagent.common.utils.LogUtils;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 流式请求服务
* 负责处理流式请求
*/
@Slf4j
@Service
public class StreamRequestService {
@Autowired
private UserSseService unifiedSseService;
@Autowired
private EventService eventService;
@Autowired
private CompletionHandlerService completionHandlerService;
/**
* 处理流式请求
*
* @param emitter SSE发射器
* @param processor Agent处理器
* @param request Agent请求
* @param agent Agent对象
* @param userId 用户ID
*/
public void handleStreamRequest(SseEmitter emitter, AgentProcessor processor, pangea.hiagent.web.dto.AgentRequest request, Agent agent, String userId,String emitterId) {
LogUtils.enterMethod("handleStreamRequest", emitter, processor, request, agent, userId);
// 参数验证
if (!validateParameters(emitter, processor, request, agent, userId)) {
return;
}
// 创建流式处理的Token消费者
StreamTokenConsumer tokenConsumer = new StreamTokenConsumer(emitter, processor, unifiedSseService, eventService, completionHandlerService);
// 设置上下文信息,用于保存对话记录
tokenConsumer.setContext(agent, request, userId);
tokenConsumer.setEmitterId(emitterId);
// 处理流式请求,将token缓冲和事件发送完全交给处理器实现
processor.processStreamRequest(request, agent, userId, tokenConsumer);
LogUtils.exitMethod("handleStreamRequest", "处理完成");
}
/**
* 验证所有必需参数
*
* @param emitter SSE发射器
* @param processor Agent处理器
* @param request Agent请求
* @param agent Agent对象
* @param userId 用户ID
* @return 验证是否通过
*/
private boolean validateParameters(SseEmitter emitter, AgentProcessor processor, pangea.hiagent.web.dto.AgentRequest request, Agent agent, String userId) {
return emitter != null && processor != null && request != null && agent != null && userId != null && !userId.isEmpty();
}
/**
* 流式处理的Token消费者实现
* 用于处理来自Agent处理器的token流,并将其转发给SSE emitter
*/
public static class StreamTokenConsumer implements TokenConsumerWithCompletion {
private final SseEmitter emitter;
private final AgentProcessor processor;
private final EventService eventService;
private final AtomicBoolean isCompleted = new AtomicBoolean(false);
private Agent agent;
private pangea.hiagent.web.dto.AgentRequest request;
private String userId;
private CompletionHandlerService completionHandlerService;
private String emitterId;
public StreamTokenConsumer(SseEmitter emitter, AgentProcessor processor, UserSseService unifiedSseService, EventService eventService, CompletionHandlerService completionHandlerService) {
this.emitter = emitter;
this.processor = processor;
this.eventService = eventService;
this.completionHandlerService = completionHandlerService;
}
public void setContext(Agent agent, pangea.hiagent.web.dto.AgentRequest request, String userId) {
this.agent = agent;
this.request = request;
this.userId = userId;
}
public void setEmitterId(String emitterId) {
this.emitterId = emitterId;
}
public String getEmitterId() {
return emitterId;
}
public String getUserId() {
return userId;
}
@Override
public void accept(String token) {
// 使用JSON格式发送token,确保转义序列被正确处理
try {
if (!isCompleted.get()) {
// 检查是否是错误消息(以[错误]或[ERROR]开头)
if (token != null && (token.startsWith("[错误]") || token.startsWith("[ERROR]"))) {
// 发送标准错误事件而不是纯文本
eventService.sendErrorEvent(emitter, token);
} else {
// 使用SSE标准事件格式发送token,以JSON格式确保转义序列正确处理
eventService.sendTokenEvent(emitter, token);
}
}
} catch (Exception e) {
log.error("发送token失败", e);
}
}
@Override
public void onComplete(String fullContent) {
// 处理完成时的回调
if (isCompleted.getAndSet(true)) {
log.debug("{} Agent处理已完成,跳过重复的完成回调", processor.getProcessorType());
return;
}
log.info("{} Agent处理完成,总字符数: {}", processor.getProcessorType(), fullContent != null ? fullContent.length() : 0);
try {
// 使用CompletionHandlerService处理完成回调
if (completionHandlerService != null) {
completionHandlerService.handleCompletion(emitter, processor, agent, request, userId, fullContent, isCompleted);
} else {
// 如果completionHandlerService不可用,使用默认处理逻辑
try {
// 发送完成事件
emitter.send("[DONE]");
// 完成 emitter
emitter.complete();
} catch (Exception e) {
log.error("处理完成事件失败", e);
} }
} catch (Exception e) {
log.error("处理完成事件失败", e);
// 确保即使出现异常也完成emitter
try {
emitter.complete();
} catch (Exception ex) {
log.error("完成emitter时发生错误", ex);
}
}
} }
}
\ No newline at end of file
package pangea.hiagent.agent.sse;
import lombok.extern.slf4j.Slf4j;
import pangea.hiagent.web.dto.WorkPanelEvent;
import pangea.hiagent.workpanel.event.EventService;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.function.Consumer;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* SSE连接协调器
* 专门负责协调SSE连接的创建、管理和销毁过程
*/
@Slf4j
@Component
public class SseConnectionCoordinator {
private final UserSseService unifiedSseService;
private final EventService eventService;
public SseConnectionCoordinator(
UserSseService unifiedSseService,
EventService eventService) {
this.unifiedSseService = unifiedSseService;
this.eventService = eventService;
}
/**
* 创建并注册SSE连接
*
* @param userId 用户ID
* @return SSE Emitter
*/
public SseEmitter createAndRegisterConnection(String userId) {
log.debug("开始为用户 {} 创建SSE连接", userId);
// 创建 SSE emitter
SseEmitter emitter = unifiedSseService.createEmitter();
log.debug("SSE Emitter创建成功");
// 注册用户的SSE连接
unifiedSseService.registerSession(userId, emitter);
log.debug("用户 {} 的SSE连接注册成功", userId);
// 注册 emitter 回调
unifiedSseService.registerCallbacks(emitter, userId);
log.debug("SSE Emitter回调注册成功");
// 启动心跳机制
unifiedSseService.startHeartbeat(emitter, new AtomicBoolean(false));
log.debug("心跳机制启动成功");
log.info("用户 {} 的SSE连接创建和注册完成", userId);
return emitter;
}
/**
* 订阅工作面板事件
*
* @param userId 用户ID
* @param workPanelEventConsumer 工作面板事件消费者
* @param emitter SSE Emitter
*/
public void subscribeToWorkPanelEvents(String userId, Consumer<WorkPanelEvent> workPanelEventConsumer, SseEmitter emitter) {
log.debug("开始为用户 {} 订阅工作面板事件", userId);
// 发送连接成功事件
try {
WorkPanelEvent connectedEvent = WorkPanelEvent.builder()
.type("observation")
.title("连接成功")
.timestamp(System.currentTimeMillis())
.build();
eventService.sendWorkPanelEvent(emitter, connectedEvent);
log.debug("已发送连接成功事件");
} catch (Exception e) {
log.error("发送连接成功事件失败", e);
}
log.info("用户 {} 的工作面板事件订阅完成", userId);
}
}
\ No newline at end of file
package pangea.hiagent.agent.sse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* 用户会话管理器
* 专门负责管理用户的SSE会话连接
*/
@Slf4j
@Component
public class UserSseManager {
// 存储用户ID到SSE Emitter的映射关系
private final ConcurrentMap<String, SseEmitter> userEmitters = new ConcurrentHashMap<>();
// 存储SSE Emitter到用户ID的反向映射关系(用于快速查找)
private final ConcurrentMap<SseEmitter, String> emitterUsers = new ConcurrentHashMap<>();
// 存储连接创建时间,用于超时检查
private final ConcurrentMap<SseEmitter, AtomicLong> connectionTimes = new ConcurrentHashMap<>();
// 连接超时时间(毫秒),默认30分钟
private static final long CONNECTION_TIMEOUT = 30 * 60 * 1000L;
/**
* 注册用户的SSE连接
* 如果该用户已有连接,则先关闭旧连接再注册新连接
*
* @param userId 用户ID
* @param emitter SSE Emitter
* @return true表示注册成功,false表示注册失败
*/
public boolean registerSession(String userId, SseEmitter emitter) {
if (userId == null || userId.isEmpty() || emitter == null) {
log.warn("注册SSE会话失败:用户ID或Emitter为空");
return false;
}
try {
// 检查该用户是否已有连接
SseEmitter existingEmitter = userEmitters.get(userId);
if (existingEmitter != null) {
// 如果已有连接,先关闭旧连接
log.info("用户 {} 已有SSE连接,关闭旧连接", userId);
closeEmitter(existingEmitter);
// 从映射中移除旧连接
userEmitters.remove(userId, existingEmitter);
emitterUsers.remove(existingEmitter);
}
// 注册新连接
userEmitters.put(userId, emitter);
emitterUsers.put(emitter, userId);
// 记录连接创建时间
connectionTimes.put(emitter, new AtomicLong(System.currentTimeMillis()));
log.info("成功为用户 {} 注册SSE会话", userId);
return true;
} catch (Exception e) {
log.error("注册SSE会话时发生异常:用户ID={}", userId, e);
return false;
}
}
/**
* 移除用户的SSE会话
*
* @param emitter SSE Emitter
*/
public void removeSession(SseEmitter emitter) {
if (emitter == null) {
return;
}
try {
String userId = emitterUsers.get(emitter);
if (userId != null) {
userEmitters.remove(userId, emitter);
emitterUsers.remove(emitter);
connectionTimes.remove(emitter);
log.debug("已移除用户 {} 的SSE会话", userId);
}
} catch (Exception e) {
log.warn("移除SSE会话时发生异常", e);
}
}
/**
* 获取用户的当前SSE连接
*
* @param userId 用户ID
* @return SSE Emitter,如果不存在则返回null
*/
public SseEmitter getSession(String userId) {
if (userId == null || userId.isEmpty()) {
return null;
}
return userEmitters.get(userId);
}
/**
* 检查用户是否有活跃的SSE连接
*
* @param userId 用户ID
* @return true表示有活跃连接,false表示没有
*/
public boolean hasActiveSession(String userId) {
if (userId == null || userId.isEmpty()) {
return false;
}
SseEmitter emitter = userEmitters.get(userId);
return emitter != null && isEmitterValid(emitter) && !isSessionExpired(emitter);
}
/**
* 检查SSE Emitter是否仍然有效
*
* @param emitter 要检查的emitter
* @return 如果有效返回true,否则返回false
*/
public boolean isEmitterValid(SseEmitter emitter) {
if (emitter == null) {
return false;
}
// 检查emitter是否已完成
try {
// 尝试发送一个空事件来检查连接状态
emitter.send(SseEmitter.event().name("ping").data("").build());
return true;
} catch (Exception e) {
// 如果出现任何异常,认为连接已失效
return false;
}
}
/**
* 检查会话是否已过期
*
* @param emitter 要检查的emitter
* @return 如果过期返回true,否则返回false
*/
public boolean isSessionExpired(SseEmitter emitter) {
if (emitter == null) {
return true;
}
AtomicLong connectionTime = connectionTimes.get(emitter);
if (connectionTime == null) {
return false; // 如果没有记录时间,假设未过期
}
long currentTime = System.currentTimeMillis();
long connectionStartTime = connectionTime.get();
return (currentTime - connectionStartTime) > CONNECTION_TIMEOUT;
}
/**
* 关闭指定的SSE Emitter
*
* @param emitter SSE Emitter
*/
public void closeEmitter(SseEmitter emitter) {
if (emitter == null) {
return;
}
try {
emitter.complete();
log.debug("已关闭SSE Emitter");
} catch (Exception e) {
log.warn("关闭SSE Emitter时发生异常", e);
}
}
/**
* 获取当前会话的用户数量
*
* @return 用户数量
*/
public int getSessionCount() {
return userEmitters.size();
}
/**
* 清理所有会话(用于系统关闭时)
*/
public void clearAllSessions() {
try {
for (SseEmitter emitter : emitterUsers.keySet()) {
try {
emitter.complete();
} catch (Exception e) {
log.warn("关闭SSE Emitter时发生异常", e);
}
}
userEmitters.clear();
emitterUsers.clear();
connectionTimes.clear();
log.info("已清理所有SSE会话");
} catch (Exception e) {
log.error("清理所有SSE会话时发生异常", e);
}
}
/**
* 处理连接完成事件
* 职责:协调完成连接的清理工作
*
* @param emitter SSE Emitter
*/
public void handleConnectionCompletion(SseEmitter emitter) {
log.debug("处理SSE连接完成事件");
// 移除连接
removeSession(emitter);
}
/**
* 处理连接超时事件
* 职责:协调超时连接的清理工作
*
* @param emitter SSE Emitter
*/
public void handleConnectionTimeout(SseEmitter emitter) {
log.debug("处理SSE连接超时事件");
// 移除连接
removeSession(emitter);
}
/**
* 处理连接错误事件
* 职责:协调错误连接的清理工作
*
* @param emitter SSE Emitter
*/
public void handleConnectionError(SseEmitter emitter) {
log.debug("处理SSE连接错误事件");
// 移除连接
removeSession(emitter);
}
}
\ No newline at end of file
......@@ -5,7 +5,6 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.reflection.MetaObject;
import org.springframework.stereotype.Component;
import pangea.hiagent.common.utils.UserUtils;
import java.time.LocalDateTime;
/**
......@@ -46,7 +45,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
if (metaObject.hasSetter("createdBy")) {
Object createdBy = getFieldValByName("createdBy", metaObject);
if (createdBy == null) {
String userId = UserUtils.getCurrentUserId();
String userId = getCurrentUserIdWithContext();
if (userId != null) {
this.strictInsertFill(metaObject, "createdBy", String.class, userId);
log.debug("自动填充createdBy字段: {}", userId);
......@@ -60,7 +59,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
if (metaObject.hasSetter("updatedBy")) {
Object updatedBy = getFieldValByName("updatedBy", metaObject);
if (updatedBy == null) {
String userId = UserUtils.getCurrentUserId();
String userId = getCurrentUserIdWithContext();
if (userId != null) {
this.strictInsertFill(metaObject, "updatedBy", String.class, userId);
log.debug("自动填充updatedBy字段: {}", userId);
......@@ -91,7 +90,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
Object updatedBy = getFieldValByName("updatedBy", metaObject);
// 如果updatedBy为空或者需要强制更新,则填充当前用户ID
if (updatedBy == null) {
String userId = UserUtils.getCurrentUserId();
String userId = getCurrentUserIdWithContext();
if (userId != null) {
this.strictUpdateFill(metaObject, "updatedBy", String.class, userId);
log.debug("自动填充updatedBy字段: {}", userId);
......@@ -101,4 +100,39 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
}
}
}
/**
* 获取当前用户ID,支持异步线程上下文
* 该方法支持以下场景:
* 1. 同步请求:从SecurityContext获取用户ID
* 2. 异步任务:从AsyncUserContextDecorator传播的上下文获取用户ID
* 3. 故障转移:尝试直接解析Token获取用户ID
*
* @return 用户ID,如果无法获取则返回null
*/
private String getCurrentUserIdWithContext() {
try {
// 方式1:首先尝试从SecurityContext获取(支持同步请求和AsyncUserContextDecorator传播)
String userId = UserUtils.getCurrentUserId();
if (userId != null) {
log.debug("通过SecurityContext成功获取用户ID: {}", userId);
return userId;
}
log.debug("无法从SecurityContext获取用户ID,可能是异步线程且未使用AsyncUserContextDecorator包装");
// 方式2:尝试直接从请求中解析Token(故障转移)
String asyncUserId = UserUtils.getCurrentUserIdInAsync();
if (asyncUserId != null) {
log.debug("通过直接解析Token成功获取用户ID: {}", asyncUserId);
return asyncUserId;
}
log.warn("无法通过任何方式获取当前用户ID,createdBy/updatedBy字段将不被填充");
return null;
} catch (Exception e) {
log.error("获取用户ID时发生异常", e);
return null;
}
}
}
\ No newline at end of file
......@@ -21,6 +21,7 @@ import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.web.service.TimerService;
import pangea.hiagent.security.DefaultPermissionEvaluator;
import pangea.hiagent.security.JwtAuthenticationFilter;
import pangea.hiagent.security.SseAuthorizationFilter;
import java.io.IOException;
import java.util.Arrays;
......@@ -33,11 +34,13 @@ import java.util.Collections;
public class SecurityConfig {
private final JwtAuthenticationFilter jwtAuthenticationFilter;
private final SseAuthorizationFilter sseAuthorizationFilter;
private final AgentService agentService;
private final TimerService timerService;
public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter, AgentService agentService, TimerService timerService) {
public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter, SseAuthorizationFilter sseAuthorizationFilter, AgentService agentService, TimerService timerService) {
this.jwtAuthenticationFilter = jwtAuthenticationFilter;
this.sseAuthorizationFilter = sseAuthorizationFilter;
this.agentService = agentService;
this.timerService = timerService;
}
......@@ -205,6 +208,8 @@ public class SecurityConfig {
)
// 添加JWT认证过滤器
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class)
// 添加SSE授权检查过滤器,在JWT过滤器之后但在其他安全过滤器之前运行
.addFilterAfter(sseAuthorizationFilter, JwtAuthenticationFilter.class)
// 配置X-Frame-Options头部,允许同源iframe嵌入
.headers(headers -> headers
.frameOptions(frameOptions -> frameOptions
......
package pangea.hiagent.common.config;
// package pangea.hiagent.common.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
// import org.springframework.context.annotation.Configuration;
// import org.springframework.context.annotation.EnableAspectJAutoProxy;
/**
* 工具执行日志记录配置类
* 启用AOP代理以实现工具执行信息的自动记录
*/
@Configuration
@EnableAspectJAutoProxy
public class ToolExecutionLoggingConfig {
}
\ No newline at end of file
// /**
// * 工具执行日志记录配置类
// * 启用AOP代理以实现工具执行信息的自动记录
// */
// @Configuration
// @EnableAspectJAutoProxy
// public class ToolExecutionLoggingConfig {
// }
\ No newline at end of file
......@@ -14,6 +14,7 @@ import org.springframework.web.method.annotation.MethodArgumentTypeMismatchExcep
import pangea.hiagent.agent.service.ExceptionMonitoringService;
import pangea.hiagent.web.dto.ApiResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.stream.Collectors;
import org.springframework.security.authorization.AuthorizationDeniedException;
......@@ -212,29 +213,66 @@ public class GlobalExceptionHandler {
AuthorizationDeniedException e, HttpServletRequest request) {
log.warn("访问被拒绝: {} - URL: {}", e.getMessage(), request.getRequestURL());
// 更全面地检查响应是否已经提交
boolean responseCommitted = false;
// 检查是否为SSE端点的异常
String requestUri = request.getRequestURI();
boolean isSseEndpoint = requestUri.contains("/api/v1/agent/chat-stream") || requestUri.contains("/api/v1/agent/timeline-events");
// 检查响应是否已经提交
jakarta.servlet.http.HttpServletResponse httpResponse = null;
if (org.springframework.web.context.request.RequestContextHolder.getRequestAttributes() != null) {
Object requestAttributes = org.springframework.web.context.request.RequestContextHolder
.getRequestAttributes();
if (requestAttributes instanceof org.springframework.web.context.request.ServletRequestAttributes) {
org.springframework.web.context.request.ServletRequestAttributes servletRequestAttributes =
(org.springframework.web.context.request.ServletRequestAttributes) requestAttributes;
if (servletRequestAttributes.getResponse() instanceof jakarta.servlet.http.HttpServletResponse) {
httpResponse = (jakarta.servlet.http.HttpServletResponse) servletRequestAttributes.getResponse();
}
}
// 检查request属性
if (request.getAttribute("jakarta.servlet.error.exception") != null) {
responseCommitted = true;
// 检查响应是否已提交
if (httpResponse != null && httpResponse.isCommitted()) {
log.warn("响应已提交,无法发送访问拒绝错误: {}", request.getRequestURL());
// 如果是SSE端点且响应已提交,返回空响应避免二次异常
return ResponseEntity.ok().build();
}
}
// 检查response是否已提交
if (request instanceof org.springframework.web.context.request.NativeWebRequest) {
Object nativeResponse = ((org.springframework.web.context.request.NativeWebRequest) request).getNativeResponse();
if (nativeResponse instanceof jakarta.servlet.http.HttpServletResponse) {
if (((jakarta.servlet.http.HttpServletResponse) nativeResponse).isCommitted()) {
responseCommitted = true;
// 如果是SSE端点,但响应未提交,发送SSE格式的错误响应
if (isSseEndpoint) {
try {
jakarta.servlet.http.HttpServletResponse sseResponse = null;
if (org.springframework.web.context.request.RequestContextHolder.getRequestAttributes() != null) {
Object requestAttributes = org.springframework.web.context.request.RequestContextHolder
.getRequestAttributes();
if (requestAttributes instanceof org.springframework.web.context.request.ServletRequestAttributes) {
org.springframework.web.context.request.ServletRequestAttributes servletRequestAttributes =
(org.springframework.web.context.request.ServletRequestAttributes) requestAttributes;
if (servletRequestAttributes.getResponse() instanceof jakarta.servlet.http.HttpServletResponse) {
sseResponse = (jakarta.servlet.http.HttpServletResponse) servletRequestAttributes.getResponse();
}
}
}
// 如果响应已提交,记录日志并返回空响应以避免二次异常
if (responseCommitted) {
log.warn("响应已提交,无法发送访问拒绝错误: {}", request.getRequestURL());
// 返回空响应而不是build(),避免潜在的响应提交冲突
return ResponseEntity.status(HttpStatus.FORBIDDEN).body(null);
if (sseResponse != null) {
sseResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
sseResponse.setContentType("text/event-stream;charset=UTF-8");
sseResponse.setCharacterEncoding("UTF-8");
// 发送SSE格式的错误事件
sseResponse.getWriter().write("event: error\n");
sseResponse.getWriter().write("data: {\"error\": \"访问被拒绝,无权限访问该资源\", \"code\": 403, \"timestamp\": " +
System.currentTimeMillis() + "}\n\n");
sseResponse.getWriter().flush();
log.debug("已发送SSE 访问拒绝错误响应");
}
return ResponseEntity.ok().build();
} catch (Exception ex) {
log.error("发送SSE访问拒绝错误响应失败", ex);
return ResponseEntity.ok().build();
}
}
ApiResponse.ErrorDetail errorDetail = ApiResponse.ErrorDetail.builder()
......@@ -242,9 +280,9 @@ public class GlobalExceptionHandler {
.details("您没有权限执行此操作")
.build();
ApiResponse<Void> response = ApiResponse.error(ErrorCode.FORBIDDEN.getCode(),
ApiResponse<Void> finalResponse = ApiResponse.error(ErrorCode.FORBIDDEN.getCode(),
ErrorCode.FORBIDDEN.getMessage(), errorDetail);
return ResponseEntity.status(HttpStatus.FORBIDDEN).body(response);
return ResponseEntity.status(HttpStatus.FORBIDDEN).body(finalResponse);
}
/**
......
package pangea.hiagent.common.utils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import java.io.Serializable;
import java.util.concurrent.Callable;
/**
......@@ -13,6 +17,79 @@ import java.util.concurrent.Callable;
@Component
public class AsyncUserContextDecorator {
/**
* 用户上下文持有者类,用于在异步线程间传递认证信息
*/
public static class UserContextHolder implements Serializable {
private static final long serialVersionUID = 1L;
private final Authentication authentication;
public UserContextHolder(Authentication authentication) {
this.authentication = authentication;
}
public Authentication getAuthentication() {
return authentication;
}
}
/**
* 捕获当前线程的用户上下文
* @return 用户上下文持有者对象
*/
public static UserContextHolder captureUserContext() {
try {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null) {
log.debug("捕获到当前线程的用户认证信息: {}", authentication.getPrincipal());
return new UserContextHolder(authentication);
} else {
log.debug("当前线程无用户认证信息");
return null;
}
} catch (Exception e) {
log.error("捕获用户上下文时发生异常", e);
return null;
}
}
/**
* 将用户上下文传播到当前线程
* @param userContextHolder 用户上下文持有者对象
*/
public static void propagateUserContext(UserContextHolder userContextHolder) {
try {
if (userContextHolder != null) {
Authentication authentication = userContextHolder.getAuthentication();
if (authentication != null) {
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
log.debug("已将用户认证信息传播到当前线程: {}", authentication.getPrincipal());
} else {
log.debug("用户上下文持有者中的认证信息为空");
}
} else {
log.debug("用户上下文持有者为空");
}
} catch (Exception e) {
log.error("传播用户上下文时发生异常", e);
}
}
/**
* 清理当前线程的用户上下文
*/
public static void clearUserContext() {
try {
SecurityContextHolder.clearContext();
log.debug("已清理当前线程的用户上下文");
} catch (Exception e) {
log.error("清理用户上下文时发生异常", e);
}
}
/**
* 包装Runnable任务,自动传播用户上下文
* @param runnable 原始任务
......@@ -20,18 +97,18 @@ public class AsyncUserContextDecorator {
*/
public static Runnable wrapWithContext(Runnable runnable) {
// 捕获当前线程的用户上下文
UserContextPropagationUtil.UserContextHolder userContext = UserContextPropagationUtil.captureUserContext();
UserContextHolder userContext = captureUserContext();
return () -> {
try {
// 在异步线程中传播用户上下文
UserContextPropagationUtil.propagateUserContext(userContext);
propagateUserContext(userContext);
// 执行原始任务
runnable.run();
} finally {
// 清理当前线程的用户上下文
UserContextPropagationUtil.clearUserContext();
clearUserContext();
}
};
}
......@@ -44,18 +121,18 @@ public class AsyncUserContextDecorator {
*/
public static <V> Callable<V> wrapWithContext(Callable<V> callable) {
// 捕获当前线程的用户上下文
UserContextPropagationUtil.UserContextHolder userContext = UserContextPropagationUtil.captureUserContext();
UserContextHolder userContext = captureUserContext();
return () -> {
try {
// 在异步线程中传播用户上下文
UserContextPropagationUtil.propagateUserContext(userContext);
propagateUserContext(userContext);
// 执行原始任务
return callable.call();
} finally {
// 清理当前线程的用户上下文
UserContextPropagationUtil.clearUserContext();
clearUserContext();
}
};
}
......
package pangea.hiagent.common.utils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* 异步用户上下文使用示例
* 展示如何在异步任务中正确获取用户认证信息
*/
@Slf4j
@Component
public class AsyncUserContextUsageExample {
// 示例线程池
private final ExecutorService executorService = Executors.newFixedThreadPool(10);
/**
* 方式一:使用SecurityContextHolder的InheritableThreadLocal策略(推荐)
* 适用于父子线程关系明确的场景
*/
public void executeTaskWithInheritableThreadLocal() {
// 在主线程中获取用户ID(正常情况下可以获取到)
String userId = UserUtils.getCurrentUserId();
log.info("主线程中获取到用户ID: {}", userId);
// 提交异步任务,由于使用了InheritableThreadLocal策略,子线程可以继承父线程的SecurityContext
CompletableFuture.runAsync(() -> {
// 在异步线程中获取用户ID
String asyncUserId = UserUtils.getCurrentUserId();
log.info("异步线程中获取到用户ID: {}", asyncUserId);
// 执行业务逻辑
performBusinessLogic(asyncUserId);
}, executorService);
}
/**
* 方式二:使用UserContextPropagationUtil手动传播用户上下文
* 适用于复杂的异步场景或需要更精确控制的场景
*/
public void executeTaskWithManualPropagation() {
// 在主线程中获取用户ID
String userId = UserUtils.getCurrentUserId();
log.info("主线程中获取到用户ID: {}", userId);
// 提交异步任务,手动传播用户上下文
CompletableFuture.runAsync(AsyncUserContextDecorator.wrapWithContext(() -> {
// 在异步线程中获取用户ID
String asyncUserId = UserUtils.getCurrentUserId();
log.info("异步线程中获取到用户ID: {}", asyncUserId);
// 执行业务逻辑
performBusinessLogic(asyncUserId);
}), executorService);
}
/**
* 方式三:使用专门的异步环境获取方法
* 适用于无法通过线程上下文传播获取用户信息的场景
*/
public void executeTaskWithDirectTokenParsing() {
// 在主线程中获取用户ID
String userId = UserUtils.getCurrentUserId();
log.info("主线程中获取到用户ID: {}", userId);
// 提交异步任务,直接解析请求中的token获取用户ID
CompletableFuture.runAsync(() -> {
// 在异步线程中通过直接解析token获取用户ID
String asyncUserId = UserUtils.getCurrentUserIdInAsync();
log.info("异步线程中通过直接解析token获取到用户ID: {}", asyncUserId);
// 执行业务逻辑
performBusinessLogic(asyncUserId);
}, executorService);
}
/**
* 执行业务逻辑示例
* @param userId 用户ID
*/
private void performBusinessLogic(String userId) {
if (userId != null) {
log.info("为用户 {} 执行业务逻辑", userId);
// 这里执行具体的业务逻辑
} else {
log.warn("未获取到用户ID,执行匿名用户逻辑");
// 这里执行匿名用户的业务逻辑
}
}
/**
* 清理资源
*/
public void shutdown() {
executorService.shutdown();
}
}
\ No newline at end of file
package pangea.hiagent.common.utils;
public class Contants {
public static final String LOCATOR_SCHEMA = "{\n" +
" \"$schema\": \"http://json-schema.org/draft-07/schema#\",\n" +
" \"type\": \"array\",\n" +
" \"items\": {\n" +
" \"type\": \"object\",\n" +
" \"properties\": {\n" +
" \"field_name\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"locator\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"label_tag\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"attributes\": {\n" +
" \"type\": \"object\",\n" +
" \"properties\": {\n" +
" \"type\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"maxlength\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"class\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"name\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"value\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"autocomplete\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"placeholder\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"readonly\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"id\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"droptreeids\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"vetitle\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"contenteditable\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"style\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"tipstext\": {\n" +
" \"type\": \"string\"\n" +
" },\n" +
" \"fylx\": {\n" +
" \"type\": \"string\"\n" +
" }\n" +
" },\n" +
" \"additionalProperties\": false,\n" +
" \"required\": [\n" +
" \"class\",\n" +
" \"value\"\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"additionalProperties\": false,\n" +
" \"required\": [\n" +
" \"field_name\",\n" +
" \"locator\",\n" +
" \"attributes\"\n" +
" ]\n" +
" }\n" +
"}";
}
package pangea.hiagent.common.utils;
import java.security.SecureRandom;
import java.util.concurrent.atomic.AtomicLong;
public class HybridUniqueLongGenerator {
private static final SecureRandom random = new SecureRandom();
private static final AtomicLong counter = new AtomicLong(0);
public static long generateUnique13DigitNumber() {
long timestamp = System.currentTimeMillis();
long count = counter.incrementAndGet();
// 使用时间戳的前10位 + 计数器的后3位
long timestampPart = (timestamp / 1000) * 1000;
long counterPart = count % 1000;
return timestampPart + counterPart;
}
// 更随机的版本,但仍保证唯一
public static synchronized long generateRandomUnique() {
long timestamp = System.currentTimeMillis();
// 在时间戳基础上加上一个小的随机偏移
int randomOffset = random.nextInt(100);
long result = timestamp * 100 + randomOffset;
// 确保是13位
while (result >= 10000000000000L) {
result /= 10;
}
while (result < 1000000000000L) {
result *= 10;
result += random.nextInt(10);
}
return result;
}
}
\ No newline at end of file
package pangea.hiagent.common.utils;
import java.security.SecureRandom;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.concurrent.atomic.AtomicLong;
public class InputCodeGenerator {
private static final String CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
private static final SecureRandom random = new SecureRandom();
private static final AtomicLong sequence = new AtomicLong(0);
public static String generateUniqueInputCode(String prefix) {
// 当前时间戳(毫秒)
long timestamp = System.currentTimeMillis();
// 序列号
long seq = sequence.incrementAndGet();
// 组合时间戳和序列号
long combined = (timestamp << 10) | (seq & 0x3FF); // 取序列号后10位
// 转为36进制
String code = Long.toString(Math.abs(combined), 36).toUpperCase();
// 确保8位长度
if (code.length() > 8) {
code = code.substring(code.length() - 8);
} else if (code.length() < 8) {
// 前面补随机字符
StringBuilder sb = new StringBuilder();
for (int i = code.length(); i < 8; i++) {
sb.append(CHARS.charAt(random.nextInt(CHARS.length())));
}
code = sb.toString() + code;
}
return prefix + code;
}
}
package pangea.hiagent.common.utils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import java.io.Serializable;
/**
* 用户上下文传播工具类
* 用于在异步线程间传播用户认证信息
*/
@Slf4j
@Component
public class UserContextPropagationUtil {
/**
* 用户上下文持有者类,用于在异步线程间传递认证信息
*/
public static class UserContextHolder implements Serializable {
private static final long serialVersionUID = 1L;
private final Authentication authentication;
public UserContextHolder(Authentication authentication) {
this.authentication = authentication;
}
public Authentication getAuthentication() {
return authentication;
}
}
/**
* 捕获当前线程的用户上下文
* @return 用户上下文持有者对象
*/
public static UserContextHolder captureUserContext() {
try {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null) {
log.debug("捕获到当前线程的用户认证信息: {}", authentication.getPrincipal());
return new UserContextHolder(authentication);
} else {
log.debug("当前线程无用户认证信息");
return null;
}
} catch (Exception e) {
log.error("捕获用户上下文时发生异常", e);
return null;
}
}
/**
* 将用户上下文传播到当前线程
* @param userContextHolder 用户上下文持有者对象
*/
public static void propagateUserContext(UserContextHolder userContextHolder) {
try {
if (userContextHolder != null) {
Authentication authentication = userContextHolder.getAuthentication();
if (authentication != null) {
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
log.debug("已将用户认证信息传播到当前线程: {}", authentication.getPrincipal());
} else {
log.debug("用户上下文持有者中的认证信息为空");
}
} else {
log.debug("用户上下文持有者为空");
}
} catch (Exception e) {
log.error("传播用户上下文时发生异常", e);
}
}
/**
* 清理当前线程的用户上下文
*/
public static void clearUserContext() {
try {
SecurityContextHolder.clearContext();
log.debug("已清理当前线程的用户上下文");
} catch (Exception e) {
log.error("清理用户上下文时发生异常", e);
}
}
}
\ No newline at end of file
......@@ -26,11 +26,22 @@ public class UserUtils {
UserUtils.jwtUtil = jwtUtil;
}
public static String getCurrentUserId() {
String username = getCurrentUserIdInSync();
if (username==null || username.isEmpty()) {
username = getCurrentUserIdInAsync();
}
return username;
}
/**
* 获取当前认证用户ID
*
* @return 用户ID,如果未认证则返回null
*/
public static String getCurrentUserId() {
public static String getCurrentUserIdInSync() {
try {
// 首先尝试从SecurityContext获取
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
......@@ -71,6 +82,7 @@ public class UserUtils {
/**
* 在异步线程环境中获取当前认证用户ID
* 该方法专为异步线程环境设计,通过JWT令牌解析获取用户ID
*
* @return 用户ID,如果未认证则返回null
*/
public static String getCurrentUserIdInAsync() {
......@@ -94,6 +106,7 @@ public class UserUtils {
/**
* 从当前请求中提取JWT令牌并解析用户ID
*
* @return 用户ID,如果无法解析则返回null
*/
private static String getUserIdFromRequest() {
......@@ -161,6 +174,7 @@ public class UserUtils {
/**
* 检查当前用户是否已认证
*
* @return true表示已认证,false表示未认证
*/
public static boolean isAuthenticated() {
......@@ -169,6 +183,7 @@ public class UserUtils {
/**
* 检查用户是否是管理员
*
* @param userId 用户ID
* @return true表示是管理员,false表示不是管理员
*/
......
package pangea.hiagent.security;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.common.utils.JwtUtil;
import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.model.Agent;
import pangea.hiagent.common.utils.UserUtils;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
/**
* SSE流式端点授权检查过滤器
* 在Spring Security的AuthorizationFilter之前运行,提前处理流式端点的身份验证检查
* 避免响应被提交后才处理异常的问题
*/
@Slf4j
@Component
public class SseAuthorizationFilter extends OncePerRequestFilter {
private static final String STREAM_ENDPOINT = "/api/v1/agent/chat-stream";
private static final String TIMELINE_ENDPOINT = "/api/v1/agent/timeline-events";
private final JwtUtil jwtUtil;
private final AgentService agentService;
public SseAuthorizationFilter(JwtUtil jwtUtil, AgentService agentService) {
this.jwtUtil = jwtUtil;
this.agentService = agentService;
}
/**
* 发送SSE格式的未授权错误响应
*/
private void sendSseUnauthorizedError(HttpServletResponse response) {
try {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
response.setContentType("text/event-stream;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
// 发送SSE格式的错误事件
response.getWriter().write("event: error\n");
response.getWriter().write("data: {\"error\": \"未授权访问,请先登录\", \"code\": 401, \"timestamp\": " +
System.currentTimeMillis() + "}\n\n");
response.getWriter().flush();
log.debug("已发送SSE未授权错误响应");
} catch (IOException e) {
log.error("发送SSE未授权错误响应失败", e);
}
}
/**
* 发送SSE格式的Agent不存在错误响应
*/
private void sendSseAgentNotFoundError(HttpServletResponse response) {
try {
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
response.setContentType("text/event-stream;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
// 发送SSE格式的错误事件
response.getWriter().write("event: error\n");
response.getWriter().write("data: {\"error\": \"Agent不存在\", \"code\": 404, \"timestamp\": " +
System.currentTimeMillis() + "}\n\n");
response.getWriter().flush();
log.debug("已发送SSE Agent不存在错误响应");
} catch (IOException e) {
log.error("发送SSE Agent不存在错误响应失败", e);
}
}
/**
* 发送SSE格式的访问拒绝错误响应
*/
private void sendSseAccessDeniedError(HttpServletResponse response) {
try {
response.setStatus(HttpServletResponse.SC_FORBIDDEN);
response.setContentType("text/event-stream;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
// 发送SSE格式的错误事件
response.getWriter().write("event: error\n");
response.getWriter().write("data: {\"error\": \"访问被拒绝,无权限访问该Agent\", \"code\": 403, \"timestamp\": " +
System.currentTimeMillis() + "}\n\n");
response.getWriter().flush();
log.debug("已发送SSE 访问拒绝错误响应");
} catch (IOException e) {
log.error("发送SSE 访问拒绝错误响应失败", e);
}
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
String requestUri = request.getRequestURI();
boolean isStreamEndpoint = requestUri.contains(STREAM_ENDPOINT);
boolean isTimelineEndpoint = requestUri.contains(TIMELINE_ENDPOINT);
// 只处理SSE端点
if (isStreamEndpoint || isTimelineEndpoint) {
log.debug("SSE端点授权检查: {} {}", request.getMethod(), requestUri);
// 检查响应是否已经提交,避免后续错误处理异常
if (response.isCommitted()) {
log.warn("响应已提交,无法处理SSE端点授权检查");
return;
}
// 从SecurityContext获取当前认证用户
String userId = null;
var authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null && authentication.isAuthenticated() && !"anonymousUser".equals(authentication.getPrincipal())) {
userId = authentication.getName();
}
if (userId != null) {
log.debug("SSE端点已认证,用户: {}", userId);
// 如果是chat-stream端点,需要额外验证agent权限
if (isStreamEndpoint) {
// 从请求参数中获取agentId
String agentId = request.getParameter("agentId");
if (agentId != null) {
try {
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
log.warn("SSE端点访问失败:Agent不存在 - AgentId: {}", agentId);
sendSseAgentNotFoundError(response);
return;
}
// 验证用户是否有权限访问该agent
if (!agent.getOwner().equals(userId) && !UserUtils.isAdminUser(userId)) {
log.warn("SSE端点访问失败:用户 {} 无权限访问Agent: {}", userId, agentId);
sendSseAccessDeniedError(response);
return;
}
log.debug("SSE端点Agent权限验证成功,用户: {}, Agent: {}", userId, agentId);
} catch (Exception e) {
log.error("SSE端点Agent权限验证异常: {}", e.getMessage());
sendSseAccessDeniedError(response);
return;
}
} else {
log.warn("SSE端点请求缺少agentId参数");
sendSseAgentNotFoundError(response);
return;
}
}
// 继续执行过滤器链
filterChain.doFilter(request, response);
return;
} else {
// 用户未认证,拒绝连接
log.warn("SSE端点未认证访问,拒绝连接: {} {}", request.getMethod(), requestUri);
sendSseUnauthorizedError(response);
return;
}
}
// 继续执行过滤器链(非SSE端点)
filterChain.doFilter(request, response);
}
/**
* 从请求头或参数中提取Token
*/
private String extractTokenFromRequest(HttpServletRequest request) {
// 首先尝试从请求头中提取Token
String authHeader = request.getHeader("Authorization");
if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) {
return authHeader.substring(7);
}
// 如果请求头中没有Token,则尝试从URL参数中提取
String tokenParam = request.getParameter("token");
if (StringUtils.hasText(tokenParam)) {
return tokenParam;
}
return null;
}
/**
* 确定此过滤器是否应处理给定请求
* 只处理SSE流式端点
*/
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
String requestUri = request.getRequestURI();
boolean isStreamEndpoint = requestUri.contains(STREAM_ENDPOINT);
boolean isTimelineEndpoint = requestUri.contains(TIMELINE_ENDPOINT);
// 如果不是SSE端点,跳过此过滤器
return !(isStreamEndpoint || isTimelineEndpoint);
}
}
......@@ -262,7 +262,21 @@ public class PlaywrightWebTools {
return executeWithPage(url, page -> {
// 获取所有a标签的href属性
Object result = page.locator("a").evaluateAll("elements => elements.map(el => el.href)");
List<String> links = (List<String>) result;
// 安全地进行类型转换
List<?> rawList;
if (result instanceof List) {
rawList = (List<?>) result;
} else {
log.warn("预期返回List类型,但实际返回: {}", result != null ? result.getClass().getName() : "null");
return "获取链接失败:返回类型错误";
}
// 安全地转换为List<String>
List<String> links = rawList.stream()
.map(item -> item != null ? item.toString() : "")
.filter(str -> !str.isEmpty())
.toList();
return links.isEmpty() ? "未找到任何链接" : String.join(", ", links);
});
}
......
......@@ -16,7 +16,6 @@ import pangea.hiagent.model.Tool;
import pangea.hiagent.web.repository.AgentDialogueRepository;
import pangea.hiagent.web.repository.AgentRepository;
import pangea.hiagent.web.repository.LlmConfigRepository;
import pangea.hiagent.web.service.AgentToolRelationService;
import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.llm.LlmModelFactory;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment