Commit b2158e16 authored by ligaowei's avatar ligaowei

修复ReAct智能体记忆功能,确保对话历史正确加载和保存

parent bee22bb0
......@@ -229,6 +229,13 @@
<version>2.9.1</version>
</dependency>
<!-- Hibernate Validator for Jakarta Bean Validation -->
<dependency>
<groupId>org.hibernate.validator</groupId>
<artifactId>hibernate-validator</artifactId>
<version>8.0.1.Final</version>
</dependency>
<!-- SpringDoc OpenAPI for Swagger -->
<dependency>
<groupId>org.springdoc</groupId>
......
......@@ -31,6 +31,7 @@ import java.util.UUID;
*/
@Slf4j
@Component
@SuppressWarnings("unchecked")
public class OAuth2AuthenticationStrategy implements AuthenticationStrategy {
private final RestTemplate restTemplate;
......
......@@ -8,6 +8,7 @@ import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.model.Agent;
......@@ -16,7 +17,7 @@ import pangea.hiagent.rag.RagService;
import pangea.hiagent.react.ReactCallback;
import pangea.hiagent.react.ReactExecutor;
import pangea.hiagent.service.AgentService;
import pangea.hiagent.service.AgentToolManager;
import pangea.hiagent.core.AgentToolManager;
import java.util.List;
import java.util.Set;
......
......@@ -147,12 +147,15 @@ public class DefaultReactExecutor implements ReactExecutor {
// 添加历史消息到Prompt
messages.addAll(historyMessages);
// 将当前用户消息添加到内存中,以便下次对话使用
memoryService.addUserMessageToMemory(sessionId, userInput);
} catch (Exception e) {
log.warn("获取历史对话记录时发生错误: {}", e.getMessage());
}
}
// 添加当前用户消息
// 添加当前用户消息到Prompt
messages.add(new UserMessage(userInput));
return new Prompt(messages);
......@@ -258,6 +261,16 @@ public class DefaultReactExecutor implements ReactExecutor {
// 触发最终答案步骤
triggerFinalAnswerStep(fullResponse.toString());
// 将助理回复添加到ChatMemory
if (agent != null) {
try {
String sessionId = memoryService.generateSessionId(agent);
memoryService.addAssistantMessageToMemory(sessionId, fullResponse.toString());
} catch (Exception e) {
log.warn("保存助理回复到内存时发生错误: {}", e.getMessage());
}
}
// 发送完成事件,包含完整内容
sendCompletionEvent(tokenConsumer, fullResponse.toString());
}
......
package pangea.hiagent.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import pangea.hiagent.model.Agent;
import pangea.hiagent.model.Tool;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Agent工具管理服务类
* 负责管理Agent可用的工具列表
*/
@Slf4j
@Service
public class AgentToolManager {
@Autowired
private ToolService toolService;
/**
* 获取Agent可用的工具列表
* @param agent Agent对象
* @return 工具列表
*/
public List<Tool> getAvailableTools(Agent agent) {
try {
log.info("获取Agent可用工具列表,Agent ID: {}, 名称: {}", agent.getId(), agent.getName());
// 获取Agent所有者的所有活跃工具
List<Tool> allTools = toolService.getUserToolsByStatus(agent.getOwner(), "active");
log.info("用户所有活跃工具数量: {}", allTools != null ? allTools.size() : 0);
if (allTools == null || allTools.isEmpty()) {
log.warn("Agent: {} 没有配置可用的工具", agent.getId());
return List.of();
}
// 如果Agent配置了特定的工具列表,则只返回配置的工具
List<String> toolNames = agent.getToolNames();
log.info("Agent配置的工具名称列表: {}", toolNames);
if (toolNames != null && !toolNames.isEmpty()) {
// 根据工具名称筛选工具
List<Tool> filteredTools = filterToolsByName(allTools, toolNames);
log.info("筛选后的工具数量: {}", filteredTools.size());
return filteredTools;
}
return allTools;
} catch (Exception e) {
log.error("获取Agent可用工具时发生错误", e);
return List.of();
}
}
/**
* 根据工具名称筛选工具
* @param allTools 所有工具
* @param toolNames 工具名称列表
* @return 筛选后的工具列表
*/
public List<Tool> filterToolsByName(List<Tool> allTools, List<String> toolNames) {
return allTools.stream()
.filter(tool -> toolNames.contains(tool.getName()))
.collect(Collectors.toList());
}
/**
* 根据工具名称集合筛选工具实例(用于ReActService)
* @param allTools 所有工具实例
* @param toolNames 工具名称集合
* @return 筛选后的工具实例列表
*/
public List<Object> filterToolsByInstances(List<Object> allTools, Set<String> toolNames) {
log.debug("开始筛选工具实例,工具名称集合: {}", toolNames);
if (toolNames == null || toolNames.isEmpty()) {
log.debug("工具名称集合为空,返回所有工具实例");
return allTools;
}
List<Object> filteredTools = allTools.stream()
.filter(tool -> {
// 获取工具类名(不含包名)
String className = tool.getClass().getSimpleName();
log.debug("检查工具类: {}", className);
// 检查类名是否匹配
boolean isMatch = toolNames.contains(className) ||
toolNames.stream().anyMatch(name ->
className.toLowerCase().contains(name.toLowerCase()));
if (isMatch) {
log.debug("工具 {} 匹配成功", className);
}
return isMatch;
})
.collect(Collectors.toList());
log.debug("筛选完成,返回 {} 个工具实例", filteredTools.size());
return filteredTools;
}
/**
* 构建工具描述文本
* @param tools 工具列表
* @return 工具描述文本
*/
public String buildToolsDescription(List<Tool> tools) {
if (tools.isEmpty()) {
return "(暂无可用工具)";
}
StringBuilder description = new StringBuilder();
for (int i = 0; i < tools.size(); i++) {
Tool tool = tools.get(i);
description.append(i + 1).append(". ");
description.append(tool.getName());
if (hasValue(tool.getDisplayName())) {
description.append(" - ").append(tool.getDisplayName());
}
if (hasValue(tool.getDescription())) {
description.append(" - ").append(tool.getDescription());
}
description.append("\n");
}
return description.toString();
}
/**
* 检查字符串是否有值
* @param value 字符串值
* @return 是否有值
*/
private boolean hasValue(String value) {
return value != null && !value.isEmpty();
}
}
\ No newline at end of file
......@@ -8,10 +8,11 @@ import pangea.hiagent.model.Agent;
import pangea.hiagent.service.AgentService;
import pangea.hiagent.service.ToolService;
import pangea.hiagent.rag.RagService;
import pangea.hiagent.tool.ReactCallback;
import pangea.hiagent.tool.DefaultReactExecutor;
import pangea.hiagent.react.ReactCallback;
import pangea.hiagent.react.DefaultReactExecutor;
import pangea.hiagent.memory.MemoryService;
import pangea.hiagent.workpanel.IWorkPanelDataCollector;
import pangea.hiagent.core.ReActService;
import java.util.List;
import java.util.Set;
......
......@@ -315,6 +315,9 @@ const sendMessage = async () => {
isStreaming: false
})
// 记录会话信息用于调试
console.log('[ChatArea] 发送消息,Agent ID:', selectedAgent.value, '消息内容:', userMessage)
await scrollToBottom()
// 添加AI消息容器(流式接收)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment