package pangea.hiagent.security;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import pangea.hiagent.model.Agent;
import pangea.hiagent.web.service.AgentService;

import java.io.IOException;

/**
 * SSE流式端点授权检查过滤器
 * 在Spring Security的AuthorizationFilter之前运行，提前处理流式端点的身份验证检查
 * 避免响应被提交后才处理异常的问题
 */
@Slf4j
@Component
public class SseAuthorizationFilter extends OncePerRequestFilter {

    private static final String STREAM_ENDPOINT = "/api/v1/agent/chat-stream";
    private static final String TIMELINE_ENDPOINT = "/api/v1/agent/timeline-events";
    
    private final AgentService agentService;
    
    public SseAuthorizationFilter(AgentService agentService) {
        this.agentService = agentService;
    }
        
    /**
     * 发送SSE格式的错误响应
     */
    private void sendSseError(HttpServletResponse response, int status, String errorMessage) {
        try {
            response.setStatus(status);
            response.setContentType("text/event-stream;charset=UTF-8");
            response.setCharacterEncoding("UTF-8");
                
            // 发送SSE格式的错误事件
            response.getWriter().write("event: error\n");
            response.getWriter().write("data: {\"error\": \"" + errorMessage + "\", \"code\": " + status + ", \"timestamp\": " + 
                    System.currentTimeMillis() + "}\n\n");
            response.getWriter().flush();
                
            log.debug("已发送SSE错误响应: {} - {}", status, errorMessage);
        } catch (IOException e) {
            log.error("发送SSE错误响应失败: {}", e.getMessage());
        }
    }
        
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
        
        String requestUri = request.getRequestURI();
        boolean isStreamEndpoint = requestUri.contains(STREAM_ENDPOINT);
        boolean isTimelineEndpoint = requestUri.contains(TIMELINE_ENDPOINT);
        
        // 只处理SSE端点
        if (isStreamEndpoint || isTimelineEndpoint) {
            log.debug("SSE端点授权检查: {} {}", request.getMethod(), requestUri);
            
            // 检查响应是否已经提交，避免后续错误处理异常
            if (response.isCommitted()) {
                log.warn("响应已提交，无法处理SSE端点授权检查");
                return;
            }
            
            // 从SecurityContext获取当前认证用户
            String userId = getCurrentUserId();
            
            if (userId != null) {
                log.debug("SSE端点已认证，用户: {}", userId);
                
                // 如果是chat-stream端点，需要额外验证agent权限
                if (isStreamEndpoint) {
                    String agentId = request.getParameter("agentId");
                    if (agentId != null) {
                        try {
                            Agent agent = agentService.getAgent(agentId);
                            if (agent == null) {
                                log.warn("SSE端点访问失败：Agent不存在 - AgentId: {}", agentId);
                                sendSseError(response, HttpServletResponse.SC_NOT_FOUND, "Agent不存在");
                                return;
                            }
                            
                            // 验证用户是否有权限访问该agent
                            if (!agent.getOwner().equals(userId) && !isAdminUser(userId)) {
                                log.warn("SSE端点访问失败：用户 {} 无权限访问Agent: {}", userId, agentId);
                                sendSseError(response, HttpServletResponse.SC_FORBIDDEN, "访问被拒绝，无权限访问该Agent");
                                return;
                            }
                            
                            log.debug("SSE端点Agent权限验证成功，用户: {}, Agent: {}", userId, agentId);
                        } catch (Exception e) {
                            log.error("SSE端点Agent权限验证异常: {}", e.getMessage());
                            sendSseError(response, HttpServletResponse.SC_FORBIDDEN, "访问被拒绝");
                            return;
                        }
                    } else {
                        log.warn("SSE端点请求缺少agentId参数");
                        sendSseError(response, HttpServletResponse.SC_NOT_FOUND, "Agent不存在");
                        return;
                    }
                }
                
                // 继续执行过滤器链
                filterChain.doFilter(request, response);
                return;
            } else {
                // 用户未认证，拒绝连接
                log.warn("SSE端点未认证访问，拒绝连接: {} {}", request.getMethod(), requestUri);
                sendSseError(response, HttpServletResponse.SC_UNAUTHORIZED, "未授权访问，请先登录");
                return;
            }
        }
        
        // 继续执行过滤器链（非SSE端点）
        filterChain.doFilter(request, response);
    }
    
    /**
     * 从SecurityContext获取当前认证用户ID
     */
    private String getCurrentUserId() {
        var authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication != null && authentication.isAuthenticated() && !"anonymousUser".equals(authentication.getPrincipal())) {
            return authentication.getName();
        }
        return null;
    }
    
    /**
     * 检查是否为管理员用户
     */
    private boolean isAdminUser(String userId) {
        // 与DefaultPermissionEvaluator保持一致的管理员检查逻辑
        return "admin".equals(userId) || "user-001".equals(userId);
    }
    
    /**
     * 确定此过滤器是否应处理给定请求
     * 只处理SSE流式端点
     */
    @Override
    protected boolean shouldNotFilter(HttpServletRequest request) {
        String requestUri = request.getRequestURI();
        return !(requestUri.contains(STREAM_ENDPOINT) || requestUri.contains(TIMELINE_ENDPOINT));
    }
}
