Commit 4eb44b8e authored by youxiaoji's avatar youxiaoji

+ [会话支持]

parent 31c51157
...@@ -126,6 +126,7 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -126,6 +126,7 @@ public class DefaultReactExecutor implements ReactExecutor {
messages.add(new SystemMessage(systemPrompt)); messages.add(new SystemMessage(systemPrompt));
if (!newChat) {
if (agent != null) { if (agent != null) {
try { try {
// 如果没有提供用户ID,则尝试获取当前用户ID // 如果没有提供用户ID,则尝试获取当前用户ID
...@@ -138,15 +139,16 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -138,15 +139,16 @@ public class DefaultReactExecutor implements ReactExecutor {
List<org.springframework.ai.chat.messages.Message> historyMessages = List<org.springframework.ai.chat.messages.Message> historyMessages =
memoryService.getHistoryMessages(sessionId, historyLength); memoryService.getHistoryMessages(sessionId, historyLength);
if (!newChat) {
messages.addAll(historyMessages); messages.addAll(historyMessages);
}
memoryService.addUserMessageToMemory(sessionId, userInput); memoryService.addUserMessageToMemory(sessionId, userInput);
} catch (Exception e) { } catch (Exception e) {
log.warn("获取历史对话记录时发生错误: {}", e.getMessage()); log.warn("获取历史对话记录时发生错误: {}", e.getMessage());
} }
} }
}
messages.add(new UserMessage(userInput)); messages.add(new UserMessage(userInput));
for (Message message : messages) { for (Message message : messages) {
...@@ -179,13 +181,13 @@ public class DefaultReactExecutor implements ReactExecutor { ...@@ -179,13 +181,13 @@ public class DefaultReactExecutor implements ReactExecutor {
log.info("agentTools {}", agentTools); log.info("agentTools {}", agentTools);
if (agent.getId().compareToIgnoreCase("agent-8") == 0) { if (agent.getId().compareToIgnoreCase("agent-8") == 0) {
if (!chatService.chatExists(tmpUserId, agent.getId())) { // if (!chatService.chatExists(tmpUserId, agent.getId())) {
log.info("new chat for {} {} ", userId, agent.getId()); // log.info("new chat for {} {} ", userId, agent.getId());
prompt = buildPromptWithHistory(defaultSystemPrompt, userInput, agent, tmpUserId, true); prompt = buildPromptWithHistory(agent.getSystemPrompt(), userInput, agent, tmpUserId, true);
} // }
chatClient.prompt(prompt) chatClient.prompt(prompt)
.tools(agentTools.toArray()) .tools(agentTools.toArray())
.toolContext(Map.of("emitterId", emitterId, "userId", sseTokenEmitter.getUserId(),"agentId",agent.getId())) .toolContext(Map.of("emitterId", emitterId, "userId", sseTokenEmitter.getUserId(), "agentId", agent.getId()))
.stream() .stream()
.chatResponse() .chatResponse()
.subscribe( .subscribe(
......
...@@ -2,6 +2,7 @@ package pangea.hiagent.agent.service; ...@@ -2,6 +2,7 @@ package pangea.hiagent.agent.service;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
...@@ -80,7 +81,7 @@ public class AgentChatService { ...@@ -80,7 +81,7 @@ public class AgentChatService {
* @param response HTTP响应 * @param response HTTP响应
* @return SSE emitter * @return SSE emitter
*/ */
public SseEmitter handleChatStream(String agentId, ChatRequest chatRequest, HttpServletResponse response) { public PangeaEmitter handleChatStream(String agentId, ChatRequest chatRequest, HttpServletResponse response) {
log.info("开始处理流式对话请求,AgentId: {}, 用户消息: {}", agentId, chatRequest.getMessage()); log.info("开始处理流式对话请求,AgentId: {}, 用户消息: {}", agentId, chatRequest.getMessage());
// 尝试获取当前用户ID,优先从SecurityContext获取,其次从请求中解析JWT // 尝试获取当前用户ID,优先从SecurityContext获取,其次从请求中解析JWT
...@@ -104,7 +105,7 @@ public class AgentChatService { ...@@ -104,7 +105,7 @@ public class AgentChatService {
emitter.complete(); emitter.complete();
} }
} }
return emitter; return createPangeaEmitter(emitter,chatRequest);
} }
// 验证Agent是否存在 // 验证Agent是否存在
...@@ -122,7 +123,7 @@ public class AgentChatService { ...@@ -122,7 +123,7 @@ public class AgentChatService {
emitter.complete(); emitter.complete();
} }
} }
return emitter; return createPangeaEmitter(emitter,chatRequest);
} }
// 创建 SSE emitter // 创建 SSE emitter
...@@ -130,10 +131,22 @@ public class AgentChatService { ...@@ -130,10 +131,22 @@ public class AgentChatService {
String emitterId = UUID.randomUUID().toString(); String emitterId = UUID.randomUUID().toString();
log.info("emitterId: {}", emitterId); log.info("emitterId: {}", emitterId);
userSseService.registerEmitter(emitterId, emitter); userSseService.registerEmitter(emitterId, emitter);
// 异步处理对话,避免阻塞HTTP连接 // 异步处理对话,避免阻塞HTTP连接
processChatStreamAsync(emitter, agent, chatRequest, userId,emitterId); processChatStreamAsync(emitter, agent, chatRequest, userId,emitterId);
return emitter; return createPangeaEmitter(emitter,chatRequest);
}
private PangeaEmitter createPangeaEmitter(SseEmitter sseEmitter,ChatRequest chatRequest){
PangeaEmitter pangeaEmitter = null;
if(StringUtils.isEmpty(chatRequest.getChatId())){
String chatId = UUID.randomUUID().toString();
chatRequest.setChatId(chatId);
pangeaEmitter = new PangeaEmitter(sseEmitter,chatId);
}else{
pangeaEmitter = new PangeaEmitter(sseEmitter,chatRequest.getChatId());
}
return pangeaEmitter;
} }
/** /**
...@@ -202,6 +215,7 @@ public class AgentChatService { ...@@ -202,6 +215,7 @@ public class AgentChatService {
// 创建新的SseTokenEmitter实例 // 创建新的SseTokenEmitter实例
SseTokenEmitter tokenEmitter = new SseTokenEmitter(userSseService, emitter, agent, request, userId, this::handleCompletion); SseTokenEmitter tokenEmitter = new SseTokenEmitter(userSseService, emitter, agent, request, userId, this::handleCompletion);
tokenEmitter.setEmitterId(emitterId); tokenEmitter.setEmitterId(emitterId);
tokenEmitter.setChatId(chatRequest.getChatId());
// 处理流式请求前再次检查连接状态 // 处理流式请求前再次检查连接状态
if (!userSseService.isEmitterCompleted(emitter)) { if (!userSseService.isEmitterCompleted(emitter)) {
processor.processStreamRequest(request, agent, userId, tokenEmitter); processor.processStreamRequest(request, agent, userId, tokenEmitter);
......
...@@ -26,6 +26,7 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion { ...@@ -26,6 +26,7 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion {
private final String userId; private final String userId;
private final CompletionCallback completionCallback; private final CompletionCallback completionCallback;
private String emitterId; private String emitterId;
private String chatId;
/** /**
* 构造函数 * 构造函数
...@@ -170,4 +171,11 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion { ...@@ -170,4 +171,11 @@ public class SseTokenEmitter implements TokenConsumerWithCompletion {
public String getUserId() { public String getUserId() {
return userId; return userId;
} }
public void setChatId(String chatId) {
this.chatId = chatId;
}
public String getChatId() {
return chatId;
}
} }
\ No newline at end of file
...@@ -5,6 +5,7 @@ import org.springframework.web.bind.annotation.*; ...@@ -5,6 +5,7 @@ import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import pangea.hiagent.agent.service.AgentChatService; import pangea.hiagent.agent.service.AgentChatService;
import pangea.hiagent.agent.service.PangeaEmitter;
import pangea.hiagent.web.dto.ChatRequest; import pangea.hiagent.web.dto.ChatRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import jakarta.validation.Valid; import jakarta.validation.Valid;
...@@ -34,7 +35,7 @@ public class AgentChatController { ...@@ -34,7 +35,7 @@ public class AgentChatController {
* @return SSE emitter * @return SSE emitter
*/ */
@PostMapping("/chat-stream") @PostMapping("/chat-stream")
public SseEmitter chatStream( public PangeaEmitter chatStream(
@RequestParam @NotBlank(message = "Agent ID不能为空") String agentId, @RequestParam @NotBlank(message = "Agent ID不能为空") String agentId,
@RequestBody @Valid ChatRequest chatRequest, @RequestBody @Valid ChatRequest chatRequest,
HttpServletResponse response) { HttpServletResponse response) {
......
...@@ -17,6 +17,7 @@ import java.util.List; ...@@ -17,6 +17,7 @@ import java.util.List;
@AllArgsConstructor @AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
public class AgentRequest { public class AgentRequest {
private String chatId;
private String agentId; private String agentId;
private String systemPrompt; private String systemPrompt;
private String userMessage; private String userMessage;
......
...@@ -18,6 +18,9 @@ import jakarta.validation.constraints.NotBlank; ...@@ -18,6 +18,9 @@ import jakarta.validation.constraints.NotBlank;
@AllArgsConstructor @AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
public class ChatRequest { public class ChatRequest {
private String chatId;
@NotBlank(message = "用户消息不能为空") @NotBlank(message = "用户消息不能为空")
private String message; private String message;
...@@ -31,6 +34,7 @@ public class ChatRequest { ...@@ -31,6 +34,7 @@ public class ChatRequest {
*/ */
public AgentRequest toAgentRequest(String agentId, pangea.hiagent.model.Agent agent, pangea.hiagent.tool.AgentToolManager agentToolManager) { public AgentRequest toAgentRequest(String agentId, pangea.hiagent.model.Agent agent, pangea.hiagent.tool.AgentToolManager agentToolManager) {
return AgentRequest.builder() return AgentRequest.builder()
.chatId(this.chatId)
.agentId(agentId) .agentId(agentId)
.systemPrompt(agent.getSystemPrompt()) .systemPrompt(agent.getSystemPrompt())
.userMessage(this.message) .userMessage(this.message)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment