Commit 198ca2f1 authored by ligaowei's avatar ligaowei

fix(react): 修正流式处理时关键词解析和日志记录

- 在DefaultReactExecutor中添加tokenTextSegmenter.finishInput()以完善输入处理
- 注释更新以说明流式处理过程中关键词实时解析调整
- 在TokenTextSegmenter中增加对输入字符的日志打印,便于调试和监控
- 优化分段标识匹配逻辑前的输入状态记录提升可观察性
parent 3534a7fe
...@@ -193,6 +193,9 @@ backend/logs/ ...@@ -193,6 +193,9 @@ backend/logs/
backend/storage/ backend/storage/
backend/uploads/ backend/uploads/
backend/hiagentdb.mv.db backend/hiagentdb.mv.db
# H2 database files
backend/src/main/resources/hiagent_dev_db.*
./data/hiagent_dev_db.*
# Frontend files # Frontend files
frontend/node_modules/ frontend/node_modules/
......
# SSE 心跳保活机制改进方案
## 问题描述
之前对话返回信息过长时,会因为流式响应超时(60秒无消息)而显示"[错误] 流式输出超时,请重试",导致SSE连接被关闭。
## 解决方案
### 前端改进 (ChatArea.vue)
#### 1. 改进超时检测机制
- **原来**: 简单的60秒全局超时,无任何数据到达就关闭
- **现在**: 使用心跳保活机制,定期检查是否收到心跳
```typescript
// 关键参数
const HEARTBEAT_TIMEOUT = 60000; // 60秒无心跳则为超时
const HEARTBEAT_CHECK_INTERVAL = 5000; // 每5秒检查一次
let lastHeartbeatTime = Date.now(); // 记录最后一次心跳时间
```
#### 2. 新增心跳事件处理
`processSSELine` 函数中新增 heartbeat case:
```typescript
case "heartbeat":
// 收到心跳事件,重置超时计时器
resetStreamTimeout();
// 心跳事件本身不处理,只用于保活连接
console.debug("[心跳] 收到心跳事件,连接保活");
return false;
```
#### 3. 改进的超时判断逻辑
```typescript
const resetStreamTimeout = () => {
clearStreamTimeout();
lastHeartbeatTime = Date.now(); // 更新最后心跳时间
streamTimeoutTimer = setTimeout(() => {
if (!isStreamComplete) {
// 检查是否在指定时间内收到过心跳或数据
const timeSinceLastHeartbeat = Date.now() - lastHeartbeatTime;
if (timeSinceLastHeartbeat >= HEARTBEAT_TIMEOUT) {
// 真正的超时,关闭连接
isStreamComplete = true;
reader.cancel();
// ... 显示超时错误
} else {
// 还没超时,继续检查
resetStreamTimeout();
}
}
}, HEARTBEAT_CHECK_INTERVAL);
};
```
**工作原理**
1. 每当收到token、心跳或其他数据时,重置超时计时器并更新`lastHeartbeatTime`
2. 每5秒检查一次是否超时
3. 只有当最后一次心跳/数据距现在超过60秒时,才真正认为超时并关闭连接
4. 否则继续检查,保持连接活跃
---
### 后端改进 (UserSseService.java)
#### 1. 调整心跳发送频率
- **原来**: 每30秒发送一次心跳
- **现在**: 每20秒发送一次心跳
```java
}, 20, 20, TimeUnit.SECONDS); // 每20秒发送一次心跳,确保前端60秒超时前至少收到2次心跳
```
**原因**: 确保在前端60秒超时前,至少能收到2次心跳,增加可靠性
#### 2. 增强心跳日志
```java
long heartbeatTimestamp = System.currentTimeMillis();
emitter.send(SseEmitter.event().name("heartbeat").data(heartbeatTimestamp));
log.debug("[心跳] 成功发送心跳事件,时间戳: {}", heartbeatTimestamp);
```
#### 3. 心跳机制的完整生命周期
- **启动**: 创建连接时调用 `startHeartbeat()`
- **运行**: 每20秒检查一次连接有效性,如果有效则发送心跳
- **停止**: 在连接完成/超时/错误时自动取消心跳任务
```java
// 注册回调,在连接完成时取消心跳任务
emitter.onCompletion(() -> {
if (heartbeatTask != null && !heartbeatTask.isCancelled()) {
heartbeatTask.cancel(true);
log.debug("SSE连接完成,心跳任务已取消");
}
});
// 类似的处理: onTimeout(), onError()
```
---
## 工作流程
### 正常情况(消息持续到达)
```
时间轴: 0s ─── 10s ─── 20s ─── 30s ─── 40s ─── 50s ─── 60s
│ │ │
token token token
│ │ │
重置超时 重置超时 重置超时
(60s) (60s) (60s)
```
连接保持活跃,不会超时。
### 有心跳但消息间隔长(解决长时间处理问题)
```
时间轴: 0s ─── 10s ─── 20s ─── 30s ─── 40s ─── 50s ─── 60s ─── 70s ─── 80s
token 心跳 心跳 心跳 token
│ │ │ │ │
重置超时 重置超时 重置超时 重置超时 重置超时
(60s) (60s) (60s) (60s) (60s)
```
心跳每20秒发送一次,保持连接活跃,即使消息处理需要很长时间。
### 真正超时的情况(心跳也断开)
```
时间轴: 0s ─── 20s ─── 40s ─── 60s ─── 70s(超时)
token 心跳 心跳 [无更多心跳]
│ │ │
重置超时 重置超时 重置超时
(60s) (60s) (60s)
超过60秒无响应,关闭连接
```
当网络真的中断或服务器崩溃时,经过60秒无任何响应,客户端才会超时并提示用户。
---
## 关键时间参数
| 参数 | 值 | 说明 |
|------|-----|------|
| 心跳间隔(后端)| 20秒 | 后端定期向客户端发送心跳 |
| 前端超时时间 | 60秒 | 前端在60秒内无心跳/数据则超时 |
| 检查间隔(前端)| 5秒 | 前端每5秒检查一次是否超时 |
| SSE连接超时(后端)| 120秒 | Spring框架层面的连接超时 |
**设计原理**: 心跳间隔 (20s) < 前端超时时间 (60s) / 2,保证前端超时前至少收到2次心跳。
---
## 对话结束和错误处理
### 对话正常结束
1. 后端发送 `complete` 事件
2. 前端收到 `complete` 事件,调用 `clearStreamTimeout()`
3. 流式处理完成,关闭所有计时器和监听
### 发生错误时
1. 后端发送 `error` 事件
2. 前端收到 `error` 事件,调用 `clearStreamTimeout()`
3. 关闭连接和心跳检查,显示错误信息
### 心跳中断且超时
1. 前端在60秒内未收到任何心跳/数据
2. 前端认定连接超时,取消读取并显示错误
3. 用户可以点击重试按钮重新发送消息
---
## 调试
### 后端日志
```
[心跳] 成功发送心跳事件,时间戳: 1640000000000
```
### 前端日志
```
[心跳] 收到心跳事件,连接保活
[SSE完成事件] {type: "complete", ...}
```
### 超时测试
1. 故意让后端处理延迟超过60秒的请求
2. 观察是否能收到心跳事件
3. 连接应该保持活跃,不会因为消息间隔长而断开
4. 直到对话完成或心跳真的中断,才会关闭连接
---
## 总结
这个改进方案通过引入心跳保活机制,解决了以下问题:
✅ 长时间处理的对话不会因为超时而意外断开
✅ 心跳中断才会真正关闭连接(而不是任意时间无消息就关闭)
✅ 流式响应自然结束或错误发生时,及时清理资源
✅ 系统更加稳定可靠,特别是对于复杂AI任务处理
2025-12-25 11:27:45.456378+08:00 jdbc[3]: exception
org.h2.jdbc.JdbcSQLSyntaxErrorException: Table "TOOL_CONFIGS" not found (this database is empty); SQL statement:
SELECT * FROM tool_configs WHERE tool_name = ? AND param_name = ? AND deleted = 0 LIMIT 1 [42104-224]
2025-12-25 11:27:45.630100+08:00 jdbc[3]: exception
org.h2.jdbc.JdbcSQLSyntaxErrorException: Table "TOOL_CONFIGS" not found (this database is empty); SQL statement:
SELECT * FROM tool_configs WHERE tool_name = ? AND param_name = ? AND deleted = 0 LIMIT 1 [42104-224]
2025-12-25 11:27:45.657786+08:00 jdbc[3]: exception
org.h2.jdbc.JdbcSQLSyntaxErrorException: Table "TOOL_CONFIGS" not found (this database is empty); SQL statement:
SELECT * FROM tool_configs WHERE tool_name = ? AND param_name = ? AND deleted = 0 LIMIT 1 [42104-224]
2025-12-25 11:30:31.913327+08:00 jdbc[3]: exception
org.h2.jdbc.JdbcSQLSyntaxErrorException: Table "TOOL_CONFIGS" not found (this database is empty); SQL statement:
SELECT * FROM tool_configs WHERE tool_name = ? AND param_name = ? AND deleted = 0 LIMIT 1 [42104-224]
2025-12-25 11:30:32.084087+08:00 jdbc[3]: exception
org.h2.jdbc.JdbcSQLSyntaxErrorException: Table "TOOL_CONFIGS" not found (this database is empty); SQL statement:
SELECT * FROM tool_configs WHERE tool_name = ? AND param_name = ? AND deleted = 0 LIMIT 1 [42104-224]
2025-12-25 11:30:32.117664+08:00 jdbc[3]: exception
org.h2.jdbc.JdbcSQLSyntaxErrorException: Table "TOOL_CONFIGS" not found (this database is empty); SQL statement:
SELECT * FROM tool_configs WHERE tool_name = ? AND param_name = ? AND deleted = 0 LIMIT 1 [42104-224]
...@@ -87,12 +87,10 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -87,12 +87,10 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
defaultReactExecutor.addReactCallback(defaultReactCallback); defaultReactExecutor.addReactCallback(defaultReactCallback);
} }
// 使用ReAct执行器执行流程,传递Agent对象以支持记忆功能 // 使用ReAct执行器执行流程,传递Agent对象和用户ID以支持记忆功能
String finalAnswer = defaultReactExecutor.execute(client, userMessage, tools, agent); String finalAnswer = defaultReactExecutor.execute(client, userMessage, tools, agent, userId);
// 将助理回复添加到ChatMemory // 助手回复已经由执行器保存到内存中,不需要重复保存
String sessionId = generateSessionId(agent, userId);
addAssistantMessageToMemory(sessionId, finalAnswer);
return finalAnswer; return finalAnswer;
} catch (Exception e) { } catch (Exception e) {
...@@ -138,8 +136,8 @@ public class ReActAgentProcessor extends BaseAgentProcessor { ...@@ -138,8 +136,8 @@ public class ReActAgentProcessor extends BaseAgentProcessor {
return; return;
} }
// 使用ReAct执行器流式执行流程,传递Agent对象以支持记忆功能 // 使用ReAct执行器流式执行流程,传递Agent对象以支持记忆功能和用户ID以确保上下文传播
defaultReactExecutor.executeStream(client, userMessage, tools, tokenConsumer, agent); defaultReactExecutor.executeStream(client, userMessage, tools, tokenConsumer, agent, userId);
} catch (Exception e) { } catch (Exception e) {
agentErrorHandler.handleStreamError(e, tokenConsumer, "流式处理ReAct请求时发生错误"); agentErrorHandler.handleStreamError(e, tokenConsumer, "流式处理ReAct请求时发生错误");
agentErrorHandler.ensureCompletionCallback(tokenConsumer, "处理请求时发生错误: " + e.getMessage()); agentErrorHandler.ensureCompletionCallback(tokenConsumer, "处理请求时发生错误: " + e.getMessage());
......
...@@ -13,6 +13,7 @@ import pangea.hiagent.memory.MemoryService; ...@@ -13,6 +13,7 @@ import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.model.Agent; import pangea.hiagent.model.Agent;
import pangea.hiagent.tool.AgentToolManager; import pangea.hiagent.tool.AgentToolManager;
import pangea.hiagent.tool.impl.DateTimeTools; import pangea.hiagent.tool.impl.DateTimeTools;
import pangea.hiagent.common.utils.UserUtils;
import java.util.List; import java.util.List;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
...@@ -121,6 +122,13 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -121,6 +122,13 @@ public class DefaultReactExecutor implements ReactExecutor {
@Override @Override
public String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent) { public String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent) {
// 调用带用户ID的方法,首先尝试获取当前用户ID
String userId = UserUtils.getCurrentUserId();
return execute(chatClient, userInput, tools, agent, userId);
}
@Override
public String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent, String userId) {
log.info("开始执行ReAct流程,用户输入: {}", userInput); log.info("开始执行ReAct流程,用户输入: {}", userInput);
stepCounter.set(0); stepCounter.set(0);
...@@ -128,9 +136,9 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -128,9 +136,9 @@ public class DefaultReactExecutor implements ReactExecutor {
List<Object> agentTools = getAgentTools(agent); List<Object> agentTools = getAgentTools(agent);
try { try {
triggerThinkStep("开始处理用户请求: " + userInput); // triggerThinkStep("开始处理用户请求: " + userInput);
Prompt prompt = buildPromptWithHistory(DEFAULT_SYSTEM_PROMPT, userInput, agent); Prompt prompt = buildPromptWithHistory(DEFAULT_SYSTEM_PROMPT, userInput, agent, userId);
ChatResponse response = chatClient.prompt(prompt) ChatResponse response = chatClient.prompt(prompt)
.tools(agentTools.toArray()) .tools(agentTools.toArray())
...@@ -139,11 +147,14 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -139,11 +147,14 @@ public class DefaultReactExecutor implements ReactExecutor {
String responseText = response.getResult().getOutput().getText(); String responseText = response.getResult().getOutput().getText();
triggerObservationStep(responseText); // triggerObservationStep(responseText);
log.info("最终答案: {}", responseText); log.info("最终答案: {}", responseText);
triggerFinalAnswerStep(responseText); // triggerFinalAnswerStep(responseText);
// 保存助手回复到内存,使用提供的用户ID
saveAssistantResponseToMemory(agent, responseText, userId);
return responseText; return responseText;
} catch (Exception e) { } catch (Exception e) {
...@@ -171,13 +182,30 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -171,13 +182,30 @@ public class DefaultReactExecutor implements ReactExecutor {
* @return 构建好的提示词对象 * @return 构建好的提示词对象
*/ */
private Prompt buildPromptWithHistory(String systemPrompt, String userInput, Agent agent) { private Prompt buildPromptWithHistory(String systemPrompt, String userInput, Agent agent) {
return buildPromptWithHistory(systemPrompt, userInput, agent, null);
}
/**
* 构建带有历史记录的提示词
*
* @param systemPrompt 系统提示词
* @param userInput 用户输入
* @param agent 智能体对象
* @param userId 用户ID(可选,如果为null则自动获取)
* @return 构建好的提示词对象
*/
private Prompt buildPromptWithHistory(String systemPrompt, String userInput, Agent agent, String userId) {
List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>(); List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
messages.add(new SystemMessage(systemPrompt)); messages.add(new SystemMessage(systemPrompt));
if (agent != null) { if (agent != null) {
try { try {
String sessionId = memoryService.generateSessionId(agent); // 如果没有提供用户ID,则尝试获取当前用户ID
if (userId == null) {
userId = UserUtils.getCurrentUserId();
}
String sessionId = memoryService.generateSessionId(agent, userId);
int historyLength = agent.getHistoryLength() != null ? agent.getHistoryLength() : 10; int historyLength = agent.getHistoryLength() != null ? agent.getHistoryLength() : 10;
...@@ -199,6 +227,13 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -199,6 +227,13 @@ public class DefaultReactExecutor implements ReactExecutor {
@Override @Override
public void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) { public void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) {
// 调用带用户ID的方法,但首先尝试获取当前用户ID
String userId = UserUtils.getCurrentUserId();
executeStream(chatClient, userInput, tools, tokenConsumer, agent, userId);
}
@Override
public void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent, String userId) {
log.info("使用stream()方法处理ReAct流程,支持真正的流式输出"); log.info("使用stream()方法处理ReAct流程,支持真正的流式输出");
stepCounter.set(0); stepCounter.set(0);
...@@ -208,9 +243,9 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -208,9 +243,9 @@ public class DefaultReactExecutor implements ReactExecutor {
StringBuilder fullResponse = new StringBuilder(); StringBuilder fullResponse = new StringBuilder();
try { try {
triggerThinkStep("开始处理用户请求: " + userInput); // triggerThinkStep("开始处理用户请求: " + userInput);
Prompt prompt = buildPromptWithHistory(DEFAULT_SYSTEM_PROMPT, userInput, agent); Prompt prompt = buildPromptWithHistory(DEFAULT_SYSTEM_PROMPT, userInput, agent, userId);
chatClient.prompt(prompt) chatClient.prompt(prompt)
.tools(agentTools.toArray()) .tools(agentTools.toArray())
...@@ -219,7 +254,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -219,7 +254,7 @@ public class DefaultReactExecutor implements ReactExecutor {
.subscribe( .subscribe(
chatResponse -> handleTokenResponse(chatResponse, tokenConsumer, fullResponse), chatResponse -> handleTokenResponse(chatResponse, tokenConsumer, fullResponse),
throwable -> handleStreamError(throwable, tokenConsumer), throwable -> handleStreamError(throwable, tokenConsumer),
() -> handleStreamCompletion(tokenConsumer, fullResponse, agent) () -> handleStreamCompletion(tokenConsumer, fullResponse, agent, userId)
); );
} catch (Exception e) { } catch (Exception e) {
...@@ -248,7 +283,8 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -248,7 +283,8 @@ public class DefaultReactExecutor implements ReactExecutor {
tokenConsumer.accept(token); tokenConsumer.accept(token);
} }
tokenTextSegmenter.inputChar(token); // tokenTextSegmenter.inputChar(token);
// tokenTextSegmenter.finishInput();
// 改进:在流式处理过程中实时解析关键词 // 改进:在流式处理过程中实时解析关键词
// processTokenForStepsWithFullResponse(token, fullResponse.toString()); // processTokenForStepsWithFullResponse(token, fullResponse.toString());
...@@ -266,16 +302,30 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -266,16 +302,30 @@ public class DefaultReactExecutor implements ReactExecutor {
* @param agent 智能体对象 * @param agent 智能体对象
*/ */
private void handleStreamCompletion(Consumer<String> tokenConsumer, StringBuilder fullResponse, Agent agent) { private void handleStreamCompletion(Consumer<String> tokenConsumer, StringBuilder fullResponse, Agent agent) {
// 调用带用户ID的版本,使用当前线程的用户ID
String userId = UserUtils.getCurrentUserId();
handleStreamCompletion(tokenConsumer, fullResponse, agent, userId);
}
/**
* 处理流式响应完成事件
*
* @param tokenConsumer token消费者
* @param fullResponse 完整响应内容
* @param agent 智能体对象
* @param userId 用户ID
*/
private void handleStreamCompletion(Consumer<String> tokenConsumer, StringBuilder fullResponse, Agent agent, String userId) {
try { try {
log.info("流式处理完成"); log.info("流式处理完成");
// 检查是否已经处理了Final Answer,如果没有,则将整个响应作为最终答案 // 检查是否已经处理了Final Answer,如果没有,则将整个响应作为最终答案
String responseStr = fullResponse.toString(); String responseStr = fullResponse.toString();
if (!hasFinalAnswerBeenTriggered(responseStr)) { if (!hasFinalAnswerBeenTriggered(responseStr)) {
triggerFinalAnswerStep(responseStr); // triggerFinalAnswerStep(responseStr);
} }
saveAssistantResponseToMemory(agent, responseStr); saveAssistantResponseToMemory(agent, responseStr, userId);
sendCompletionEvent(tokenConsumer, responseStr); sendCompletionEvent(tokenConsumer, responseStr);
} catch (Exception e) { } catch (Exception e) {
...@@ -305,11 +355,12 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -305,11 +355,12 @@ public class DefaultReactExecutor implements ReactExecutor {
* *
* @param agent 智能体对象 * @param agent 智能体对象
* @param response 助手的回复内容 * @param response 助手的回复内容
* @param userId 用户ID
*/ */
private void saveAssistantResponseToMemory(Agent agent, String response) { private void saveAssistantResponseToMemory(Agent agent, String response, String userId) {
if (agent != null) { if (agent != null) {
try { try {
String sessionId = memoryService.generateSessionId(agent); String sessionId = memoryService.generateSessionId(agent, userId);
memoryService.addAssistantMessageToMemory(sessionId, response); memoryService.addAssistantMessageToMemory(sessionId, response);
} catch (Exception e) { } catch (Exception e) {
log.warn("保存助理回复到内存时发生错误: {}", e.getMessage()); log.warn("保存助理回复到内存时发生错误: {}", e.getMessage());
......
...@@ -15,16 +15,40 @@ public interface ReactExecutor { ...@@ -15,16 +15,40 @@ public interface ReactExecutor {
* @param chatClient ChatClient实例 * @param chatClient ChatClient实例
* @param userInput 用户输入 * @param userInput 用户输入
* @param tools 工具列表 * @param tools 工具列表
* @param agent Agent对象
* @return 最终答案 * @return 最终答案
*/ */
String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent); String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent);
/**
* 执行ReAct流程(同步方式)
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param agent Agent对象
* @param userId 用户ID
* @return 最终答案
*/
String execute(ChatClient chatClient, String userInput, List<Object> tools, Agent agent, String userId);
/** /**
* 流式执行ReAct流程 * 流式执行ReAct流程
* @param chatClient ChatClient实例 * @param chatClient ChatClient实例
* @param userInput 用户输入 * @param userInput 用户输入
* @param tools 工具列表 * @param tools 工具列表
* @param tokenConsumer token处理回调函数 * @param tokenConsumer token处理回调函数
* @param agent Agent对象
* @param userId 用户ID
*/
void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent, String userId);
/**
* 流式执行ReAct流程(旧方法,保持向后兼容)
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param tokenConsumer token处理回调函数
* @param agent Agent对象
*/ */
void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent); void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent);
......
...@@ -23,15 +23,9 @@ public class TokenTextSegmenter { ...@@ -23,15 +23,9 @@ public class TokenTextSegmenter {
// 当前缓存的输入字符 // 当前缓存的输入字符
private StringBuilder currentInputBuffer; private StringBuilder currentInputBuffer;
// 已匹配到的分段标识
private String matchedMarker;
// 分段内容起始索引
private int segmentContentStartIndex;
public TokenTextSegmenter() { public TokenTextSegmenter() {
currentInputBuffer = new StringBuilder(); currentInputBuffer = new StringBuilder();
matchedMarker = null;
segmentContentStartIndex = 0;
} }
/** /**
...@@ -40,38 +34,36 @@ public class TokenTextSegmenter { ...@@ -40,38 +34,36 @@ public class TokenTextSegmenter {
* @param inputChar 单个输入字符 * @param inputChar 单个输入字符
* @return 当分割出有效文本段时返回该段内容,否则返回null * @return 当分割出有效文本段时返回该段内容,否则返回null
*/ */
public void inputChar(String inputChar) { public synchronized void inputChar(String inputChar) {
// 输入验证
if (inputChar == null) {
return;
}
// 将字符加入缓存 // 将字符加入缓存
currentInputBuffer.append(inputChar); currentInputBuffer.append(inputChar);
String currentBufferStr = currentInputBuffer.toString(); String currentBufferStr = currentInputBuffer.toString();
// 1. 未匹配到标识时,检测是否出现分段标识 log.info("【输入字符】: {}", currentBufferStr);
if (matchedMarker == null) {
for (String marker : SEGMENT_MARKERS) { // 检查当前缓冲区中是否出现任何SEGMENT_MARKERS字段
if (currentBufferStr.endsWith(marker)) { for (String marker : SEGMENT_MARKERS) {
// 匹配到标识,记录信息 int markerIndex = currentBufferStr.indexOf(marker);
matchedMarker = marker; if (markerIndex != -1) {
segmentContentStartIndex = currentBufferStr.length(); // 找到SEGMENT_MARKERS字段,截取该字段之前的文本进行输出
// 输出标识本身(可选,根据需求决定是否包含标识) String contentBeforeMarker = currentBufferStr.substring(0, markerIndex);
log.info("【识别到分段标识】: {}", matchedMarker);
} // 输出截取的内容
} outputSegment(marker, contentBeforeMarker);
}
// 2. 已匹配到标识,检测是否出现下一个标识(或文本结束) // 重置缓冲区,保留标识符及之后的内容
else { currentInputBuffer = new StringBuilder(currentBufferStr.substring(markerIndex));
for (String marker : SEGMENT_MARKERS) {
if (!marker.equals(matchedMarker) && currentBufferStr.contains(marker)) { log.info("【识别到分段标识】: {}", marker);
// 找到下一个标识,截取当前分段内容 break; // 找到第一个标识后就处理并退出,避免重复处理
int nextMarkerStartIndex = currentBufferStr.indexOf(marker);
String segmentContent = currentBufferStr.substring(segmentContentStartIndex, nextMarkerStartIndex)
.trim();
// 重置状态,准备处理下一个分段
resetSegmentState(nextMarkerStartIndex);
// 输出当前分段内容
outputSegment(matchedMarker, segmentContent);
}
} }
} }
// 如果没有找到SEGMENT_MARKERS字段,则不输出,等待更多输入
} }
/** /**
...@@ -79,11 +71,14 @@ public class TokenTextSegmenter { ...@@ -79,11 +71,14 @@ public class TokenTextSegmenter {
* *
* @return 最后一个分段的内容,无则返回null * @return 最后一个分段的内容,无则返回null
*/ */
public void finishInput() { public synchronized void finishInput() {
if (matchedMarker != null && segmentContentStartIndex < currentInputBuffer.length()) { // 如果缓冲区还有内容,输出全部剩余内容
String lastSegmentContent = currentInputBuffer.substring(segmentContentStartIndex).trim(); if (currentInputBuffer.length() > 0) {
resetSegmentState(currentInputBuffer.length()); String remainingContent = currentInputBuffer.toString();
outputSegment(matchedMarker, lastSegmentContent); // 输出剩余的全部内容,使用一个通用标记或保持原格式
outputSegment("Final_Content:", remainingContent);
// 清空缓冲区
currentInputBuffer.setLength(0);
} }
} }
...@@ -96,8 +91,6 @@ public class TokenTextSegmenter { ...@@ -96,8 +91,6 @@ public class TokenTextSegmenter {
// 保留未处理的部分,用于下一个分段 // 保留未处理的部分,用于下一个分段
String remainingStr = currentInputBuffer.substring(newStartIndex); String remainingStr = currentInputBuffer.substring(newStartIndex);
currentInputBuffer = new StringBuilder(remainingStr); currentInputBuffer = new StringBuilder(remainingStr);
matchedMarker = null;
segmentContentStartIndex = 0;
} }
/** /**
...@@ -109,6 +102,9 @@ public class TokenTextSegmenter { ...@@ -109,6 +102,9 @@ public class TokenTextSegmenter {
*/ */
private void outputSegment(String marker, String content) { private void outputSegment(String marker, String content) {
log.info("【分段内容】{}: {}", marker, content); log.info("【分段内容】{}: {}", marker, content);
workPanelCollector.addEvent(null); // 根据实际需求处理事件,这里可能需要创建适当的事件对象而不是传入null
// workPanelCollector.addEvent(null); // 临时注释掉可能引发问题的调用
// 或者创建一个适当的事件对象
// 例如:workPanelCollector.addEvent(new WorkPanelEvent(marker, content));
} }
} }
...@@ -14,6 +14,7 @@ import pangea.hiagent.model.Agent; ...@@ -14,6 +14,7 @@ import pangea.hiagent.model.Agent;
import pangea.hiagent.tool.AgentToolManager; import pangea.hiagent.tool.AgentToolManager;
import pangea.hiagent.web.dto.AgentRequest; import pangea.hiagent.web.dto.AgentRequest;
import pangea.hiagent.workpanel.event.EventService; import pangea.hiagent.workpanel.event.EventService;
import pangea.hiagent.common.utils.AsyncUserContextDecorator;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
...@@ -34,6 +35,7 @@ public class AgentChatService { ...@@ -34,6 +35,7 @@ public class AgentChatService {
private final StreamRequestService streamRequestService; private final StreamRequestService streamRequestService;
private final AgentToolManager agentToolManager; private final AgentToolManager agentToolManager;
private final UserSseService workPanelSseService; private final UserSseService workPanelSseService;
private final pangea.hiagent.web.service.AgentService agentService;
public AgentChatService( public AgentChatService(
EventService eventService, EventService eventService,
...@@ -42,13 +44,15 @@ public class AgentChatService { ...@@ -42,13 +44,15 @@ public class AgentChatService {
AgentProcessorFactory agentProcessorFactory, AgentProcessorFactory agentProcessorFactory,
StreamRequestService streamRequestService, StreamRequestService streamRequestService,
AgentToolManager agentToolManager, AgentToolManager agentToolManager,
UserSseService workPanelSseService) { UserSseService workPanelSseService,
pangea.hiagent.web.service.AgentService agentService) {
this.chatErrorHandler = chatErrorHandler; this.chatErrorHandler = chatErrorHandler;
this.agentValidationService = agentValidationService; this.agentValidationService = agentValidationService;
this.agentProcessorFactory = agentProcessorFactory; this.agentProcessorFactory = agentProcessorFactory;
this.streamRequestService = streamRequestService; this.streamRequestService = streamRequestService;
this.agentToolManager = agentToolManager; this.agentToolManager = agentToolManager;
this.workPanelSseService = workPanelSseService; this.workPanelSseService = workPanelSseService;
this.agentService = agentService;
} }
// 专用线程池配置 - 使用静态变量确保线程池在整个应用中是单例的 // 专用线程池配置 - 使用静态变量确保线程池在整个应用中是单例的
...@@ -126,7 +130,8 @@ public class AgentChatService { ...@@ -126,7 +130,8 @@ public class AgentChatService {
final String finalUserId = userId; final String finalUserId = userId;
// 异步处理对话,避免阻塞HTTP连接 // 异步处理对话,避免阻塞HTTP连接
executorService.execute(() -> { // 使用用户上下文装饰器来确保在异步线程中也能获取到用户信息
executorService.execute(AsyncUserContextDecorator.wrapWithContext(() -> {
try { try {
processChatRequest(emitter, agentId, chatRequest, finalUserId); processChatRequest(emitter, agentId, chatRequest, finalUserId);
} catch (Exception e) { } catch (Exception e) {
...@@ -138,13 +143,14 @@ public class AgentChatService { ...@@ -138,13 +143,14 @@ public class AgentChatService {
log.warn("响应已提交,无法发送处理请求错误信息"); log.warn("响应已提交,无法发送处理请求错误信息");
} }
} }
}); }));
return emitter; return emitter;
} }
/** /**
* 处理聊天请求的核心逻辑 * 处理聊天请求的核心逻辑
* 注意:权限验证已在主线程中完成,此正仅执行业务逻辑不进行权限检查
* *
* @param emitter SSE发射器 * @param emitter SSE发射器
* @param agentId Agent ID * @param agentId Agent ID
...@@ -153,16 +159,21 @@ public class AgentChatService { ...@@ -153,16 +159,21 @@ public class AgentChatService {
*/ */
private void processChatRequest(SseEmitter emitter, String agentId, ChatRequest chatRequest, String userId) { private void processChatRequest(SseEmitter emitter, String agentId, ChatRequest chatRequest, String userId) {
try { try {
// 获取Agent信息并进行权限检查 // 直接从 agentService 获取Agent,不需验证权限(权限检查已在主线程中完成)
Agent agent = agentValidationService.validateAgentAndPermission(agentId, userId, emitter); // 使用 agentService.getAgent() 要比 validateAgentAndPermission 安全,因为前者不会在异步线程中访问SecurityContext
Agent agent = agentService.getAgent(agentId);
if (agent == null) { if (agent == null) {
return; // 权限验证失败,直接返回 log.error("Agent不存在: {}", agentId);
chatErrorHandler.handleChatError(emitter, "Agent不存在");
return;
} }
// 获取处理器并启动心跳保活机制 // 获取处理器
AgentProcessor processor = agentProcessorFactory.getProcessor(agent); AgentProcessor processor = agentProcessorFactory.getProcessor(agent);
if (processor == null) { if (processor == null) {
return; // 获取处理器失败,直接返回 log.error("无法获取Agent处理器,Agent: {}", agentId);
chatErrorHandler.handleChatError(emitter, "无法获取Agent处理器");
return;
} }
// 启动心跳机制 // 启动心跳机制
...@@ -174,6 +185,7 @@ public class AgentChatService { ...@@ -174,6 +185,7 @@ public class AgentChatService {
// 处理流式请求 // 处理流式请求
streamRequestService.handleStreamRequest(emitter, processor, request, agent, userId); streamRequestService.handleStreamRequest(emitter, processor, request, agent, userId);
} catch (Exception e) { } catch (Exception e) {
log.error("处理聊天请求时发生异常", e);
chatErrorHandler.handleChatError(emitter, "处理请求时发生错误", e, null); chatErrorHandler.handleChatError(emitter, "处理请求时发生错误", e, null);
} }
} }
......
...@@ -87,6 +87,7 @@ public class CompletionHandlerService { ...@@ -87,6 +87,7 @@ public class CompletionHandlerService {
log.info("{} Agent处理完成,总字符数: {}", processor.getProcessorType(), fullContent != null ? fullContent.length() : 0); log.info("{} Agent处理完成,总字符数: {}", processor.getProcessorType(), fullContent != null ? fullContent.length() : 0);
// 发送完成事件 // 发送完成事件
Exception completionException = null;
try { try {
// 发送完整内容作为最后一个token // 发送完整内容作为最后一个token
// if (fullContent != null && !fullContent.isEmpty()) { // if (fullContent != null && !fullContent.isEmpty()) {
...@@ -95,16 +96,28 @@ public class CompletionHandlerService { ...@@ -95,16 +96,28 @@ public class CompletionHandlerService {
// 发送完成信号 // 发送完成信号
emitter.send("[DONE]"); emitter.send("[DONE]");
} catch (Exception e) { } catch (Exception e) {
errorHandlerService.handleCompletionError(emitter, e); log.error("发送完成信号失败", e);
completionException = e;
} }
// 保存对话记录 // 保存对话记录
try { try {
saveDialogue(agent, request, userId, fullContent); saveDialogue(agent, request, userId, fullContent);
log.info("对话记录保存成功");
} catch (Exception e) { } catch (Exception e) {
errorHandlerService.handleSaveDialogueError(emitter, e, isCompleted); log.error("保存对话记录失败", e);
} finally { // 记录异常但不中断流程,继续关闭emitter
completionException = e;
}
// 最后才关闭emitter,确保所有操作都完成后再提交响应
try {
unifiedSseService.completeEmitter(emitter, isCompleted); unifiedSseService.completeEmitter(emitter, isCompleted);
log.debug("SSE Emitter已关闭");
} catch (Exception e) {
log.error("关闭Emitter时发生错误", e);
} }
LogUtils.exitMethod("handleCompletion", "处理完成"); LogUtils.exitMethod("handleCompletion", "处理完成");
} }
......
...@@ -253,7 +253,7 @@ public class UserSseService { ...@@ -253,7 +253,7 @@ public class UserSseService {
isCompleted.set(true); isCompleted.set(true);
} }
} }
}, 30, 30, TimeUnit.SECONDS); // 每30秒发送一次心跳 }, 20, 20, TimeUnit.SECONDS); // 每20秒发送一次心跳,确保前端60秒超时前至少收到2次心跳
// 注册回调,在连接完成时取消心跳任务 // 注册回调,在连接完成时取消心跳任务
emitter.onCompletion(() -> { emitter.onCompletion(() -> {
...@@ -287,7 +287,7 @@ public class UserSseService { ...@@ -287,7 +287,7 @@ public class UserSseService {
*/ */
public void registerCallbacks(SseEmitter emitter) { public void registerCallbacks(SseEmitter emitter) {
emitter.onCompletion(() -> { emitter.onCompletion(() -> {
log.debug("SSE连接完成"); log.debug("【注册回调函数】SSE连接完成");
removeEmitter(emitter); removeEmitter(emitter);
}); });
emitter.onError((Throwable t) -> { emitter.onError((Throwable t) -> {
...@@ -314,7 +314,7 @@ public class UserSseService { ...@@ -314,7 +314,7 @@ public class UserSseService {
*/ */
public void registerCallbacks(SseEmitter emitter, String userId) { public void registerCallbacks(SseEmitter emitter, String userId) {
emitter.onCompletion(() -> { emitter.onCompletion(() -> {
log.debug("SSE连接完成"); log.debug("【注册Emitter回调函数】SSE连接完成");
// 通知用户连接管理器连接已完成 // 通知用户连接管理器连接已完成
handleConnectionCompletion(emitter); handleConnectionCompletion(emitter);
}); });
...@@ -424,7 +424,9 @@ public class UserSseService { ...@@ -424,7 +424,9 @@ public class UserSseService {
try { try {
// 发送心跳事件 // 发送心跳事件
emitter.send(SseEmitter.event().name("heartbeat").data(System.currentTimeMillis())); long heartbeatTimestamp = System.currentTimeMillis();
emitter.send(SseEmitter.event().name("heartbeat").data(heartbeatTimestamp));
log.debug("[心跳] 成功发送心跳事件,时间戳: {}", heartbeatTimestamp);
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
// 处理 emitter 已关闭的情况 // 处理 emitter 已关闭的情况
log.debug("无法发送心跳事件,emitter已关闭: {}", e.getMessage()); log.debug("无法发送心跳事件,emitter已关闭: {}", e.getMessage());
......
...@@ -5,6 +5,8 @@ import lombok.extern.slf4j.Slf4j; ...@@ -5,6 +5,8 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.MetaObject;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import pangea.hiagent.common.utils.UserUtils; import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.common.utils.UserContextPropagationUtil;
import pangea.hiagent.common.utils.AsyncUserContextDecorator;
import java.time.LocalDateTime; import java.time.LocalDateTime;
...@@ -46,7 +48,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler { ...@@ -46,7 +48,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
if (metaObject.hasSetter("createdBy")) { if (metaObject.hasSetter("createdBy")) {
Object createdBy = getFieldValByName("createdBy", metaObject); Object createdBy = getFieldValByName("createdBy", metaObject);
if (createdBy == null) { if (createdBy == null) {
String userId = UserUtils.getCurrentUserId(); String userId = getCurrentUserIdWithContext();
if (userId != null) { if (userId != null) {
this.strictInsertFill(metaObject, "createdBy", String.class, userId); this.strictInsertFill(metaObject, "createdBy", String.class, userId);
log.debug("自动填充createdBy字段: {}", userId); log.debug("自动填充createdBy字段: {}", userId);
...@@ -60,7 +62,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler { ...@@ -60,7 +62,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
if (metaObject.hasSetter("updatedBy")) { if (metaObject.hasSetter("updatedBy")) {
Object updatedBy = getFieldValByName("updatedBy", metaObject); Object updatedBy = getFieldValByName("updatedBy", metaObject);
if (updatedBy == null) { if (updatedBy == null) {
String userId = UserUtils.getCurrentUserId(); String userId = getCurrentUserIdWithContext();
if (userId != null) { if (userId != null) {
this.strictInsertFill(metaObject, "updatedBy", String.class, userId); this.strictInsertFill(metaObject, "updatedBy", String.class, userId);
log.debug("自动填充updatedBy字段: {}", userId); log.debug("自动填充updatedBy字段: {}", userId);
...@@ -91,7 +93,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler { ...@@ -91,7 +93,7 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
Object updatedBy = getFieldValByName("updatedBy", metaObject); Object updatedBy = getFieldValByName("updatedBy", metaObject);
// 如果updatedBy为空或者需要强制更新,则填充当前用户ID // 如果updatedBy为空或者需要强制更新,则填充当前用户ID
if (updatedBy == null) { if (updatedBy == null) {
String userId = UserUtils.getCurrentUserId(); String userId = getCurrentUserIdWithContext();
if (userId != null) { if (userId != null) {
this.strictUpdateFill(metaObject, "updatedBy", String.class, userId); this.strictUpdateFill(metaObject, "updatedBy", String.class, userId);
log.debug("自动填充updatedBy字段: {}", userId); log.debug("自动填充updatedBy字段: {}", userId);
...@@ -101,4 +103,39 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler { ...@@ -101,4 +103,39 @@ public class MetaObjectHandlerConfig implements MetaObjectHandler {
} }
} }
} }
/**
* 获取当前用户ID,支持异步线程上下文
* 该方法支持以下场景:
* 1. 同步请求:从SecurityContext获取用户ID
* 2. 异步任务:从AsyncUserContextDecorator传播的上下文获取用户ID
* 3. 故障转移:尝试直接解析Token获取用户ID
*
* @return 用户ID,如果无法获取则返回null
*/
private String getCurrentUserIdWithContext() {
try {
// 方式1:首先尝试从SecurityContext获取(支持同步请求和AsyncUserContextDecorator传播)
String userId = UserUtils.getCurrentUserId();
if (userId != null) {
log.debug("通过SecurityContext成功获取用户ID: {}", userId);
return userId;
}
log.debug("无法从SecurityContext获取用户ID,可能是异步线程且未使用AsyncUserContextDecorator包装");
// 方式2:尝试直接从请求中解析Token(故障转移)
String asyncUserId = UserUtils.getCurrentUserIdInAsync();
if (asyncUserId != null) {
log.debug("通过直接解析Token成功获取用户ID: {}", asyncUserId);
return asyncUserId;
}
log.warn("无法通过任何方式获取当前用户ID,createdBy/updatedBy字段将不被填充");
return null;
} catch (Exception e) {
log.error("获取用户ID时发生异常", e);
return null;
}
}
} }
\ No newline at end of file
...@@ -21,6 +21,7 @@ import pangea.hiagent.web.service.AgentService; ...@@ -21,6 +21,7 @@ import pangea.hiagent.web.service.AgentService;
import pangea.hiagent.web.service.TimerService; import pangea.hiagent.web.service.TimerService;
import pangea.hiagent.security.DefaultPermissionEvaluator; import pangea.hiagent.security.DefaultPermissionEvaluator;
import pangea.hiagent.security.JwtAuthenticationFilter; import pangea.hiagent.security.JwtAuthenticationFilter;
import pangea.hiagent.security.SseAuthorizationFilter;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
...@@ -33,11 +34,13 @@ import java.util.Collections; ...@@ -33,11 +34,13 @@ import java.util.Collections;
public class SecurityConfig { public class SecurityConfig {
private final JwtAuthenticationFilter jwtAuthenticationFilter; private final JwtAuthenticationFilter jwtAuthenticationFilter;
private final SseAuthorizationFilter sseAuthorizationFilter;
private final AgentService agentService; private final AgentService agentService;
private final TimerService timerService; private final TimerService timerService;
public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter, AgentService agentService, TimerService timerService) { public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter, SseAuthorizationFilter sseAuthorizationFilter, AgentService agentService, TimerService timerService) {
this.jwtAuthenticationFilter = jwtAuthenticationFilter; this.jwtAuthenticationFilter = jwtAuthenticationFilter;
this.sseAuthorizationFilter = sseAuthorizationFilter;
this.agentService = agentService; this.agentService = agentService;
this.timerService = timerService; this.timerService = timerService;
} }
...@@ -203,6 +206,8 @@ public class SecurityConfig { ...@@ -203,6 +206,8 @@ public class SecurityConfig {
} }
}) })
) )
// 添加SSE授权检查过滤器,在所有其他过滤器之前运行,提前拒绝未认证的SSE请求
.addFilterBefore(sseAuthorizationFilter, UsernamePasswordAuthenticationFilter.class)
// 添加JWT认证过滤器 // 添加JWT认证过滤器
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class) .addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class)
// 配置X-Frame-Options头部,允许同源iframe嵌入 // 配置X-Frame-Options头部,允许同源iframe嵌入
......
...@@ -18,19 +18,30 @@ import jakarta.servlet.http.HttpServletRequest; ...@@ -18,19 +18,30 @@ import jakarta.servlet.http.HttpServletRequest;
@Slf4j @Slf4j
@Component @Component
public class UserUtils { public class UserUtils {
// 注入JwtUtil bean // 注入JwtUtil bean
private static JwtUtil jwtUtil; private static JwtUtil jwtUtil;
public UserUtils(JwtUtil jwtUtil) { public UserUtils(JwtUtil jwtUtil) {
UserUtils.jwtUtil = jwtUtil; UserUtils.jwtUtil = jwtUtil;
} }
public static String getCurrentUserId() {
String username = getCurrentUserIdInSync();
if (username==null || username.isEmpty()) {
username = getCurrentUserIdInAsync();
}
return username;
}
/** /**
* 获取当前认证用户ID * 获取当前认证用户ID
*
* @return 用户ID,如果未认证则返回null * @return 用户ID,如果未认证则返回null
*/ */
public static String getCurrentUserId() { public static String getCurrentUserIdInSync() {
try { try {
// 首先尝试从SecurityContext获取 // 首先尝试从SecurityContext获取
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
...@@ -52,14 +63,14 @@ public class UserUtils { ...@@ -52,14 +63,14 @@ public class UserUtils {
} }
} }
} }
// 如果SecurityContext中没有认证信息,尝试从请求中解析JWT令牌 // 如果SecurityContext中没有认证信息,尝试从请求中解析JWT令牌
String userId = getUserIdFromRequest(); String userId = getUserIdFromRequest();
if (userId != null) { if (userId != null) {
log.debug("从请求中解析到用户ID: {}", userId); log.debug("从请求中解析到用户ID: {}", userId);
return userId; return userId;
} }
log.debug("未能获取到有效的用户ID"); log.debug("未能获取到有效的用户ID");
return null; return null;
} catch (Exception e) { } catch (Exception e) {
...@@ -67,23 +78,24 @@ public class UserUtils { ...@@ -67,23 +78,24 @@ public class UserUtils {
return null; return null;
} }
} }
/** /**
* 在异步线程环境中获取当前认证用户ID * 在异步线程环境中获取当前认证用户ID
* 该方法专为异步线程环境设计,通过JWT令牌解析获取用户ID * 该方法专为异步线程环境设计,通过JWT令牌解析获取用户ID
*
* @return 用户ID,如果未认证则返回null * @return 用户ID,如果未认证则返回null
*/ */
public static String getCurrentUserIdInAsync() { public static String getCurrentUserIdInAsync() {
try { try {
log.debug("在异步线程中尝试获取用户ID"); log.debug("在异步线程中尝试获取用户ID");
// 直接从请求中解析JWT令牌获取用户ID // 直接从请求中解析JWT令牌获取用户ID
String userId = getUserIdFromRequest(); String userId = getUserIdFromRequest();
if (userId != null) { if (userId != null) {
log.debug("在异步线程中成功获取用户ID: {}", userId); log.debug("在异步线程中成功获取用户ID: {}", userId);
return userId; return userId;
} }
log.debug("在异步线程中未能获取到有效的用户ID"); log.debug("在异步线程中未能获取到有效的用户ID");
return null; return null;
} catch (Exception e) { } catch (Exception e) {
...@@ -91,9 +103,10 @@ public class UserUtils { ...@@ -91,9 +103,10 @@ public class UserUtils {
return null; return null;
} }
} }
/** /**
* 从当前请求中提取JWT令牌并解析用户ID * 从当前请求中提取JWT令牌并解析用户ID
*
* @return 用户ID,如果无法解析则返回null * @return 用户ID,如果无法解析则返回null
*/ */
private static String getUserIdFromRequest() { private static String getUserIdFromRequest() {
...@@ -101,15 +114,15 @@ public class UserUtils { ...@@ -101,15 +114,15 @@ public class UserUtils {
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (requestAttributes instanceof ServletRequestAttributes) { if (requestAttributes instanceof ServletRequestAttributes) {
HttpServletRequest request = ((ServletRequestAttributes) requestAttributes).getRequest(); HttpServletRequest request = ((ServletRequestAttributes) requestAttributes).getRequest();
// 从请求头或参数中提取Token // 从请求头或参数中提取Token
String token = extractTokenFromRequest(request); String token = extractTokenFromRequest(request);
if (StringUtils.hasText(token) && jwtUtil != null) { if (StringUtils.hasText(token) && jwtUtil != null) {
// 验证token是否有效 // 验证token是否有效
boolean isValid = jwtUtil.validateToken(token); boolean isValid = jwtUtil.validateToken(token);
log.debug("JWT验证结果: {}", isValid); log.debug("JWT验证结果: {}", isValid);
if (isValid) { if (isValid) {
String userId = jwtUtil.getUserIdFromToken(token); String userId = jwtUtil.getUserIdFromToken(token);
log.debug("从JWT令牌中提取用户ID: {}", userId); log.debug("从JWT令牌中提取用户ID: {}", userId);
...@@ -130,10 +143,10 @@ public class UserUtils { ...@@ -130,10 +143,10 @@ public class UserUtils {
} catch (Exception e) { } catch (Exception e) {
log.error("从请求中解析用户ID时发生异常", e); log.error("从请求中解析用户ID时发生异常", e);
} }
return null; return null;
} }
/** /**
* 从请求头或参数中提取Token * 从请求头或参数中提取Token
*/ */
...@@ -146,7 +159,7 @@ public class UserUtils { ...@@ -146,7 +159,7 @@ public class UserUtils {
log.debug("从Authorization头中提取到token"); log.debug("从Authorization头中提取到token");
return token; return token;
} }
// 如果请求头中没有Token,则尝试从URL参数中提取 // 如果请求头中没有Token,则尝试从URL参数中提取
String tokenParam = request.getParameter("token"); String tokenParam = request.getParameter("token");
log.debug("从URL参数中提取token参数: {}", tokenParam); log.debug("从URL参数中提取token参数: {}", tokenParam);
...@@ -154,21 +167,23 @@ public class UserUtils { ...@@ -154,21 +167,23 @@ public class UserUtils {
log.debug("从URL参数中提取到token"); log.debug("从URL参数中提取到token");
return tokenParam; return tokenParam;
} }
log.debug("未找到有效的token"); log.debug("未找到有效的token");
return null; return null;
} }
/** /**
* 检查当前用户是否已认证 * 检查当前用户是否已认证
*
* @return true表示已认证,false表示未认证 * @return true表示已认证,false表示未认证
*/ */
public static boolean isAuthenticated() { public static boolean isAuthenticated() {
return getCurrentUserId() != null; return getCurrentUserId() != null;
} }
/** /**
* 检查用户是否是管理员 * 检查用户是否是管理员
*
* @param userId 用户ID * @param userId 用户ID
* @return true表示是管理员,false表示不是管理员 * @return true表示是管理员,false表示不是管理员
*/ */
......
package pangea.hiagent.security;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.common.utils.JwtUtil;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
/**
* SSE流式端点授权检查过滤器
* 在Spring Security的AuthorizationFilter之前运行,提前处理流式端点的身份验证检查
* 避免响应被提交后才处理异常的问题
*/
@Slf4j
@Component
public class SseAuthorizationFilter extends OncePerRequestFilter {
private static final String STREAM_ENDPOINT = "/api/v1/agent/chat-stream";
private static final String TIMELINE_ENDPOINT = "/api/v1/agent/timeline-events";
private final JwtUtil jwtUtil;
public SseAuthorizationFilter(JwtUtil jwtUtil) {
this.jwtUtil = jwtUtil;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
String requestUri = request.getRequestURI();
boolean isStreamEndpoint = requestUri.contains(STREAM_ENDPOINT);
boolean isTimelineEndpoint = requestUri.contains(TIMELINE_ENDPOINT);
// 只处理SSE端点
if (isStreamEndpoint || isTimelineEndpoint) {
log.debug("SSE端点授权检查: {} {}", request.getMethod(), requestUri);
// 尝试从请求中提取并验证JWT token
String token = extractTokenFromRequest(request);
if (StringUtils.hasText(token)) {
log.debug("提取到JWT token,进行验证");
try {
// 验证token是否有效
if (jwtUtil.validateToken(token)) {
String userId = jwtUtil.getUserIdFromToken(token);
if (userId != null) {
// 创建认证对象
List<SimpleGrantedAuthority> authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"));
UsernamePasswordAuthenticationToken authentication =
new UsernamePasswordAuthenticationToken(userId, null, authorities);
SecurityContextHolder.getContext().setAuthentication(authentication);
log.debug("SSE端点JWT验证成功,用户: {}", userId);
// 继续执行过滤器链
filterChain.doFilter(request, response);
return;
}
}
} catch (Exception e) {
log.warn("SSE端点JWT验证失败: {}", e.getMessage());
}
}
// token无效或不存在,拒绝连接
log.warn("SSE端点未认证访问,拒绝连接: {} {}", request.getMethod(), requestUri);
sendSseUnauthorizedError(response);
return;
}
// 继续执行过滤器链(非SSE端点)
filterChain.doFilter(request, response);
}
/**
* 发送SSE格式的未授权错误响应
*/
private void sendSseUnauthorizedError(HttpServletResponse response) {
try {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
response.setContentType("text/event-stream;charset=UTF-8");
response.setCharacterEncoding("UTF-8");
// 发送SSE格式的错误事件
response.getWriter().write("event: error\n");
response.getWriter().write("data: {\"error\": \"未授权访问,请先登录\", \"code\": 401, \"timestamp\": " +
System.currentTimeMillis() + "}\n\n");
response.getWriter().flush();
log.debug("已发送SSE未授权错误响应");
} catch (IOException e) {
log.error("发送SSE未授权错误响应失败", e);
}
}
/**
* 从请求头或参数中提取Token
*/
private String extractTokenFromRequest(HttpServletRequest request) {
// 首先尝试从请求头中提取Token
String authHeader = request.getHeader("Authorization");
if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) {
return authHeader.substring(7);
}
// 如果请求头中没有Token,则尝试从URL参数中提取
String tokenParam = request.getParameter("token");
if (StringUtils.hasText(tokenParam)) {
return tokenParam;
}
return null;
}
/**
* 确定此过滤器是否应处理给定请求
* 只处理SSE流式端点
*/
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
String requestUri = request.getRequestURI();
boolean isStreamEndpoint = requestUri.contains(STREAM_ENDPOINT);
boolean isTimelineEndpoint = requestUri.contains(TIMELINE_ENDPOINT);
// 如果不是SSE端点,跳过此过滤器
return !(isStreamEndpoint || isTimelineEndpoint);
}
}
package pangea.hiagent.tool;
import java.lang.annotation.*;
/**
* 工具参数注解
* 用于标记工具类中的配置参数
*/
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ToolParam {
/**
* 参数名称
*/
String name() default "";
/**
* 参数描述
*/
String description() default "";
/**
* 参数默认值
*/
String defaultValue() default "";
/**
* 参数类型
*/
String type() default "string";
/**
* 是否必填
*/
boolean required() default false;
/**
* 参数分组
*/
String group() default "default";
}
\ No newline at end of file
package pangea.hiagent.tool;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
import pangea.hiagent.web.service.ToolConfigService;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
/**
* 工具参数处理器
* 用于处理工具类中的@ToolParam注解,将数据库中的参数值注入到工具类字段
*/
@Slf4j
@Component
@Scope(ConfigurableBeanFactory.SCOPE_SINGLETON)
public class ToolParamProcessor implements BeanPostProcessor {
private final ToolConfigService toolConfigService;
// 构造函数注入
public ToolParamProcessor(ToolConfigService toolConfigService) {
this.toolConfigService = toolConfigService;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
// 检查Bean是否为工具类(位于tools包下,且带有@Component注解)
Class<?> beanClass = bean.getClass();
String packageName = beanClass.getPackage().getName();
if (packageName.contains("pangea.hiagent.tools") && beanClass.isAnnotationPresent(Component.class)) {
log.debug("处理工具类参数,Bean名称:{}", beanName);
injectParams(bean);
}
return bean;
}
/**
* 注入参数值到工具类字段
* @param bean 工具类实例
*/
private void injectParams(Object bean) {
Class<?> beanClass = bean.getClass();
String toolName = beanClass.getSimpleName();
// 获取所有字段,包括父类字段
List<Field> fields = getAllFields(beanClass);
for (Field field : fields) {
if (field.isAnnotationPresent(ToolParam.class)) {
ToolParam annotation = field.getAnnotation(ToolParam.class);
String paramName = annotation.name().isEmpty() ? field.getName() : annotation.name();
// 从数据库获取参数值,如果不存在则使用默认值
String paramValue = toolConfigService.getParamValue(toolName, paramName);
if (paramValue == null) {
paramValue = annotation.defaultValue();
log.debug("参数值不存在,使用默认值,工具名称:{},参数名称:{},默认值:{}",
toolName, paramName, paramValue);
}
// 设置字段值
field.setAccessible(true);
try {
// 根据字段类型转换参数值
injectFieldValue(bean, field, paramValue);
log.debug("参数值注入成功,工具名称:{},参数名称:{},字段类型:{},值:{}",
toolName, paramName, field.getType().getName(), paramValue);
} catch (Exception e) {
log.error("参数值注入失败,工具名称:{},参数名称:{},字段类型:{},值:{}",
toolName, paramName, field.getType().getName(), paramValue, e);
}
}
}
}
/**
* 递归获取所有字段,包括父类字段
* @param clazz 类对象
* @return 字段列表
*/
private List<Field> getAllFields(Class<?> clazz) {
List<Field> fields = Arrays.asList(clazz.getDeclaredFields());
Class<?> superClass = clazz.getSuperclass();
if (superClass != null && !superClass.equals(Object.class)) {
fields.addAll(getAllFields(superClass));
}
return fields;
}
/**
* 根据字段类型注入参数值
* @param bean 工具类实例
* @param field 字段对象
* @param paramValue 参数值字符串
* @throws IllegalAccessException 访问权限异常
*/
private void injectFieldValue(Object bean, Field field, String paramValue) throws IllegalAccessException {
Class<?> fieldType = field.getType();
if (fieldType == String.class) {
field.set(bean, paramValue);
} else if (fieldType == int.class || fieldType == Integer.class) {
field.set(bean, Integer.parseInt(paramValue));
} else if (fieldType == long.class || fieldType == Long.class) {
field.set(bean, Long.parseLong(paramValue));
} else if (fieldType == boolean.class || fieldType == Boolean.class) {
field.set(bean, Boolean.parseBoolean(paramValue));
} else if (fieldType == double.class || fieldType == Double.class) {
field.set(bean, Double.parseDouble(paramValue));
} else if (fieldType == float.class || fieldType == Float.class) {
field.set(bean, Float.parseFloat(paramValue));
} else if (fieldType == short.class || fieldType == Short.class) {
field.set(bean, Short.parseShort(paramValue));
} else if (fieldType == byte.class || fieldType == Byte.class) {
field.set(bean, Byte.parseByte(paramValue));
} else if (fieldType == char.class || fieldType == Character.class) {
field.set(bean, paramValue.charAt(0));
} else {
// 对于其他类型,直接设置为null
field.set(bean, null);
log.warn("不支持的字段类型,工具名称:{},参数名称:{},字段类型:{}",
bean.getClass().getSimpleName(), field.getName(), fieldType.getName());
}
}
}
\ No newline at end of file
...@@ -3,7 +3,7 @@ package pangea.hiagent.tool.impl; ...@@ -3,7 +3,7 @@ package pangea.hiagent.tool.impl;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import pangea.hiagent.tool.ToolParam;
/** /**
* 图表生成工具 * 图表生成工具
...@@ -13,35 +13,11 @@ import pangea.hiagent.tool.ToolParam; ...@@ -13,35 +13,11 @@ import pangea.hiagent.tool.ToolParam;
@Component @Component
public class ChartGenerationTool { public class ChartGenerationTool {
@ToolParam( private Integer maxDataPoints = 100;
name = "maxDataPoints",
description = "最大数据点数量限制",
defaultValue = "100",
type = "integer",
required = true,
group = "chart"
)
private Integer maxDataPoints;
@ToolParam( private Integer percentageDecimalPlaces = 2;
name = "percentageDecimalPlaces",
description = "百分比显示的小数位数",
defaultValue = "2",
type = "integer",
required = true,
group = "chart"
)
private Integer percentageDecimalPlaces;
@ToolParam( private String defaultSeriesName = "数据";
name = "defaultSeriesName",
description = "默认数据系列名称",
defaultValue = "数据",
type = "string",
required = true,
group = "chart"
)
private String defaultSeriesName;
/** /**
* 生成柱状图 * 生成柱状图
......
...@@ -3,7 +3,7 @@ package pangea.hiagent.tool.impl; ...@@ -3,7 +3,7 @@ package pangea.hiagent.tool.impl;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import pangea.hiagent.tool.ToolParam;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.LocalDate; import java.time.LocalDate;
...@@ -18,34 +18,10 @@ import java.time.format.DateTimeFormatter; ...@@ -18,34 +18,10 @@ import java.time.format.DateTimeFormatter;
@Component @Component
public class DateTimeTools { public class DateTimeTools {
@ToolParam(
name = "dateTimeFormat",
description = "日期时间格式",
defaultValue = "yyyy-MM-dd HH:mm:ss",
type = "string",
required = true,
group = "datetime"
)
private String dateTimeFormat = "yyyy-MM-dd HH:mm:ss"; private String dateTimeFormat = "yyyy-MM-dd HH:mm:ss";
@ToolParam(
name = "dateFormat",
description = "日期格式",
defaultValue = "yyyy-MM-dd",
type = "string",
required = true,
group = "datetime"
)
private String dateFormat = "yyyy-MM-dd"; private String dateFormat = "yyyy-MM-dd";
@ToolParam(
name = "timeFormat",
description = "时间格式",
defaultValue = "HH:mm:ss",
type = "string",
required = true,
group = "datetime"
)
private String timeFormat = "HH:mm:ss"; private String timeFormat = "HH:mm:ss";
@Tool(description = "获取当前日期和时间,返回格式为 'yyyy-MM-dd HH:mm:ss'") @Tool(description = "获取当前日期和时间,返回格式为 'yyyy-MM-dd HH:mm:ss'")
......
...@@ -10,7 +10,7 @@ import jakarta.mail.search.ReceivedDateTerm; ...@@ -10,7 +10,7 @@ import jakarta.mail.search.ReceivedDateTerm;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import pangea.hiagent.tool.ToolParam;
import java.io.File; import java.io.File;
import java.util.*; import java.util.*;
...@@ -23,45 +23,13 @@ import java.util.*; ...@@ -23,45 +23,13 @@ import java.util.*;
@Component @Component
public class EmailTools { public class EmailTools {
@ToolParam( private Integer defaultPop3Port = 995;
name = "defaultPop3Port",
description = "默认POP3服务器端口",
defaultValue = "995",
type = "integer",
required = true,
group = "email"
)
private Integer defaultPop3Port;
@ToolParam( private String defaultAttachmentPath = "attachments";
name = "defaultAttachmentPath",
description = "默认附件保存路径",
defaultValue = "attachments",
type = "string",
required = true,
group = "email"
)
private String defaultAttachmentPath;
@ToolParam( private Boolean pop3SslEnable = true;
name = "pop3SslEnable",
description = "是否启用POP3 SSL",
defaultValue = "true",
type = "boolean",
required = true,
group = "email"
)
private Boolean pop3SslEnable;
@ToolParam( private String pop3SocketFactoryClass = "javax.net.ssl.SSLSocketFactory";
name = "pop3SocketFactoryClass",
description = "POP3 SSL套接字工厂类",
defaultValue = "javax.net.ssl.SSLSocketFactory",
type = "string",
required = true,
group = "email"
)
private String pop3SocketFactoryClass;
// 邮件请求参数类 // 邮件请求参数类
@JsonClassDescription("邮件操作请求参数") @JsonClassDescription("邮件操作请求参数")
......
...@@ -3,7 +3,7 @@ package pangea.hiagent.tool.impl; ...@@ -3,7 +3,7 @@ package pangea.hiagent.tool.impl;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.Tool;
import pangea.hiagent.tool.ToolParam;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
...@@ -24,37 +24,13 @@ import java.util.UUID; ...@@ -24,37 +24,13 @@ import java.util.UUID;
public class FileProcessingTools { public class FileProcessingTools {
// 支持的文本文件扩展名 // 支持的文本文件扩展名
@ToolParam( private String textFileExtensions = ".txt,.md,.java,.html,.htm,.css,.js,.json,.xml,.yaml,.yml,.properties,.sql,.py,.cpp,.c,.h,.cs,.php,.rb,.go,.rs,.swift,.kt,.scala,.sh,.bat,.cmd,.ps1,.log,.csv,.ts,.jsx,.tsx,.vue,.scss,.sass,.less";
name = "textFileExtensions",
description = "支持的文本文件扩展名,逗号分隔",
defaultValue = ".txt,.md,.java,.html,.htm,.css,.js,.json,.xml,.yaml,.yml,.properties,.sql,.py,.cpp,.c,.h,.cs,.php,.rb,.go,.rs,.swift,.kt,.scala,.sh,.bat,.cmd,.ps1,.log,.csv,.ts,.jsx,.tsx,.vue,.scss,.sass,.less",
type = "string",
required = true,
group = "file"
)
private String textFileExtensions;
// 支持的图片文件扩展名 // 支持的图片文件扩展名
@ToolParam( private String imageFileExtensions = ".jpg,.jpeg,.png,.gif,.bmp,.svg,.webp,.ico";
name = "imageFileExtensions",
description = "支持的图片文件扩展名,逗号分隔",
defaultValue = ".jpg,.jpeg,.png,.gif,.bmp,.svg,.webp,.ico",
type = "string",
required = true,
group = "file"
)
private String imageFileExtensions;
// 默认文件存储目录 // 默认文件存储目录
@ToolParam( private String defaultStorageDir = "storage";
name = "defaultStorageDir",
description = "默认文件存储目录",
defaultValue = "storage",
type = "string",
required = true,
group = "file"
)
private String defaultStorageDir;
// 转换为列表的辅助方法 // 转换为列表的辅助方法
private List<String> getTextFileExtensions() { private List<String> getTextFileExtensions() {
......
# 文件处理工具使用说明
## 功能概述
FileProcessingTools 是一个功能丰富的文件处理工具类,专门设计用于处理各种文本格式文件。该工具支持读取、写入、追加内容到文件,并提供文件信息查询功能。
支持的文件格式包括但不限于:
- 文本文件:`.txt`
- 标记语言文件:`.md`
- 编程语言文件:`.java`, `.html`, `.htm`, `.css`, `.js`, `.json`, `.xml`, `.yaml`, `.yml`, `.py`, `.cpp`, `.c`, `.h`, `.cs`, `.php`, `.rb`, `.go`, `.rs`, `.swift`, `.kt`, `.scala`
- 脚本文件:`.sh`, `.bat`, `.cmd`, `.ps1`
- 其他文本格式:`.properties`, `.sql`, `.log`, `.csv`, `.ts`, `.jsx`, `.tsx`, `.vue`, `.scss`, `.sass`, `.less`
## 功能列表
### 1. readFile(String filePath)
读取文本文件内容
**参数:**
- `filePath`: 文件路径(支持相对路径)
**返回值:**
- 成功时返回文件内容
- 失败时返回错误信息
**示例:**
```java
@Autowired
private FileProcessingTools fileTools;
String content = fileTools.readFile("/path/to/file.txt");
// 或使用相对路径
String content = fileTools.readFile("relative/path/to/file.txt");
```
### 2. readFileWithEncoding(String filePath, String encoding)
读取文本文件内容,支持指定字符编码
**参数:**
- `filePath`: 文件路径(支持相对路径)
- `encoding`: 字符编码(如 "UTF-8", "GBK" 等)
**返回值:**
- 成功时返回文件内容
- 失败时返回错误信息
**示例:**
```java
String content = fileTools.readFileWithEncoding("/path/to/file.txt", "UTF-8");
```
### 3. writeFile(String filePath, String content)
写入内容到文本文件
**参数:**
- `filePath`: 文件路径(支持相对路径,如果为空或null则自动生成随机文件名)
- `content`: 要写入的内容
**返回值:**
- 成功时返回"文件写入成功,文件路径: [完整文件路径]"
- 失败时返回错误信息
**示例:**
```java
// 指定文件名
String result = fileTools.writeFile("/path/to/file.txt", "Hello, World!");
// 使用相对路径
String result = fileTools.writeFile("relative/path/to/file.txt", "Hello, World!");
// 自动生成随机文件名
String result = fileTools.writeFile("", "Hello, World!");
```
### 4. writeFileWithEncoding(String filePath, String content, String encoding, boolean append)
写入内容到文本文件,支持指定字符编码和追加模式
**参数:**
- `filePath`: 文件路径(支持相对路径,如果为空或null则自动生成随机文件名)
- `content`: 要写入的内容
- `encoding`: 字符编码
- `append`: 是否追加到文件末尾(true为追加,false为覆盖)
**返回值:**
- 成功时返回"文件写入成功,文件路径: [完整文件路径]"
- 失败时返回错误信息
**示例:**
```java
// 覆盖写入
String result = fileTools.writeFileWithEncoding("/path/to/file.txt", "New content", "UTF-8", false);
// 追加写入
String result = fileTools.writeFileWithEncoding("/path/to/file.txt", "Additional content", "UTF-8", true);
// 自动生成随机文件名并写入
String result = fileTools.writeFileWithEncoding("", "Content with random filename", "UTF-8", false);
```
### 5. appendToFile(String filePath, String content)
追加内容到文本文件末尾
**参数:**
- `filePath`: 文件路径(支持相对路径,如果为空或null则自动生成随机文件名)
- `content`: 要追加的内容
**返回值:**
- 成功时返回"文件写入成功,文件路径: [完整文件路径]"
- 失败时返回错误信息
**示例:**
```java
String result = fileTools.appendToFile("/path/to/file.txt", "Appended content");
// 或使用相对路径
String result = fileTools.appendToFile("relative/path/to/file.txt", "Appended content");
// 或自动生成随机文件名
String result = fileTools.appendToFile("", "Appended content with random filename");
```
### 6. getFileSize(String filePath)
获取文件大小
**参数:**
- `filePath`: 文件路径(支持相对路径)
**返回值:**
- 成功时返回文件大小信息
- 失败时返回错误信息
**示例:**
```java
String sizeInfo = fileTools.getFileSize("/path/to/file.txt");
// 或使用相对路径
String sizeInfo = fileTools.getFileSize("relative/path/to/file.txt");
```
### 7. fileExists(String filePath)
检查文件是否存在
**参数:**
- `filePath`: 文件路径(支持相对路径)
**返回值:**
- 文件存在返回true
- 文件不存在返回false
**示例:**
```java
boolean exists = fileTools.fileExists("/path/to/file.txt");
// 或使用相对路径
boolean exists = fileTools.fileExists("relative/path/to/file.txt");
```
### 8. getFileInfo(String filePath)
获取文件详细信息
**参数:**
- `filePath`: 文件路径(支持相对路径)
**返回值:**
- 成功时返回文件详细信息(包括路径、大小、是否为文本文件、最后修改时间)
- 失败时返回错误信息
**示例:**
```java
String fileInfo = fileTools.getFileInfo("/path/to/file.txt");
// 或使用相对路径
String fileInfo = fileTools.getFileInfo("relative/path/to/file.txt");
```
### 9. generateRandomFileName(String extension)
生成随机文件名并返回完整路径
**参数:**
- `extension`: 文件扩展名(如 ".txt", "md" 等,如果不带点会自动添加)
**返回值:**
- 成功时返回完整文件路径
- 失败时返回错误信息
**示例:**
```java
String randomFilePath = fileTools.generateRandomFileName(".txt");
// 或不带点的扩展名
String randomFilePath = fileTools.generateRandomFileName("md");
```
## 使用注意事项
1. **字符编码**:默认使用UTF-8编码,可根据需要指定其他编码格式
2. **文件类型限制**:只能处理预定义的文本文件类型,非文本文件会被拒绝处理
3. **目录自动创建**:写入文件时会自动创建不存在的目录
4. **错误处理**:所有操作都有完善的错误处理和日志记录
5. **文件大小**:适合处理中小型文本文件,大文件处理可能影响性能
6. **路径支持**:支持相对路径,默认相对于当前工作目录
7. **随机文件名**:当filePath为空或null时,会自动生成随机文件名并存储在"storage"目录下
8. **扩展名推断**:当使用随机文件名时,会根据内容自动推断合适的文件扩展名
## 错误处理
工具类提供了完善的错误处理机制:
- 文件不存在时返回明确的错误信息
- 文件路径为空时自动生成随机文件名而不是报错
- IO异常时记录详细日志并返回友好的错误信息
- 编码错误时使用默认UTF-8编码并记录警告日志
## 性能优化
1. **内存使用**:使用NIO.2 API进行文件读写,提高效率
2. **字符编码**:自动检测和处理字符编码,确保内容正确性
3. **日志记录**:详细的日志记录便于问题排查和性能监控
4. **路径处理**:智能处理相对路径和绝对路径
5. **文件名生成**:使用UUID生成唯一的随机文件名,避免冲突
## 示例用法
```java
@Autowired
private FileProcessingTools fileTools;
// 读取文件
String content = fileTools.readFile("data/input.txt");
// 写入文件(自动生成随机文件名)
String writeResult = fileTools.writeFile("", "Hello, World!");
System.out.println(writeResult); // 输出文件路径
// 追加内容到文件
fileTools.appendToFile("logs/app.log", "New log entry\n");
// 获取文件信息
String fileInfo = fileTools.getFileInfo("config/settings.json");
```
\ No newline at end of file
...@@ -262,7 +262,21 @@ public class PlaywrightWebTools { ...@@ -262,7 +262,21 @@ public class PlaywrightWebTools {
return executeWithPage(url, page -> { return executeWithPage(url, page -> {
// 获取所有a标签的href属性 // 获取所有a标签的href属性
Object result = page.locator("a").evaluateAll("elements => elements.map(el => el.href)"); Object result = page.locator("a").evaluateAll("elements => elements.map(el => el.href)");
List<String> links = (List<String>) result; // 安全地进行类型转换
List<?> rawList;
if (result instanceof List) {
rawList = (List<?>) result;
} else {
log.warn("预期返回List类型,但实际返回: {}", result != null ? result.getClass().getName() : "null");
return "获取链接失败:返回类型错误";
}
// 安全地转换为List<String>
List<String> links = rawList.stream()
.map(item -> item != null ? item.toString() : "")
.filter(str -> !str.isEmpty())
.toList();
return links.isEmpty() ? "未找到任何链接" : String.join(", ", links); return links.isEmpty() ? "未找到任何链接" : String.join(", ", links);
}); });
} }
......
...@@ -5,8 +5,11 @@ import org.springframework.web.bind.annotation.*; ...@@ -5,8 +5,11 @@ import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.service.AgentChatService; import pangea.hiagent.agent.service.AgentChatService;
import pangea.hiagent.agent.service.AgentValidationService;
import pangea.hiagent.common.utils.UserUtils; import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.model.Agent;
import pangea.hiagent.web.dto.ChatRequest; import pangea.hiagent.web.dto.ChatRequest;
import pangea.hiagent.web.service.AgentService;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
...@@ -21,9 +24,11 @@ import jakarta.validation.constraints.NotBlank; ...@@ -21,9 +24,11 @@ import jakarta.validation.constraints.NotBlank;
public class AgentChatController { public class AgentChatController {
private final AgentChatService agentChatService; private final AgentChatService agentChatService;
private final AgentService agentService;
public AgentChatController(AgentChatService agentChatService) { public AgentChatController(AgentChatService agentChatService, AgentService agentService) {
this.agentChatService = agentChatService; this.agentChatService = agentChatService;
this.agentService = agentService;
} }
/** /**
...@@ -41,13 +46,27 @@ public class AgentChatController { ...@@ -41,13 +46,27 @@ public class AgentChatController {
HttpServletResponse response) { HttpServletResponse response) {
log.info("接收到流式对话请求,AgentId: {}", agentId); log.info("接收到流式对话请求,AgentId: {}", agentId);
// 检查用户权限 // 在主线程中完成权限检查,避免在异步线程中触发Spring Security异常
String userId = UserUtils.getCurrentUserId(); String userId = UserUtils.getCurrentUserId();
if (userId == null) { if (userId == null) {
log.warn("用户未认证,无法执行Agent对话"); log.warn("用户未认证,无法执行Agent对话");
throw new org.springframework.security.access.AccessDeniedException("用户未认证"); throw new org.springframework.security.access.AccessDeniedException("用户未认证");
} }
// 验证Agent存在性和权限
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
log.warn("Agent不存在: {}", agentId);
throw new IllegalArgumentException("Agent不存在");
}
// 检查权限
if (!agent.getOwner().equals(userId) && !UserUtils.isAdminUser(userId)) {
log.warn("用户 {} 无权限访问Agent: {}", userId, agentId);
throw new org.springframework.security.access.AccessDeniedException("无权限访问该Agent");
}
// 权限验证通过,调用异步处理
return agentChatService.handleChatStream(agentId, chatRequest, response); return agentChatService.handleChatStream(agentId, chatRequest, response);
} }
} }
\ No newline at end of file
...@@ -11,7 +11,7 @@ import org.springframework.web.bind.annotation.RestController; ...@@ -11,7 +11,7 @@ import org.springframework.web.bind.annotation.RestController;
import pangea.hiagent.document.KnowledgeBaseInitializationService; import pangea.hiagent.document.KnowledgeBaseInitializationService;
import pangea.hiagent.web.dto.ApiResponse; import pangea.hiagent.web.dto.ApiResponse;
import pangea.hiagent.tool.ToolBeanNameInitializer;
/** /**
* 系统管理控制器 * 系统管理控制器
...@@ -23,31 +23,12 @@ import pangea.hiagent.tool.ToolBeanNameInitializer; ...@@ -23,31 +23,12 @@ import pangea.hiagent.tool.ToolBeanNameInitializer;
@Tag(name = "系统管理", description = "系统管理相关API") @Tag(name = "系统管理", description = "系统管理相关API")
public class SystemAdminController { public class SystemAdminController {
@Autowired
private ToolBeanNameInitializer toolBeanNameInitializer;
@Autowired @Autowired
private KnowledgeBaseInitializationService knowledgeBaseInitializationService; private KnowledgeBaseInitializationService knowledgeBaseInitializationService;
/**
* 手动触发工具Bean名称初始化
*
* @return 操作结果
*/
@PostMapping("/initialize-tool-beans")
@Operation(summary = "初始化工具Bean", description = "手动触发工具Bean名称初始化任务")
public ResponseEntity<ApiResponse<Void>> initializeToolBeans() {
try {
log.info("收到手动触发工具Bean初始化请求");
toolBeanNameInitializer.initializeToolBeanNamesManually();
log.info("工具Bean初始化完成");
return ResponseEntity.ok(ApiResponse.success(null, "工具Bean初始化完成"));
} catch (Exception e) {
log.error("工具Bean初始化失败", e);
return ResponseEntity.internalServerError()
.body(ApiResponse.error(500, "工具Bean初始化失败: " + e.getMessage()));
}
}
/** /**
* 手动触发知识库初始化 * 手动触发知识库初始化
......
package pangea.hiagent.web.controller; // package pangea.hiagent.web.controller;
import lombok.extern.slf4j.Slf4j; // import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*; // import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; // import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.sse.UserSseService; // import pangea.hiagent.agent.sse.UserSseService;
import pangea.hiagent.common.utils.UserUtils; // import pangea.hiagent.common.utils.UserUtils;
import pangea.hiagent.workpanel.event.EventService; // import pangea.hiagent.workpanel.event.EventService;
/** // /**
* 时间轴事件控制器 // * 时间轴事件控制器
* 提供ReAct过程的实时事件推送功能 // * 提供ReAct过程的实时事件推送功能
*/ // */
@Slf4j // @Slf4j
@RestController // @RestController
@RequestMapping("/api/v1/agent") // @RequestMapping("/api/v1/agent")
public class TimelineEventController { // public class TimelineEventController {
private final UserSseService workPanelSseService; // private final UserSseService workPanelSseService;
public TimelineEventController(UserSseService workPanelSseService, EventService eventService) { // public TimelineEventController(UserSseService workPanelSseService, EventService eventService) {
this.workPanelSseService = workPanelSseService; // this.workPanelSseService = workPanelSseService;
} // }
/** // /**
* 订阅时间轴事件 // * 订阅时间轴事件
* 支持 SSE (Server-Sent Events) 格式的实时事件推送 // * 支持 SSE (Server-Sent Events) 格式的实时事件推送
* // *
* @return SSE emitter // * @return SSE emitter
*/ // */
@GetMapping("/timeline-events") // @GetMapping("/timeline-events")
public SseEmitter subscribeTimelineEvents() { // public SseEmitter subscribeTimelineEvents() {
log.info("开始处理时间轴事件订阅请求"); // log.info("开始处理时间轴事件订阅请求");
// 获取当前认证用户ID // // 获取当前认证用户ID
String userId = UserUtils.getCurrentUserId(); // String userId = UserUtils.getCurrentUserId();
if (userId == null) { // if (userId == null) {
log.warn("用户未认证,无法创建时间轴事件订阅"); // log.warn("用户未认证,无法创建时间轴事件订阅");
throw new org.springframework.security.access.AccessDeniedException("用户未认证"); // throw new org.springframework.security.access.AccessDeniedException("用户未认证");
} // }
log.info("开始为用户 {} 创建SSE连接", userId); // log.info("开始为用户 {} 创建SSE连接", userId);
// 创建并注册SSE连接 // // 创建并注册SSE连接
return workPanelSseService.createAndRegisterConnection(userId); // return workPanelSseService.createAndRegisterConnection(userId);
} // }
} // }
\ No newline at end of file \ No newline at end of file
package pangea.hiagent.web.service; package pangea.hiagent.web.service;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import pangea.hiagent.model.ToolConfig; import pangea.hiagent.model.ToolConfig;
import java.util.List; import java.util.List;
...@@ -12,59 +15,69 @@ import java.util.Map; ...@@ -12,59 +15,69 @@ import java.util.Map;
public interface ToolConfigService { public interface ToolConfigService {
/** /**
* 根据工具名称获取参数配置 * 根据工具名称获取参数配置(带缓存)
* @param toolName 工具名称 * @param toolName 工具名称
* @return 参数配置键值对 * @return 参数配置键值对
*/ */
@Cacheable(value = "toolConfigByToolName", key = "#toolName")
Map<String, String> getToolParams(String toolName); Map<String, String> getToolParams(String toolName);
/** /**
* 根据工具名称和参数名称获取参数值 * 根据工具名称和参数名称获取参数值(带缓存)
* @param toolName 工具名称 * @param toolName 工具名称
* @param paramName 参数名称 * @param paramName 参数名称
* @return 参数值 * @return 参数值
*/ */
@Cacheable(value = "toolConfig", key = "#toolName + '_' + #paramName")
String getParamValue(String toolName, String paramName); String getParamValue(String toolName, String paramName);
/** /**
* 保存参数值 * 保存参数值(自动清除缓存)
* @param toolName 工具名称 * @param toolName 工具名称
* @param paramName 参数名称 * @param paramName 参数名称
* @param paramValue 参数值 * @param paramValue 参数值
*/ */
@CacheEvict(value = "toolConfig", key = "#toolName + '_' + #paramName")
void saveParamValue(String toolName, String paramName, String paramValue); void saveParamValue(String toolName, String paramName, String paramValue);
/** /**
* 获取所有工具配置 * 获取所有工具配置(带缓存)
* @return 工具配置列表 * @return 工具配置列表
*/ */
@Cacheable(value = "allToolConfigs", key = "'all'")
List<ToolConfig> getAllToolConfigs(); List<ToolConfig> getAllToolConfigs();
/** /**
* 根据工具名称和参数名称获取工具配置 * 根据工具名称和参数名称获取工具配置(带缓存)
* @param toolName 工具名称 * @param toolName 工具名称
* @param paramName 参数名称 * @param paramName 参数名称
* @return 工具配置对象 * @return 工具配置对象
*/ */
@Cacheable(value = "toolConfig", key = "#toolName + '_' + #paramName")
ToolConfig getToolConfig(String toolName, String paramName); ToolConfig getToolConfig(String toolName, String paramName);
/** /**
* 保存工具配置 * 保存工具配置(自动清除相关缓存)
* @param toolConfig 工具配置对象 * @param toolConfig 工具配置对象
* @return 保存后的工具配置对象 * @return 保存后的工具配置对象
*/ */
@CacheEvict(value = {"toolConfig", "toolConfigByToolName", "toolConfigsByToolName"},
key = "#toolConfig.toolName + '_' + #toolConfig.paramName")
ToolConfig saveToolConfig(ToolConfig toolConfig); ToolConfig saveToolConfig(ToolConfig toolConfig);
/** /**
* 删除工具配置 * 删除工具配置(自动清除相关缓存)
* @param id 配置ID * @param id 配置ID
*/ */
@CacheEvict(value = {"toolConfig", "toolConfigByToolName", "toolConfigsByToolName"},
allEntries = true) // 删除配置时清除所有缓存,因为不知道具体工具名
void deleteToolConfig(String id); void deleteToolConfig(String id);
/** /**
* 根据工具名称获取工具配置列表 * 根据工具名称获取工具配置列表(带缓存)
* @param toolName 工具名称 * @param toolName 工具名称
* @return 工具配置列表 * @return 工具配置列表
*/ */
@Cacheable(value = "toolConfigsByToolName", key = "#toolName")
List<ToolConfig> getToolConfigsByToolName(String toolName); List<ToolConfig> getToolConfigsByToolName(String toolName);
} }
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -107,6 +107,6 @@ REM 设置更多调试参数 ...@@ -107,6 +107,6 @@ REM 设置更多调试参数
set JAVA_OPTS=-Dfile.encoding=UTF-8 -Dspring.profiles.active=dev -Dlogging.level.root=DEBUG -Dlogging.level.pangea.hiagent=TRACE -Dlogging.level.org.springframework.web=DEBUG -Dlogging.level.org.springframework.security=DEBUG -Dlogging.level.org.springframework.web.socket=DEBUG -Dlogging.level.org.projectlombok=DEBUG set JAVA_OPTS=-Dfile.encoding=UTF-8 -Dspring.profiles.active=dev -Dlogging.level.root=DEBUG -Dlogging.level.pangea.hiagent=TRACE -Dlogging.level.org.springframework.web=DEBUG -Dlogging.level.org.springframework.security=DEBUG -Dlogging.level.org.springframework.web.socket=DEBUG -Dlogging.level.org.projectlombok=DEBUG
echo [INFO] 启动Spring Boot应用... echo [INFO] 启动Spring Boot应用...
call mvn spring-boot:run -Dspring-boot.run.arguments="--spring.jpa.hibernate.ddl-auto=create-drop --logging.level.root=DEBUG --logging.level.pangea.hiagent=TRACE --logging.level.org.springframework.web=DEBUG --logging.level.org.springframework.security=DEBUG --logging.level.org.springframework.web.socket=DEBUG --logging.level.org.projectlombok=DEBUG" call mvn spring-boot:run -Dspring-boot.run.arguments="--spring-boot.run.profiles=dev"
pause pause
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment