package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)

// Config holds API keys and backend service URLs, loaded once from the environment.
type Config struct {
	AnthropicKey string
	GoogleKey    string
	ServiceURL   string
	DeepseekURL  string
	OpenAIURL    string
}

var (
	config     Config
	configOnce sync.Once
)

// TokenCountRequest is the request body for POST /count_tokens.
type TokenCountRequest struct {
	Model    string    `json:"model" binding:"required"`
	Messages []Message `json:"messages" binding:"required"`
	System   *string   `json:"system,omitempty"`
}

type Message struct {
	Role    string `json:"role" binding:"required"`
	Content string `json:"content" binding:"required"`
}

// TokenCountResponse is the successful response body.
type TokenCountResponse struct {
	InputTokens int `json:"input_tokens"`
}

// ErrorResponse is the JSON error payload.
type ErrorResponse struct {
	Error string `json:"error"`
}
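
// For reference, a request/response pair in the wire format defined above
// (the token count is illustrative, not a real measurement):
//
//	POST /count_tokens
//	{"model": "claude-3-5-sonnet", "system": "You are helpful.",
//	 "messages": [{"role": "user", "content": "Hello"}]}
//
//	200 OK
//	{"input_tokens": 42}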

// ModelRule maps a keyword set (all keywords must appear, case-insensitively,
// in the input) to a canonical model name.
type ModelRule struct {
	Keywords []string
	Target   string
}

// Rules are checked in order, so more specific rules (more keywords) must come
// before broader ones; otherwise e.g. "claude-3-5-sonnet" would hit the plain
// "claude 3 sonnet" rule first and resolve to the wrong model.
var modelRules = []ModelRule{
	{
		Keywords: []string{"gpt"},
		Target:   "gpt-3.5-turbo",
	},
	{
		Keywords: []string{"openai"},
		Target:   "gpt-3.5-turbo",
	},
	{
		Keywords: []string{"deepseek"},
		Target:   "deepseek-v3",
	},
	{
		Keywords: []string{"claude", "3", "7"},
		Target:   "claude-3-7-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "sonnet"},
		Target:   "claude-3-5-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "haiku"},
		Target:   "claude-3-5-haiku-latest",
	},
	{
		Keywords: []string{"claude", "3", "opus"},
		Target:   "claude-3-opus-latest",
	},
	{
		Keywords: []string{"claude", "3", "sonnet"},
		Target:   "claude-3-sonnet-20240229",
	},
	{
		Keywords: []string{"claude", "3", "haiku"},
		Target:   "claude-3-haiku-20240307",
	},
	{
		Keywords: []string{"gemini", "2.0"},
		Target:   "gemini-2.0-flash",
	},
	{
		Keywords: []string{"gemini", "2.5"},
		Target:   "gemini-2.0-flash", // gemini-2.0-flash currently stands in for 2.5
	},
	{
		Keywords: []string{"gemini", "1.5"},
		Target:   "gemini-1.5-flash",
	},
}

// matchModelName resolves a free-form model name to the canonical name of a
// supported tokenizer, returning the input unchanged when nothing matches.
func matchModelName(input string) string {
	// Match case-insensitively.
	input = strings.ToLower(input)

	// Special case: 4o/o1/o3 names all resolve to gpt-4o (these models share
	// the o200k_base encoding).
	if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) ||
		strings.Contains(input, "o1") ||
		strings.Contains(input, "o3") {
		return "gpt-4o"
	}

	// Special case: gpt-3.5 and gpt-4 both resolve to gpt-4, which shares
	// gpt-3.5's cl100k_base encoding.
	if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) ||
		(strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) {
		return "gpt-4"
	}

	// Walk the keyword rules in order.
	for _, rule := range modelRules {
		matches := true
		for _, keyword := range rule.Keywords {
			if !strings.Contains(input, strings.ToLower(keyword)) {
				matches = false
				break
			}
		}
		if matches {
			return rule.Target
		}
	}

	// No rule matched; return the input unchanged.
	return input
}
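
// A few illustrative resolutions under the special cases and rules above:
//
//	"GPT-4o-mini"                → "gpt-4o"                   (4o special case)
//	"claude-3-5-sonnet-20241022" → "claude-3-5-sonnet-latest"
//	"gemini-1.5-pro"             → "gemini-1.5-flash"
//	"my-custom-model"            → "my-custom-model"          (no match; unchanged)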

// loadConfig reads configuration from the environment exactly once.
func loadConfig() Config {
	configOnce.Do(func() {
		config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY")
		if config.AnthropicKey == "" {
			log.Println("Warning: ANTHROPIC_API_KEY is not set; Claude models will be unavailable")
		}

		config.GoogleKey = os.Getenv("GOOGLE_API_KEY")
		if config.GoogleKey == "" {
			log.Println("Warning: GOOGLE_API_KEY is not set; Gemini models will be unavailable")
		}

		// Deepseek tokenizer sidecar URL.
		config.DeepseekURL = os.Getenv("DEEPSEEK_URL")
		if config.DeepseekURL == "" {
			config.DeepseekURL = "http://127.0.0.1:7861" // default local address
			log.Println("Using default Deepseek service URL:", config.DeepseekURL)
		}

		// OpenAI tokenizer sidecar URL.
		config.OpenAIURL = os.Getenv("OPENAI_URL")
		if config.OpenAIURL == "" {
			config.OpenAIURL = "http://127.0.0.1:7862" // default local address
			log.Println("Using default OpenAI service URL:", config.OpenAIURL)
		}

		// Public service URL, used by the keep-alive pinger.
		config.ServiceURL = os.Getenv("SERVICE_URL")
		if config.ServiceURL == "" {
			log.Println("SERVICE_URL is not set; keep-alive is disabled")
		}
	})
	return config
}
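
// A typical local setup might look like this (values are placeholders):
//
//	export ANTHROPIC_API_KEY=sk-ant-...
//	export GOOGLE_API_KEY=...
//	export DEEPSEEK_URL=http://127.0.0.1:7861
//	export OPENAI_URL=http://127.0.0.1:7862
//	export SERVICE_URL=https://your-deployment.example.com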

// countTokensWithClaude forwards the request to Anthropic's native count_tokens endpoint.
func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) {
	// Prepare the request to the Anthropic API; cap the round trip so a hung
	// upstream cannot stall the handler indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err)
	}

	// Build the HTTP request.
	request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err)
	}

	// Set request headers.
	request.Header.Set("x-api-key", config.AnthropicKey)
	request.Header.Set("anthropic-version", "2023-06-01")
	request.Header.Set("content-type", "application/json")

	// Send the request.
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("发送请求到Anthropic API失败: %v", err)
	}
	defer response.Body.Close()

	// Surface non-2xx responses instead of decoding them as success bodies.
	if response.StatusCode != http.StatusOK {
		return TokenCountResponse{}, fmt.Errorf("Anthropic API returned status %d", response.StatusCode)
	}

	// Decode the response.
	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}

	return result, nil
}

// countTokensWithGemini counts tokens via the official Gemini SDK.
func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) {
	// Ensure the API key is configured.
	if config.GoogleKey == "" {
		return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY 未设置")
	}

	// Create the Gemini client.
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("创建Gemini客户端失败: %v", err)
	}
	defer client.Close()

	// The model name has already been canonicalized by matchModelName.
	modelName := req.Model

	// Create a handle for the Gemini model.
	model := client.GenerativeModel(modelName)

	// Flatten the chat into a single prompt string.
	var content string
	if req.System != nil && *req.System != "" {
		content += *req.System + "\n\n"
	}

	for _, msg := range req.Messages {
		if msg.Role == "user" {
			content += "用户: " + msg.Content + "\n"
		} else if msg.Role == "assistant" {
			content += "助手: " + msg.Content + "\n"
		} else {
			content += msg.Role + ": " + msg.Content + "\n"
		}
	}

	// Count tokens.
	tokResp, err := model.CountTokens(ctx, genai.Text(content))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("计算Gemini token失败: %v", err)
	}

	return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil
}
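
// The flattening above turns the chat plus optional system prompt into one
// text block; e.g. a system prompt and two messages become:
//
//	You are helpful.
//
//	User: Hello
//	Assistant: Hi there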

// countTokensWithDeepseek proxies the request to the Deepseek tokenizer sidecar.
func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) {
	// Prepare the request.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err)
	}

	// Build the HTTP request.
	request, err := http.NewRequest("POST", config.DeepseekURL+"/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err)
	}

	// Set request headers.
	request.Header.Set("Content-Type", "application/json")

	// Send the request.
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("发送请求到Deepseek服务失败: %v", err)
	}
	defer response.Body.Close()

	// Surface non-2xx responses instead of decoding them as success bodies.
	if response.StatusCode != http.StatusOK {
		return TokenCountResponse{}, fmt.Errorf("Deepseek service returned status %d", response.StatusCode)
	}

	// Decode the response.
	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}

	return result, nil
}

// countTokensWithOpenAI proxies the request to the OpenAI tokenizer sidecar.
func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) {
	// Prepare the request.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err)
	}

	// Build the HTTP request.
	request, err := http.NewRequest("POST", config.OpenAIURL+"/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err)
	}

	// Set request headers.
	request.Header.Set("Content-Type", "application/json")

	// Send the request.
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("发送请求到OpenAI服务失败: %v", err)
	}
	defer response.Body.Close()

	// Surface non-2xx responses instead of decoding them as success bodies.
	if response.StatusCode != http.StatusOK {
		return TokenCountResponse{}, fmt.Errorf("OpenAI service returned status %d", response.StatusCode)
	}

	// Decode the response; the sidecar also reports which model and encoding
	// it actually used.
	var result struct {
		InputTokens int    `json:"input_tokens"`
		Model       string `json:"model"`
		Encoding    string `json:"encoding"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}

	return TokenCountResponse{InputTokens: result.InputTokens}, nil
}
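
// The sidecar is assumed to reply with a superset of TokenCountResponse, e.g.
// {"input_tokens": 5, "model": "gpt-4o", "encoding": "o200k_base"}; only
// input_tokens is forwarded to the caller.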

// countTokens handles POST /count_tokens and dispatches to a backend by model name.
func countTokens(c *gin.Context) {
	var req TokenCountRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
		return
	}

	// Keep the original model name for error and warning messages.
	originalModel := req.Model

	// Treat the model as unsupported unless its name mentions a known
	// model family.
	isUnsupportedModel := true

	modelLower := strings.ToLower(req.Model)
	if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") ||
		strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") ||
		strings.HasPrefix(modelLower, "claude") ||
		strings.Contains(modelLower, "gemini") ||
		strings.Contains(modelLower, "deepseek") {
		isUnsupportedModel = false
	}

	// Resolve the model name to a supported tokenizer.
	req.Model = matchModelName(req.Model)

	var result TokenCountResponse
	var err error

	// Check for Deepseek models first.
	if strings.Contains(strings.ToLower(req.Model), "deepseek") {
		// Use the Deepseek tokenizer sidecar.
		result, err = countTokensWithDeepseek(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") {
		// Use the OpenAI tokenizer sidecar.
		result, err = countTokensWithOpenAI(req)
	} else if strings.HasPrefix(strings.ToLower(req.Model), "claude") {
		// Use the Anthropic API.
		if config.AnthropicKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY is not set; Claude models are unavailable"})
			return
		}
		result, err = countTokensWithClaude(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gemini") {
		// Use the Gemini API.
		if config.GoogleKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY is not set; Gemini models are unavailable"})
			return
		}
		result, err = countTokensWithGemini(req)
	} else if isUnsupportedModel {
		// Unknown model family: fall back to a gpt-4o-based estimate.
		gptReq := req
		gptReq.Model = "gpt-4o"

		// Use the OpenAI tokenizer sidecar.
		result, err = countTokensWithOpenAI(gptReq)

		if err == nil {
			// Return the estimate with a warning attached.
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   result.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
			return
		}
	} else {
		// The name mentioned a known family but no backend matched above;
		// still try to return a gpt-4o-based estimate.
		gptReq := req
		gptReq.Model = "gpt-4o"

		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
		if estimateErr == nil {
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
		} else {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
		}
		return
	}

	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
		return
	}

	// Return the result.
	c.JSON(http.StatusOK, result)
}
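
// Example invocation against a local instance (default port 7860; the exact
// count depends on the backend tokenizer):
//
//	curl -s -X POST http://127.0.0.1:7860/count_tokens \
//	  -H 'Content-Type: application/json' \
//	  -d '{"model":"claude-3-5-sonnet","messages":[{"role":"user","content":"Hello"}]}'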

// healthCheck reports service liveness; it is also the keep-alive target.
func healthCheck(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"status": "healthy",
		"time":   time.Now().Format(time.RFC3339),
	})
}

// startKeepAlive periodically pings the service's own /health endpoint so
// that free-tier hosts (e.g. Hugging Face Spaces) do not idle it out.
func startKeepAlive() {
	if config.ServiceURL == "" {
		return
	}

	healthURL := fmt.Sprintf("%s/health", config.ServiceURL)
	ticker := time.NewTicker(10 * time.Hour)

	// Run one check immediately, then once per tick.
	go func() {
		log.Printf("Starting keep-alive checks to %s", healthURL)
		for {
			resp, err := http.Get(healthURL)
			if err != nil {
				log.Printf("Keep-alive check failed: %v", err)
			} else {
				resp.Body.Close()
				log.Printf("Keep-alive check successful")
			}

			// Wait for the next tick.
			<-ticker.C
		}
	}()
}

func main() {
	// Load configuration.
	loadConfig()

	// Run gin in release mode.
	gin.SetMode(gin.ReleaseMode)

	// Create the router.
	r := gin.Default()

	// CORS middleware (Logger and Recovery are already attached by gin.Default).
	r.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})

	// Routes.
	r.GET("/health", healthCheck)
	r.POST("/count_tokens", countTokens)

	// Determine the listen port.
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860" // Hugging Face默认端口
	}

	// Start the keep-alive pinger.
	startKeepAlive()

	// Start the HTTP server.
	log.Printf("Server starting on port %s", port)
	if err := r.Run(":" + port); err != nil {
		log.Fatal(err)
	}
}
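
// To run locally (a sketch, assuming Go and the module dependencies are in
// place):
//
//	go run main.go
//	curl http://127.0.0.1:7860/health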