package pangea.hiagent.common.config;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

import pangea.hiagent.workpanel.playwright.PlaywrightManager;
import pangea.hiagent.common.utils.JwtUtil;
import pangea.hiagent.websocket.DomSyncHandler;

import java.util.Map;
import lombok.extern.slf4j.Slf4j;

/**
 * WebSocket configuration that exposes the DOM sync endpoint and protects it with JWT authentication.
 *
 * <p>The {@link DomSyncHandler} is declared as a Spring bean so the container manages its lifecycle.
 * It is registered at {@code /ws/dom-sync} behind {@link JwtHandshakeInterceptor}.
 */
@Slf4j
@Configuration
@EnableWebSocket
public class DomSyncWebSocketConfig implements WebSocketConfigurer {

    /** URL path of the DOM sync WebSocket endpoint. */
    private static final String DOM_SYNC_ENDPOINT = "/ws/dom-sync";

    /** Allowed handshake origin; "*" accepts any origin. */
    private static final String ANY_ORIGIN = "*";

    private final JwtHandshakeInterceptor jwtHandshakeInterceptor;

    // Injected lazily; presumably to defer PlaywrightManager creation or break a dependency
    // cycle at startup -- TODO confirm.
    @Autowired
    @Lazy
    private PlaywrightManager playwrightManager;

    public DomSyncWebSocketConfig(JwtHandshakeInterceptor jwtHandshakeInterceptor) {
        this.jwtHandshakeInterceptor = jwtHandshakeInterceptor;
    }

    /**
     * Creates the DOM sync handler bean and hands it the lazily injected PlaywrightManager
     * through its setter.
     *
     * @return the handler to register on the DOM sync endpoint
     */
    @Bean
    public DomSyncHandler domSyncHandler() {
        DomSyncHandler syncHandler = new DomSyncHandler();
        syncHandler.setPlaywrightManager(this.playwrightManager);
        return syncHandler;
    }

    /**
     * Registers the DOM sync handler on its endpoint, guarded by the JWT handshake interceptor.
     *
     * @param registry the registry supplied by Spring's WebSocket support
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        DomSyncHandler syncHandler = domSyncHandler();
        registry.addHandler(syncHandler, DOM_SYNC_ENDPOINT)
                .addInterceptors(jwtHandshakeInterceptor)
                // Production: replace with explicit domains; never allow "*".
                .setAllowedOrigins(ANY_ORIGIN);
    }
}

/**
 * JWT握手拦截器，用于WebSocket连接时的认证
 */
@Slf4j
@Component
class JwtHandshakeInterceptor implements HandshakeInterceptor {
    private final JwtUtil jwtUtil;
    
    public JwtHandshakeInterceptor(JwtUtil jwtUtil) {
        this.jwtUtil = jwtUtil;
    }
    
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, 
                                 WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        String token = extractTokenFromRequest(request);
        String clientInfo = "[" + (request.getRemoteAddress() != null ? request.getRemoteAddress().toString() : "unknown") + "] ";
        
        log.info(clientInfo + "WebSocket握手请求 - URI: {}, Query: {}", request.getURI(), request.getURI().getQuery());
        
        if (StringUtils.hasText(token)) {
            try {
                log.debug(clientInfo + "Token提取成功，长度: {}", token.length());
                // 验证token是否有效
                boolean isValid = jwtUtil.validateToken(token);
                if (isValid) {
                    // 获取真实的用户ID
                    String userId = jwtUtil.getUserIdFromToken(token);
                    if (userId != null) {
                        attributes.put("token", token);
                        attributes.put("userId", userId);
                        log.info(clientInfo + "WebSocket连接认证成功，用户ID: {}", userId);
                        return true;
                    } else {
                        log.error(clientInfo + "错误：无法从token中提取用户ID。Token长度: {}", token.length());
                        log.error(clientInfo + "token前50字符: {}", token.substring(0, Math.min(50, token.length())));
                        // 尝试从token的payload中直接解析userId
                        try {
                            String[] parts = token.split("\\.");
                            if (parts.length > 1) {
                                String payload = new String(java.util.Base64.getUrlDecoder().decode(parts[1]));
                                log.error(clientInfo + "token payload: {}", payload);
                            }
                        } catch (Exception payloadEx) {
                            log.error(clientInfo + "解析token payload时发生异常: {}", payloadEx.getMessage(), payloadEx);
                        }
                    }
                } else {
                    boolean isExpired = jwtUtil.isTokenExpired(token);
                    log.error(clientInfo + "JWT验证失败。Token已过期: {}", isExpired);
                    
                    // 如果Token已过期，返回401状态码和明确的错误信息
                    response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
                    response.getHeaders().add("WWW-Authenticate", "Bearer error=\"invalid_token\", error_description=\"Token expired\"");
                    return false;
                }
            } catch (Exception e) {
                log.error(clientInfo + "JWT验证过程中发生异常: {}", e.getClass().getSimpleName(), e);
                
                // 如果验证过程出现异常，返回401状态码
                response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
                response.getHeaders().add("WWW-Authenticate", "Bearer error=\"invalid_token\", error_description=\"Token validation failed\"");
                return false;
            }
        } else {
            log.warn(clientInfo + "WebSocket连接缺少认证token");
            log.warn(clientInfo + "请求头Authorization: {}", request.getHeaders().getFirst("Authorization"));
            String query = request.getURI().getQuery();
            log.warn(clientInfo + "查询字符串: {}", query != null ? query : "(为空)");
        }
        
        // 如果没有有效的token，拒绝连接
        log.warn(clientInfo + "拒绝WebSocket连接，返回401 UNAUTHORIZED");
        response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
        response.getHeaders().add("WWW-Authenticate", "Bearer realm=\"WebSocket\"");
        return false;
    }
    
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, 
                             WebSocketHandler wsHandler, Exception exception) {
        String clientInfo = "[" + (request.getRemoteAddress() != null ? request.getRemoteAddress().toString() : "unknown") + "] ";
        if (exception != null) {
            log.error(clientInfo + "WebSocket握手失败，异常: {}", exception.getClass().getSimpleName(), exception);
        } else {
            log.info(clientInfo + "WebSocket握手后处理完成");
        }
    }
    
    /**
     * 从请求头或参数中提取Token
     * 复用JwtAuthenticationFilter中的逻辑
     */
    private String extractTokenFromRequest(ServerHttpRequest request) {
        // 首先尝试从请求头中提取Token
        String authHeader = request.getHeaders().getFirst("Authorization");
        if (StringUtils.hasText(authHeader) && authHeader.startsWith("Bearer ")) {
            String token = authHeader.substring(7);
            log.debug("从Authorization头中提取Token，长度: {}", token.length());
            return token;
        }
        
        // 如果请求头中没有Token，则尝试从URL参数中提取
        String query = request.getURI().getQuery();
        if (query != null) {
            try {
                UriComponentsBuilder builder = UriComponentsBuilder.newInstance().query(query);
                String token = builder.build().getQueryParams().getFirst("token");
                if (StringUtils.hasText(token)) {
                    log.debug("从URL参数中提取Token，长度: {}", token.length());
                    return token;
                } else {
                    log.debug("URL中没有token参数，Query: {}", query);
                }
            } catch (Exception e) {
                log.warn("解析URL参数时出错: {}", e.getMessage());
            }
        }
        
        return null;
    }
}