package pangea.hiagent.security;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.common.utils.JwtUtil;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collections;
import java.util.List;

/**
 * Authorization pre-check filter for the SSE streaming endpoints.
 *
 * <p>Intended to run before Spring Security's AuthorizationFilter. It rejects unauthenticated
 * requests to the streaming endpoints up front with an SSE-formatted 401 error event. This avoids
 * failing only after the streaming response has already been committed. The actual filter
 * ordering is configured outside this class (NOTE(review): confirm in the security configuration).
 *
 * <p>Only requests whose URI contains {@link #STREAM_ENDPOINT} or {@link #TIMELINE_ENDPOINT}
 * are handled. All other requests are skipped via {@link #shouldNotFilter(HttpServletRequest)}.
 */
@Slf4j
@Component
public class SseAuthorizationFilter extends OncePerRequestFilter {

    /** URI fragment identifying the chat streaming endpoint. */
    private static final String STREAM_ENDPOINT = "/api/v1/agent/chat-stream";
    /** URI fragment identifying the timeline event streaming endpoint. */
    private static final String TIMELINE_ENDPOINT = "/api/v1/agent/timeline-events";
    /** Prefix of a bearer-token {@code Authorization} header value. */
    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtUtil jwtUtil;

    public SseAuthorizationFilter(JwtUtil jwtUtil) {
        this.jwtUtil = jwtUtil;
    }

    /**
     * Authenticates SSE endpoint requests via JWT, rejecting unauthenticated ones with an SSE 401 event.
     *
     * <p>On success, a {@code ROLE_USER} authentication whose principal is the token's user id is
     * stored in the {@link SecurityContextHolder}, and the rest of the chain is invoked.
     *
     * @param request     current request
     * @param response    current response
     * @param filterChain remaining filter chain
     * @throws ServletException propagated unchanged from the downstream chain
     * @throws IOException      propagated unchanged from the downstream chain
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {

        String requestUri = request.getRequestURI();

        // Defensive pass-through: shouldNotFilter already excludes non-SSE requests.
        if (!isSseEndpoint(requestUri)) {
            filterChain.doFilter(request, response);
            return;
        }

        log.debug("SSE端点授权检查: {} {}", request.getMethod(), requestUri);

        String userId = authenticate(request);
        if (userId == null) {
            // Token missing, invalid, without a user id, or validation threw: reject the connection.
            log.warn("SSE端点未认证访问，拒绝连接: {} {}", request.getMethod(), requestUri);
            sendSseUnauthorizedError(response);
            return;
        }

        List<SimpleGrantedAuthority> authorities = Collections.singletonList(new SimpleGrantedAuthority("ROLE_USER"));
        UsernamePasswordAuthenticationToken authentication =
                new UsernamePasswordAuthenticationToken(userId, null, authorities);
        // Populate a fresh context rather than mutating the shared one, per Spring Security guidance.
        SecurityContext context = SecurityContextHolder.createEmptyContext();
        context.setAuthentication(authentication);
        SecurityContextHolder.setContext(context);
        log.debug("SSE端点JWT验证成功，用户: {}", userId);

        // Deliberately outside any try/catch. Downstream exceptions must propagate, not be
        // mistaken for JWT failures and answered with a 401 on an already-committed stream.
        filterChain.doFilter(request, response);
    }

    /**
     * Extracts the JWT from the request and validates it.
     *
     * @param request current request
     * @return the user id from a valid token, or {@code null} when the token is absent or invalid,
     *         carries no user id, or validation throws
     */
    private String authenticate(HttpServletRequest request) {
        String token = extractTokenFromRequest(request);
        if (!StringUtils.hasText(token)) {
            return null;
        }
        log.debug("提取到JWT token，进行验证");
        try {
            if (jwtUtil.validateToken(token)) {
                return jwtUtil.getUserIdFromToken(token);
            }
        } catch (Exception e) {
            // JwtUtil's exception contract is not visible here; treat any failure as "not authenticated".
            log.warn("SSE端点JWT验证失败: {}", e.getMessage());
            log.debug("SSE端点JWT验证异常详情", e);
        }
        return null;
    }

    /**
     * Writes an SSE-formatted 401 error event ({@code event: error}) to the response.
     *
     * <p>Does nothing except log when the response is already committed, since the status and
     * headers can no longer be changed at that point. I/O failures are logged, not rethrown.
     *
     * @param response response to write the error event to
     */
    private void sendSseUnauthorizedError(HttpServletResponse response) {
        if (response.isCommitted()) {
            log.warn("响应已提交，无法发送SSE未授权错误响应");
            return;
        }
        try {
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
            response.setContentType("text/event-stream;charset=UTF-8");
            response.setCharacterEncoding("UTF-8");

            // Send the error as an SSE event so EventSource-style clients can parse it.
            PrintWriter writer = response.getWriter();
            writer.write("event: error\n");
            writer.write("data: {\"error\": \"未授权访问，请先登录\", \"code\": 401, \"timestamp\": " +
                    System.currentTimeMillis() + "}\n\n");
            writer.flush();

            log.debug("已发送SSE未授权错误响应");
        } catch (IOException e) {
            log.error("发送SSE未授权错误响应失败", e);
        }
    }

    /**
     * Extracts the token from the {@code Authorization: Bearer ...} header or the {@code token}
     * query/form parameter.
     *
     * <p>The header takes precedence. A header with an empty bearer value falls back to the parameter.
     * NOTE(review): presumably the query parameter exists because browser EventSource cannot set
     * headers. Tokens in URLs may appear in access logs — confirm this is acceptable.
     *
     * @param request current request
     * @return the non-blank token, or {@code null} if none was supplied
     */
    private String extractTokenFromRequest(HttpServletRequest request) {
        String authHeader = request.getHeader("Authorization");
        if (StringUtils.hasText(authHeader) && authHeader.startsWith(BEARER_PREFIX)) {
            String headerToken = authHeader.substring(BEARER_PREFIX.length()).trim();
            if (StringUtils.hasText(headerToken)) {
                return headerToken;
            }
        }

        String tokenParam = request.getParameter("token");
        return StringUtils.hasText(tokenParam) ? tokenParam : null;
    }

    /**
     * Returns whether the given request URI targets one of the SSE streaming endpoints.
     *
     * <p>Matching is by substring, so context paths and trailing segments are tolerated.
     *
     * @param requestUri request URI, may be {@code null}
     * @return {@code true} if the URI contains either SSE endpoint path
     */
    private static boolean isSseEndpoint(String requestUri) {
        return requestUri != null
                && (requestUri.contains(STREAM_ENDPOINT) || requestUri.contains(TIMELINE_ENDPOINT));
    }

    /**
     * Skips this filter for every request that is not an SSE streaming endpoint.
     *
     * @param request current request
     * @return {@code true} when the request is not an SSE endpoint
     */
    @Override
    protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
        return !isSseEndpoint(request.getRequestURI());
    }
}
