Commit bee22bb0 authored by ligaowei's avatar ligaowei

重构ReAct执行器架构:1. 在ReactExecutor接口中添加executeWithAgent和executeStreamWithAgent方法...

重构ReAct执行器架构:1. 在ReactExecutor接口中添加executeWithAgent和executeStreamWithAgent方法 2. 修改ReActService中消除具体类型判断 3. 移动相关类到react包下
parent 46bfcc21
......@@ -39,6 +39,9 @@ public class ToolExecutionLoggerAspect {
String className = signature.getDeclaringType().getSimpleName();
String fullMethodName = className + "." + methodName;
// 使用完整的工具名称(类名)以确保开始和完成事件能够正确匹配
String toolName = className;
// 获取工具描述
String toolDescription = tool.description();
......@@ -62,7 +65,7 @@ public class ToolExecutionLoggerAspect {
// 记录工具调用开始
if (workPanelDataCollector != null) {
try {
workPanelDataCollector.recordToolCallStart(className, methodName, inputParams);
workPanelDataCollector.recordToolCallStart(toolName, methodName, inputParams);
} catch (Exception e) {
log.warn("记录工具调用开始时发生错误: {}", e.getMessage());
}
......@@ -81,7 +84,7 @@ public class ToolExecutionLoggerAspect {
// 记录工具调用完成
if (workPanelDataCollector != null) {
try {
workPanelDataCollector.recordToolCallComplete(className, result, "success", executionTime);
workPanelDataCollector.recordToolCallComplete(toolName, result, "success", executionTime);
} catch (Exception e) {
log.warn("记录工具调用完成时发生错误: {}", e.getMessage());
}
......@@ -99,7 +102,7 @@ public class ToolExecutionLoggerAspect {
// 记录工具调用错误
if (workPanelDataCollector != null) {
try {
workPanelDataCollector.recordToolCallComplete(className, e.getMessage(), "error", executionTime);
workPanelDataCollector.recordToolCallComplete(toolName, e.getMessage(), "error", executionTime);
} catch (Exception ex) {
log.warn("记录工具调用错误时发生错误: {}", ex.getMessage());
}
......
......@@ -7,7 +7,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.AgentChatService;
import pangea.hiagent.core.AgentChatService;
import pangea.hiagent.workpanel.SseEventManager;
import pangea.hiagent.dto.AgentRequest;
import pangea.hiagent.dto.ChatRequest;
......
......@@ -190,35 +190,46 @@ public class TimelineEventController {
sb.append("\"toolAction\":\"").append(sanitizeJsonString(event.getToolAction())).append("\",");
// 正确序列化 toolInput 为 JSON 对象
if (event.getToolInput() != null) {
// 无论toolInput是否为null,都要包含这个字段以确保前端能正确识别
try {
if (event.getToolInput() != null) {
String toolInputJson = objectMapper.writeValueAsString(event.getToolInput());
sb.append("\"toolInput\":").append(toolInputJson).append(",");
log.debug("[toolInput序列化成功] 工具={}, JSON={}", event.getToolName(), toolInputJson);
} else {
sb.append("\"toolInput\":null,");
log.debug("[toolInput为null] 工具={}, 类型={}", event.getToolName(), event.getEventType());
}
} catch (Exception e) {
// 如果序列化失败,记录警告并回退到字符串表示
log.warn("[序列化toolInput失败] 工具={}, 错误={}, 已回退为字符串表示", event.getToolName(), e.getMessage());
if (event.getToolInput() != null) {
sb.append("\"toolInput\":\"").append(sanitizeJsonString(event.getToolInput().toString())).append("\",");
}
} else {
log.debug("[toolInput为null] 工具={}, 类型={}", event.getToolName(), event.getEventType());
sb.append("\"toolInput\":null,");
}
}
// 正确序列化 toolOutput 为 JSON 对象
if (event.getToolOutput() != null) {
// 无论toolOutput是否为null,都要包含这个字段以确保前端能正确识别
try {
if (event.getToolOutput() != null) {
String toolOutputJson = objectMapper.writeValueAsString(event.getToolOutput());
sb.append("\"toolOutput\":").append(toolOutputJson).append(",");
log.debug("[toolOutput序列化成功] 工具={}, JSON={}", event.getToolName(), toolOutputJson);
} else {
sb.append("\"toolOutput\":null,");
log.debug("[toolOutput为null] 工具={}, 类型={}", event.getToolName(), event.getEventType());
}
} catch (Exception e) {
// 如果序列化失败,记录警告并回退到字符串表示
log.warn("[序列化toolOutput失败] 工具={}, 错误={}, 已回退为字符串表示", event.getToolName(), e.getMessage());
if (event.getToolOutput() != null) {
sb.append("\"toolOutput\":\"").append(sanitizeJsonString(String.valueOf(event.getToolOutput()))).append("\",");
}
} else {
log.debug("[toolOutput为null] 工具={}, 类型={}", event.getToolName(), event.getEventType());
sb.append("\"toolOutput\":null,");
}
}
if (event.getToolStatus() != null) {
sb.append("\"toolStatus\":\"").append(sanitizeJsonString(event.getToolStatus())).append("\",");
}
......
package pangea.hiagent.agent;
package pangea.hiagent.core;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.AssistantMessage;
......@@ -15,7 +15,7 @@ import pangea.hiagent.dto.WorkPanelEvent;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.AgentDialogue;
import pangea.hiagent.service.AgentService;
import pangea.hiagent.agent.ReActService;
import pangea.hiagent.core.ReActService;
import pangea.hiagent.workpanel.SseEventManager;
import pangea.hiagent.memory.SmartHistorySummarizer;
import java.util.ArrayList;
......@@ -51,7 +51,7 @@ public class AgentChatService {
public String handleReActAgentRequest(Agent agent, AgentRequest request, String userId) {
log.info("使用ReAct Agent处理请求");
// 使用ReAct Agent处理请求,传递userId以支持记忆功能
String responseContent = reActService.processRequestWithUserId(agent, request.getUserMessage(), userId);
String responseContent = reActService.processRequest(agent, request.getUserMessage(), userId);
// 保存对话记录并返回结果
return responseContent;
......@@ -120,7 +120,7 @@ public class AgentChatService {
Consumer<String> tokenConsumer) {
log.info("使用ReAct Agent处理流式请求");
// 使用ReAct Agent流式处理请求,传递userId以支持记忆功能
reActService.processRequestStreamWithUserId(agent, request.getUserMessage(), tokenConsumer, userId);
reActService.processRequestStream(agent, request.getUserMessage(), tokenConsumer, userId);
}
/**
......
package pangea.hiagent.core;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Service;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.Tool;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Agent工具管理服务类
* 负责管理Agent可用的工具列表
*/
@Slf4j
@Service
public class AgentToolManager {
@Autowired
private pangea.hiagent.service.ToolService toolService;
@Autowired
private ApplicationContext applicationContext;
/**
* 获取Agent可用的工具列表
* @param agent Agent对象
* @return 工具列表
*/
public List<Tool> getAvailableTools(Agent agent) {
try {
log.info("获取Agent可用工具列表,Agent ID: {}, 名称: {}", agent.getId(), agent.getName());
// 获取Agent所有者的所有活跃工具
List<Tool> allTools = toolService.getUserToolsByStatus(agent.getOwner(), "active");
log.info("用户所有活跃工具数量: {}", allTools != null ? allTools.size() : 0);
if (allTools == null || allTools.isEmpty()) {
log.warn("Agent: {} 没有配置可用的工具", agent.getId());
return List.of();
}
// 如果Agent配置了特定的工具列表,则只返回配置的工具
List<String> toolNames = agent.getToolNames();
log.info("Agent配置的工具名称列表: {}", toolNames);
if (toolNames != null && !toolNames.isEmpty()) {
// 根据工具名称筛选工具
List<Tool> filteredTools = filterToolsByName(allTools, toolNames);
log.info("筛选后的工具数量: {}", filteredTools.size());
return filteredTools;
}
return allTools;
} catch (Exception e) {
log.error("获取Agent可用工具时发生错误", e);
return List.of();
}
}
/**
* 根据工具名称筛选工具
* @param allTools 所有工具
* @param toolNames 工具名称列表
* @return 筛选后的工具列表
*/
public List<Tool> filterToolsByName(List<Tool> allTools, List<String> toolNames) {
return allTools.stream()
.filter(tool -> toolNames.contains(tool.getName()))
.collect(Collectors.toList());
}
/**
* 根据工具名称集合筛选工具实例(用于ReActService)
* @param allTools 所有工具实例
* @param toolNames 工具名称集合
* @return 筛选后的工具实例列表
*/
public List<Object> filterToolsByInstances(List<Object> allTools, Set<String> toolNames) {
log.debug("开始筛选工具实例,工具名称集合: {}", toolNames);
if (toolNames == null || toolNames.isEmpty()) {
log.debug("工具名称集合为空,返回所有工具实例");
return allTools;
}
List<Object> filteredTools = allTools.stream()
.filter(tool -> {
// 获取工具类名(不含包名)
String className = tool.getClass().getSimpleName();
log.debug("检查工具类: {}", className);
// 检查类名是否匹配
boolean isMatch = toolNames.contains(className) ||
toolNames.stream().anyMatch(name ->
className.toLowerCase().contains(name.toLowerCase()));
if (isMatch) {
log.debug("工具 {} 匹配成功", className);
}
return isMatch;
})
.collect(Collectors.toList());
log.debug("筛选完成,返回 {} 个工具实例", filteredTools.size());
return filteredTools;
}
/**
* 构建工具描述文本
* @param tools 工具列表
* @return 工具描述文本
*/
public String buildToolsDescription(List<Tool> tools) {
if (tools.isEmpty()) {
return "(暂无可用工具)";
}
StringBuilder description = new StringBuilder();
for (int i = 0; i < tools.size(); i++) {
Tool tool = tools.get(i);
description.append(i + 1).append(". ");
description.append(tool.getName());
if (hasValue(tool.getDisplayName())) {
description.append(" - ").append(tool.getDisplayName());
}
if (hasValue(tool.getDescription())) {
description.append(" - ").append(tool.getDescription());
}
description.append("\n");
}
return description.toString();
}
/**
* 检查字符串是否有值
* @param value 字符串值
* @return 是否有值
*/
private boolean hasValue(String value) {
return value != null && !value.isEmpty();
}
/**
* 根据Agent获取可用的工具实例
* @param agent Agent对象
* @return 工具实例列表
*/
public List<Object> getAvailableToolInstances(Agent agent) {
// 获取Agent可用的工具定义
List<Tool> availableTools = getAvailableTools(agent);
// 获取所有Spring管理的bean名称
String[] beanNames = applicationContext.getBeanDefinitionNames();
// 根据工具名称筛选对应的工具实例
List<Object> toolInstances = new ArrayList<>();
Set<String> availableToolNames = availableTools.stream()
.map(Tool::getName)
.collect(Collectors.toSet());
for (String beanName : beanNames) {
Object bean = applicationContext.getBean(beanName);
String simpleClassName = bean.getClass().getSimpleName();
// 检查bean的类名是否与可用工具名称匹配
if (availableToolNames.contains(simpleClassName)) {
toolInstances.add(bean);
}
}
log.info("智能体{}获取到的工具实例数量: {}", agent.getName(), toolInstances.size());
return toolInstances;
}
}
\ No newline at end of file
package pangea.hiagent.prompt;
package pangea.hiagent.core;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.Tool;
import pangea.hiagent.service.ToolService;
import pangea.hiagent.core.AgentToolManager;
import java.util.List;
import java.util.stream.Collectors;
/**
* 提示词服务类
......@@ -19,7 +18,7 @@ import java.util.stream.Collectors;
public class PromptService {
@Autowired
private ToolService toolService;
private AgentToolManager agentToolManager;
/**
* 构建系统提示词 - 根据Agent配置的工具动态生成
......@@ -34,7 +33,7 @@ public class PromptService {
try {
// 获取Agent配置的可用工具列表
List<Tool> agentTools = getAvailableTools(agent);
List<Tool> agentTools = agentToolManager.getAvailableTools(agent);
// 如果没有工具,直接返回默认提示词
if (agentTools.isEmpty()) {
......@@ -42,7 +41,7 @@ public class PromptService {
}
// 构建工具描述部分
String toolsDescription = buildToolsDescription(agentTools);
String toolsDescription = agentToolManager.buildToolsDescription(agentTools);
String toolsList = buildToolsList(agentTools);
// 构建默认系统提示词,包含动态生成的工具信息
......@@ -54,90 +53,6 @@ public class PromptService {
}
}
/**
* 获取Agent可用的工具列表
* @param agent Agent对象
* @return 工具列表
*/
private List<Tool> getAvailableTools(Agent agent) {
try {
log.info("获取Agent可用工具列表,Agent ID: {}, 名称: {}", agent.getId(), agent.getName());
// 获取Agent所有者的所有活跃工具
List<Tool> allTools = toolService.getUserToolsByStatus(agent.getOwner(), "active");
log.info("用户所有活跃工具数量: {}", allTools != null ? allTools.size() : 0);
if (allTools == null || allTools.isEmpty()) {
log.warn("Agent: {} 没有配置可用的工具", agent.getId());
return List.of();
}
// 如果Agent配置了特定的工具列表,则只返回配置的工具
List<String> toolNames = agent.getToolNames();
log.info("Agent配置的工具名称列表: {}", toolNames);
if (toolNames != null && !toolNames.isEmpty()) {
// 根据工具名称筛选工具
List<Tool> filteredTools = filterToolsByName(allTools, toolNames);
log.info("筛选后的工具数量: {}", filteredTools.size());
return filteredTools;
}
return allTools;
} catch (Exception e) {
log.error("获取Agent可用工具时发生错误", e);
return List.of();
}
}
/**
* 根据工具名称筛选工具
* @param allTools 所有工具
* @param toolNames 工具名称列表
* @return 筛选后的工具列表
*/
private List<Tool> filterToolsByName(List<Tool> allTools, List<String> toolNames) {
return allTools.stream()
.filter(tool -> toolNames.contains(tool.getName()))
.collect(Collectors.toList());
}
/**
* 构建工具描述文本
* @param tools 工具列表
* @return 工具描述文本
*/
private String buildToolsDescription(List<Tool> tools) {
if (tools.isEmpty()) {
return "(暂无可用工具)";
}
StringBuilder description = new StringBuilder();
for (int i = 0; i < tools.size(); i++) {
Tool tool = tools.get(i);
description.append(i + 1).append(". ");
description.append(tool.getName());
if (hasValue(tool.getDisplayName())) {
description.append(" - ").append(tool.getDisplayName());
}
if (hasValue(tool.getDescription())) {
description.append(" - ").append(tool.getDescription());
}
description.append("\n");
}
return description.toString();
}
/**
* 检查字符串是否有值
* @param value 字符串值
* @return 是否有值
*/
private boolean hasValue(String value) {
return value != null && !value.isEmpty();
}
/**
* 构建工具名称列表(用于Action的可选值)
* @param tools 工具列表
......@@ -150,7 +65,7 @@ public class PromptService {
return tools.stream()
.map(Tool::getName)
.collect(Collectors.joining(", "));
.collect(java.util.stream.Collectors.joining(", "));
}
/**
......
package pangea.hiagent.agent;
package pangea.hiagent.core;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.stereotype.Service;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.model.Agent;
import pangea.hiagent.service.AgentService;
import pangea.hiagent.dto.WorkPanelEvent;
import pangea.hiagent.model.Tool;
import pangea.hiagent.rag.RagService;
import pangea.hiagent.tool.DefaultReactExecutor;
import pangea.hiagent.tool.ReactCallback;
import pangea.hiagent.tool.ReactExecutor;
import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.workpanel.IWorkPanelDataCollector;
import pangea.hiagent.react.ReactCallback;
import pangea.hiagent.react.ReactExecutor;
import pangea.hiagent.service.AgentService;
import pangea.hiagent.service.AgentToolManager;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.function.Consumer;
/**
* 基于Spring AI ChatClient的ReAct Service类
......@@ -30,6 +32,19 @@ import java.util.stream.Collectors;
@Service
public class ReActService {
/**
* Token消费者接口,支持完成回调
*/
public interface TokenConsumerWithCompletion extends Consumer<String> {
/**
* 当流式处理完成时调用
* @param fullContent 完整的内容
*/
default void onComplete(String fullContent) {
// 默认实现为空
}
}
// 添加默认构造函数以支持测试
public ReActService() {
}
......@@ -39,14 +54,10 @@ public class ReActService {
@Autowired
private RagService ragService;
// 注入所有带@Component注解的工具类
@Autowired
@Lazy
private List<Object> allTools;
// 通过AgentToolManager管理工具,不再直接注入所有@Component工具类
@Autowired
private IWorkPanelDataCollector workPanelCollector;
private pangea.hiagent.workpanel.IWorkPanelDataCollector workPanelCollector;
@Autowired
private MemoryService memoryService;
......@@ -56,62 +67,18 @@ public class ReActService {
@Autowired
private ReactExecutor defaultReactExecutor;
/**
* 根据工具名称筛选工具实例
* @param toolNames 工具名称集合
* @return 筛选后的工具实例列表
*/
private List<Object> filterToolsByNames(Set<String> toolNames) {
log.debug("开始筛选工具,工具名称集合: {}", toolNames);
if (toolNames == null || toolNames.isEmpty()) {
log.debug("工具名称集合为空,返回所有工具");
return allTools;
}
List<Object> filteredTools = allTools.stream()
.filter(tool -> {
// 获取工具类名(不含包名)
String className = tool.getClass().getSimpleName();
log.debug("检查工具类: {}", className);
// 检查类名是否匹配
boolean isMatch = toolNames.contains(className) ||
toolNames.stream().anyMatch(name ->
className.toLowerCase().contains(name.toLowerCase()));
if (isMatch) {
log.debug("工具 {} 匹配成功", className);
}
return isMatch;
})
.collect(Collectors.toList());
log.debug("筛选完成,返回 {} 个工具", filteredTools.size());
return filteredTools;
}
@Autowired
private AgentToolManager agentToolManager;
/**
* 处理用户请求的主方法(同步方式)
*
* @param agent Agent对象
* @param userMessage 用户消息
* @return 处理结果
*/
public String processRequest(Agent agent, String userMessage) {
return processRequestWithUserId(agent, userMessage, null);
}
/**
* 处理用户请求的主方法(同步方式)- 支持显式传递userId
*
* @param agent Agent对象
* @param userMessage 用户消息
* @param userId 用户ID(可选)
* @return 处理结果
*/
public String processRequestWithUserId(Agent agent, String userMessage, String userId) {
public String processRequest(Agent agent, String userMessage, String userId) {
log.info("开始处理ReAct Agent请求,Agent ID: {}, 用户消息: {}", agent.getId(), userMessage);
try {
......@@ -132,14 +99,16 @@ public class ReActService {
}
// 准备执行环境
ChatClient client = prepareChatClient(agent);
List<Object> tools = prepareTools(agent);
ChatClient client = ChatClient.builder(agentService.getChatModelForAgent(agent)).build();
List<Object> tools = agentToolManager.getAvailableToolInstances(agent);
// 添加自定义回调到ReAct执行器
addReactCallbackIfNeeded();
if (defaultReactExecutor != null && defaultReactCallback != null) {
defaultReactExecutor.addReactCallback(defaultReactCallback);
}
// 使用ReAct执行器执行流程,传递Agent对象以支持记忆功能
String finalAnswer = executeReactProcess(client, userMessage, tools, agent);
String finalAnswer = defaultReactExecutor.executeWithAgent(client, userMessage, tools, agent);
// 将助理回复添加到ChatMemory
memoryService.addAssistantMessageToMemory(sessionId, finalAnswer);
......@@ -158,66 +127,19 @@ public class ReActService {
}
/**
* 准备ChatClient
* @param agent Agent对象
* @return ChatClient实例
*/
private ChatClient prepareChatClient(Agent agent) {
// 根据Agent配置获取对应的ChatModel
ChatModel chatModel = agentService.getChatModelForAgent(agent);
log.info("获取ChatModel成功: {}", chatModel.getClass().getName());
// 使用获取的ChatModel构建ChatClient
return ChatClient.builder(chatModel).build();
}
/**
* 准备工具列表
* 处理用户请求的主方法(同步方式)- 默认不传递userId
*
* @param agent Agent对象
* @return 工具列表
*/
private List<Object> prepareTools(Agent agent) {
log.info("准备工具列表,Agent ID: {}, Agent名称: {}", agent.getId(), agent.getName());
// 获取Agent配置的工具名称集合
Set<String> toolNames = agent.getToolNameSet();
log.info("Agent配置的工具名称集合: {}", toolNames);
// 根据工具名称筛选工具实例
List<Object> tools = filterToolsByNames(toolNames);
log.info("筛选后的工具数量: {}", tools.size());
// 记录工具详情
tools.forEach(tool -> log.debug("工具类: {}", tool.getClass().getSimpleName()));
return tools;
}
/**
* 添加ReAct回调(如果需要)
*/
private void addReactCallbackIfNeeded() {
if (defaultReactExecutor != null && defaultReactCallback != null) {
defaultReactExecutor.addReactCallback(defaultReactCallback);
}
}
/**
* 执行ReAct流程
* @param client ChatClient实例
* @param userMessage 用户消息
* @param tools 工具列表
* @param agent Agent对象
* @return 最终答案
* @return 处理结果
*/
private String executeReactProcess(ChatClient client, String userMessage, List<Object> tools, Agent agent) {
if (defaultReactExecutor instanceof DefaultReactExecutor) {
return ((DefaultReactExecutor) defaultReactExecutor).executeWithAgent(client, userMessage, tools, agent);
} else {
return defaultReactExecutor.execute(client, userMessage, tools);
}
public String processRequest(Agent agent, String userMessage) {
return processRequest(agent, userMessage, null);
}
// 移除了prepareChatClient、prepareTools、addReactCallbackIfNeeded和executeReactProcess方法,
// 因为它们只是简单地封装了其他方法调用,可以直接内联到processRequest方法中以提高效率
/**
* 尝试RAG增强
* @param agent Agent对象
......@@ -257,215 +179,149 @@ public class ReActService {
return response;
} catch (Exception e) {
log.error("保存响应到内存时发生错误", e);
return response; // 即使保存失败也返回响应
return response; // 即使保存失败,也要返回原始响应
}
}
/**
* 保存RAG响应到ChatMemory (已废弃,使用saveResponseToMemory替代)
* 保存RAG响应到ChatMemory
* @param agent Agent对象
* @param ragResponse RAG响应
* @return RAG响应
* @param ragResponse RAG响应内容
* @return RAG响应内容
*/
@Deprecated
private String saveRagResponseToMemory(Agent agent, String ragResponse) {
return saveResponseToMemory(agent, ragResponse);
}
/**
* 处理流式错误
*/
private void handleStreamError(AtomicBoolean isCompleted, Consumer<String> tokenConsumer, String errorMessage) {
// 确保只处理一次错误
if (!isCompleted.getAndSet(true) && tokenConsumer != null) {
try {
// 记录详细错误日志
log.error("流式处理错误: {}", errorMessage);
// 同时将错误信息记录到工作面板
if (workPanelCollector != null) {
try {
workPanelCollector.recordLog("流式处理错误: " + errorMessage, "error");
String sessionId = memoryService.generateSessionId(agent);
// 将RAG响应添加到ChatMemory
memoryService.addAssistantMessageToMemory(sessionId, "[RAG增强] " + ragResponse);
return ragResponse;
} catch (Exception e) {
log.debug("记录错误到工作面板失败: {}", e.getMessage());
log.error("保存RAG响应到内存时发生错误", e);
return ragResponse; // 即使保存失败,也要返回原始响应
}
}
// 发送错误信息给客户端
tokenConsumer.accept("[ERROR] " + errorMessage);
} catch (Exception e) {
log.error("发送错误消息时发生异常", e);
/**
* 判断异常是否为401未授权错误
* @param e 异常对象
* @return 是否为401错误
*/
private boolean isUnauthorizedError(Throwable e) {
if (e == null) {
return false;
}
// 检查异常消息中是否包含401 Unauthorized
String message = e.getMessage();
if (message != null && (message.contains("401 Unauthorized") || message.contains("Unauthorized"))) {
return true;
}
// 递归检查cause
return isUnauthorizedError(e.getCause());
}
/**
* 获取流式模型
* 初始化工作面板
*/
private StreamingChatModel getStreamingChatModel(Agent agent) {
private void initializeWorkPanel(Agent agent) {
if (workPanelCollector != null) {
try {
ChatModel chatModel = agentService.getChatModelForAgent(agent);
if (!(chatModel instanceof StreamingChatModel)) {
log.warn("模型不支持流式输出: {}", chatModel.getClass().getName());
return null;
}
return (StreamingChatModel) chatModel;
workPanelCollector.clear();
workPanelCollector.recordLog("开始处理Agent请求: " + agent.getName(), "info");
} catch (Exception e) {
log.error("获取流式模型失败: {}", e.getMessage(), e);
return null;
log.error("初始化工作面板时发生错误", e);
}
}
}
/**
* 流式处理ReAct Agent请求
*
* 优化后的实现采用更直接的流式处理方式,确保与普通Agent流式处理保持一致的行为
* 核心优化点:
* 1. 简化了Consumer包装逻辑,减少不必要的复杂性
* 2. 统一了onComplete回调机制,确保前端能正确接收到完整内容
* 3. 增强了错误处理机制,提供更清晰的错误信息
* 4. 支持对话历史记忆功能
*
* @param agent Agent对象
* @param userMessage 用户消息
* @param tokenConsumer token处理回调函数(前端的TokenConsumerWithCompletion实现)
* 为流式处理设置工作面板数据收集器
*/
public void processRequestStream(Agent agent, String userMessage, Consumer<String> tokenConsumer) {
processRequestStreamWithUserId(agent, userMessage, tokenConsumer, null);
public void setWorkPanelEventSubscriber(java.util.function.Consumer<pangea.hiagent.dto.WorkPanelEvent> consumer) {
// 订阅工作面板事件,用于实时推送
if (workPanelCollector != null && consumer != null) {
workPanelCollector.subscribe(consumer);
}
}
/**
* 流式处理ReAct Agent请求 - 支持显式传递userId
* 流式处理用户请求
*
* @param agent Agent对象
* @param userMessage 用户消息
* @param tokenConsumer token处理回调函数(前端的TokenConsumerWithCompletion实现)
* @param tokenConsumer token处理回调函数
* @param userId 用户ID(可选)
*/
public void processRequestStreamWithUserId(Agent agent, String userMessage, Consumer<String> tokenConsumer, String userId) {
AtomicBoolean isCompleted = new AtomicBoolean(false);
try {
public void processRequestStream(Agent agent, String userMessage, java.util.function.Consumer<String> tokenConsumer, String userId) {
log.info("开始流式处理ReAct Agent请求,Agent ID: {}, 用户消息: {}", agent.getId(), userMessage);
// 检查用户消息是否为空
if (userMessage == null || userMessage.trim().isEmpty()) {
String errorMsg = "用户消息不能为空";
log.error(errorMsg);
handleStreamError(isCompleted, tokenConsumer, errorMsg);
return;
}
try {
// 为每个用户-Agent组合创建唯一的会话ID
String sessionId = memoryService.generateSessionId(agent, userId);
// 添加用户消息到ChatMemory
memoryService.addUserMessageToMemory(sessionId, userMessage);
// 初始化工作面板
initializeWorkPanel(agent);
// 获取流式模型
StreamingChatModel streamingChatModel = getStreamingChatModel(agent);
if (streamingChatModel == null) {
String errorMsg = "当前模型不支持流式输出,请检查Agent配置";
log.error(errorMsg);
handleStreamError(isCompleted, tokenConsumer, errorMsg);
// 检查是否启用RAG并尝试RAG增强
String ragResponse = tryRagEnhancement(agent, userMessage);
if (ragResponse != null) {
// 触发最终答案回调
if (defaultReactCallback != null) {
defaultReactCallback.onFinalAnswer(ragResponse);
}
// 对于流式处理,我们需要将RAG响应作为token发送
if (tokenConsumer != null) {
tokenConsumer.accept(ragResponse);
// 发送完成信号
if (tokenConsumer instanceof TokenConsumerWithCompletion) {
((TokenConsumerWithCompletion) tokenConsumer).onComplete(ragResponse);
}
}
return;
}
// 准备执行环境
ChatClient client = ChatClient.builder((ChatModel) streamingChatModel).build();
List<Object> tools = prepareTools(agent);
ChatClient client = ChatClient.builder(agentService.getChatModelForAgent(agent)).build();
List<Object> tools = agentToolManager.getAvailableToolInstances(agent);
// 添加自定义回调到ReAct执行器
addReactCallbackIfNeeded();
// 直接传递tokenConsumer给ReAct执行器,简化处理逻辑
// ReAct执行器会负责处理token和onComplete回调
// 传递Agent对象以支持记忆功能
executeReactStreamProcess(client, userMessage, tools, tokenConsumer, agent);
log.debug("流式执行完成");
if (defaultReactExecutor != null && defaultReactCallback != null) {
defaultReactExecutor.addReactCallback(defaultReactCallback);
}
// 使用ReAct执行器流式执行流程,传递Agent对象以支持记忆功能
defaultReactExecutor.executeStreamWithAgent(client, userMessage, tools, tokenConsumer, agent);
} catch (Exception e) {
String errorMsg = "流式处理ReAct请求时发生错误: " + e.getMessage();
// 检查是否是401 Unauthorized错误
if (isUnauthorizedError(e)) {
log.error("LLM返回401未授权错误: {}", e.getMessage());
errorMsg = " 请配置API密钥";
} else {
log.error(errorMsg, e);
if (tokenConsumer != null) {
tokenConsumer.accept(" 请配置API密钥");
}
handleStreamError(isCompleted, tokenConsumer, errorMsg);
} else {
log.error("流式处理ReAct请求时发生错误", e);
if (tokenConsumer != null) {
tokenConsumer.accept("处理请求时发生错误: " + e.getMessage());
}
}
/**
* 执行ReAct流式流程
* @param client ChatClient实例
* @param userMessage 用户消息
* @param tools 工具列表
* @param tokenConsumer token处理回调函数
* @param agent Agent对象
*/
private void executeReactStreamProcess(ChatClient client, String userMessage, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) {
if (defaultReactExecutor instanceof DefaultReactExecutor) {
((DefaultReactExecutor) defaultReactExecutor).executeStreamWithAgent(
client, userMessage, tools, tokenConsumer, agent);
} else {
defaultReactExecutor.executeStream(client, userMessage, tools, tokenConsumer);
}
}
/**
* 初始化工作面板
* 流式处理用户请求(默认不传递userId)
*
* @param agent Agent对象
* @param userMessage 用户消息
* @param tokenConsumer token处理回调函数
*/
private void initializeWorkPanel(Agent agent) {
if (workPanelCollector != null) {
try {
workPanelCollector.clear();
workPanelCollector.recordLog("开始处理Agent请求: " + agent.getName(), "info");
} catch (Exception e) {
log.error("初始化工作面板时发生错误", e);
}
}
public void processRequestStream(Agent agent, String userMessage, java.util.function.Consumer<String> tokenConsumer) {
processRequestStream(agent, userMessage, tokenConsumer, null);
}
/**
* 为流式处理设置工作面板数据收集器
*/
public void setWorkPanelEventSubscriber(Consumer<WorkPanelEvent> consumer) {
// 订阅工作面板事件,用于实时推送
if (workPanelCollector != null && consumer != null) {
workPanelCollector.subscribe(consumer);
}
}
/**
* 获取工作面板数据收集器
*/
public IWorkPanelDataCollector getWorkPanelCollector() {
public pangea.hiagent.workpanel.IWorkPanelDataCollector getWorkPanelCollector() {
return workPanelCollector;
}
/**
* 判断异常是否为401未授权错误
* @param e 异常对象
* @return 是否为401错误
*/
private boolean isUnauthorizedError(Throwable e) {
if (e == null) {
return false;
}
// 检查异常消息中是否包含401 Unauthorized
String message = e.getMessage();
if (message != null && (message.contains("401 Unauthorized") || message.contains("Unauthorized"))) {
return true;
}
// 递归检查cause
return isUnauthorizedError(e.getCause());
}
}
\ No newline at end of file
package pangea.hiagent.tool;
package pangea.hiagent.react;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
......
package pangea.hiagent.tool;
package pangea.hiagent.react;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
......@@ -9,7 +9,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy;
import pangea.hiagent.workpanel.IWorkPanelDataCollector;
import pangea.hiagent.agent.AgentChatService;
import pangea.hiagent.core.AgentChatService;
import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.model.Agent;
import java.util.List;
......
package pangea.hiagent.tool;
package pangea.hiagent.react;
/**
* ReAct回调接口,用于捕获ReAct执行的每一步
......
package pangea.hiagent.tool;
package pangea.hiagent.react;
import org.springframework.ai.chat.client.ChatClient;
import pangea.hiagent.model.Agent;
import java.util.List;
import java.util.function.Consumer;
......@@ -18,6 +19,18 @@ public interface ReactExecutor {
*/
String execute(ChatClient chatClient, String userInput, List<Object> tools);
/**
* 执行ReAct流程(同步方式)- 支持Agent配置
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param agent Agent对象(可选)
* @return 最终答案
*/
default String executeWithAgent(ChatClient chatClient, String userInput, List<Object> tools, Agent agent) {
return execute(chatClient, userInput, tools);
}
/**
* 流式执行ReAct流程
* @param chatClient ChatClient实例
......@@ -27,6 +40,18 @@ public interface ReactExecutor {
*/
void executeStream(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer);
/**
* 流式执行ReAct流程 - 支持Agent配置
* @param chatClient ChatClient实例
* @param userInput 用户输入
* @param tools 工具列表
* @param tokenConsumer token处理回调函数
* @param agent Agent对象(可选)
*/
default void executeStreamWithAgent(ChatClient chatClient, String userInput, List<Object> tools, Consumer<String> tokenConsumer, Agent agent) {
executeStream(chatClient, userInput, tools, tokenConsumer);
}
/**
* 添加ReAct回调
* @param callback ReAct回调
......
package pangea.hiagent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.Tool;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Agent工具管理服务类
* 负责管理Agent可用的工具列表
*/
@Slf4j
@Service
public class AgentToolManager {
@Autowired
private ToolService toolService;
/**
* 获取Agent可用的工具列表
* @param agent Agent对象
* @return 工具列表
*/
public List<Tool> getAvailableTools(Agent agent) {
try {
log.info("获取Agent可用工具列表,Agent ID: {}, 名称: {}", agent.getId(), agent.getName());
// 获取Agent所有者的所有活跃工具
List<Tool> allTools = toolService.getUserToolsByStatus(agent.getOwner(), "active");
log.info("用户所有活跃工具数量: {}", allTools != null ? allTools.size() : 0);
if (allTools == null || allTools.isEmpty()) {
log.warn("Agent: {} 没有配置可用的工具", agent.getId());
return List.of();
}
// 如果Agent配置了特定的工具列表,则只返回配置的工具
List<String> toolNames = agent.getToolNames();
log.info("Agent配置的工具名称列表: {}", toolNames);
if (toolNames != null && !toolNames.isEmpty()) {
// 根据工具名称筛选工具
List<Tool> filteredTools = filterToolsByName(allTools, toolNames);
log.info("筛选后的工具数量: {}", filteredTools.size());
return filteredTools;
}
return allTools;
} catch (Exception e) {
log.error("获取Agent可用工具时发生错误", e);
return List.of();
}
}
/**
* 根据工具名称筛选工具
* @param allTools 所有工具
* @param toolNames 工具名称列表
* @return 筛选后的工具列表
*/
public List<Tool> filterToolsByName(List<Tool> allTools, List<String> toolNames) {
return allTools.stream()
.filter(tool -> toolNames.contains(tool.getName()))
.collect(Collectors.toList());
}
/**
* 根据工具名称集合筛选工具实例(用于ReActService)
* @param allTools 所有工具实例
* @param toolNames 工具名称集合
* @return 筛选后的工具实例列表
*/
public List<Object> filterToolsByInstances(List<Object> allTools, Set<String> toolNames) {
log.debug("开始筛选工具实例,工具名称集合: {}", toolNames);
if (toolNames == null || toolNames.isEmpty()) {
log.debug("工具名称集合为空,返回所有工具实例");
return allTools;
}
List<Object> filteredTools = allTools.stream()
.filter(tool -> {
// 获取工具类名(不含包名)
String className = tool.getClass().getSimpleName();
log.debug("检查工具类: {}", className);
// 检查类名是否匹配
boolean isMatch = toolNames.contains(className) ||
toolNames.stream().anyMatch(name ->
className.toLowerCase().contains(name.toLowerCase()));
if (isMatch) {
log.debug("工具 {} 匹配成功", className);
}
return isMatch;
})
.collect(Collectors.toList());
log.debug("筛选完成,返回 {} 个工具实例", filteredTools.size());
return filteredTools;
}
/**
* 构建工具描述文本
* @param tools 工具列表
* @return 工具描述文本
*/
public String buildToolsDescription(List<Tool> tools) {
if (tools.isEmpty()) {
return "(暂无可用工具)";
}
StringBuilder description = new StringBuilder();
for (int i = 0; i < tools.size(); i++) {
Tool tool = tools.get(i);
description.append(i + 1).append(". ");
description.append(tool.getName());
if (hasValue(tool.getDisplayName())) {
description.append(" - ").append(tool.getDisplayName());
}
if (hasValue(tool.getDescription())) {
description.append(" - ").append(tool.getDescription());
}
description.append("\n");
}
return description.toString();
}
/**
* 检查字符串是否有值
* @param value 字符串值
* @return 是否有值
*/
private boolean hasValue(String value) {
return value != null && !value.isEmpty();
}
}
\ No newline at end of file
......@@ -26,7 +26,7 @@ public class WorkPanelService {
private AgentService agentService;
@Autowired
private pangea.hiagent.agent.ReActService reActService;
private pangea.hiagent.core.ReActService reActService;
// 用于跟踪已发送的事件ID,防止重复发送
private final Map<String, Set<String>> sentEventIds = new ConcurrentHashMap<>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment