package pangea.hiagent.websocket;

import lombok.extern.slf4j.Slf4j;

import java.nio.ByteBuffer;
import java.util.Objects;

/**
 * WebSocket二进制消息协议处理类
 * 
 * 协议格式：
 * ┌────────┬─────────┬─────────┬──────────────┬──────────────┐
 * │ 头字节 │ 消息ID  │ 总分片数 │ 当前分片索引 │    数据      │
 * │(1B)   │ (4B)    │ (2B)    │ (2B)         │   (可变)     │
 * └────────┴─────────┴─────────┴──────────────┴──────────────┘
 * 
 * 头字节定义：
 * bit 7-5: 消息类型 (000=data, 001=ack, 010=error)
 * bit 4-2: 编码方式 (000=raw, 001=gzip, 010=brotli)
 * bit 1-0: 保留位
 */
@Slf4j
public class BinaryMessageProtocol {
    // ========== 消息类型常量 ==========
    public static final byte TYPE_DATA = 0x00;      // 数据帧
    public static final byte TYPE_ACK = 0x01;       // 确认帧
    public static final byte TYPE_ERROR = 0x02;     // 错误帧
    
    // ========== 编码类型常量 ==========
    public static final byte ENCODING_RAW = 0x00;       // 无编码
    public static final byte ENCODING_GZIP = 0x01;      // GZIP压缩
    public static final byte ENCODING_BROTLI = 0x02;    // Brotli压缩
    
    // ========== 协议字段大小 ==========
    public static final int HEADER_SIZE = 12;           // 协议头大小（字节）
    public static final int MAX_FRAGMENT_SIZE = 65535 - HEADER_SIZE;  // 最大分片数据大小
    
    // ========== 消息ID生成器 ==========
    private static int nextMessageId = 1;
    
    /**
     * 生成唯一的消息ID
     */
    public static synchronized int generateMessageId() {
        if (nextMessageId >= Integer.MAX_VALUE) {
            nextMessageId = 1;
        }
        return nextMessageId++;
    }
    
    /**
     * 编码二进制消息头
     * 
     * @param messageType 消息类型 (TYPE_DATA/TYPE_ACK/TYPE_ERROR)
     * @param messageId 消息ID (全局唯一)
     * @param totalFragments 总分片数
     * @param currentFragment 当前分片索引 (从0开始)
     * @param encoding 编码方式 (ENCODING_RAW/ENCODING_GZIP/ENCODING_BROTLI)
     * @return 编码后的12字节头数据
     * @throws IllegalArgumentException 参数验证失败
     */
    public static byte[] encodeHeader(
            byte messageType,
            int messageId,
            int totalFragments,
            int currentFragment,
            byte encoding) {
        
        // ========== 参数验证 ==========
        validateMessageType(messageType);
        validateEncoding(encoding);
        
        if (totalFragments <= 0 || totalFragments > 65535) {
            throw new IllegalArgumentException("总分片数必须在1-65535之间，当前值: " + totalFragments);
        }
        
        if (currentFragment < 0 || currentFragment >= totalFragments) {
            throw new IllegalArgumentException(
                String.format("当前分片索引越界，应在0-%d之间，当前值: %d", 
                    totalFragments - 1, currentFragment)
            );
        }
        
        // ========== 编码头信息 ==========
        ByteBuffer buffer = ByteBuffer.allocate(HEADER_SIZE);
        
        // 第1字节：消息类型(3bit) + 编码方式(3bit) + 保留(2bit)
        byte headerByte = (byte) ((messageType & 0x07) << 5 | (encoding & 0x07) << 2);
        buffer.put(headerByte);
        
        // 第2-5字节：消息ID (4字节，大端序)
        buffer.putInt(messageId);
        
        // 第6-7字节：总分片数 (2字节，大端序)
        buffer.putShort((short) totalFragments);
        
        // 第8-9字节：当前分片索引 (2字节，大端序)
        buffer.putShort((short) currentFragment);
        
        // 第10-11字节：保留，用于扩展
        buffer.putShort((short) 0);
        
        return buffer.array();
    }
    
    /**
     * 解码二进制消息头
     * 
     * @param header 包含协议头的字节数组（至少12字节）
     * @return 解码后的消息头对象
     * @throws IllegalArgumentException header长度不足或格式错误
     */
    public static MessageHeader decodeHeader(byte[] header) {
        if (header == null || header.length < HEADER_SIZE) {
            throw new IllegalArgumentException(
                String.format("消息头长度不足，期望至少%d字节，实际%d字节", 
                    HEADER_SIZE, header == null ? 0 : header.length)
            );
        }
        
        ByteBuffer buffer = ByteBuffer.wrap(header, 0, HEADER_SIZE);
        
        // 解析第1字节
        byte headerByte = buffer.get();
        byte messageType = (byte) ((headerByte >> 5) & 0x07);
        byte encoding = (byte) ((headerByte >> 2) & 0x07);
        
        // 验证解析出的值
        validateMessageType(messageType);
        validateEncoding(encoding);
        
        // 解析ID和分片信息
        int messageId = buffer.getInt();
        int totalFragments = buffer.getShort() & 0xFFFF;
        int currentFragment = buffer.getShort() & 0xFFFF;
        
        // 验证分片信息
        if (totalFragments <= 0 || totalFragments > 65535) {
            throw new IllegalArgumentException("总分片数无效: " + totalFragments);
        }
        
        if (currentFragment >= totalFragments) {
            throw new IllegalArgumentException(
                String.format("分片索引越界: %d >= %d", currentFragment, totalFragments)
            );
        }
        
        return new MessageHeader(messageType, encoding, messageId, totalFragments, currentFragment);
    }
    
    /**
     * 从完整消息中提取数据部分（跳过12字节的协议头）
     * 
     * @param message 完整的消息字节数组
     * @return 数据部分的字节数组
     * @throws IllegalArgumentException message长度不足
     */
    public static byte[] extractData(byte[] message) {
        if (message == null || message.length < HEADER_SIZE) {
            throw new IllegalArgumentException(
                String.format("消息长度不足，期望至少%d字节", HEADER_SIZE)
            );
        }
        
        byte[] data = new byte[message.length - HEADER_SIZE];
        System.arraycopy(message, HEADER_SIZE, data, 0, data.length);
        return data;
    }
    
    /**
     * 将数据和头信息合并为完整消息
     * 
     * @param header 12字节的协议头
     * @param data 消息数据部分
     * @return 完整的消息字节数组
     */
    public static byte[] buildMessage(byte[] header, byte[] data) {
        if (header == null || header.length != HEADER_SIZE) {
            throw new IllegalArgumentException("协议头长度必须为" + HEADER_SIZE + "字节");
        }
        
        if (data == null || data.length == 0) {
            return header;
        }
        
        byte[] message = new byte[HEADER_SIZE + data.length];
        System.arraycopy(header, 0, message, 0, HEADER_SIZE);
        System.arraycopy(data, 0, message, HEADER_SIZE, data.length);
        return message;
    }
    
    /**
     * 验证消息类型是否有效
     */
    private static void validateMessageType(byte type) {
        if (type != TYPE_DATA && type != TYPE_ACK && type != TYPE_ERROR) {
            throw new IllegalArgumentException("无效的消息类型: " + type);
        }
    }
    
    /**
     * 验证编码方式是否有效
     */
    private static void validateEncoding(byte encoding) {
        if (encoding != ENCODING_RAW && encoding != ENCODING_GZIP && encoding != ENCODING_BROTLI) {
            throw new IllegalArgumentException("无效的编码方式: " + encoding);
        }
    }
    
    /**
     * 消息头信息类
     */
    public static class MessageHeader {
        public final byte messageType;      // 消息类型
        public final byte encoding;         // 编码方式
        public final int messageId;         // 消息ID
        public final int totalFragments;    // 总分片数
        public final int currentFragment;   // 当前分片索引
        
        public MessageHeader(byte type, byte enc, int id, int total, int current) {
            this.messageType = type;
            this.encoding = enc;
            this.messageId = id;
            this.totalFragments = total;
            this.currentFragment = current;
        }
        
        @Override
        public String toString() {
            return String.format(
                "MessageHeader{type=%d, encoding=%d, msgId=%d, totalFrags=%d, currentFrag=%d}",
                messageType, encoding, messageId, totalFragments, currentFragment
            );
        }
        
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            MessageHeader that = (MessageHeader) o;
            return messageType == that.messageType &&
                   encoding == that.encoding &&
                   messageId == that.messageId &&
                   totalFragments == that.totalFragments &&
                   currentFragment == that.currentFragment;
        }
        
        @Override
        public int hashCode() {
            return Objects.hash(messageType, encoding, messageId, totalFragments, currentFragment);
        }
    }
    
    /**
     * 计算传输效率
     * 原始数据大小 -> 编码后大小 -> 加入协议头后的最终大小
     */
    public static class TransmissionStats {
        public int originalSize;      // 原始数据大小
        public int encodedSize;       // 编码后大小
        public int totalSize;         // 加入协议头后的总大小
        public double compressionRatio;  // 压缩比
        
        public TransmissionStats(int original, int encoded) {
            this.originalSize = original;
            this.encodedSize = encoded;
            this.totalSize = encoded + HEADER_SIZE;
            this.compressionRatio = (double) encoded / original;
        }
        
        @Override
        public String toString() {
            return String.format(
                "原始:%dB, 编码:%dB (%.2fx), 加头:%dB, 压缩比:%.2f%%",
                originalSize, encodedSize, compressionRatio, totalSize,
                (1 - (double) totalSize / originalSize) * 100
            );
        }
    }
}
