package pangea.hiagent.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

import pangea.hiagent.utils.JwtUtil;
import pangea.hiagent.websocket.DomSyncHandler;

import java.util.Map;

/**
 * WebSocket configuration: registers the DOM-sync handler endpoint and attaches
 * the JWT handshake interceptor to it.
 */
@Configuration
@EnableWebSocket
public class DomSyncWebSocketConfig implements WebSocketConfigurer {

    /** URL path on which the DOM-sync WebSocket endpoint is exposed. */
    private static final String DOM_SYNC_ENDPOINT = "/ws/dom-sync";

    /**
     * Origins allowed to open the WebSocket.
     * Production: replace with concrete domain names; the wildcard must not be used there.
     */
    private static final String ALLOWED_ORIGINS = "*";

    private final JwtHandshakeInterceptor jwtHandshakeInterceptor;

    /**
     * @param jwtHandshakeInterceptor interceptor applied to every handshake on the
     *                                DOM-sync endpoint
     */
    public DomSyncWebSocketConfig(JwtHandshakeInterceptor jwtHandshakeInterceptor) {
        this.jwtHandshakeInterceptor = jwtHandshakeInterceptor;
    }

    /**
     * Declares the {@link DomSyncHandler} as a bean so that Spring manages its lifecycle.
     *
     * @return a new handler instance
     */
    @Bean
    public DomSyncHandler domSyncHandler() {
        return new DomSyncHandler();
    }

    /**
     * Maps the DOM-sync handler to its endpoint, adds the JWT handshake interceptor
     * and configures the allowed origins.
     *
     * @param registry registry the handler is added to
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // Obtained through the @Bean method; presumably resolved to the container-managed
        // instance by the default @Configuration method proxying.
        DomSyncHandler handler = domSyncHandler();

        registry.addHandler(handler, DOM_SYNC_ENDPOINT)
                .addInterceptors(jwtHandshakeInterceptor)
                .setAllowedOrigins(ALLOWED_ORIGINS);
    }
}

/**
 * Handshake interceptor that authenticates WebSocket connections with a JWT.
 *
 * <p>The token is taken from the {@code Authorization: Bearer <token>} request header
 * or, when the header yields no token, from the {@code token} URL query parameter.
 * A handshake is accepted only if {@link JwtUtil#validateToken} reports the token as
 * valid and a non-null user id can be read from it; otherwise the response status is
 * set to 401 and the handshake is rejected.
 */
@Component
class JwtHandshakeInterceptor implements HandshakeInterceptor {
    /** Scheme prefix (including the separating space) expected in the Authorization header. */
    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtUtil jwtUtil;

    /**
     * @param jwtUtil helper used to validate tokens and read the user id from them
     */
    public JwtHandshakeInterceptor(JwtUtil jwtUtil) {
        this.jwtUtil = jwtUtil;
    }

    /**
     * Authenticates the handshake request.
     *
     * <p>On success the raw token and the user id are stored in {@code attributes}
     * under the keys {@code "token"} and {@code "userId"}.
     *
     * @param request    the handshake request the token is extracted from
     * @param response   the handshake response; its status is set to 401 on rejection
     * @param wsHandler  the target handler (unused)
     * @param attributes handshake attributes that receive {@code "token"} and {@code "userId"}
     * @return {@code true} if the token is valid and carries a user id, {@code false} otherwise
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                 WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        String token = extractTokenFromRequest(request);

        // NOTE(review): diagnostics go to System.out/System.err and printStackTrace();
        // consider switching to the project's logging facade.
        if (StringUtils.hasText(token)) {
            try {
                if (jwtUtil.validateToken(token)) {
                    String userId = jwtUtil.getUserIdFromToken(token);
                    if (userId != null) {
                        attributes.put("token", token);
                        attributes.put("userId", userId);
                        System.out.println("WebSocket连接认证成功，用户ID: " + userId);
                        return true;
                    }
                    System.err.println("无法从token中提取用户ID");
                } else {
                    System.err.println("JWT验证失败，token可能已过期或无效");
                }
            } catch (Exception e) {
                // Any failure while validating is treated as an authentication failure:
                // fall through to the rejection path below instead of propagating.
                System.err.println("JWT验证过程中发生错误: " + e.getMessage());
                e.printStackTrace();
            }
        }

        // No valid token: reject the connection.
        System.err.println("WebSocket连接缺少有效的认证token");
        response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
        return false;
    }

    /**
     * Post-handshake hook; nothing needs to be done here.
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                             WebSocketHandler wsHandler, Exception exception) {
        // Intentionally empty.
    }

    /**
     * Extracts the token from the request header or, failing that, the URL query string.
     *
     * <p>The {@code Bearer} scheme is matched case-insensitively and the header token is
     * trimmed. If the header carries the scheme but only a blank token, the {@code token}
     * query parameter is still consulted.
     *
     * @param request the handshake request
     * @return the token, or {@code null} if neither source provides a non-blank one
     */
    private String extractTokenFromRequest(ServerHttpRequest request) {
        // First try the Authorization header.
        String authHeader = request.getHeaders().getFirst("Authorization");
        if (authHeader != null
                && authHeader.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) {
            String headerToken = authHeader.substring(BEARER_PREFIX.length()).trim();
            if (StringUtils.hasText(headerToken)) {
                return headerToken;
            }
        }

        // Otherwise fall back to the "token" URL query parameter.
        String query = request.getURI().getQuery();
        if (query != null) {
            String queryToken = UriComponentsBuilder.newInstance()
                    .query(query)
                    .build()
                    .getQueryParams()
                    .getFirst("token");
            if (StringUtils.hasText(queryToken)) {
                return queryToken;
            }
        }

        return null;
    }
}