Commit c107c059 authored by 王舵's avatar 王舵

merge: 合并 main分支代码,解决冲突

parents 5daacfc6 23ae72a7
# HiAgent 问题修复指南
## 问题分析
从终端日志中发现两个主要问题:
1. **LLM配置验证失败**`java.lang.IllegalArgumentException: LLM配置验证失败: deepseek`
2. **Spring Security访问被拒绝**`AuthorizationDeniedException: Access Denied`
## 根本原因
### LLM配置问题
[data.sql](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/resources/data.sql)中的deepseek配置API密钥为空,而[DeepSeekModelAdapter.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/llm/DeepSeekModelAdapter.java)的验证逻辑要求必须有非空的API密钥。
### 安全配置问题
缺少必要的环境变量,包括`DEEPSEEK_API_KEY``JWT_SECRET`
## 解决方案
### 方案一:使用环境变量(推荐)
1. 编辑[run-with-env.bat](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/run-with-env.bat)文件,将占位符替换为实际值:
```batch
set DEEPSEEK_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # 替换为你的DeepSeek API密钥
set JWT_SECRET=your-secure-jwt-secret-key # 替换为你自己的JWT密钥
```
2. 运行[run-with-env.bat](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/run-with-env.bat)启动应用:
```bash
run-with-env.bat
```
### 方案二:临时修复(仅用于测试)
如果你只是想快速测试应用而不关心安全性,可以:
1. 修改[LlmConfigService.java](file:///c:/Users/Gavin/Documents/PangeaFinal/HiAgent/backend/src/main/java/pangea/hiagent/service/LlmConfigService.java)中的验证逻辑,允许空API密钥:
```java
// 在DeepSeekModelAdapter.java中修改validateConfig方法
@Override
public boolean validateConfig(LlmConfig config) {
return config != null &&
config.getEnabled();
// 移除了对API密钥非空的检查
}
```
注意:这种方法仅适用于测试环境,生产环境中必须配置有效的API密钥。
## 登录凭证
默认登录账户:
- 用户名:`admin`
- 密码:`admin123` (如果使用的是开发环境默认密码)
## 验证修复
启动应用后,可以通过以下方式验证修复是否成功:
1. 访问 http://localhost:8080 并使用默认账户登录
2. 进入Agent管理页面,确认Agent可以正常加载
3. 尝试与Agent进行对话,确认不再出现"LLM配置验证失败"错误
## 故障排除
如果仍然遇到问题,请检查:
1. 确认环境变量已正确设置
2. 确认数据库已正确初始化
3. 查看应用启动日志中是否有其他错误信息
4. 确认网络连接正常,可以访问DeepSeek API
\ No newline at end of file
# HiAgent - 智能AI助手
HiAgent 是一个功能强大的个人AI助手,集成了多种工具和服务,能够帮助用户完成各种任务。
## 🌟 核心功能
### 网页访问和内容提取
- **网页访问工具**:能够根据网站名称或URL访问网页并在工作面板中预览
- **网页内容提取工具**:智能提取网页正文内容,自动识别并提取文章标题和主要内容,过滤掉广告、导航栏等无关内容
- **增强型网页嵌入预览**:支持多种加载策略(直接HTML、直接获取内容、iframe),自动处理X-Frame-Options等安全限制
### 计算和数据处理
- **计算器工具**:执行基本数学运算和复杂数学计算
- **日期时间工具**:获取当前时间、日期计算等
- **文件处理工具**:文件上传、下载和处理
- **字符串处理工具**:文本处理和转换功能
### 其他实用工具
- **天气查询工具**:获取指定城市的天气信息
- **OAuth2.0授权工具**:支持通过用户名和密码凭证获取网页资源访问授权,实现标准OAuth2.0认证流程
## 🛠 技术架构
### 后端技术栈
- **Spring Boot 3.3.4**:基于Java 17的现代化Web框架
- **Spring AI**:集成多种AI模型和服务
- **MySQL/H2**:数据存储
- **Redis**:缓存和会话管理
- **Milvus**:向量数据库支持
- **RabbitMQ**:消息队列服务
### 前端技术栈
- **Vue 3**:现代化的前端框架
- **TypeScript**:类型安全的JavaScript超集
- **Vite**:快速的构建工具
## 📦 主要工具介绍
### 网页内容提取工具 (WebContentExtractorTool)
这是一个专门用于从网页中提取有意义文本内容的工具。它能够自动识别并提取网页的标题和正文内容,同时过滤掉广告、导航栏等无关内容。
#### 功能特点
1. **智能内容提取**:自动识别网页的主要内容区域
2. **广告过滤**:自动过滤广告、导航栏等无关内容
3. **格式保留**:保留原文的标题层级和段落结构
4. **错误处理**:完善的错误处理机制和日志记录
#### 使用方法
在Agent对话中直接调用:
```
extractWebContent("https://example.com/article")
```
### 网页访问工具 (WebPageAccessTools)
提供根据网站名称或URL地址访问网页并在工作面板中预览的功能。
#### 功能特点
1. **多种访问方式**:支持按网站名称或直接URL访问
2. **内置网站映射**:支持常见网站的快捷访问
3. **工作面板集成**:直接在工作面板中预览网页内容
4. **多种加载策略**:支持HTML内容、直接获取内容和iframe三种加载方式
5. **智能错误处理**:自动处理X-Frame-Options等安全限制,提供友好的错误提示
#### 使用方法
```
accessWebSiteByName("百度")
accessWebSiteByUrl("https://www.example.com")
```
### 增强型网页嵌入预览 (EmbedPreview)
提供增强的网页嵌入预览功能,支持多种加载策略以应对不同的安全限制。
#### 功能特点
1. **多种加载策略**
- **HTML内容**:直接渲染后端提供的HTML内容
- **直接获取**:通过代理API获取网页内容并直接渲染(绕过X-Frame-Options限制)
- **iframe加载**:传统的iframe嵌入方式(备选方案)
2. **智能回退机制**:当一种策略失败时自动尝试其他策略
3. **安全处理**:使用DOMPurify清理内容,防止XSS攻击
4. **错误处理**:完善的错误处理和用户友好的错误提示
5. **响应式设计**:适配不同屏幕尺寸
#### 使用方法
```vue
<EmbedPreview
:html-content="htmlContent"
:embed-url="url"
embed-title="预览标题"
embed-type="网页"
/>
```
### OAuth2.0授权工具 (OAuth2AuthorizationTool)
这是一个支持OAuth2.0标准认证流程的工具,允许用户通过用户名和密码凭证获取访问受保护资源的令牌。
#### 功能特点
1. **标准OAuth2.0支持**:完全符合OAuth2.0 RFC标准
2. **密码凭证流**:支持Resource Owner Password Credentials Grant流程
3. **令牌管理**:自动管理和缓存访问令牌
4. **令牌刷新**:支持使用刷新令牌获取新的访问令牌
5. **安全存储**:令牌安全存储,自动处理过期令牌
6. **资源访问**:使用获取的令牌访问受保护的资源
#### 使用方法
```
// 1. 获取访问令牌
authorizeWithPasswordCredentials(
"https://example.com/oauth/token",
"your-client-id",
"your-client-secret",
"your-username",
"your-password",
"read write"
)
// 2. 刷新访问令牌
refreshToken(
"https://example.com/oauth/token",
"your-client-id",
"your-client-secret",
"your-refresh-token"
)
// 3. 访问受保护资源
accessProtectedResource(
"https://example.com/api/protected",
"https://example.com/oauth/token",
"your-client-id"
)
```
有关更详细的使用说明,请参阅 [OAuth2.0工具使用指南](OAUTH2_TOOL_USAGE_GUIDE.md)
## 🚀 快速开始
### 环境要求
- Java 17+
- Node.js 16+
- Maven 3.8+
- MySQL 8.0+ (可选,也可使用内置H2数据库)
### 后端启动
```bash
cd backend
mvn spring-boot:run
```
### 前端启动
```bash
cd frontend
npm install
npm run dev
```
### 一键启动脚本
- Windows: `run-all-debug.bat`
- 后端独立: `run-backend-debug.bat`
- 前端独立: `run-frontend-debug.bat`
## 📁 项目结构
```
HiAgent/
├── backend/ # 后端服务
│ ├── src/
│ │ ├── main/
│ │ │ ├── java/ # Java源代码
│ │ │ └── resources/ # 配置文件和静态资源
│ │ └── test/ # 测试代码
│ └── pom.xml # Maven配置文件
├── frontend/ # 前端应用
│ ├── src/ # Vue源代码
│ ├── public/ # 静态资源
│ └── package.json # NPM配置文件
├── docker-compose.yml # Docker编排文件
└── README.md # 项目说明文件
```
## 🔧 配置说明
### 数据库配置
项目支持MySQL和H2数据库,默认使用H2内存数据库,可通过修改`application.yml`配置切换。
### AI模型配置
支持多种AI模型:
- OpenAI/DeepSeek
- Ollama本地模型
- 其他兼容OpenAI API的模型
## 🧪 测试
### 后端测试
```bash
cd backend
mvn test
```
### 前端测试
```bash
cd frontend
npm run test
```
## 🐳 Docker部署
使用docker-compose一键部署:
```bash
docker-compose up -d
```
## 📚 文档
项目包含丰富的技术文档:
- 工具使用说明
- 架构设计文档
- 部署指南
- 故障排查手册
## 🤝 贡献
欢迎提交Issue和Pull Request来改进项目。
## 📄 许可证
本项目采用MIT许可证。
## 🙏 致谢
感谢所有开源项目的贡献者们,以及支持这个项目开发的用户。
\ No newline at end of file
{
"name": "backend",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"dependencies": {
"playwright": "^1.57.0"
}
},
"node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
"integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
"hasInstallScript": true,
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
},
"node_modules/playwright": {
"version": "1.57.0",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.57.0.tgz",
"integrity": "sha512-ilYQj1s8sr2ppEJ2YVadYBN0Mb3mdo9J0wQ+UuDhzYqURwSoW4n1Xs5vs7ORwgDGmyEh33tRMeS8KhdkMoLXQw==",
"license": "Apache-2.0",
"dependencies": {
"playwright-core": "1.57.0"
},
"bin": {
"playwright": "cli.js"
},
"engines": {
"node": ">=18"
},
"optionalDependencies": {
"fsevents": "2.3.2"
}
},
"node_modules/playwright-core": {
"version": "1.57.0",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.57.0.tgz",
"integrity": "sha512-agTcKlMw/mjBWOnD6kFZttAAGHgi/Nw0CZ2o6JqWSbMlI219lAFLZZCyqByTsvVAJq5XA5H8cA6PrvBRpBWEuQ==",
"license": "Apache-2.0",
"bin": {
"playwright-core": "cli.js"
},
"engines": {
"node": ">=18"
}
}
}
}
{
"dependencies": {
"playwright": "^1.57.0"
}
}
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
<milvus-lite.version>2.3.0</milvus-lite.version> <milvus-lite.version>2.3.0</milvus-lite.version>
<jjwt.version>0.12.6</jjwt.version> <jjwt.version>0.12.6</jjwt.version>
<caffeine.version>3.1.8</caffeine.version> <caffeine.version>3.1.8</caffeine.version>
<maven.compiler.encoding>UTF-8</maven.compiler.encoding>
</properties> </properties>
<dependencyManagement> <dependencyManagement>
...@@ -307,6 +308,25 @@ ...@@ -307,6 +308,25 @@
<version>2.0.48</version> <version>2.0.48</version>
</dependency> </dependency>
<!-- Quartz for job scheduling -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-quartz</artifactId>
</dependency>
<!-- Spring Boot Mail Starter for POP3 email access -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-mail</artifactId>
</dependency>
<!-- Cron utils for cron expression parsing -->
<dependency>
<groupId>com.cronutils</groupId>
<artifactId>cron-utils</artifactId>
<version>9.2.0</version>
</dependency>
</dependencies> </dependencies>
...@@ -381,6 +401,7 @@ ...@@ -381,6 +401,7 @@
<configuration> <configuration>
<source>17</source> <source>17</source>
<target>17</target> <target>17</target>
<encoding>UTF-8</encoding>
<annotationProcessorPaths> <annotationProcessorPaths>
<path> <path>
<groupId>org.projectlombok</groupId> <groupId>org.projectlombok</groupId>
...@@ -396,6 +417,9 @@ ...@@ -396,6 +417,9 @@
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<version>3.0.0</version> <version>3.0.0</version>
<configuration>
<argLine>-Dfile.encoding=UTF-8</argLine>
</configuration>
</plugin> </plugin>
</plugins> </plugins>
</build> </build>
......
...@@ -17,24 +17,32 @@ import pangea.hiagent.utils.JwtUtil; ...@@ -17,24 +17,32 @@ import pangea.hiagent.utils.JwtUtil;
import pangea.hiagent.websocket.DomSyncHandler; import pangea.hiagent.websocket.DomSyncHandler;
import java.util.Map; import java.util.Map;
import lombok.extern.slf4j.Slf4j;
/** /**
* WebSocket配置类 * WebSocket配置类
*/ */
@Slf4j
@Configuration @Configuration
@EnableWebSocket @EnableWebSocket
public class DomSyncWebSocketConfig implements WebSocketConfigurer { public class DomSyncWebSocketConfig implements WebSocketConfigurer {
private final JwtHandshakeInterceptor jwtHandshakeInterceptor; private final JwtHandshakeInterceptor jwtHandshakeInterceptor;
private final pangea.hiagent.core.PlaywrightManager playwrightManager;
public DomSyncWebSocketConfig(JwtHandshakeInterceptor jwtHandshakeInterceptor) { public DomSyncWebSocketConfig(JwtHandshakeInterceptor jwtHandshakeInterceptor,
pangea.hiagent.core.PlaywrightManager playwrightManager) {
this.jwtHandshakeInterceptor = jwtHandshakeInterceptor; this.jwtHandshakeInterceptor = jwtHandshakeInterceptor;
this.playwrightManager = playwrightManager;
} }
// 注入DomSyncHandler,交由Spring管理生命周期 // 注入DomSyncHandler,交由Spring管理生命周期
@Bean @Bean
public DomSyncHandler domSyncHandler() { public DomSyncHandler domSyncHandler() {
return new DomSyncHandler(); DomSyncHandler handler = new DomSyncHandler();
// 通过设置器注入PlaywrightManager
handler.setPlaywrightManager(playwrightManager);
return handler;
} }
@Override @Override
...@@ -50,6 +58,7 @@ public class DomSyncWebSocketConfig implements WebSocketConfigurer { ...@@ -50,6 +58,7 @@ public class DomSyncWebSocketConfig implements WebSocketConfigurer {
/** /**
* JWT握手拦截器,用于WebSocket连接时的认证 * JWT握手拦截器,用于WebSocket连接时的认证
*/ */
@Slf4j
@Component @Component
class JwtHandshakeInterceptor implements HandshakeInterceptor { class JwtHandshakeInterceptor implements HandshakeInterceptor {
private final JwtUtil jwtUtil; private final JwtUtil jwtUtil;
...@@ -62,9 +71,13 @@ class JwtHandshakeInterceptor implements HandshakeInterceptor { ...@@ -62,9 +71,13 @@ class JwtHandshakeInterceptor implements HandshakeInterceptor {
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
String token = extractTokenFromRequest(request); String token = extractTokenFromRequest(request);
String clientInfo = "[" + (request.getRemoteAddress() != null ? request.getRemoteAddress().toString() : "unknown") + "] ";
log.info(clientInfo + "WebSocket握手请求 - URI: {}, Query: {}", request.getURI(), request.getURI().getQuery());
if (StringUtils.hasText(token)) { if (StringUtils.hasText(token)) {
try { try {
log.debug(clientInfo + "Token提取成功,长度: {}", token.length());
// 验证token是否有效 // 验证token是否有效
boolean isValid = jwtUtil.validateToken(token); boolean isValid = jwtUtil.validateToken(token);
if (isValid) { if (isValid) {
...@@ -73,30 +86,62 @@ class JwtHandshakeInterceptor implements HandshakeInterceptor { ...@@ -73,30 +86,62 @@ class JwtHandshakeInterceptor implements HandshakeInterceptor {
if (userId != null) { if (userId != null) {
attributes.put("token", token); attributes.put("token", token);
attributes.put("userId", userId); attributes.put("userId", userId);
System.out.println("WebSocket连接认证成功,用户ID: " + userId); log.info(clientInfo + "WebSocket连接认证成功,用户ID: {}", userId);
return true; return true;
} else { } else {
System.err.println("无法从token中提取用户ID"); log.error(clientInfo + "错误:无法从token中提取用户ID。Token长度: {}", token.length());
log.error(clientInfo + "token前50字符: {}", token.substring(0, Math.min(50, token.length())));
// 尝试从token的payload中直接解析userId
try {
String[] parts = token.split("\\.");
if (parts.length > 1) {
String payload = new String(java.util.Base64.getUrlDecoder().decode(parts[1]));
log.error(clientInfo + "token payload: {}", payload);
}
} catch (Exception payloadEx) {
log.error(clientInfo + "解析token payload时发生异常: {}", payloadEx.getMessage(), payloadEx);
}
} }
} else { } else {
System.err.println("JWT验证失败,token可能已过期或无效"); boolean isExpired = jwtUtil.isTokenExpired(token);
log.error(clientInfo + "JWT验证失败。Token已过期: {}", isExpired);
// 如果Token已过期,返回401状态码和明确的错误信息
response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
response.getHeaders().add("WWW-Authenticate", "Bearer error=\"invalid_token\", error_description=\"Token expired\"");
return false;
} }
} catch (Exception e) { } catch (Exception e) {
System.err.println("JWT验证过程中发生错误: " + e.getMessage()); log.error(clientInfo + "JWT验证过程中发生异常: {}", e.getClass().getSimpleName(), e);
e.printStackTrace();
// 如果验证过程出现异常,返回401状态码
response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
response.getHeaders().add("WWW-Authenticate", "Bearer error=\"invalid_token\", error_description=\"Token validation failed\"");
return false;
} }
} else {
log.warn(clientInfo + "WebSocket连接缺少认证token");
log.warn(clientInfo + "请求头Authorization: {}", request.getHeaders().getFirst("Authorization"));
String query = request.getURI().getQuery();
log.warn(clientInfo + "查询字符串: {}", query != null ? query : "(为空)");
} }
// 如果没有有效的token,拒绝连接 // 如果没有有效的token,拒绝连接
System.err.println("WebSocket连接缺少有效的认证token"); log.warn(clientInfo + "拒绝WebSocket连接,返回401 UNAUTHORIZED");
response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED); response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
response.getHeaders().add("WWW-Authenticate", "Bearer realm=\"WebSocket\"");
return false; return false;
} }
@Override @Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) { WebSocketHandler wsHandler, Exception exception) {
// 握手后处理,这里不需要特殊处理 String clientInfo = "[" + (request.getRemoteAddress() != null ? request.getRemoteAddress().toString() : "unknown") + "] ";
if (exception != null) {
log.error(clientInfo + "WebSocket握手失败,异常: {}", exception.getClass().getSimpleName(), exception);
} else {
log.info(clientInfo + "WebSocket握手后处理完成");
}
} }
/** /**
...@@ -108,16 +153,24 @@ class JwtHandshakeInterceptor implements HandshakeInterceptor { ...@@ -108,16 +153,24 @@ class JwtHandshakeInterceptor implements HandshakeInterceptor {
String authHeader = request.getHeaders().getFirst("Authorization"); String authHeader = request.getHeaders().getFirst("Authorization");
if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) { if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) {
String token = authHeader.substring(7); String token = authHeader.substring(7);
log.debug("从Authorization头中提取Token,长度: {}", token.length());
return token; return token;
} }
// 如果请求头中没有Token,则尝试从URL参数中提取 // 如果请求头中没有Token,则尝试从URL参数中提取
String query = request.getURI().getQuery(); String query = request.getURI().getQuery();
if (query != null) { if (query != null) {
UriComponentsBuilder builder = UriComponentsBuilder.newInstance().query(query); try {
String token = builder.build().getQueryParams().getFirst("token"); UriComponentsBuilder builder = UriComponentsBuilder.newInstance().query(query);
if (StringUtils.hasText(token)) { String token = builder.build().getQueryParams().getFirst("token");
return token; if (StringUtils.hasText(token)) {
log.debug("从URL参数中提取Token,长度: {}", token.length());
return token;
} else {
log.debug("URL中没有token参数,Query: {}", query);
}
} catch (Exception e) {
log.warn("解析URL参数时出错: {}", e.getMessage());
} }
} }
......
package pangea.hiagent.config;
import org.quartz.Scheduler;
import org.quartz.spi.JobFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.quartz.SchedulerFactoryBean;
import org.springframework.scheduling.quartz.SpringBeanJobFactory;
/**
* Quartz配置类
* 配置Quartz调度器和相关组件
*/
@Configuration
public class QuartzConfig {
/**
* 配置JobFactory,用于将Spring的Bean注入到Quartz的Job中
*/
@Bean
public JobFactory jobFactory() {
SpringBeanJobFactory jobFactory = new SpringBeanJobFactory();
return jobFactory;
}
/**
* 配置SchedulerFactoryBean,用于创建Scheduler实例
*/
@Bean
public SchedulerFactoryBean schedulerFactoryBean(@Autowired JobFactory jobFactory) {
SchedulerFactoryBean factory = new SchedulerFactoryBean();
factory.setJobFactory(jobFactory);
factory.setWaitForJobsToCompleteOnShutdown(true);
factory.setOverwriteExistingJobs(true);
return factory;
}
/**
* 配置Scheduler实例,用于管理和执行定时任务
*/
@Bean
public Scheduler scheduler(@Autowired SchedulerFactoryBean factory) {
return factory.getScheduler();
}
}
...@@ -19,6 +19,8 @@ import org.springframework.web.cors.CorsConfigurationSource; ...@@ -19,6 +19,8 @@ import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import pangea.hiagent.security.DefaultPermissionEvaluator; import pangea.hiagent.security.DefaultPermissionEvaluator;
import pangea.hiagent.security.JwtAuthenticationFilter; import pangea.hiagent.security.JwtAuthenticationFilter;
import pangea.hiagent.service.AgentService;
import pangea.hiagent.service.TimerService;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
...@@ -30,11 +32,13 @@ import java.util.Collections; ...@@ -30,11 +32,13 @@ import java.util.Collections;
public class SecurityConfig { public class SecurityConfig {
private final JwtAuthenticationFilter jwtAuthenticationFilter; private final JwtAuthenticationFilter jwtAuthenticationFilter;
private final DefaultPermissionEvaluator customPermissionEvaluator; private final AgentService agentService;
private final TimerService timerService;
public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter, DefaultPermissionEvaluator customPermissionEvaluator) { public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter, AgentService agentService, TimerService timerService) {
this.jwtAuthenticationFilter = jwtAuthenticationFilter; this.jwtAuthenticationFilter = jwtAuthenticationFilter;
this.customPermissionEvaluator = customPermissionEvaluator; this.agentService = agentService;
this.timerService = timerService;
} }
/** /**
...@@ -51,7 +55,9 @@ public class SecurityConfig { ...@@ -51,7 +55,9 @@ public class SecurityConfig {
@Bean @Bean
public MethodSecurityExpressionHandler methodSecurityExpressionHandler() { public MethodSecurityExpressionHandler methodSecurityExpressionHandler() {
DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler(); DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler();
expressionHandler.setPermissionEvaluator(customPermissionEvaluator); // 创建带有AgentService和TimerService的权限评估器
DefaultPermissionEvaluator permissionEvaluator = new DefaultPermissionEvaluator(agentService, timerService);
expressionHandler.setPermissionEvaluator(permissionEvaluator);
return expressionHandler; return expressionHandler;
} }
...@@ -87,6 +93,8 @@ public class SecurityConfig { ...@@ -87,6 +93,8 @@ public class SecurityConfig {
.sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) .sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
// 配置请求授权 // 配置请求授权
.authorizeHttpRequests(authz -> authz .authorizeHttpRequests(authz -> authz
// WebSocket端点 - 由握手拦截器处理认证,不需要通过Spring Security过滤链
.requestMatchers("/ws/**").permitAll()
// OAuth2 相关端点公开访问 // OAuth2 相关端点公开访问
.requestMatchers("/api/v1/auth/oauth2/**").permitAll() .requestMatchers("/api/v1/auth/oauth2/**").permitAll()
// OAuth2提供商管理端点需要认证(仅管理员可访问) // OAuth2提供商管理端点需要认证(仅管理员可访问)
...@@ -123,6 +131,12 @@ public class SecurityConfig { ...@@ -123,6 +131,12 @@ public class SecurityConfig {
response.getWriter().write("{\"code\":401,\"message\":\"未授权访问\",\"timestamp\":" + System.currentTimeMillis() + "}"); response.getWriter().write("{\"code\":401,\"message\":\"未授权访问\",\"timestamp\":" + System.currentTimeMillis() + "}");
}) })
.accessDeniedHandler((request, response, accessDeniedException) -> { .accessDeniedHandler((request, response, accessDeniedException) -> {
// 检查响应是否已经提交
if (response.isCommitted()) {
System.err.println("响应已经提交,无法处理访问拒绝异常: " + request.getRequestURI());
return;
}
response.setStatus(403); response.setStatus(403);
response.setContentType("application/json;charset=UTF-8"); response.setContentType("application/json;charset=UTF-8");
response.getWriter().write("{\"code\":403,\"message\":\"访问被拒绝\",\"timestamp\":" + System.currentTimeMillis() + "}"); response.getWriter().write("{\"code\":403,\"message\":\"访问被拒绝\",\"timestamp\":" + System.currentTimeMillis() + "}");
......
package pangea.hiagent.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.context.SecurityContextHolder;
import jakarta.annotation.PostConstruct;
/**
* SecurityContext配置类
* 用于配置SecurityContextHolder策略,支持异步线程间传播认证信息
*/
@Configuration
public class SecurityContextConfig {
/**
* 在应用启动时设置SecurityContextHolder策略为MODE_INHERITABLETHREADLOCAL
* 这样可以在父子线程之间自动传播SecurityContext
*/
@PostConstruct
public void configureSecurityContextHolderStrategy() {
// 设置SecurityContextHolder策略为可继承的ThreadLocal模式
// 这样在异步线程中也可以获取到父线程的认证信息
SecurityContextHolder.setStrategyName(SecurityContextHolder.MODE_INHERITABLETHREADLOCAL);
}
}
\ No newline at end of file
...@@ -10,6 +10,7 @@ import pangea.hiagent.dto.PageData; ...@@ -10,6 +10,7 @@ import pangea.hiagent.dto.PageData;
import pangea.hiagent.model.Agent; import pangea.hiagent.model.Agent;
import pangea.hiagent.service.AgentService; import pangea.hiagent.service.AgentService;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import pangea.hiagent.utils.UserUtils;
/** /**
* Agent API控制器 * Agent API控制器
...@@ -152,8 +153,6 @@ public class AgentController { ...@@ -152,8 +153,6 @@ public class AgentController {
* 获取当前认证用户ID * 获取当前认证用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); return UserUtils.getCurrentUserId();
return (authentication != null && authentication.getPrincipal() != null) ?
(String) authentication.getPrincipal() : null;
} }
} }
\ No newline at end of file
...@@ -8,6 +8,7 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; ...@@ -8,6 +8,7 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import pangea.hiagent.workpanel.SseEventManager; import pangea.hiagent.workpanel.SseEventManager;
import pangea.hiagent.utils.UserUtils;
import pangea.hiagent.dto.WorkPanelEvent; import pangea.hiagent.dto.WorkPanelEvent;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
...@@ -30,11 +31,7 @@ public class TimelineEventController { ...@@ -30,11 +31,7 @@ public class TimelineEventController {
* 获取当前认证用户ID * 获取当前认证用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); return UserUtils.getCurrentUserId();
if (authentication != null && authentication.getPrincipal() != null) {
return (String) authentication.getPrincipal();
}
return null;
} }
/** /**
...@@ -48,23 +45,18 @@ public class TimelineEventController { ...@@ -48,23 +45,18 @@ public class TimelineEventController {
log.info("开始处理时间轴事件订阅请求"); log.info("开始处理时间轴事件订阅请求");
String userId = getCurrentUserId(); String userId = getCurrentUserId();
// 创建 SSE emitter
SseEmitter emitter = new SseEmitter(300000L); // 5分钟超时
if (userId == null) { if (userId == null) {
log.error("用户未认证"); log.error("用户未认证");
// 立即创建并完成emitter,不发送任何数据 // 使用sendError方法发送错误信息,而不是直接completeWithError
SseEmitter emitter = new SseEmitter(300000L); sseEventManager.sendError(emitter, "用户未认证");
try {
emitter.completeWithError(new IllegalArgumentException("用户未认证"));
} catch (Exception e) {
log.error("完成SSE连接失败", e);
}
return emitter; return emitter;
} }
log.debug("用户认证成功,用户ID: {}", userId); log.debug("用户认证成功,用户ID: {}", userId);
// 创建 SSE emitter
SseEmitter emitter = new SseEmitter(300000L); // 5分钟超时
// 注册 emitter 回调 // 注册 emitter 回调
emitter.onCompletion(() -> { emitter.onCompletion(() -> {
log.debug("SSE连接完成"); log.debug("SSE连接完成");
......
package pangea.hiagent.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import pangea.hiagent.dto.ApiResponse;
import pangea.hiagent.dto.TimerExecutionHistoryDto;
import pangea.hiagent.service.HistoryService;
/**
* 定时器执行历史API控制器
* 负责处理执行历史的查询和管理
*/
@Slf4j
@RestController
@RequestMapping("/api/v1/timer-history")
public class TimerHistoryController {
private final HistoryService historyService;
public TimerHistoryController(HistoryService historyService) {
this.historyService = historyService;
}
/**
* 获取执行历史列表
*/
@GetMapping
public ApiResponse<Page<TimerExecutionHistoryDto>> listExecutionHistory(
@RequestParam(required = false) String timerId,
@RequestParam(required = false) Integer success,
@RequestParam(required = false) String startTime,
@RequestParam(required = false) String endTime,
@RequestParam(defaultValue = "1") int page,
@RequestParam(defaultValue = "10") int size) {
try {
log.info("获取执行历史列表,timerId: {}, success: {}, startTime: {}, endTime: {}",
timerId, success, startTime, endTime);
Page<TimerExecutionHistoryDto> historyPage = historyService.getExecutionHistoryList(
timerId, success, startTime, endTime, page, size);
return ApiResponse.success(historyPage, "获取执行历史成功");
} catch (Exception e) {
log.error("获取执行历史失败", e);
return ApiResponse.error(4001, "获取执行历史失败: " + e.getMessage());
}
}
/**
* 获取指定定时器的执行历史
*/
@GetMapping("/{timerId}")
public ApiResponse<Page<TimerExecutionHistoryDto>> listTimerExecutionHistory(
@PathVariable String timerId,
@RequestParam(defaultValue = "1") int page,
@RequestParam(defaultValue = "10") int size) {
try {
log.info("获取定时器 {} 的执行历史", timerId);
Page<TimerExecutionHistoryDto> historyPage = historyService.getExecutionHistoryByTimerId(
timerId, page, size);
return ApiResponse.success(historyPage, "获取定时器执行历史成功");
} catch (Exception e) {
log.error("获取定时器执行历史失败", e);
return ApiResponse.error(4001, "获取定时器执行历史失败: " + e.getMessage());
}
}
/**
* 获取执行历史详情
*/
@GetMapping("/detail/{id}")
public ApiResponse<TimerExecutionHistoryDto> getExecutionHistoryDetail(@PathVariable Long id) {
try {
log.info("获取执行历史详情: {}", id);
TimerExecutionHistoryDto historyDetail = historyService.getExecutionHistoryDetail(id);
if (historyDetail == null) {
return ApiResponse.error(4004, "执行历史不存在");
}
return ApiResponse.success(historyDetail, "获取执行历史详情成功");
} catch (Exception e) {
log.error("获取执行历史详情失败", e);
return ApiResponse.error(4001, "获取执行历史详情失败: " + e.getMessage());
}
}
}
\ No newline at end of file
package pangea.hiagent.controller;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import pangea.hiagent.model.ToolConfig;
import pangea.hiagent.service.ToolConfigService;
import java.util.List;
import java.util.Map;
/**
* 工具配置控制器
* 提供参数配置的REST API
*/
@Slf4j
@RestController
@RequestMapping("/api/v1/tool-configs")
public class ToolConfigController {
@Autowired
private ToolConfigService toolConfigService;
/**
* 获取所有工具配置
* @return 工具配置列表
*/
@GetMapping
public ResponseEntity<List<ToolConfig>> getAllToolConfigs() {
log.debug("获取所有工具配置");
List<ToolConfig> toolConfigs = toolConfigService.getAllToolConfigs();
return ResponseEntity.ok(toolConfigs);
}
/**
* 根据工具名称获取参数配置
* @param toolName 工具名称
* @return 参数配置键值对
*/
@GetMapping("/{toolName}")
public ResponseEntity<Map<String, String>> getToolParams(@PathVariable String toolName) {
log.debug("根据工具名称获取参数配置,工具名称:{}", toolName);
Map<String, String> params = toolConfigService.getToolParams(toolName);
return ResponseEntity.ok(params);
}
/**
* 根据工具名称和参数名称获取参数值
* @param toolName 工具名称
* @param paramName 参数名称
* @return 参数值
*/
@GetMapping("/{toolName}/{paramName}")
public ResponseEntity<String> getParamValue(@PathVariable String toolName, @PathVariable String paramName) {
log.debug("根据工具名称和参数名称获取参数值,工具名称:{},参数名称:{}", toolName, paramName);
String paramValue = toolConfigService.getParamValue(toolName, paramName);
return ResponseEntity.ok(paramValue);
}
/**
* 保存工具配置
* @param toolConfig 工具配置对象
* @return 保存后的工具配置对象
*/
@PostMapping
public ResponseEntity<ToolConfig> saveToolConfig(@RequestBody ToolConfig toolConfig) {
log.debug("保存工具配置:{}", toolConfig);
ToolConfig savedConfig = toolConfigService.saveToolConfig(toolConfig);
if (savedConfig != null) {
return ResponseEntity.ok(savedConfig);
} else {
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build();
}
}
/**
* 保存参数值
* @param toolName 工具名称
* @param paramName 参数名称
* @param paramValue 参数值
* @return 保存结果
*/
@PutMapping("/{toolName}/{paramName}")
public ResponseEntity<Void> saveParamValue(@PathVariable String toolName, @PathVariable String paramName, @RequestBody String paramValue) {
log.debug("保存参数值,工具名称:{},参数名称:{},参数值:{}", toolName, paramName, paramValue);
toolConfigService.saveParamValue(toolName, paramName, paramValue);
return ResponseEntity.ok().build();
}
/**
* 删除工具配置
* @param id 配置ID
* @return 删除结果
*/
@DeleteMapping("/{id}")
public ResponseEntity<Void> deleteToolConfig(@PathVariable String id) {
log.debug("删除工具配置,ID:{}", id);
toolConfigService.deleteToolConfig(id);
return ResponseEntity.ok().build();
}
}
\ No newline at end of file
...@@ -9,6 +9,7 @@ import org.springframework.web.bind.annotation.*; ...@@ -9,6 +9,7 @@ import org.springframework.web.bind.annotation.*;
import pangea.hiagent.dto.ApiResponse; import pangea.hiagent.dto.ApiResponse;
import pangea.hiagent.model.Tool; import pangea.hiagent.model.Tool;
import pangea.hiagent.service.ToolService; import pangea.hiagent.service.ToolService;
import pangea.hiagent.utils.UserUtils;
import java.util.List; import java.util.List;
...@@ -33,11 +34,7 @@ public class ToolController { ...@@ -33,11 +34,7 @@ public class ToolController {
* @return 用户ID * @return 用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); return UserUtils.getCurrentUserId();
if (authentication != null && authentication.getPrincipal() instanceof String) {
return (String) authentication.getPrincipal();
}
return null;
} }
/** /**
...@@ -129,7 +126,7 @@ public class ToolController { ...@@ -129,7 +126,7 @@ public class ToolController {
* 获取工具列表 * 获取工具列表
*/ */
@GetMapping @GetMapping
@Operation(summary = "获取工具列表", description = "获取所有可用工具") @Operation(summary = "获取工具列表", description = "获取当前用户可用工具")
public ApiResponse<List<Tool>> getTools() { public ApiResponse<List<Tool>> getTools() {
try { try {
String userId = getCurrentUserId(); String userId = getCurrentUserId();
...@@ -137,7 +134,7 @@ public class ToolController { ...@@ -137,7 +134,7 @@ public class ToolController {
return ApiResponse.error(4001, "用户未认证"); return ApiResponse.error(4001, "用户未认证");
} }
List<Tool> tools = toolService.getAllTools(); List<Tool> tools = toolService.getUserTools(userId);
return ApiResponse.success(tools, "获取工具列表成功"); return ApiResponse.success(tools, "获取工具列表成功");
} catch (Exception e) { } catch (Exception e) {
log.error("获取工具列表失败", e); log.error("获取工具列表失败", e);
......
package pangea.hiagent.core;
import com.microsoft.playwright.Browser;
import com.microsoft.playwright.BrowserContext;
import com.microsoft.playwright.Playwright;
/**
* Playwright管理器接口
* 提供统一的Playwright实例管理和用户隔离机制
*/
public interface PlaywrightManager {
/**
* 获取共享的Playwright实例
*
* @return Playwright实例
*/
Playwright getPlaywright();
/**
* 获取共享的浏览器实例
*
* @return Browser实例
*/
Browser getBrowser();
/**
* 为指定用户获取专用的浏览器上下文
* 实现用户级别的隔离
*
* @param userId 用户ID
* @return 该用户专用的BrowserContext
*/
BrowserContext getUserContext(String userId);
/**
* 为指定用户获取专用的浏览器上下文(带自定义配置)
*
* @param userId 用户ID
* @param options 浏览器上下文选项
* @return 该用户专用的BrowserContext
*/
BrowserContext getUserContext(String userId, Browser.NewContextOptions options);
/**
* 释放指定用户的浏览器上下文
*
* @param userId 用户ID
*/
void releaseUserContext(String userId);
/**
* 释放所有资源
*/
void destroy();
}
\ No newline at end of file
package pangea.hiagent.core;
import com.microsoft.playwright.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.concurrent.*;
/**
* Playwright管理器实现类
* 负责统一管理Playwright实例和用户隔离的BrowserContext
*/
@Slf4j
@Component
public class PlaywrightManagerImpl implements PlaywrightManager {
// 共享的Playwright实例
private Playwright playwright;
// 共享的浏览器实例
private Browser browser;
// 用户浏览器上下文映射表(用户ID -> BrowserContext)
private final ConcurrentMap<String, BrowserContext> userContexts = new ConcurrentHashMap<>();
// 用户上下文创建时间映射表(用于超时清理)
private final ConcurrentMap<String, Long> contextCreationTimes = new ConcurrentHashMap<>();
// 用户上下文超时时间(毫秒),默认30分钟
private static final long CONTEXT_TIMEOUT = 30 * 60 * 1000;
// 清理任务调度器
private ScheduledExecutorService cleanupScheduler;
/**
* 初始化Playwright和浏览器实例
*/
@PostConstruct
public void initialize() {
try {
log.info("正在初始化Playwright管理器...");
// 创建Playwright实例
this.playwright = Playwright.create();
// 启动Chrome浏览器,无头模式
this.browser = playwright.chromium().launch(new BrowserType.LaunchOptions()
.setHeadless(true)
.setArgs(java.util.Arrays.asList(
"--no-sandbox",
"--disable-dev-shm-usage",
"--disable-gpu",
"--remote-allow-origins=*")));
// 初始化清理任务调度器
this.cleanupScheduler = Executors.newSingleThreadScheduledExecutor();
// 每5分钟检查一次超时的用户上下文
this.cleanupScheduler.scheduleAtFixedRate(this::cleanupExpiredContexts,
5, 5, TimeUnit.MINUTES);
log.info("Playwright管理器初始化成功");
} catch (Exception e) {
log.error("Playwright管理器初始化失败: ", e);
throw new RuntimeException("Failed to initialize Playwright manager", e);
}
}
@Override
public Playwright getPlaywright() {
if (playwright == null) {
throw new IllegalStateException("Playwright instance is not initialized");
}
return playwright;
}
@Override
public Browser getBrowser() {
if (browser == null || !browser.isConnected()) {
throw new IllegalStateException("Browser instance is not available");
}
return browser;
}
@Override
public BrowserContext getUserContext(String userId) {
Browser.NewContextOptions options = new Browser.NewContextOptions()
.setViewportSize(1344, 2992) // 设置视口大小,与前端一致;手机型号:Google Pixel 9 Pro XL
.setUserAgent("Mozilla/5.0 (Linux; Android 15; Pixel 9 Pro XL Build/UP2A.250105.004) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Mobile Safari/537.36"); // 设置用户代理
return getUserContext(userId, options);
}
@Override
public BrowserContext getUserContext(String userId, Browser.NewContextOptions options) {
if (userId == null || userId.isEmpty()) {
throw new IllegalArgumentException("User ID cannot be null or empty");
}
if (options == null) {
options = new Browser.NewContextOptions();
}
// 尝试从缓存中获取已存在的上下文
BrowserContext context = userContexts.get(userId);
// 如果上下文不存在或已关闭,则创建新的
if (context == null || context.pages().isEmpty()) {
try {
log.debug("为用户 {} 创建新的浏览器上下文", userId);
context = browser.newContext(options);
userContexts.put(userId, context);
contextCreationTimes.put(userId, System.currentTimeMillis());
} catch (Exception e) {
log.error("为用户 {} 创建浏览器上下文失败", userId, e);
throw new RuntimeException("Failed to create browser context for user: " + userId, e);
}
}
return context;
}
@Override
public void releaseUserContext(String userId) {
if (userId == null || userId.isEmpty()) {
return;
}
BrowserContext context = userContexts.remove(userId);
contextCreationTimes.remove(userId);
if (context != null) {
try {
context.close();
log.debug("用户 {} 的浏览器上下文已释放", userId);
} catch (Exception e) {
log.warn("关闭用户 {} 的浏览器上下文时发生异常", userId, e);
}
}
}
/**
* 清理过期的用户上下文
*/
private void cleanupExpiredContexts() {
long currentTime = System.currentTimeMillis();
long expiredThreshold = currentTime - CONTEXT_TIMEOUT;
for (String userId : contextCreationTimes.keySet()) {
Long creationTime = contextCreationTimes.get(userId);
if (creationTime != null && creationTime < expiredThreshold) {
log.info("清理过期的用户上下文: {}", userId);
releaseUserContext(userId);
}
}
}
/**
* 销毁所有资源
*/
@PreDestroy
@Override
public void destroy() {
log.info("开始销毁Playwright管理器资源...");
try {
// 关闭清理任务调度器
if (cleanupScheduler != null) {
cleanupScheduler.shutdown();
if (!cleanupScheduler.awaitTermination(5, TimeUnit.SECONDS)) {
cleanupScheduler.shutdownNow();
}
}
} catch (Exception e) {
log.warn("关闭清理任务调度器时发生异常", e);
}
// 关闭所有用户上下文
for (String userId : userContexts.keySet()) {
releaseUserContext(userId);
}
// 关闭浏览器
try {
if (browser != null && browser.isConnected()) {
browser.close();
log.info("浏览器实例已关闭");
}
} catch (Exception e) {
log.warn("关闭浏览器实例时发生异常", e);
}
// 关闭Playwright
try {
if (playwright != null) {
playwright.close();
log.info("Playwright实例已关闭");
}
} catch (Exception e) {
log.warn("关闭Playwright实例时发生异常", e);
}
log.info("Playwright管理器资源已全部销毁");
}
}
\ No newline at end of file
# Playwright实例管理优化方案
## 1. 当前问题分析
通过对代码库的分析,我们发现当前Playwright的使用存在以下问题:
### 1.1 重复实例化问题
目前系统中有三个独立的Playwright实例:
1. **DomSyncHandler.java** - WebSocket处理器中的Playwright实例
2. **PlaywrightWebTools.java** - 网页自动化工具类中的Playwright实例
3. **HisenseSsoAuthTool.java** - 海信SSO认证工具类中的Playwright实例
每个实例都在各自的类中独立创建和管理,造成资源浪费和维护困难。
### 1.2 资源管理不统一
各个Playwright实例的生命周期管理分散在不同的类中,缺乏统一的资源回收机制,可能导致内存泄漏。
### 1.3 用户隔离缺失
当前实现中没有有效的用户隔离机制,所有操作都在共享的浏览器上下文中执行,存在安全隐患。
## 2. 优化目标
1. **统一实例管理**:创建单一的Playwright管理器,整个应用共享一个Playwright实例
2. **资源优化**:减少重复创建的开销,提高资源利用率
3. **用户隔离**:实现基于BrowserContext的用户隔离机制
4. **易于维护**:提供清晰的接口和生命周期管理
## 3. 设计方案
### 3.1 架构设计
我们将采用以下架构:
```
+---------------------+
| PlaywrightManager | <- 统一管理Playwright实例
+----------+----------+
|
| 1..1
|
+----------v----------+
| PlaywrightInstance | <- 封装Playwright核心实例
+----------+----------+
|
| 1..*
|
+----------v----------+
| BrowserContextPool | <- 管理用户隔离的BrowserContext
+----------+----------+
|
| 1..*
|
+----------v----------+
| BrowserContext | <- 每个用户独立的浏览上下文
+---------------------+
```
### 3.2 核心组件
#### 3.2.1 PlaywrightManager (接口)
定义Playwright管理器的核心接口:
- 获取共享Playwright实例
- 获取用户专属BrowserContext
- 资源释放
#### 3.2.2 PlaywrightManagerImpl (实现)
PlaywrightManager的具体实现:
- 单例模式确保只有一个Playwright实例
- 管理BrowserContext池
- 实现资源的初始化和销毁
#### 3.2.3 UserContextManager
负责用户上下文管理:
- 为每个用户创建独立的BrowserContext
- 管理上下文的生命周期
- 实现超时自动清理机制
## 4. 实施步骤
### 4.1 创建Playwright管理接口和实现类
1. 创建PlaywrightManager接口
2. 创建PlaywrightManagerImpl实现类
3. 配置Spring Bean管理
### 4.2 实现用户隔离机制
1. 创建UserContextManager类
2. 实现基于用户ID的BrowserContext分配
3. 添加超时清理机制
### 4.3 重构现有代码
1. 修改DomSyncHandler以使用新的Playwright管理器
2. 修改PlaywrightWebTools以使用新的Playwright管理器
3. 修改HisenseSsoAuthTool以使用新的Playwright管理器
## 5. 预期收益
### 5.1 性能提升
- 减少Playwright实例创建次数,降低系统开销
- 统一资源管理,避免内存泄漏
### 5.2 安全增强
- 实现用户级别的浏览上下文隔离,每个用户拥有独立的浏览环境
- 通过JWT认证机制获取真实用户ID,确保上下文隔离基于实际用户身份
- 防止用户间数据交叉污染
- 在用户会话结束时正确释放资源,防止资源泄露
- 拒绝未认证的WebSocket连接,提升了整体安全性
### 5.3 可维护性改善
- 集中管理Playwright相关资源
- 简化代码维护和升级
## 6. 风险与应对
### 6.1 兼容性风险
- **风险**:重构可能影响现有功能
- **应对**:充分测试,逐步替换
### 6.2 并发访问风险
- **风险**:多线程环境下可能出现资源竞争
- **应对**:使用线程安全的数据结构和同步机制
## 7. 后续优化建议
1. 添加监控指标,跟踪Playwright资源使用情况
2. 实现动态扩缩容的BrowserContext池
3. 添加更精细的权限控制机制
\ No newline at end of file
package pangea.hiagent.dto;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
/**
* 提示词模板DTO
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class PromptTemplateDto {
private String id;
private String name;
private String description;
private String templateContent;
private String paramSchema;
private String templateType;
private Integer isSystem;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
private String createdBy;
}
package pangea.hiagent.dto;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import java.time.LocalDateTime;
/**
* 定时器配置DTO
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class TimerConfigDto {
private String id;
@NotBlank(message = "定时器名称不能为空")
private String name;
private String description;
@NotBlank(message = "Cron表达式不能为空")
private String cronExpression;
@NotNull(message = "启用状态不能为空")
private Integer enabled;
@NotBlank(message = "关联Agent ID不能为空")
private String agentId;
private String agentName;
private String promptTemplate;
private String paramsJson;
private LocalDateTime lastExecutionTime;
private LocalDateTime nextExecutionTime;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
private String createdBy;
}
package pangea.hiagent.dto;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
/**
* 定时器执行历史DTO
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class TimerExecutionHistoryDto {
private String id;
private String timerId;
private String timerName;
private LocalDateTime executionTime;
private Integer success;
private String result;
private String errorMessage;
private Long executionDuration;
private String actualPrompt;
private LocalDateTime createdAt;
}
...@@ -15,6 +15,8 @@ import pangea.hiagent.dto.ApiResponse; ...@@ -15,6 +15,8 @@ import pangea.hiagent.dto.ApiResponse;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.springframework.security.authorization.AuthorizationDeniedException;
/** /**
* 全局异常处理器 * 全局异常处理器
* 统一处理系统中的各种异常 * 统一处理系统中的各种异常
...@@ -154,6 +156,34 @@ public class GlobalExceptionHandler { ...@@ -154,6 +156,34 @@ public class GlobalExceptionHandler {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body(response); return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body(response);
} }
/**
* 处理授权拒绝异常
*/
@ExceptionHandler(AuthorizationDeniedException.class)
public ResponseEntity<ApiResponse<Void>> handleAuthorizationDeniedException(
AuthorizationDeniedException e, HttpServletRequest request) {
log.warn("访问被拒绝: {} - URL: {}", e.getMessage(), request.getRequestURL());
// 检查响应是否已经提交
if (request.getAttribute("jakarta.servlet.error.exception") != null ||
(request instanceof org.springframework.web.context.request.NativeWebRequest &&
((org.springframework.web.context.request.NativeWebRequest) request).getNativeResponse() instanceof jakarta.servlet.http.HttpServletResponse &&
((jakarta.servlet.http.HttpServletResponse) ((org.springframework.web.context.request.NativeWebRequest) request).getNativeResponse()).isCommitted())) {
log.warn("响应已提交,无法发送访问拒绝错误: {}", request.getRequestURL());
// 响应已提交,无法发送错误响应
return ResponseEntity.status(HttpStatus.FORBIDDEN).build();
}
ApiResponse.ErrorDetail errorDetail = ApiResponse.ErrorDetail.builder()
.type("ACCESS_DENIED")
.details("您没有权限执行此操作")
.build();
ApiResponse<Void> response = ApiResponse.error(ErrorCode.FORBIDDEN.getCode(),
ErrorCode.FORBIDDEN.getMessage(), errorDetail);
return ResponseEntity.status(HttpStatus.FORBIDDEN).body(response);
}
/** /**
* 处理系统异常 * 处理系统异常
* 增强版本:更好地处理SSE流式响应中的异常 * 增强版本:更好地处理SSE流式响应中的异常
...@@ -177,6 +207,11 @@ public class GlobalExceptionHandler { ...@@ -177,6 +207,11 @@ public class GlobalExceptionHandler {
if (log.isDebugEnabled()) { if (log.isDebugEnabled()) {
log.debug("SSE异步请求不可用,客户端已断开连接 - URL: {}", request.getRequestURL()); log.debug("SSE异步请求不可用,客户端已断开连接 - URL: {}", request.getRequestURL());
} }
} else if (e.getMessage() != null && e.getMessage().contains("response has already been committed")) {
// 响应已提交异常 - 客户端已断开
if (log.isDebugEnabled()) {
log.debug("响应已提交,客户端可能已断开连接 - URL: {}", request.getRequestURL());
}
} else { } else {
// 非IOException的SSE异常才记录为ERROR // 非IOException的SSE异常才记录为ERROR
log.error("SSE流式处理异常 - URL: {} - 异常类型: {} - 异常消息: {}", log.error("SSE流式处理异常 - URL: {} - 异常类型: {} - 异常消息: {}",
...@@ -228,6 +263,9 @@ public class GlobalExceptionHandler { ...@@ -228,6 +263,9 @@ public class GlobalExceptionHandler {
boolean isStreamPath = requestUri != null && (requestUri.contains("stream") || boolean isStreamPath = requestUri != null && (requestUri.contains("stream") ||
requestUri.contains("chat") && requestUri.contains("event")); requestUri.contains("chat") && requestUri.contains("event"));
// 特别检查chat-stream路径
boolean isChatStreamPath = requestUri != null && requestUri.contains("chat-stream");
// 检查异常链中是否包含SSE相关异常 // 检查异常链中是否包含SSE相关异常
boolean hasSseException = checkForSseException(e); boolean hasSseException = checkForSseException(e);
...@@ -237,9 +275,10 @@ public class GlobalExceptionHandler { ...@@ -237,9 +275,10 @@ public class GlobalExceptionHandler {
(e.getMessage().contains("Socket") || (e.getMessage().contains("Socket") ||
e.getMessage().contains("软件中止") || e.getMessage().contains("软件中止") ||
e.getMessage().contains("ServletOutputStream") || e.getMessage().contains("ServletOutputStream") ||
e.getMessage().contains("Pipe")); e.getMessage().contains("Pipe") ||
e.getMessage().contains("Software caused connection abort"));
return isAcceptingStream || isStreamContent || isStreamPath || hasSseException || isSseOperationException; return isAcceptingStream || isStreamContent || isStreamPath || isChatStreamPath || hasSseException || isSseOperationException;
} }
/** /**
...@@ -266,7 +305,8 @@ public class GlobalExceptionHandler { ...@@ -266,7 +305,8 @@ public class GlobalExceptionHandler {
message.contains("software") || message.contains("software") ||
message.contains("软件中止") || message.contains("软件中止") ||
message.contains("断开") || message.contains("断开") ||
message.contains("AsyncRequestNotUsable")) { message.contains("AsyncRequestNotUsable") ||
message.contains("Software caused connection abort")) {
return true; return true;
} }
} }
...@@ -278,7 +318,10 @@ public class GlobalExceptionHandler { ...@@ -278,7 +318,10 @@ public class GlobalExceptionHandler {
return true; return true;
} }
String causeMsg = cause.getMessage(); String causeMsg = cause.getMessage();
if (causeMsg != null && (causeMsg.contains("Socket") || causeMsg.contains("Pipe"))) { if (causeMsg != null && (causeMsg.contains("Socket") ||
causeMsg.contains("Pipe") ||
causeMsg.contains("Software caused connection abort") ||
causeMsg.contains("软件中止"))) {
return true; return true;
} }
return checkForSseException(cause); return checkForSseException(cause);
......
...@@ -47,6 +47,11 @@ public class CaffeineChatMemory implements ChatMemory { ...@@ -47,6 +47,11 @@ public class CaffeineChatMemory implements ChatMemory {
cache.put(conversationId, existingMessages); cache.put(conversationId, existingMessages);
log.debug("成功将{}条消息添加到会话{}", messages.size(), conversationId); log.debug("成功将{}条消息添加到会话{}", messages.size(), conversationId);
// 如果会话ID包含null,记录警告信息
if (conversationId != null && conversationId.contains("null")) {
log.warn("检测到包含'null'的会话ID: {},可能存在认证问题", conversationId);
}
} catch (Exception e) { } catch (Exception e) {
log.error("保存消息到Caffeine缓存时发生错误", e); log.error("保存消息到Caffeine缓存时发生错误", e);
throw new RuntimeException("Failed to save messages to Caffeine cache", e); throw new RuntimeException("Failed to save messages to Caffeine cache", e);
......
...@@ -8,6 +8,7 @@ import org.springframework.beans.factory.annotation.Autowired; ...@@ -8,6 +8,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import pangea.hiagent.utils.UserUtils;
import pangea.hiagent.model.Agent; import pangea.hiagent.model.Agent;
import java.util.Collections; import java.util.Collections;
...@@ -46,6 +47,12 @@ public class MemoryService { ...@@ -46,6 +47,12 @@ public class MemoryService {
if (userId == null) { if (userId == null) {
userId = getCurrentUserId(); userId = getCurrentUserId();
} }
// 如果userId仍然为null,使用默认值避免生成"null_xxx"格式的会话ID
if (userId == null) {
userId = "unknown-user";
}
return userId + "_" + agent.getId(); return userId + "_" + agent.getId();
} }
...@@ -54,9 +61,11 @@ public class MemoryService { ...@@ -54,9 +61,11 @@ public class MemoryService {
* @return 用户ID * @return 用户ID
*/ */
private String getCurrentUserId() { private String getCurrentUserId() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); String userId = UserUtils.getCurrentUserId();
return (authentication != null && authentication.getPrincipal() != null) ? if (userId == null) {
(String) authentication.getPrincipal() : null; log.warn("无法通过UserUtils获取当前用户ID");
}
return userId;
} }
/** /**
......
package pangea.hiagent.model;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
* 提示词模板实体类
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("hiagent_prompt_template")
public class PromptTemplate extends BaseEntity {
/**
* 模板名称
*/
private String name;
/**
* 模板描述
*/
private String description;
/**
* 模板内容
*/
private String templateContent;
/**
* 参数Schema定义(JSON格式)
*/
private String paramSchema;
/**
* 模板类型
*/
private String templateType;
/**
* 是否为系统模板(0-自定义,1-系统)
*/
private Integer isSystem;
}
package pangea.hiagent.model;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
* 定时器配置实体类
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("hiagent_timer_config")
public class TimerConfig extends BaseEntity {
/**
* 定时器名称
*/
private String name;
/**
* 定时器描述
*/
private String description;
/**
* Cron表达式(支持秒级)
*/
private String cronExpression;
/**
* 启用状态(0-禁用,1-启用)
*/
private Integer enabled;
/**
* 关联的Agent ID
*/
private String agentId;
/**
* 关联的Agent名称
*/
private String agentName;
/**
* 提示词模板
*/
private String promptTemplate;
/**
* 动态参数配置(JSON格式)
*/
private String paramsJson;
/**
* 最后执行时间
*/
private java.time.LocalDateTime lastExecutionTime;
/**
* 下次执行时间
*/
private java.time.LocalDateTime nextExecutionTime;
}
package pangea.hiagent.model;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableLogic;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import java.time.LocalDateTime;
/**
* 定时器执行历史实体类
*/
@Data
@TableName("hiagent_timer_execution_history")
public class TimerExecutionHistory {
/**
* 主键ID,使用数据库自增策略
*/
@TableId(value = "id", type = IdType.AUTO)
private Long id;
/**
* 关联的定时器ID
*/
@TableField("timer_id")
private String timerId;
/**
* 定时器名称
*/
@TableField("timer_name")
private String timerName;
/**
* 执行时间
*/
@TableField("execution_time")
private LocalDateTime executionTime;
/**
* 执行结果(0-失败,1-成功)
*/
@TableField("success")
private Integer success;
/**
* 执行结果详情
*/
@TableField("result")
private String result;
/**
* 错误信息
*/
@TableField("error_message")
private String errorMessage;
/**
* 执行时长(毫秒)
*/
@TableField("execution_duration")
private Long executionDuration;
/**
* 实际执行的提示词
*/
@TableField("actual_prompt")
private String actualPrompt;
/**
* 创建时间
*/
@TableField("created_at")
private LocalDateTime createdAt;
/**
* 更新时间
*/
@TableField("updated_at")
private LocalDateTime updatedAt;
/**
* 创建人
*/
@TableField("created_by")
private String createdBy;
/**
* 更新人
*/
@TableField("updated_by")
private String updatedBy;
/**
* 删除标记(0-未删除,1-已删除)
*/
@TableLogic
@TableField("deleted")
private Integer deleted;
/**
* 备注
*/
@TableField("remark")
private String remark;
}
package pangea.hiagent.model;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
* 工具配置实体类
* 用于存储工具参数配置
*/
@Data
@EqualsAndHashCode(callSuper = true)
@TableName("tool_configs")
public class ToolConfig extends BaseEntity {
/**
* 工具名称
*/
@TableField("tool_name")
private String toolName;
/**
* 参数名称
*/
@TableField("param_name")
private String paramName;
/**
* 参数值
*/
@TableField("param_value")
private String paramValue;
/**
* 参数描述
*/
private String description;
/**
* 默认值
*/
@TableField("default_value")
private String defaultValue;
/**
* 参数类型
*/
private String type;
/**
* 是否必填
*/
private Boolean required;
/**
* 参数分组
*/
@TableField("group_name")
private String groupName;
}
\ No newline at end of file
...@@ -246,33 +246,72 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -246,33 +246,72 @@ public class DefaultReactExecutor implements ReactExecutor {
} }
}, },
throwable -> { throwable -> {
log.error("流式处理出错", throwable); log.error("流式处理出错: {}", throwable.getMessage(), throwable);
// 检查是否是401 Unauthorized错误 // 检查是否是401 Unauthorized错误
if (isUnauthorizedError(throwable)) { if (isUnauthorizedError(throwable)) {
log.error("LLM返回401未授权错误: {}", throwable.getMessage()); log.error("LLM返回401未授权错误,请检查API密钥配置");
sendErrorToConsumer(tokenConsumer, " 请配置API密钥"); recordStreamError("LLM返回401未授权错误");
try {
if (tokenConsumer != null) {
tokenConsumer.accept("[错误] 请配置API密钥");
}
} catch (Exception e) {
log.error("发送API密钥错误失败: {}", e.getMessage());
}
} else if (throwable.getMessage() != null && throwable.getMessage().contains("timeout")) {
log.error("流式处理超时: {}", throwable.getMessage());
recordStreamError("流式处理超时");
try {
if (tokenConsumer != null) {
tokenConsumer.accept("[错误] 流式处理超时,请稍后重试");
}
} catch (Exception e) {
log.error("发送超时错误失败: {}", e.getMessage());
}
} else { } else {
recordStreamError(throwable.getMessage()); // 一般错误
sendErrorToConsumer(tokenConsumer, throwable.getMessage()); recordStreamError("流式处理异常: " + throwable.getMessage());
}
},
() -> {
log.info("流式处理完成");
// 触发最终答案步骤
triggerFinalAnswerStep(fullResponse.toString());
// 将助理回复添加到ChatMemory
if (agent != null) {
try { try {
String sessionId = memoryService.generateSessionId(agent); if (tokenConsumer != null) {
memoryService.addAssistantMessageToMemory(sessionId, fullResponse.toString()); tokenConsumer.accept("[错误] 流式处理失败: " + throwable.getMessage());
}
} catch (Exception e) { } catch (Exception e) {
log.warn("保存助理回复到内存时发生错误: {}", e.getMessage()); log.error("发送错误信息失败: {}", e.getMessage());
} }
} }
// 发送完成事件,包含完整内容 // 确保即使出现错误也能标记完成
sendCompletionEvent(tokenConsumer, fullResponse.toString()); log.debug("标记流式处理完成(因错误而终止)");
},
() -> {
try {
log.info("流式处理完成");
// 触发最终答案步骤
triggerFinalAnswerStep(fullResponse.toString());
// 将助理回复添加到ChatMemory
if (agent != null) {
try {
String sessionId = memoryService.generateSessionId(agent);
memoryService.addAssistantMessageToMemory(sessionId, fullResponse.toString());
} catch (Exception e) {
log.warn("保存助理回复到内存时发生错误: {}", e.getMessage());
}
}
// 发送完成事件,包含完整内容
sendCompletionEvent(tokenConsumer, fullResponse.toString());
} catch (Exception e) {
log.error("处理流式完成回调时发生错误", e);
// 即使在完成回调中出现错误,也要确保标记完成
if (tokenConsumer instanceof AgentChatService.TokenConsumerWithCompletion) {
try {
((AgentChatService.TokenConsumerWithCompletion) tokenConsumer).onComplete("[处理完成时发生错误] " + e.getMessage());
} catch (Exception ex) {
log.error("调用onComplete时发生错误", ex);
}
}
}
} }
); );
......
package pangea.hiagent.repository;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import pangea.hiagent.model.PromptTemplate;
/**
* 提示词模板Repository接口
*/
@Mapper
public interface PromptTemplateRepository extends BaseMapper<PromptTemplate> {
}
package pangea.hiagent.repository;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import pangea.hiagent.model.TimerConfig;
/**
* 定时器配置Repository接口
*/
@Mapper
public interface TimerConfigRepository extends BaseMapper<TimerConfig> {
}
package pangea.hiagent.repository;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import pangea.hiagent.model.TimerExecutionHistory;
/**
* 定时器执行历史Repository接口
*/
@Mapper
public interface TimerExecutionHistoryRepository extends BaseMapper<TimerExecutionHistory> {
}
package pangea.hiagent.repository;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Select;
import pangea.hiagent.model.ToolConfig;
import java.util.List;
import java.util.Map;
/**
* 工具配置仓库接口
* 提供工具配置数据访问功能
*/
@Mapper
public interface ToolConfigRepository extends BaseMapper<ToolConfig> {
/**
* 根据工具名称获取配置列表
* @param toolName 工具名称
* @return 配置列表
*/
@Select("SELECT * FROM tool_configs WHERE tool_name = #{toolName} AND deleted = 0")
List<ToolConfig> findByToolName(String toolName);
/**
* 根据工具名称和参数名称获取配置
* @param toolName 工具名称
* @param paramName 参数名称
* @return 配置对象
*/
@Select("SELECT * FROM tool_configs WHERE tool_name = #{toolName} AND param_name = #{paramName} AND deleted = 0 LIMIT 1")
ToolConfig findByToolNameAndParamName(String toolName, String paramName);
/**
* 获取所有工具配置列表
* @return 配置列表
*/
@Select("SELECT * FROM tool_configs WHERE deleted = 0 ORDER BY tool_name, group_name, param_name")
List<ToolConfig> findAllActive();
/**
* 根据工具名称获取参数键值对
* @param toolName 工具名称
* @return 参数键值对
*/
@Select("SELECT param_name, param_value FROM tool_configs WHERE tool_name = #{toolName} AND deleted = 0")
List<Map<String, Object>> findParamValuesByToolName(String toolName);
}
\ No newline at end of file
package pangea.hiagent.scheduler;
import lombok.extern.slf4j.Slf4j;
import org.quartz.JobExecutionContext;
import org.quartz.JobExecutionException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.quartz.QuartzJobBean;
import pangea.hiagent.service.TimerService;
/**
* 定时器任务执行类
* 由Quartz调度器触发,执行具体的定时器任务
*/
@Slf4j
public class TimerJob extends QuartzJobBean {
@Autowired
private TimerService timerService;
/**
* 执行定时器任务
* @param context 任务执行上下文,包含任务的参数信息
*/
@Override
protected void executeInternal(JobExecutionContext context) throws JobExecutionException {
try {
// 从上下文中获取定时器ID
String timerId = context.getJobDetail().getJobDataMap().getString("timerId");
log.info("开始执行定时器任务: {}", timerId);
if (timerId == null || timerId.isEmpty()) {
log.error("定时器任务缺少timerId参数");
return;
}
// 调用TimerService执行定时器任务
timerService.executeTimerTask(timerId);
log.info("定时器任务执行完成: {}", timerId);
} catch (Exception e) {
log.error("定时器任务执行失败", e);
throw new JobExecutionException(e);
}
}
}
package pangea.hiagent.scheduler;
import com.cronutils.model.CronType;
import com.cronutils.model.definition.CronDefinitionBuilder;
import com.cronutils.parser.CronParser;
import lombok.extern.slf4j.Slf4j;
import org.quartz.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import pangea.hiagent.model.TimerConfig;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Calendar;
import java.util.Date;
/**
* 定时器调度管理器
* 负责管理Quartz的Job和Trigger,实现动态添加、更新和删除定时任务
*/
@Slf4j
@Component
public class TimerScheduler {
@Autowired
private Scheduler scheduler;
// Cron解析器,用于验证Cron表达式
private final CronParser cronParser = new CronParser(CronDefinitionBuilder.instanceDefinitionFor(CronType.QUARTZ));
/**
* 添加或更新定时器任务
* @param timerConfig 定时器配置信息
*/
public void addOrUpdateTimer(TimerConfig timerConfig) {
try {
log.info("添加或更新定时器任务: {}", timerConfig.getId());
// 验证Cron表达式
cronParser.parse(timerConfig.getCronExpression());
// 构建JobDetail
JobDetail jobDetail = buildJobDetail(timerConfig);
// 构建Trigger
Trigger trigger = buildTrigger(jobDetail, timerConfig);
// 添加或更新Job和Trigger
if (scheduler.checkExists(jobDetail.getKey())) {
scheduler.rescheduleJob(trigger.getKey(), trigger);
log.info("更新定时器任务: {}", timerConfig.getId());
} else {
scheduler.scheduleJob(jobDetail, trigger);
log.info("添加定时器任务: {}", timerConfig.getId());
}
// 如果定时器被禁用,暂停任务
if (timerConfig.getEnabled() == 0) {
scheduler.pauseJob(jobDetail.getKey());
log.info("暂停定时器任务: {}", timerConfig.getId());
} else {
// 如果定时器被启用,恢复任务
scheduler.resumeJob(jobDetail.getKey());
log.info("恢复定时器任务: {}", timerConfig.getId());
}
} catch (Exception e) {
log.error("添加或更新定时器任务失败: {}", timerConfig.getId(), e);
throw new RuntimeException("添加或更新定时器任务失败: " + e.getMessage(), e);
}
}
/**
* 删除定时器任务
* @param timerId 定时器ID
*/
public void deleteTimer(String timerId) {
try {
log.info("删除定时器任务: {}", timerId);
JobKey jobKey = JobKey.jobKey("timerJob_" + timerId, "timerGroup");
scheduler.deleteJob(jobKey);
log.info("删除定时器任务成功: {}", timerId);
} catch (Exception e) {
log.error("删除定时器任务失败: {}", timerId, e);
throw new RuntimeException("删除定时器任务失败: " + e.getMessage(), e);
}
}
/**
* 启用定时器任务
* @param timerId 定时器ID
*/
public void enableTimer(String timerId) {
try {
log.info("启用定时器任务: {}", timerId);
JobKey jobKey = JobKey.jobKey("timerJob_" + timerId, "timerGroup");
scheduler.resumeJob(jobKey);
log.info("启用定时器任务成功: {}", timerId);
} catch (Exception e) {
log.error("启用定时器任务失败: {}", timerId, e);
throw new RuntimeException("启用定时器任务失败: " + e.getMessage(), e);
}
}
/**
* 禁用定时器任务
* @param timerId 定时器ID
*/
public void disableTimer(String timerId) {
try {
log.info("禁用定时器任务: {}", timerId);
JobKey jobKey = JobKey.jobKey("timerJob_" + timerId, "timerGroup");
scheduler.pauseJob(jobKey);
log.info("禁用定时器任务成功: {}", timerId);
} catch (Exception e) {
log.error("禁用定时器任务失败: {}", timerId, e);
throw new RuntimeException("禁用定时器任务失败: " + e.getMessage(), e);
}
}
/**
* 构建JobDetail
* @param timerConfig 定时器配置信息
* @return JobDetail对象
*/
private JobDetail buildJobDetail(TimerConfig timerConfig) {
JobDataMap jobDataMap = new JobDataMap();
jobDataMap.put("timerId", timerConfig.getId());
return JobBuilder.newJob(TimerJob.class)
.withIdentity("timerJob_" + timerConfig.getId(), "timerGroup")
.withDescription(timerConfig.getName())
.usingJobData(jobDataMap)
.storeDurably(false)
.build();
}
/**
* 构建Trigger
* @param jobDetail JobDetail对象
* @param timerConfig 定时器配置信息
* @return Trigger对象
*/
private Trigger buildTrigger(JobDetail jobDetail, TimerConfig timerConfig) {
return TriggerBuilder.newTrigger()
.forJob(jobDetail)
.withIdentity("timerTrigger_" + timerConfig.getId(), "timerGroup")
.withDescription(timerConfig.getName())
.withSchedule(CronScheduleBuilder.cronSchedule(timerConfig.getCronExpression()))
.startNow()
.build();
}
/**
* 计算下次执行时间
* @param cronExpression Cron表达式
* @return 下次执行时间
*/
public LocalDateTime calculateNextExecutionTime(String cronExpression) {
try {
// 简化实现,暂时不计算下次执行时间
// 后续可以使用其他方式实现,或者升级cron-utils库版本
return null;
} catch (Exception e) {
log.error("计算下次执行时间失败: {}", cronExpression, e);
return null;
}
}
}
...@@ -5,7 +5,9 @@ import org.springframework.security.access.PermissionEvaluator; ...@@ -5,7 +5,9 @@ import org.springframework.security.access.PermissionEvaluator;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import pangea.hiagent.model.Agent; import pangea.hiagent.model.Agent;
import pangea.hiagent.model.TimerConfig;
import pangea.hiagent.service.AgentService; import pangea.hiagent.service.AgentService;
import pangea.hiagent.service.TimerService;
import java.io.Serializable; import java.io.Serializable;
...@@ -18,13 +20,15 @@ import java.io.Serializable; ...@@ -18,13 +20,15 @@ import java.io.Serializable;
public class DefaultPermissionEvaluator implements PermissionEvaluator { public class DefaultPermissionEvaluator implements PermissionEvaluator {
private final AgentService agentService; private final AgentService agentService;
private final TimerService timerService;
public DefaultPermissionEvaluator(AgentService agentService) { public DefaultPermissionEvaluator(AgentService agentService, TimerService timerService) {
this.agentService = agentService; this.agentService = agentService;
this.timerService = timerService;
} }
/** /**
* 检查用户是否有权访问指定Agent * 检查用户是否有权访问指定资源
*/ */
@Override @Override
public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) { public boolean hasPermission(Authentication authentication, Object targetDomainObject, Object permission) {
...@@ -35,18 +39,20 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -35,18 +39,20 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
String userId = (String) authentication.getPrincipal(); String userId = (String) authentication.getPrincipal();
String perm = (String) permission; String perm = (String) permission;
// 目前只处理Agent访问权限 // 处理Agent访问权限
if (targetDomainObject instanceof Agent) { if (targetDomainObject instanceof Agent) {
Agent agent = (Agent) targetDomainObject; Agent agent = (Agent) targetDomainObject;
return checkAgentAccess(userId, agent, perm); return checkAgentAccess(userId, agent, perm);
} else if (targetDomainObject instanceof String) { }
// 假设targetDomainObject是Agent ID // 处理TimerConfig访问权限
String agentId = (String) targetDomainObject; else if (targetDomainObject instanceof TimerConfig) {
Agent agent = agentService.getAgent(agentId); TimerConfig timer = (TimerConfig) targetDomainObject;
if (agent == null) { return checkTimerAccess(userId, timer, perm);
return false; }
} // 处理基于ID的资源访问
return checkAgentAccess(userId, agent, perm); else if (targetDomainObject instanceof String) {
// 这种情况在hasPermission(Authentication, Serializable, String, Object)方法中处理
return false;
} }
return false; return false;
...@@ -68,6 +74,14 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -68,6 +74,14 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
return false; return false;
} }
return checkAgentAccess(userId, agent, perm); return checkAgentAccess(userId, agent, perm);
}
// 处理TimerConfig资源的权限检查
else if ("TimerConfig".equals(targetType)) {
TimerConfig timer = timerService.getTimerById(targetId.toString());
if (timer == null) {
return false;
}
return checkTimerAccess(userId, timer, perm);
} }
return false; return false;
...@@ -102,12 +116,41 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator { ...@@ -102,12 +116,41 @@ public class DefaultPermissionEvaluator implements PermissionEvaluator {
} }
} }
/**
* 检查用户对TimerConfig的访问权限
*/
private boolean checkTimerAccess(String userId, TimerConfig timer, String permission) {
// 管理员可以访问所有定时器
if (isAdminUser(userId)) {
return true;
}
// 检查定时器创建者
if (timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId)) {
return true;
}
// 根据权限类型进行检查
switch (permission.toLowerCase()) {
case "read":
// 所有用户都可以读取公开的定时器(如果有此概念)
return false; // 暂时不支持公开定时器
case "write":
case "delete":
// 只有创建者可以修改或删除定时器
return timer.getCreatedBy() != null && timer.getCreatedBy().equals(userId);
default:
return false;
}
}
/** /**
* 检查是否为管理员用户 * 检查是否为管理员用户
*/ */
private boolean isAdminUser(String userId) { private boolean isAdminUser(String userId) {
// 这里可以根据实际需求实现管理员检查逻辑 // 这里可以根据实际需求实现管理员检查逻辑
// 例如查询数据库或检查特殊用户ID // 例如查询数据库或检查特殊用户ID
// 当前实现保留原有逻辑,但可以通过配置或数据库来管理管理员用户
return "admin".equals(userId) || "user-001".equals(userId); return "admin".equals(userId) || "user-001".equals(userId);
} }
} }
\ No newline at end of file
...@@ -72,11 +72,6 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -72,11 +72,6 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
new UsernamePasswordAuthenticationToken(userId, null, authorities); new UsernamePasswordAuthenticationToken(userId, null, authorities);
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
log.debug("已设置SecurityContext中的认证信息,用户ID: {}, 权限: {}", userId, authentication.getAuthorities()); log.debug("已设置SecurityContext中的认证信息,用户ID: {}, 权限: {}", userId, authentication.getAuthorities());
// 认证成功后继续处理请求
filterChain.doFilter(request, response);
log.debug("JwtAuthenticationFilter处理完成: {} {}", request.getMethod(), request.getRequestURI());
return;
} else { } else {
log.warn("从token中提取的用户ID为空"); log.warn("从token中提取的用户ID为空");
} }
...@@ -88,6 +83,34 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter { ...@@ -88,6 +83,34 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
} }
} catch (Exception e) { } catch (Exception e) {
log.error("JWT认证处理异常", e); log.error("JWT认证处理异常", e);
// 检查响应是否已经提交
if (!response.isCommitted()) {
try {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
response.setContentType("application/json;charset=UTF-8");
response.getWriter().write("{\"code\":401,\"message\":\"认证失败\",\"timestamp\":" + System.currentTimeMillis() + "}");
} catch (IOException ioException) {
log.error("发送认证失败响应时发生IO异常", ioException);
}
} else {
log.warn("响应已经提交,无法发送认证失败响应");
}
}
// 检查是否是SSE端点并且响应已经提交
if ((isStreamEndpoint || isTimelineEndpoint) && response.isCommitted()) {
log.debug("SSE端点响应已提交,跳过过滤器链继续处理");
return;
}
// 特别处理流式端点的权限问题
if (isStreamEndpoint || isTimelineEndpoint) {
// 检查是否已认证
if (SecurityContextHolder.getContext().getAuthentication() == null) {
log.warn("流式端点未认证访问: {} {}", request.getMethod(), request.getRequestURI());
// 对于SSE端点,如果未认证,我们不立即返回错误,而是让后续处理决定
// 因为客户端可能会在重新连接时带上token
}
} }
// 继续执行过滤器链,让Spring Security的其他过滤器处理认证和授权 // 继续执行过滤器链,让Spring Security的其他过滤器处理认证和授权
......
...@@ -15,6 +15,8 @@ import pangea.hiagent.repository.LlmConfigRepository; ...@@ -15,6 +15,8 @@ import pangea.hiagent.repository.LlmConfigRepository;
import pangea.hiagent.llm.LlmModelFactory; import pangea.hiagent.llm.LlmModelFactory;
import java.util.List; import java.util.List;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
/** /**
* Agent服务类 * Agent服务类
...@@ -43,6 +45,7 @@ public class AgentService { ...@@ -43,6 +45,7 @@ public class AgentService {
* 创建Agent * 创建Agent
*/ */
@Transactional @Transactional
@CacheEvict(value = {"agents", "agent"}, allEntries = true)
public Agent createAgent(Agent agent) { public Agent createAgent(Agent agent) {
log.info("创建Agent: {}", agent.getName()); log.info("创建Agent: {}", agent.getName());
...@@ -80,6 +83,7 @@ public class AgentService { ...@@ -80,6 +83,7 @@ public class AgentService {
* 更新Agent * 更新Agent
*/ */
@Transactional @Transactional
@CacheEvict(value = {"agents", "agent"}, allEntries = true)
public Agent updateAgent(Agent agent) { public Agent updateAgent(Agent agent) {
log.info("更新Agent: {}", agent.getId()); log.info("更新Agent: {}", agent.getId());
...@@ -99,6 +103,7 @@ public class AgentService { ...@@ -99,6 +103,7 @@ public class AgentService {
* 删除Agent * 删除Agent
*/ */
@Transactional @Transactional
@CacheEvict(value = {"agents", "agent"}, allEntries = true)
public void deleteAgent(String id) { public void deleteAgent(String id) {
log.info("删除Agent: {}", id); log.info("删除Agent: {}", id);
agentRepository.deleteById(id); agentRepository.deleteById(id);
...@@ -110,6 +115,7 @@ public class AgentService { ...@@ -110,6 +115,7 @@ public class AgentService {
* @param id Agent ID * @param id Agent ID
* @return Agent对象,如果不存在则返回null * @return Agent对象,如果不存在则返回null
*/ */
@Cacheable(value = "agent", key = "#id")
public Agent getAgent(String id) { public Agent getAgent(String id) {
if (id == null || id.isEmpty()) { if (id == null || id.isEmpty()) {
log.warn("尝试使用无效ID获取Agent"); log.warn("尝试使用无效ID获取Agent");
...@@ -123,6 +129,7 @@ public class AgentService { ...@@ -123,6 +129,7 @@ public class AgentService {
* *
* @return Agent列表 * @return Agent列表
*/ */
@Cacheable(value = "agents")
public List<Agent> listAgents() { public List<Agent> listAgents() {
List<Agent> agents = agentRepository.selectList(null); List<Agent> agents = agentRepository.selectList(null);
log.info("获取到 {} 个Agent", agents != null ? agents.size() : 0); log.info("获取到 {} 个Agent", agents != null ? agents.size() : 0);
...@@ -150,6 +157,7 @@ public class AgentService { ...@@ -150,6 +157,7 @@ public class AgentService {
/** /**
* 获取用户的Agent列表 * 获取用户的Agent列表
*/ */
@Cacheable(value = "agents", key = "#userId")
public List<Agent> getUserAgents(String userId) { public List<Agent> getUserAgents(String userId) {
// 使用优化的查询方法 // 使用优化的查询方法
return agentRepository.findActiveAgentsByOwnerWithExplicitColumns(userId); return agentRepository.findActiveAgentsByOwnerWithExplicitColumns(userId);
......
package pangea.hiagent.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import pangea.hiagent.model.TimerExecutionHistory;
import pangea.hiagent.repository.TimerExecutionHistoryRepository;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
/**
* 执行历史清理服务
* 定期清理过期的执行历史记录
*/
@Slf4j
@Service
public class HistoryCleanupService {
private final TimerExecutionHistoryRepository timerExecutionHistoryRepository;
public HistoryCleanupService(TimerExecutionHistoryRepository timerExecutionHistoryRepository) {
this.timerExecutionHistoryRepository = timerExecutionHistoryRepository;
}
/**
* 清理过期的执行历史记录
* 每天凌晨2点执行一次,清理30天前的记录
*/
@Scheduled(cron = "0 0 2 * * ?")
@Transactional
public void cleanupOldHistory() {
log.info("开始清理过期执行历史记录");
// 计算30天前的时间
LocalDateTime cutoffTime = LocalDateTime.now().minus(30, ChronoUnit.DAYS);
// 构建查询条件:执行时间小于30天前
LambdaQueryWrapper<TimerExecutionHistory> wrapper = new LambdaQueryWrapper<>();
wrapper.lt(TimerExecutionHistory::getExecutionTime, cutoffTime);
// 清理30天前的执行历史记录
int deletedCount = timerExecutionHistoryRepository.delete(wrapper);
log.info("清理完成,共删除 {} 条过期执行历史记录", deletedCount);
}
}
\ No newline at end of file
package pangea.hiagent.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.springframework.stereotype.Service;
import pangea.hiagent.dto.TimerExecutionHistoryDto;
import pangea.hiagent.model.TimerExecutionHistory;
import pangea.hiagent.repository.TimerExecutionHistoryRepository;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.stream.Collectors;
/**
* 执行历史服务类
* 负责执行历史的查询、统计和管理
*/
@Service
public class HistoryService {
private final TimerExecutionHistoryRepository timerExecutionHistoryRepository;
private final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
public HistoryService(TimerExecutionHistoryRepository timerExecutionHistoryRepository) {
this.timerExecutionHistoryRepository = timerExecutionHistoryRepository;
}
/**
* 获取执行历史列表,支持多条件筛选和分页
*
* @param timerId 定时器ID
* @param success 执行结果(1-成功,0-失败)
* @param startTime 开始时间
* @param endTime 结束时间
* @param page 页码
* @param size 每页大小
* @return 执行历史列表
*/
public Page<TimerExecutionHistoryDto> getExecutionHistoryList(
String timerId, Integer success, String startTime, String endTime, int page, int size) {
// 构建查询条件
LambdaQueryWrapper<TimerExecutionHistory> wrapper = buildQueryWrapper(timerId, success, startTime, endTime);
// 按执行时间倒序排序
wrapper.orderByDesc(TimerExecutionHistory::getExecutionTime);
// 分页查询
Page<TimerExecutionHistory> pagination = new Page<>(page, size);
timerExecutionHistoryRepository.selectPage(pagination, wrapper);
// 转换为DTO
List<TimerExecutionHistoryDto> records = pagination.getRecords().stream()
.map(this::convertToDto)
.collect(Collectors.toList());
// 创建新的分页对象并设置数据
Page<TimerExecutionHistoryDto> resultPage = new Page<>(pagination.getCurrent(), pagination.getSize(), pagination.getTotal());
resultPage.setRecords(records);
return resultPage;
}
/**
* 获取指定定时器的执行历史
*
* @param timerId 定时器ID
* @param page 页码
* @param size 每页大小
* @return 执行历史列表
*/
public Page<TimerExecutionHistoryDto> getExecutionHistoryByTimerId(String timerId, int page, int size) {
return getExecutionHistoryList(timerId, null, null, null, page, size);
}
/**
* 获取执行历史详情
*
* @param id 执行历史ID
* @return 执行历史详情
*/
public TimerExecutionHistoryDto getExecutionHistoryDetail(Long id) {
TimerExecutionHistory history = timerExecutionHistoryRepository.selectById(id);
return history != null ? convertToDto(history) : null;
}
/**
* 构建查询条件
*/
private LambdaQueryWrapper<TimerExecutionHistory> buildQueryWrapper(
String timerId, Integer success, String startTime, String endTime) {
LambdaQueryWrapper<TimerExecutionHistory> wrapper = new LambdaQueryWrapper<>();
// 定时器ID条件
if (timerId != null && !timerId.isEmpty()) {
wrapper.eq(TimerExecutionHistory::getTimerId, timerId);
}
// 执行结果条件
if (success != null) {
wrapper.eq(TimerExecutionHistory::getSuccess, success);
}
// 开始时间条件
if (startTime != null && !startTime.isEmpty()) {
LocalDateTime start = LocalDateTime.parse(startTime, formatter);
wrapper.ge(TimerExecutionHistory::getExecutionTime, start);
}
// 结束时间条件
if (endTime != null && !endTime.isEmpty()) {
LocalDateTime end = LocalDateTime.parse(endTime, formatter);
wrapper.le(TimerExecutionHistory::getExecutionTime, end);
}
return wrapper;
}
/**
* 转换实体为DTO
*/
private TimerExecutionHistoryDto convertToDto(TimerExecutionHistory history) {
return TimerExecutionHistoryDto.builder()
.id(history.getId() != null ? history.getId().toString() : null)
.timerId(history.getTimerId())
.timerName(history.getTimerName())
.executionTime(history.getExecutionTime())
.success(history.getSuccess())
.result(history.getResult())
.errorMessage(history.getErrorMessage())
.actualPrompt(history.getActualPrompt())
.executionDuration(history.getExecutionDuration())
.build();
}
/**
* 统计定时器执行成功次数
*
* @param timerId 定时器ID
* @return 成功次数
*/
public long countSuccessExecution(String timerId) {
LambdaQueryWrapper<TimerExecutionHistory> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(TimerExecutionHistory::getTimerId, timerId)
.eq(TimerExecutionHistory::getSuccess, 1);
return timerExecutionHistoryRepository.selectCount(wrapper);
}
/**
* 统计定时器执行失败次数
*
* @param timerId 定时器ID
* @return 失败次数
*/
public long countFailedExecution(String timerId) {
LambdaQueryWrapper<TimerExecutionHistory> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(TimerExecutionHistory::getTimerId, timerId)
.eq(TimerExecutionHistory::getSuccess, 0);
return timerExecutionHistoryRepository.selectCount(wrapper);
}
}
\ No newline at end of file
package pangea.hiagent.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import pangea.hiagent.model.PromptTemplate;
import pangea.hiagent.repository.PromptTemplateRepository;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* 提示词模板服务类
* 负责提示词模板的管理和渲染
*/
@Slf4j
@Service
public class PromptTemplateService {
private final PromptTemplateRepository promptTemplateRepository;
// 模板变量正则表达式:{{variableName}}
private static final Pattern TEMPLATE_VARIABLE_PATTERN = Pattern.compile("\\{\\{(\\w+)\\}\\}");
public PromptTemplateService(PromptTemplateRepository promptTemplateRepository) {
this.promptTemplateRepository = promptTemplateRepository;
}
/**
* 创建提示词模板
*/
@Transactional
@CacheEvict(value = {"promptTemplates", "promptTemplate"}, allEntries = true)
public PromptTemplate createTemplate(PromptTemplate template) {
log.info("创建提示词模板: {}", template.getName());
// 设置默认值
if (template.getIsSystem() == null) {
template.setIsSystem(0); // 默认自定义模板
}
promptTemplateRepository.insert(template);
return template;
}
/**
* 更新提示词模板
*/
@Transactional
@CacheEvict(value = {"promptTemplates", "promptTemplate"}, allEntries = true)
public PromptTemplate updateTemplate(PromptTemplate template) {
log.info("更新提示词模板: {}", template.getId());
// 获取现有模板
PromptTemplate existingTemplate = promptTemplateRepository.selectById(template.getId());
if (existingTemplate != null) {
// 保留原始创建信息
template.setCreatedBy(existingTemplate.getCreatedBy());
template.setCreatedAt(existingTemplate.getCreatedAt());
// 系统模板不允许修改isSystem属性
template.setIsSystem(existingTemplate.getIsSystem());
}
promptTemplateRepository.updateById(template);
return template;
}
/**
* 删除提示词模板
*/
@Transactional
@CacheEvict(value = {"promptTemplates", "promptTemplate"}, allEntries = true)
public void deleteTemplate(String id) {
log.info("删除提示词模板: {}", id);
promptTemplateRepository.deleteById(id);
}
/**
* 获取提示词模板详情
*/
@Cacheable(value = "promptTemplate", key = "#id")
public PromptTemplate getTemplateById(String id) {
if (id == null || id.isEmpty()) {
log.warn("尝试使用无效ID获取提示词模板");
return null;
}
return promptTemplateRepository.selectById(id);
}
/**
* 获取提示词模板列表
*/
@Cacheable(value = "promptTemplates")
public List<PromptTemplate> listTemplates() {
List<PromptTemplate> templates = promptTemplateRepository.selectList(null);
log.info("获取到 {} 个提示词模板", templates != null ? templates.size() : 0);
return templates != null ? templates : List.of();
}
/**
* 根据类型获取提示词模板列表
*/
@Cacheable(value = "promptTemplates", key = "#templateType")
public List<PromptTemplate> listTemplatesByType(String templateType) {
LambdaQueryWrapper<PromptTemplate> wrapper = new LambdaQueryWrapper<>();
wrapper.eq(PromptTemplate::getTemplateType, templateType);
return promptTemplateRepository.selectList(wrapper);
}
/**
* 渲染提示词模板
* 替换模板中的变量为实际值
*/
public String renderTemplate(String templateContent, Map<String, Object> params) {
if (templateContent == null || templateContent.isEmpty()) {
return "";
}
log.debug("渲染提示词模板,参数: {}", params);
String renderedContent = templateContent;
Matcher matcher = TEMPLATE_VARIABLE_PATTERN.matcher(renderedContent);
while (matcher.find()) {
String variableName = matcher.group(1);
String placeholder = matcher.group(0);
Object value = params.get(variableName);
if (value != null) {
renderedContent = renderedContent.replace(placeholder, value.toString());
} else {
// 如果参数不存在,保留原始占位符
log.warn("模板变量 {} 未提供值", variableName);
}
}
log.debug("渲染后的提示词: {}", renderedContent);
return renderedContent;
}
/**
* 渲染提示词模板(根据模板ID)
*/
public String renderTemplateById(String templateId, Map<String, Object> params) {
PromptTemplate template = getTemplateById(templateId);
if (template == null) {
throw new IllegalArgumentException("提示词模板不存在: " + templateId);
}
return renderTemplate(template.getTemplateContent(), params);
}
/**
* 验证提示词模板语法
*/
public boolean validateTemplateSyntax(String templateContent) {
if (templateContent == null || templateContent.isEmpty()) {
return true;
}
// 简单验证:检查是否有未闭合的{{}}
int openCount = 0;
for (int i = 0; i < templateContent.length() - 1; i++) {
if (templateContent.charAt(i) == '{' && templateContent.charAt(i + 1) == '{') {
openCount++;
} else if (templateContent.charAt(i) == '}' && templateContent.charAt(i + 1) == '}') {
openCount--;
if (openCount < 0) {
return false;
}
}
}
return openCount == 0;
}
}
This diff is collapsed.
package pangea.hiagent.service;
import pangea.hiagent.model.ToolConfig;
import java.util.List;
import java.util.Map;
/**
* 工具配置服务接口
* 用于处理工具参数配置的读取和保存
*/
public interface ToolConfigService {
/**
* 根据工具名称获取参数配置
* @param toolName 工具名称
* @return 参数配置键值对
*/
Map<String, String> getToolParams(String toolName);
/**
* 根据工具名称和参数名称获取参数值
* @param toolName 工具名称
* @param paramName 参数名称
* @return 参数值
*/
String getParamValue(String toolName, String paramName);
/**
* 保存参数值
* @param toolName 工具名称
* @param paramName 参数名称
* @param paramValue 参数值
*/
void saveParamValue(String toolName, String paramName, String paramValue);
/**
* 获取所有工具配置
* @return 工具配置列表
*/
List<ToolConfig> getAllToolConfigs();
/**
* 根据工具名称和参数名称获取工具配置
* @param toolName 工具名称
* @param paramName 参数名称
* @return 工具配置对象
*/
ToolConfig getToolConfig(String toolName, String paramName);
/**
* 保存工具配置
* @param toolConfig 工具配置对象
* @return 保存后的工具配置对象
*/
ToolConfig saveToolConfig(ToolConfig toolConfig);
/**
* 删除工具配置
* @param id 配置ID
*/
void deleteToolConfig(String id);
/**
* 根据工具名称获取工具配置列表
* @param toolName 工具名称
* @return 工具配置列表
*/
List<ToolConfig> getToolConfigsByToolName(String toolName);
}
\ No newline at end of file
package pangea.hiagent.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import pangea.hiagent.model.ToolConfig;
import pangea.hiagent.repository.ToolConfigRepository;
import pangea.hiagent.service.ToolConfigService;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 工具配置服务实现类
* 用于处理工具参数配置的读取和保存
*/
@Slf4j
@Service
public class ToolConfigServiceImpl implements ToolConfigService {
@Autowired
private ToolConfigRepository toolConfigRepository;
@Override
public Map<String, String> getToolParams(String toolName) {
log.debug("获取工具参数配置,工具名称:{}", toolName);
Map<String, String> params = new HashMap<>();
try {
List<Map<String, Object>> paramValues = toolConfigRepository.findParamValuesByToolName(toolName);
for (Map<String, Object> paramValue : paramValues) {
String paramName = (String) paramValue.get("param_name");
String value = (String) paramValue.get("param_value");
params.put(paramName, value);
}
} catch (Exception e) {
log.error("获取工具参数配置失败:{}", e.getMessage(), e);
}
return params;
}
@Override
public String getParamValue(String toolName, String paramName) {
log.debug("获取工具参数值,工具名称:{},参数名称:{}", toolName, paramName);
try {
ToolConfig toolConfig = toolConfigRepository.findByToolNameAndParamName(toolName, paramName);
if (toolConfig != null) {
return toolConfig.getParamValue();
}
} catch (Exception e) {
log.error("获取工具参数值失败:{}", e.getMessage(), e);
}
return null;
}
@Override
public void saveParamValue(String toolName, String paramName, String paramValue) {
log.debug("保存工具参数值,工具名称:{},参数名称:{},参数值:{}", toolName, paramName, paramValue);
try {
ToolConfig toolConfig = toolConfigRepository.findByToolNameAndParamName(toolName, paramName);
if (toolConfig != null) {
toolConfig.setParamValue(paramValue);
toolConfigRepository.updateById(toolConfig);
} else {
// 如果配置不存在,创建新配置
toolConfig = new ToolConfig();
toolConfig.setToolName(toolName);
toolConfig.setParamName(paramName);
toolConfig.setParamValue(paramValue);
toolConfigRepository.insert(toolConfig);
}
} catch (Exception e) {
log.error("保存工具参数值失败:{}", e.getMessage(), e);
}
}
@Override
public List<ToolConfig> getAllToolConfigs() {
log.debug("获取所有工具配置");
try {
return toolConfigRepository.findAllActive();
} catch (Exception e) {
log.error("获取所有工具配置失败:{}", e.getMessage(), e);
return List.of();
}
}
@Override
public ToolConfig getToolConfig(String toolName, String paramName) {
log.debug("获取工具配置,工具名称:{},参数名称:{}", toolName, paramName);
try {
return toolConfigRepository.findByToolNameAndParamName(toolName, paramName);
} catch (Exception e) {
log.error("获取工具配置失败:{}", e.getMessage(), e);
return null;
}
}
@Override
public ToolConfig saveToolConfig(ToolConfig toolConfig) {
log.debug("保存工具配置:{}", toolConfig);
try {
if (toolConfig.getId() != null) {
toolConfigRepository.updateById(toolConfig);
} else {
// 检查是否已存在相同的工具名称和参数名称的配置
ToolConfig existingConfig = toolConfigRepository.findByToolNameAndParamName(
toolConfig.getToolName(), toolConfig.getParamName());
if (existingConfig != null) {
toolConfig.setId(existingConfig.getId());
toolConfigRepository.updateById(toolConfig);
} else {
toolConfigRepository.insert(toolConfig);
}
}
return toolConfig;
} catch (Exception e) {
log.error("保存工具配置失败:{}", e.getMessage(), e);
return null;
}
}
@Override
public void deleteToolConfig(String id) {
log.debug("删除工具配置,ID:{}", id);
try {
toolConfigRepository.deleteById(id);
} catch (Exception e) {
log.error("删除工具配置失败:{}", e.getMessage(), e);
}
}
@Override
public List<ToolConfig> getToolConfigsByToolName(String toolName) {
log.debug("根据工具名称获取工具配置列表,工具名称:{}", toolName);
try {
return toolConfigRepository.findByToolName(toolName);
} catch (Exception e) {
log.error("根据工具名称获取工具配置列表失败:{}", e.getMessage(), e);
return List.of();
}
}
}
\ No newline at end of file
...@@ -3,6 +3,7 @@ package pangea.hiagent.tools; ...@@ -3,6 +3,7 @@ package pangea.hiagent.tools;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import pangea.hiagent.tools.annotation.ToolParam;
/** /**
* 图表生成工具 * 图表生成工具
...@@ -12,6 +13,36 @@ import org.springframework.stereotype.Component; ...@@ -12,6 +13,36 @@ import org.springframework.stereotype.Component;
@Component @Component
public class ChartGenerationTool { public class ChartGenerationTool {
@ToolParam(
name = "maxDataPoints",
description = "最大数据点数量限制",
defaultValue = "100",
type = "integer",
required = true,
group = "chart"
)
private Integer maxDataPoints;
@ToolParam(
name = "percentageDecimalPlaces",
description = "百分比显示的小数位数",
defaultValue = "2",
type = "integer",
required = true,
group = "chart"
)
private Integer percentageDecimalPlaces;
@ToolParam(
name = "defaultSeriesName",
description = "默认数据系列名称",
defaultValue = "数据",
type = "string",
required = true,
group = "chart"
)
private String defaultSeriesName;
/** /**
* 生成柱状图 * 生成柱状图
* @param title 图表标题 * @param title 图表标题
...@@ -45,11 +76,16 @@ public class ChartGenerationTool { ...@@ -45,11 +76,16 @@ public class ChartGenerationTool {
return "错误:X轴标签数量与数据系列数量不匹配"; return "错误:X轴标签数量与数据系列数量不匹配";
} }
if (xAxisLabels.length > maxDataPoints) {
log.warn("数据点数量超过限制,当前数量:{},限制:{}", xAxisLabels.length, maxDataPoints);
return "错误:数据点数量超过限制,当前数量:" + xAxisLabels.length + ",限制:" + maxDataPoints;
}
// 生成图表描述 // 生成图表描述
StringBuilder chartDescription = new StringBuilder(); StringBuilder chartDescription = new StringBuilder();
chartDescription.append("柱状图生成成功:\n"); chartDescription.append("柱状图生成成功:\n");
chartDescription.append("标题: ").append(title).append("\n"); chartDescription.append("标题: ").append(title).append("\n");
chartDescription.append("数据系列: ").append(seriesName != null ? seriesName : "数据").append("\n"); chartDescription.append("数据系列: ").append(seriesName != null ? seriesName : defaultSeriesName).append("\n");
chartDescription.append("数据点数量: ").append(seriesData.length).append("\n"); chartDescription.append("数据点数量: ").append(seriesData.length).append("\n");
chartDescription.append("数据详情:\n"); chartDescription.append("数据详情:\n");
...@@ -166,8 +202,9 @@ public class ChartGenerationTool { ...@@ -166,8 +202,9 @@ public class ChartGenerationTool {
for (int i = 0; i < labels.length; i++) { for (int i = 0; i < labels.length; i++) {
double percentage = total > 0 ? (values[i] / total) * 100 : 0; double percentage = total > 0 ? (values[i] / total) * 100 : 0;
String format = String.format("%%.%df", percentageDecimalPlaces);
chartDescription.append(" ").append(labels[i]).append(": ").append(values[i]) chartDescription.append(" ").append(labels[i]).append(": ").append(values[i])
.append(" (").append(String.format("%.2f", percentage)).append("%)\n"); .append(" (").append(String.format(format, percentage)).append("%)\n");
} }
log.info("饼图生成完成,包含 {} 个数据项", values.length); log.info("饼图生成完成,包含 {} 个数据项", values.length);
......
...@@ -3,6 +3,7 @@ package pangea.hiagent.tools; ...@@ -3,6 +3,7 @@ package pangea.hiagent.tools;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import pangea.hiagent.tools.annotation.ToolParam;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.LocalDate; import java.time.LocalDate;
...@@ -16,17 +17,36 @@ import java.time.format.DateTimeFormatter; ...@@ -16,17 +17,36 @@ import java.time.format.DateTimeFormatter;
@Component @Component
public class DateTimeTools { public class DateTimeTools {
@ToolParam(
name = "dateTimeFormat",
description = "日期时间格式",
defaultValue = "yyyy-MM-dd HH:mm:ss",
type = "string",
required = true,
group = "datetime"
)
private String dateTimeFormat;
@ToolParam(
name = "dateFormat",
description = "日期格式",
defaultValue = "yyyy-MM-dd",
type = "string",
required = true,
group = "datetime"
)
private String dateFormat;
@Tool(description = "获取当前日期和时间") @Tool(description = "获取当前日期和时间")
public String getCurrentDateTime() { public String getCurrentDateTime() {
String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")); String dateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern(dateTimeFormat));
log.debug("获取当前日期时间: {}", dateTime); log.debug("获取当前日期时间: {}", dateTime);
return dateTime; return dateTime;
} }
@Tool(description = "获取当前日期") @Tool(description = "获取当前日期")
public String getCurrentDate() { public String getCurrentDate() {
String date = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd")); String date = LocalDate.now().format(DateTimeFormatter.ofPattern(dateFormat));
log.debug("获取当前日期: {}", date); log.debug("获取当前日期: {}", date);
return date; return date;
} }
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment