package pangea.hiagent.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

import pangea.hiagent.websocket.DomSyncHandler;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Map;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * WebSocket configuration for the DOM-sync endpoint.
 *
 * <p>Registers {@link DomSyncHandler} at {@code /ws/dom-sync} and guards the handshake with
 * {@link JwtHandshakeInterceptor}.
 */
@Configuration
@EnableWebSocket
public class DomSyncWebSocketConfig implements WebSocketConfigurer {

    /**
     * Exposes the DOM-sync handler as a Spring bean so its lifecycle is managed by the container.
     *
     * @return a new {@link DomSyncHandler}
     */
    @Bean
    public DomSyncHandler domSyncHandler() {
        return new DomSyncHandler();
    }

    /**
     * Registers the DOM-sync handler at {@code /ws/dom-sync} with JWT handshake authentication.
     *
     * @param registry the registry to add the handler mapping to
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry.addHandler(domSyncHandler(), "/ws/dom-sync")
                // Authenticate the handshake before the WebSocket session is opened.
                .addInterceptors(new JwtHandshakeInterceptor())
                // Production: replace "*" with the concrete allowed origin(s); never use a wildcard.
                .setAllowedOrigins("*");
    }

    /**
     * Handshake interceptor that requires a JWT-shaped token before a WebSocket session is opened.
     *
     * <p>The token is read from the {@code Authorization} header, or, when that header is absent
     * or blank, from the {@code token} query parameter. An optional {@code Bearer } prefix
     * (case-insensitive) is stripped. On success the token and a derived user id are stored in
     * the session attributes under {@code "token"} and {@code "userId"}. Otherwise the handshake
     * is rejected with {@code 401 Unauthorized}.
     *
     * <p><b>SECURITY:</b> only the token's <em>shape</em> is checked; its signature and expiry are
     * NOT verified, so any well-formed string is accepted. Replace {@link #isWellFormedJwt} with
     * real JWT verification before relying on this for authentication.
     */
    public static class JwtHandshakeInterceptor implements HandshakeInterceptor {

        private static final Logger LOG = Logger.getLogger(JwtHandshakeInterceptor.class.getName());

        private static final String BEARER_PREFIX = "Bearer ";

        /** Three non-empty base64url segments (header.payload.signature), tolerating '=' padding. */
        private static final Pattern JWT_SHAPE =
                Pattern.compile("[A-Za-z0-9_-]+=*\\.[A-Za-z0-9_-]+=*\\.[A-Za-z0-9_-]+=*");

        /** Number of digest bytes (rendered as 2 hex chars each) used in the derived user id. */
        private static final int USER_ID_BYTES = 8;

        @Override
        public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                       WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
            String token = extractToken(request);

            if (token != null && isWellFormedJwt(token)) {
                String userId = deriveUserId(token);
                attributes.put("token", token);
                attributes.put("userId", userId);
                // Never log the token itself: it is a bearer credential.
                LOG.fine("WebSocket连接认证成功，userId: " + userId);
                return true;
            }

            // No usable token: reject the handshake.
            LOG.warning("WebSocket连接缺少有效的认证token");
            response.setStatusCode(org.springframework.http.HttpStatus.UNAUTHORIZED);
            return false;
        }

        @Override
        public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                   WebSocketHandler wsHandler, Exception exception) {
            // Nothing to do after the handshake.
        }

        /**
         * Reads the raw token from the {@code Authorization} header, falling back to the
         * {@code token} query parameter, and strips an optional {@code Bearer } prefix.
         *
         * @param request the handshake request
         * @return the trimmed token, or {@code null} if none was supplied
         */
        private static String extractToken(ServerHttpRequest request) {
            String token = request.getHeaders().getFirst("Authorization");

            if (isBlank(token)) {
                String query = request.getURI().getQuery();
                token = (query == null)
                        ? null
                        : UriComponentsBuilder.newInstance().query(query).build()
                                .getQueryParams().getFirst("token");
            }
            if (isBlank(token)) {
                return null;
            }

            token = token.trim();
            // Authorization schemes are case-insensitive, so accept "bearer " as well as "Bearer ".
            if (token.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) {
                token = token.substring(BEARER_PREFIX.length()).trim();
            }
            return token.isEmpty() ? null : token;
        }

        /**
         * Checks that the token has the compact JWS shape {@code header.payload.signature}.
         * This is a format check only; it does not verify the signature.
         */
        private static boolean isWellFormedJwt(String token) {
            return JWT_SHAPE.matcher(token).matches();
        }

        /**
         * Derives a per-token user identifier of the form {@code user_<16 hex chars>}.
         *
         * <p>The leading characters of a JWT encode its JOSE header, which is the same for every
         * token signed with the same algorithm, so they cannot distinguish users. Hashing the whole
         * token gives a stable, distinct id per token. It also avoids trusting unverified claims
         * such as {@code sub}, which a forged token could set to another user's id.
         * TODO: once the signature is verified, derive the id from the token's subject claim.
         */
        private static String deriveUserId(String token) {
            MessageDigest sha256;
            try {
                sha256 = MessageDigest.getInstance("SHA-256");
            } catch (NoSuchAlgorithmException e) {
                // Every Java platform implementation is required to support SHA-256.
                throw new IllegalStateException("SHA-256 not available", e);
            }
            byte[] digest = sha256.digest(token.getBytes(StandardCharsets.UTF_8));

            StringBuilder id = new StringBuilder("user_");
            for (int i = 0; i < USER_ID_BYTES; i++) {
                id.append(Character.forDigit((digest[i] >> 4) & 0xF, 16))
                  .append(Character.forDigit(digest[i] & 0xF, 16));
            }
            return id.toString();
        }

        private static boolean isBlank(String s) {
            return s == null || s.trim().isEmpty();
        }
    }
}