package pangea.hiagent.tool.impl;

import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component;

import java.util.Arrays;

/**
 * 统计计算工具
 * 用于执行各种统计分析计算
 */
@Slf4j
@Component
public class StatisticalCalculationTool {
    
    public StatisticalCalculationTool() {
        // 默认构造器
    }
    
    /**
     * 计算数据的基本统计信息
     * @param data 数据数组
     * @return 统计结果
     */
    @Tool(description = "计算数据的基本统计信息，包括均值、中位数、标准差等")
    public String calculateBasicStatistics(double[] data) {
        log.debug("开始计算基本统计信息，数据点数量: {}", data != null ? data.length : 0);
        
        try {
            if (data == null || data.length == 0) {
                log.warn("数据不能为空");
                return "错误：数据不能为空";
            }
            
            // 计算基本统计信息
            double sum = 0;
            double min = data[0];
            double max = data[0];
            
            for (double value : data) {
                sum += value;
                if (value < min) min = value;
                if (value > max) max = value;
            }
            
            double mean = sum / data.length;
            
            // 计算方差和标准差
            double varianceSum = 0;
            for (double value : data) {
                varianceSum += Math.pow(value - mean, 2);
            }
            double variance = varianceSum / data.length;
            double stdDev = Math.sqrt(variance);
            
            // 计算中位数
            double[] sortedData = data.clone();
            Arrays.sort(sortedData);
            double median;
            if (sortedData.length % 2 == 0) {
                median = (sortedData[sortedData.length/2 - 1] + sortedData[sortedData.length/2]) / 2.0;
            } else {
                median = sortedData[sortedData.length/2];
            }
            
            // 生成统计结果
            StringBuilder result = new StringBuilder();
            result.append("基本统计信息计算完成：\n");
            result.append("数据点数量: ").append(data.length).append("\n");
            result.append("最小值: ").append(min).append("\n");
            result.append("最大值: ").append(max).append("\n");
            result.append("均值: ").append(String.format("%.4f", mean)).append("\n");
            result.append("中位数: ").append(String.format("%.4f", median)).append("\n");
            result.append("方差: ").append(String.format("%.4f", variance)).append("\n");
            result.append("标准差: ").append(String.format("%.4f", stdDev)).append("\n");
            
            log.info("基本统计信息计算完成，数据点数量: {}", data.length);
            return result.toString();
        } catch (Exception e) {
            log.error("计算基本统计信息时发生错误: {}", e.getMessage(), e);
            return "计算基本统计信息时发生错误: " + e.getMessage();
        }
    }
    
    /**
     * 计算两个变量之间的相关系数
     * @param x 第一个变量的数据数组
     * @param y 第二个变量的数据数组
     * @return 相关系数结果
     */
    @Tool(description = "计算两个变量之间的皮尔逊相关系数")
    public String calculateCorrelation(double[] x, double[] y) {
        log.debug("开始计算相关系数，X数据点数量: {}, Y数据点数量: {}", 
                 x != null ? x.length : 0, y != null ? y.length : 0);
        
        try {
            if (x == null || x.length == 0) {
                log.warn("X数据不能为空");
                return "错误：X数据不能为空";
            }
            
            if (y == null || y.length == 0) {
                log.warn("Y数据不能为空");
                return "错误：Y数据不能为空";
            }
            
            if (x.length != y.length) {
                log.warn("X和Y数据长度必须相等");
                return "错误：X和Y数据长度必须相等";
            }
            
            // 计算均值
            double meanX = 0, meanY = 0;
            for (int i = 0; i < x.length; i++) {
                meanX += x[i];
                meanY += y[i];
            }
            meanX /= x.length;
            meanY /= y.length;
            
            // 计算协方差和标准差
            double covariance = 0;
            double stdDevX = 0, stdDevY = 0;
            
            for (int i = 0; i < x.length; i++) {
                double diffX = x[i] - meanX;
                double diffY = y[i] - meanY;
                covariance += diffX * diffY;
                stdDevX += diffX * diffX;
                stdDevY += diffY * diffY;
            }
            
            covariance /= x.length;
            stdDevX = Math.sqrt(stdDevX / x.length);
            stdDevY = Math.sqrt(stdDevY / y.length);
            
            // 计算相关系数
            double correlation = 0;
            if (stdDevX != 0 && stdDevY != 0) {
                correlation = covariance / (stdDevX * stdDevY);
            }
            
            // 生成结果
            StringBuilder result = new StringBuilder();
            result.append("相关系数计算完成：\n");
            result.append("数据点数量: ").append(x.length).append("\n");
            result.append("X变量均值: ").append(String.format("%.4f", meanX)).append("\n");
            result.append("Y变量均值: ").append(String.format("%.4f", meanY)).append("\n");
            result.append("协方差: ").append(String.format("%.4f", covariance)).append("\n");
            result.append("X变量标准差: ").append(String.format("%.4f", stdDevX)).append("\n");
            result.append("Y变量标准差: ").append(String.format("%.4f", stdDevY)).append("\n");
            result.append("皮尔逊相关系数: ").append(String.format("%.4f", correlation)).append("\n");
            
            // 解释相关系数
            String interpretation;
            if (Math.abs(correlation) >= 0.8) {
                interpretation = "强相关";
            } else if (Math.abs(correlation) >= 0.5) {
                interpretation = "中等相关";
            } else if (Math.abs(correlation) >= 0.3) {
                interpretation = "弱相关";
            } else {
                interpretation = "几乎无相关";
            }
            
            result.append("相关性解释: ").append(interpretation).append("\n");
            
            log.info("相关系数计算完成，数据点数量: {}", x.length);
            return result.toString();
        } catch (Exception e) {
            log.error("计算相关系数时发生错误: {}", e.getMessage(), e);
            return "计算相关系数时发生错误: " + e.getMessage();
        }
    }
    
    /**
     * 执行线性回归分析
     * @param x 自变量数据数组
     * @param y 因变量数据数组
     * @return 回归分析结果
     */
    @Tool(description = "执行简单的线性回归分析，计算回归系数和拟合优度")
    public String performLinearRegression(double[] x, double[] y) {
        log.debug("开始执行线性回归分析，X数据点数量: {}, Y数据点数量: {}", 
                 x != null ? x.length : 0, y != null ? y.length : 0);
        
        try {
            if (x == null || x.length == 0) {
                log.warn("X数据不能为空");
                return "错误：X数据不能为空";
            }
            
            if (y == null || y.length == 0) {
                log.warn("Y数据不能为空");
                return "错误：Y数据不能为空";
            }
            
            if (x.length != y.length) {
                log.warn("X和Y数据长度必须相等");
                return "错误：X和Y数据长度必须相等";
            }
            
            int n = x.length;
            
            // 计算均值
            double meanX = 0, meanY = 0;
            for (int i = 0; i < n; i++) {
                meanX += x[i];
                meanY += y[i];
            }
            meanX /= n;
            meanY /= n;
            
            // 计算回归系数
            double numerator = 0, denominator = 0;
            for (int i = 0; i < n; i++) {
                numerator += (x[i] - meanX) * (y[i] - meanY);
                denominator += Math.pow(x[i] - meanX, 2);
            }
            
            // 斜率和截距
            double slope = denominator != 0 ? numerator / denominator : 0;
            double intercept = meanY - slope * meanX;
            
            // 计算拟合优度(R²)
            double ssTotal = 0, ssResidual = 0;
            for (int i = 0; i < n; i++) {
                double predictedY = slope * x[i] + intercept;
                ssTotal += Math.pow(y[i] - meanY, 2);
                ssResidual += Math.pow(y[i] - predictedY, 2);
            }
            
            double rSquared = ssTotal != 0 ? 1 - (ssResidual / ssTotal) : 1;
            
            // 生成结果
            StringBuilder result = new StringBuilder();
            result.append("线性回归分析完成：\n");
            result.append("数据点数量: ").append(n).append("\n");
            result.append("回归方程: y = ").append(String.format("%.4f", slope))
                  .append(" * x + ").append(String.format("%.4f", intercept)).append("\n");
            result.append("斜率: ").append(String.format("%.4f", slope)).append("\n");
            result.append("截距: ").append(String.format("%.4f", intercept)).append("\n");
            result.append("拟合优度(R²): ").append(String.format("%.4f", rSquared)).append("\n");
            
            // 解释拟合优度
            String interpretation;
            if (rSquared >= 0.8) {
                interpretation = "模型拟合非常好";
            } else if (rSquared >= 0.6) {
                interpretation = "模型拟合良好";
            } else if (rSquared >= 0.4) {
                interpretation = "模型拟合一般";
            } else {
                interpretation = "模型拟合较差";
            }
            
            result.append("模型解释力: ").append(interpretation).append("\n");
            
            log.info("线性回归分析完成，数据点数量: {}", n);
            return result.toString();
        } catch (Exception e) {
            log.error("执行线性回归分析时发生错误: {}", e.getMessage(), e);
            return "执行线性回归分析时发生错误: " + e.getMessage();
        }
    }
}