Commit 4eb44b8e authored by youxiaoji's avatar youxiaoji

+ [会话支持]

parent 31c51157
......@@ -126,28 +126,30 @@ public class DefaultReactExecutor implements ReactExecutor {
messages.add(new SystemMessage(systemPrompt));
if (agent != null) {
try {
// 如果没有提供用户ID,则尝试获取当前用户ID
if (userId == null) {
userId = UserUtils.getCurrentUserIdStatic();
}
String sessionId = memoryService.generateSessionId(agent, userId);
if (!newChat) {
if (agent != null) {
try {
// 如果没有提供用户ID,则尝试获取当前用户ID
if (userId == null) {
userId = UserUtils.getCurrentUserIdStatic();
}
String sessionId = memoryService.generateSessionId(agent, userId);
int historyLength = agent.getHistoryLength() != null ? agent.getHistoryLength() : 10;
int historyLength = agent.getHistoryLength() != null ? agent.getHistoryLength() : 10;
List<org.springframework.ai.chat.messages.Message> historyMessages =
memoryService.getHistoryMessages(sessionId, historyLength);
if (!newChat) {
List<org.springframework.ai.chat.messages.Message> historyMessages =
memoryService.getHistoryMessages(sessionId, historyLength);
messages.addAll(historyMessages);
}
memoryService.addUserMessageToMemory(sessionId, userInput);
} catch (Exception e) {
log.warn("获取历史对话记录时发生错误: {}", e.getMessage());
memoryService.addUserMessageToMemory(sessionId, userInput);
} catch (Exception e) {
log.warn("获取历史对话记录时发生错误: {}", e.getMessage());
}
}
}
messages.add(new UserMessage(userInput));
for (Message message : messages) {
log.info("message is {}", message);
......@@ -179,13 +181,13 @@ public class DefaultReactExecutor implements ReactExecutor {
log.info("agentTools {}", agentTools);
if (agent.getId().compareToIgnoreCase("agent-8") == 0) {
if (!chatService.chatExists(tmpUserId, agent.getId())) {
log.info("new chat for {} {} ", userId, agent.getId());
prompt = buildPromptWithHistory(defaultSystemPrompt, userInput, agent, tmpUserId, true);
}
// if (!chatService.chatExists(tmpUserId, agent.getId())) {
// log.info("new chat for {} {} ", userId, agent.getId());
prompt = buildPromptWithHistory(agent.getSystemPrompt(), userInput, agent, tmpUserId, true);
// }
chatClient.prompt(prompt)
.tools(agentTools.toArray())
.toolContext(Map.of("emitterId", emitterId, "userId", sseTokenEmitter.getUserId(),"agentId",agent.getId()))
.toolContext(Map.of("emitterId", emitterId, "userId", sseTokenEmitter.getUserId(), "agentId", agent.getId()))
.stream()
.chatResponse()
.subscribe(
......
......@@ -2,6 +2,7 @@ package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
......@@ -80,7 +81,7 @@ public class AgentChatService {
* @param response HTTP响应
* @return SSE emitter
*/
public SseEmitter handleChatStream(String agentId, ChatRequest chatRequest, HttpServletResponse response) {
public PangeaEmitter handleChatStream(String agentId, ChatRequest chatRequest, HttpServletResponse response) {
log.info("开始处理流式对话请求,AgentId: {}, 用户消息: {}", agentId, chatRequest.getMessage());
// 尝试获取当前用户ID,优先从SecurityContext获取,其次从请求中解析JWT
......@@ -104,7 +105,7 @@ public class AgentChatService {
emitter.complete();
}
}
return emitter;
return createPangeaEmitter(emitter,chatRequest);
}
// 验证Agent是否存在
......@@ -122,7 +123,7 @@ public class AgentChatService {
emitter.complete();
}
}
return emitter;
return createPangeaEmitter(emitter,chatRequest);
}
// 创建 SSE emitter
......@@ -130,10 +131,22 @@ public class AgentChatService {
String emitterId = UUID.randomUUID().toString();
log.info("emitterId: {}", emitterId);
userSseService.registerEmitter(emitterId, emitter);
// 异步处理对话,避免阻塞HTTP连接
processChatStreamAsync(emitter, agent, chatRequest, userId,emitterId);
return emitter;
return createPangeaEmitter(emitter,chatRequest);
}
private PangeaEmitter createPangeaEmitter(SseEmitter sseEmitter,ChatRequest chatRequest){
PangeaEmitter pangeaEmitter = null;
if(StringUtils.isEmpty(chatRequest.getChatId())){
String chatId = UUID.randomUUID().toString();
chatRequest.setChatId(chatId);
pangeaEmitter = new PangeaEmitter(sseEmitter,chatId);
}else{
pangeaEmitter = new PangeaEmitter(sseEmitter,chatRequest.getChatId());
}
return pangeaEmitter;
}
/**
......@@ -202,6 +215,7 @@ public class AgentChatService {
// 创建新的SseTokenEmitter实例
SseTokenEmitter tokenEmitter = new SseTokenEmitter(userSseService, emitter, agent, request, userId, this::handleCompletion);
tokenEmitter.setEmitterId(emitterId);
tokenEmitter.setChatId(chatRequest.getChatId());
// 处理流式请求前再次检查连接状态
if (!userSseService.isEmitterCompleted(emitter)) {
processor.processStreamRequest(request, agent, userId, tokenEmitter);
......
......@@ -26,6 +26,7 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion {
private final String userId;
private final CompletionCallback completionCallback;
private String emitterId;
private String chatId;
/**
* 构造函数
......@@ -170,4 +171,11 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion {
public String getUserId() {
return userId;
}
public void setChatId(String chatId) {
this.chatId = chatId;
}
public String getChatId() {
return chatId;
}
}
\ No newline at end of file
......@@ -5,6 +5,7 @@ import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.service.AgentChatService;
import pangea.hiagent.agent.service.PangeaEmitter;
import pangea.hiagent.web.dto.ChatRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.Valid;
......@@ -34,7 +35,7 @@ public class AgentChatController {
* @return SSE emitter
*/
@PostMapping("/chat-stream")
public SseEmitter chatStream(
public PangeaEmitter chatStream(
@RequestParam @NotBlank(message = "Agent ID不能为空") String agentId,
@RequestBody @Valid ChatRequest chatRequest,
HttpServletResponse response) {
......
......@@ -17,6 +17,7 @@ import java.util.List;
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class AgentRequest {
private String chatId;
private String agentId;
private String systemPrompt;
private String userMessage;
......
......@@ -18,6 +18,9 @@ import jakarta.validation.constraints.NotBlank;
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatRequest {
private String chatId;
@NotBlank(message = "用户消息不能为空")
private String message;
......@@ -31,6 +34,7 @@ public class ChatRequest {
*/
public AgentRequest toAgentRequest(String agentId, pangea.hiagent.model.Agent agent, pangea.hiagent.tool.AgentToolManager agentToolManager) {
return AgentRequest.builder()
.chatId(this.chatId)
.agentId(agentId)
.systemPrompt(agent.getSystemPrompt())
.userMessage(this.message)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment