package pangea.hiagent.memory;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Redis-backed {@link ChatMemory} implementation that persists conversation history.
 *
 * <p>Each conversation is stored under the key {@code chat_memory:<conversationId>} as a single
 * JSON array of messages. The key expires after {@code DEFAULT_EXPIRE_HOURS} hours, and that
 * expiry is reset every time {@link #add} stores at least one message.
 *
 * <p>NOTE(review): {@link #add} performs a read-modify-write on a single string value and is not
 * atomic. Concurrent adds to the same conversation can overwrite each other's messages. Confirm
 * against callers whether that matters; if so, consider a Redis list (RPUSH/LRANGE) or a
 * transaction.
 *
 * <p>NOTE(review): {@link Message} is an interface, and the plain {@link ObjectMapper} used here
 * registers no type information or mix-ins. Deserializing stored data back into
 * {@code List<Message>} therefore presumably fails at runtime. Verify the round trip, and consider
 * storing a small DTO (message type + text) instead.
 */
@Slf4j
@Component
public class RedisChatMemory implements ChatMemory {

    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    private final ObjectMapper objectMapper = new ObjectMapper();

    /** Time-to-live of a conversation's history, in hours. */
    private static final int DEFAULT_EXPIRE_HOURS = 24;

    /**
     * Appends messages to a conversation's stored history and resets its expiry.
     *
     * <p>A {@code null} or empty {@code messages} list is a no-op: nothing is read or written.
     *
     * @param conversationId conversation identifier, used to build the Redis key
     * @param messages messages to append, in order
     * @throws RuntimeException if the existing history cannot be read, or the updated history
     *     cannot be serialized or written to Redis
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        // Guard before the try block. The catch handlers below call messages.size(), which would
        // itself throw a NullPointerException for null input and mask the original failure.
        if (messages == null || messages.isEmpty()) {
            return;
        }
        try {
            // Load the full existing history (Integer.MAX_VALUE means no truncation).
            List<Message> existingMessages = get(conversationId, Integer.MAX_VALUE);

            existingMessages.addAll(messages);

            // Serialize the whole history and overwrite the key, resetting its TTL.
            String serializedMessages = objectMapper.writeValueAsString(existingMessages);
            String key = generateRedisKey(conversationId);
            redisTemplate.opsForValue().set(key, serializedMessages, DEFAULT_EXPIRE_HOURS, TimeUnit.HOURS);

            log.debug("成功将{}条消息添加到会话{}", messages.size(), conversationId);
        } catch (JsonProcessingException e) {
            log.error("序列化消息时发生错误，会话ID: {}，消息数量: {}", conversationId, messages.size(), e);
            throw new RuntimeException("Failed to serialize messages for conversation: " + conversationId, e);
        } catch (Exception e) {
            log.error("保存消息到Redis时发生错误，会话ID: {}，消息数量: {}", conversationId, messages.size(), e);
            // Include the conversation id and message count so the failure can be traced.
            throw new RuntimeException("Failed to save messages to Redis for conversation: " + conversationId + ", message count: " + messages.size(), e);
        }
    }

    /**
     * Returns the most recent {@code lastN} messages of a conversation, oldest first.
     *
     * @param conversationId conversation identifier, used to build the Redis key
     * @param lastN maximum number of trailing messages to return; values {@code <= 0} yield an
     *     empty list
     * @return a new, mutable list that is independent of any other list; empty if the
     *     conversation has no stored history
     * @throws RuntimeException if the stored value cannot be deserialized or Redis cannot be read
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        // A non-positive window selects nothing. Without this guard, a negative lastN made
        // subList throw IndexOutOfBoundsException, which was then misreported as a Redis failure.
        if (lastN <= 0) {
            return new ArrayList<>();
        }
        try {
            String key = generateRedisKey(conversationId);
            String serializedMessages = redisTemplate.opsForValue().get(key);

            if (serializedMessages == null || serializedMessages.isEmpty()) {
                return new ArrayList<>();
            }

            List<Message> messages = objectMapper.readValue(serializedMessages, new TypeReference<List<Message>>() {});
            // A stored JSON literal "null" deserializes to null; treat it as an empty history.
            if (messages == null) {
                return new ArrayList<>();
            }

            if (lastN < messages.size()) {
                // Copy the tail so the caller gets an independent, mutable list rather than a
                // view backed by the full deserialized history.
                return new ArrayList<>(messages.subList(messages.size() - lastN, messages.size()));
            }

            return messages;
        } catch (JsonProcessingException e) {
            log.error("反序列化消息时发生错误，会话ID: {}", conversationId, e);
            throw new RuntimeException("Failed to deserialize messages for conversation: " + conversationId, e);
        } catch (Exception e) {
            log.error("从Redis获取消息时发生错误，会话ID: {}", conversationId, e);
            throw new RuntimeException("Failed to get messages from Redis for conversation: " + conversationId, e);
        }
    }

    /**
     * Deletes the stored history of a conversation.
     *
     * @param conversationId conversation identifier, used to build the Redis key
     * @throws RuntimeException if the Redis delete fails
     */
    @Override
    public void clear(String conversationId) {
        try {
            String key = generateRedisKey(conversationId);
            redisTemplate.delete(key);
            log.debug("成功清除会话{}", conversationId);
        } catch (Exception e) {
            log.error("清除会话时发生错误，会话ID: {}", conversationId, e);
            throw new RuntimeException("Failed to clear conversation: " + conversationId, e);
        }
    }

    /**
     * Builds the Redis key under which a conversation's history is stored.
     *
     * @param conversationId conversation identifier
     * @return the Redis key, {@code "chat_memory:" + conversationId}
     */
    private String generateRedisKey(String conversationId) {
        return "chat_memory:" + conversationId;
    }
}