package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
)

// Config holds service configuration read from the environment.
type Config struct {
	AnthropicKey string
	GoogleKey    string
	ServiceURL   string
	DeepseekURL  string
	OpenAIURL    string
}

var (
	config     Config
	configOnce sync.Once
)

// Request structures.
type TokenCountRequest struct {
	Model    string    `json:"model" binding:"required"`
	Messages []Message `json:"messages" binding:"required"`
	System   *string   `json:"system,omitempty"`
}

type Message struct {
	Role    string `json:"role" binding:"required"`
	Content string `json:"content" binding:"required"`
}

// Response structure.
type TokenCountResponse struct {
	InputTokens int `json:"input_tokens"`
}

// Error response structure.
type ErrorResponse struct {
	Error string `json:"error"`
}

// ModelRule maps a set of keywords to a canonical model name.
type ModelRule struct {
	Keywords []string
	Target   string
}

// Model mapping rules. matchModelName returns the first rule whose keywords
// all appear in the input, so more specific rules must be listed before more
// general ones (e.g. "claude 3 5 sonnet" before "claude 3 sonnet").
var modelRules = []ModelRule{
	{Keywords: []string{"gpt"}, Target: "gpt-3.5-turbo"},
	{Keywords: []string{"openai"}, Target: "gpt-3.5-turbo"},
	{Keywords: []string{"deepseek"}, Target: "deepseek-v3"},
	{Keywords: []string{"claude", "3", "7"}, Target: "claude-3-7-sonnet-latest"},
	{Keywords: []string{"claude", "3", "5", "sonnet"}, Target: "claude-3-5-sonnet-latest"},
	{Keywords: []string{"claude", "3", "5", "haiku"}, Target: "claude-3-5-haiku-latest"},
	{Keywords: []string{"claude", "3", "sonnet"}, Target: "claude-3-sonnet-20240229"},
	{Keywords: []string{"claude", "3", "opus"}, Target: "claude-3-opus-latest"},
	{Keywords: []string{"claude", "3", "haiku"}, Target: "claude-3-haiku-20240307"},
	{Keywords: []string{"gemini", "2.0"}, Target: "gemini-2.0-flash"},
	{Keywords: []string{"gemini", "2.5"}, Target: "gemini-2.0-flash"}, // 2.0-flash currently stands in for 2.5
	{Keywords: []string{"gemini", "1.5"}, Target: "gemini-1.5-flash"},
}

// matchModelName maps a free-form model name to a canonical one.
func matchModelName(input string) string {
	// Match case-insensitively.
	input = strings.ToLower(input)

	// Special rule: OpenAI GPT-4o (also covers o1/o3 models).
	if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) ||
		strings.Contains(input, "o1") || strings.Contains(input, "o3") {
		return "gpt-4o"
	}

	// Special rule: OpenAI GPT-4. GPT-3.5 variants are mapped here as well;
	// both model families share the cl100k_base tokenizer, so counts match.
	if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) ||
		(strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) {
		return "gpt-4"
	}

	// Walk the rule table; the first rule whose keywords all match wins.
	for _, rule := range modelRules {
		matches := true
		for _, keyword := range rule.Keywords {
			if !strings.Contains(input, strings.ToLower(keyword)) {
				matches = false
				break
			}
		}
		if matches {
			return rule.Target
		}
	}

	// No rule matched; return the input unchanged.
	return input
}
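
// Illustrative matchModelName mappings, derived from the special rules and
// the rule table above (the input model names are hypothetical):
//
//	matchModelName("o1-preview")                 -> "gpt-4o" (special o1/o3 rule)
//	matchModelName("claude-3-5-sonnet-20241022") -> "claude-3-5-sonnet-latest"
//	matchModelName("gemini-1.5-pro")             -> "gemini-1.5-flash"
//	matchModelName("deepseek-chat")              -> "deepseek-v3"
//	matchModelName("totally-unknown")            -> "totally-unknown" (unchanged; countTokens then estimates with gpt-4o)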

// loadConfig reads configuration from the environment exactly once.
func loadConfig() Config {
	configOnce.Do(func() {
		config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY")
		if config.AnthropicKey == "" {
			log.Println("Warning: ANTHROPIC_API_KEY is not set; Claude models will be unavailable")
		}

		config.GoogleKey = os.Getenv("GOOGLE_API_KEY")
		if config.GoogleKey == "" {
			log.Println("Warning: GOOGLE_API_KEY is not set; Gemini models will be unavailable")
		}

		// Deepseek sidecar service URL.
		config.DeepseekURL = os.Getenv("DEEPSEEK_URL")
		if config.DeepseekURL == "" {
			config.DeepseekURL = "http://127.0.0.1:7861" // default local address
			log.Println("Using default Deepseek service address:", config.DeepseekURL)
		}

		// OpenAI sidecar service URL.
		config.OpenAIURL = os.Getenv("OPENAI_URL")
		if config.OpenAIURL == "" {
			config.OpenAIURL = "http://127.0.0.1:7862" // default local address
			log.Println("Using default OpenAI service address:", config.OpenAIURL)
		}

		// Service URL, used by the keep-alive task.
		config.ServiceURL = os.Getenv("SERVICE_URL")
		if config.ServiceURL == "" {
			log.Println("SERVICE_URL is not set; the keep-alive task will be disabled")
		}
	})
	return config
}

// countTokensWithClaude counts tokens via the Anthropic API.
func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) {
	// Prepare the request to the Anthropic API. The timeout keeps a stuck
	// upstream from hanging the handler indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
	}

	// Build the request.
	request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
	}

	// Set headers.
	request.Header.Set("x-api-key", config.AnthropicKey)
	request.Header.Set("anthropic-version", "2023-06-01")
	request.Header.Set("content-type", "application/json")

	// Send the request.
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to send request to Anthropic API: %v", err)
	}
	defer response.Body.Close()

	// Surface non-200 responses instead of silently decoding them to zero.
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return TokenCountResponse{}, fmt.Errorf("Anthropic API returned %d: %s", response.StatusCode, body)
	}

	// Decode the response.
	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}

	return result, nil
}

// countTokensWithGemini counts tokens via the Gemini API.
func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) {
	// Check the API key.
	if config.GoogleKey == "" {
		return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY is not set")
	}

	// Create the Gemini client.
	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create Gemini client: %v", err)
	}
	defer client.Close()

	// The model name has already been canonicalized by matchModelName.
	modelName := req.Model

	// Create the Gemini model.
	model := client.GenerativeModel(modelName)

	// Flatten the conversation into a single prompt string.
	var content string
	if req.System != nil && *req.System != "" {
		content += *req.System + "\n\n"
	}
	for _, msg := range req.Messages {
		if msg.Role == "user" {
			content += "User: " + msg.Content + "\n"
		} else if msg.Role == "assistant" {
			content += "Assistant: " + msg.Content + "\n"
		} else {
			content += msg.Role + ": " + msg.Content + "\n"
		}
	}

	// Count tokens.
	tokResp, err := model.CountTokens(ctx, genai.Text(content))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to count Gemini tokens: %v", err)
	}

	return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil
}

// countTokensWithDeepseek counts tokens via the Deepseek sidecar service.
func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) {
	// Prepare the request.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
	}

	// Build the request.
	request, err := http.NewRequest("POST", config.DeepseekURL+"/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
	}

	// Set headers.
	request.Header.Set("Content-Type", "application/json")

	// Send the request.
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to send request to Deepseek service: %v", err)
	}
	defer response.Body.Close()

	// Surface non-200 responses instead of silently decoding them to zero.
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return TokenCountResponse{}, fmt.Errorf("Deepseek service returned %d: %s", response.StatusCode, body)
	}

	// Decode the response.
	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}

	return result, nil
}
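
// Both sidecars are assumed to accept the same TokenCountRequest JSON this
// service accepts and to reply with at least an input_tokens field. A
// hypothetical exchange with the Deepseek sidecar (values illustrative):
//
//	POST {DEEPSEEK_URL}/count_tokens
//	{"model": "deepseek-v3", "messages": [{"role": "user", "content": "hi"}]}
//
//	HTTP/1.1 200 OK
//	{"input_tokens": 4}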

// countTokensWithOpenAI counts tokens via the OpenAI sidecar service.
func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) {
	// Prepare the request.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to serialize request: %v", err)
	}

	// Build the request.
	request, err := http.NewRequest("POST", config.OpenAIURL+"/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to create request: %v", err)
	}

	// Set headers.
	request.Header.Set("Content-Type", "application/json")

	// Send the request.
	response, err := client.Do(request)
	if err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to send request to OpenAI service: %v", err)
	}
	defer response.Body.Close()

	// Surface non-200 responses instead of silently decoding them to zero.
	if response.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(response.Body)
		return TokenCountResponse{}, fmt.Errorf("OpenAI service returned %d: %s", response.StatusCode, body)
	}

	// Decode the response; this sidecar also reports the model and encoding
	// it used, but only the token count is forwarded.
	var result struct {
		InputTokens int    `json:"input_tokens"`
		Model       string `json:"model"`
		Encoding    string `json:"encoding"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		return TokenCountResponse{}, fmt.Errorf("failed to decode response: %v", err)
	}

	return TokenCountResponse{InputTokens: result.InputTokens}, nil
}
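
// Response shapes produced by the countTokens handler below (token values
// illustrative):
//
//	Supported model:   200 {"input_tokens": 42}
//	Unsupported model: 200 {"input_tokens": 42,
//	                        "warning": "The tokenizer for model '...' is not supported yet. ...",
//	                        "estimated_with": "gpt-4o"}
//	Failure:           400/500 {"error": "..."}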

// countTokens handles POST /count_tokens. It canonicalizes the model name,
// routes the request to the matching tokenizer backend, and falls back to a
// gpt-4o estimate (with a warning) for unsupported models.
func countTokens(c *gin.Context) {
	var req TokenCountRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
		return
	}

	// Keep the original model name for warning and error messages.
	originalModel := req.Model

	// Assume the model is unsupported until a known keyword is found.
	isUnsupportedModel := true
	modelLower := strings.ToLower(req.Model)
	if strings.Contains(modelLower, "gpt") ||
		strings.Contains(modelLower, "openai") ||
		strings.Contains(modelLower, "o1") ||
		strings.Contains(modelLower, "o3") ||
		strings.HasPrefix(modelLower, "claude") ||
		strings.Contains(modelLower, "gemini") ||
		strings.Contains(modelLower, "deepseek") {
		isUnsupportedModel = false
	}

	// Canonicalize the model name.
	req.Model = matchModelName(req.Model)

	var result TokenCountResponse
	var err error

	// Check for Deepseek models first.
	if strings.Contains(strings.ToLower(req.Model), "deepseek") {
		// Use the Deepseek sidecar.
		result, err = countTokensWithDeepseek(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") {
		// Use the OpenAI sidecar.
		result, err = countTokensWithOpenAI(req)
	} else if strings.HasPrefix(strings.ToLower(req.Model), "claude") {
		// Use the Claude API.
		if config.AnthropicKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY is not set; Claude models are unavailable"})
			return
		}
		result, err = countTokensWithClaude(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gemini") {
		// Use the Gemini API.
		if config.GoogleKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY is not set; Gemini models are unavailable"})
			return
		}
		result, err = countTokensWithGemini(req)
	} else if isUnsupportedModel {
		// Unsupported model: estimate with gpt-4o.
		gptReq := req
		gptReq.Model = "gpt-4o"
		result, err = countTokensWithOpenAI(gptReq)
		if err == nil {
			// Return the estimate together with a warning.
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   result.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
			return
		}
	} else {
		// Known keyword but no matching backend: still provide a gpt-4o
		// estimate, and return an error only if that also fails.
		gptReq := req
		gptReq.Model = "gpt-4o"
		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
		if estimateErr == nil {
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
		} else {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
		}
		return
	}

	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
		return
	}

	// Return the result.
	c.JSON(http.StatusOK, result)
}

// healthCheck reports service liveness.
func healthCheck(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"status": "healthy",
		"time":   time.Now().Format(time.RFC3339),
	})
}

// startKeepAlive periodically pings the service's own /health endpoint so
// that hosting platforms do not put it to sleep.
func startKeepAlive() {
	if config.ServiceURL == "" {
		return
	}

	healthURL := fmt.Sprintf("%s/health", config.ServiceURL)
	ticker := time.NewTicker(10 * time.Hour)

	// Run the first check immediately, then once per tick.
	go func() {
		log.Printf("Starting keep-alive checks to %s", healthURL)
		for {
			resp, err := http.Get(healthURL)
			if err != nil {
				log.Printf("Keep-alive check failed: %v", err)
			} else {
				resp.Body.Close()
				log.Printf("Keep-alive check successful")
			}
			// Wait for the next tick.
			<-ticker.C
		}
	}()
}

func main() {
	// Load configuration.
	loadConfig()

	// Set gin mode.
	gin.SetMode(gin.ReleaseMode)

	// Create the router. gin.Default() already installs the Logger and
	// Recovery middleware, so Recovery is not registered a second time.
	r := gin.Default()

	// CORS middleware.
	r.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})

	// Routes.
	r.GET("/health", healthCheck)
	r.POST("/count_tokens", countTokens)

	// Resolve the port.
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860" // Hugging Face default port
	}

	// Start the keep-alive task.
	startKeepAlive()

	// Start the server.
	log.Printf("Server starting on port %s", port)
	if err := r.Run(":" + port); err != nil {
		log.Fatal(err)
	}
}
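
// Example invocation against a locally running instance (assuming the
// default port 7860 and a configured ANTHROPIC_API_KEY; the response value
// is illustrative):
//
//	curl -s -X POST http://localhost:7860/count_tokens \
//	  -H 'Content-Type: application/json' \
//	  -d '{"model": "claude-3-5-sonnet", "messages": [{"role": "user", "content": "Hello"}]}'
//
//	{"input_tokens": 9}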