Spaces:
Running
Running
| // Copyright (c) 2025-2026 libaxuan | |
| // | |
| // Permission is hereby granted, free of charge, to any person obtaining a copy | |
| // of this software and associated documentation files (the "Software"), to deal | |
| // in the Software without restriction, including without limitation the rights | |
| // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| // copies of the Software, and to permit persons to whom the Software is | |
| // furnished to do so, subject to the following conditions: | |
| // | |
| // The above copyright notice and this permission notice shall be included in all | |
| // copies or substantial portions of the Software. | |
| // | |
| // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| // SOFTWARE. | |
| package utils | |
| import ( | |
| "bufio" | |
| "context" | |
| "crypto/rand" | |
| "cursor2api-go/middleware" | |
| "cursor2api-go/models" | |
| "encoding/hex" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "os/exec" | |
| "strings" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "github.com/sirupsen/logrus" | |
| ) | |
| // GenerateRandomString 生成指定长度的随机字符串 | |
| func GenerateRandomString(length int) string { | |
| if length <= 0 { | |
| return "" | |
| } | |
| byteLen := (length + 1) / 2 | |
| bytes := make([]byte, byteLen) | |
| if _, err := rand.Read(bytes); err != nil { | |
| fallback := fmt.Sprintf("%d", time.Now().UnixNano()) | |
| if len(fallback) >= length { | |
| return fallback[:length] | |
| } | |
| return fallback | |
| } | |
| encoded := hex.EncodeToString(bytes) | |
| if len(encoded) < length { | |
| encoded += GenerateRandomString(length - len(encoded)) | |
| } | |
| return encoded[:length] | |
| } | |
| // GenerateChatCompletionID 生成聊天完成ID | |
| func GenerateChatCompletionID() string { | |
| return "chatcmpl-" + GenerateRandomString(29) | |
| } | |
| // ParseSSELine 解析SSE数据行 | |
| func ParseSSELine(line string) string { | |
| line = strings.TrimSpace(line) | |
| if strings.HasPrefix(line, "data: ") { | |
| return strings.TrimSpace(line[6:]) // 去掉 'data: ' 前缀并去除前导空格 | |
| } | |
| return "" | |
| } | |
| // WriteSSEEvent 写入SSE事件 | |
| func WriteSSEEvent(w http.ResponseWriter, event, data string) error { | |
| if event != "" { | |
| if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil { | |
| return err | |
| } | |
| } | |
| if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { | |
| return err | |
| } | |
| // 刷新缓冲区 | |
| if flusher, ok := w.(http.Flusher); ok { | |
| flusher.Flush() | |
| } | |
| return nil | |
| } | |
| // StreamChatCompletion 处理流式聊天完成 | |
| // StreamChatCompletion 处理流式聊天完成 | |
| func StreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) { | |
| // 设置SSE头 | |
| c.Header("Content-Type", "text/event-stream") | |
| c.Header("Cache-Control", "no-cache") | |
| c.Header("Connection", "keep-alive") | |
| c.Header("Access-Control-Allow-Origin", "*") | |
| // 生成响应ID | |
| responseID := GenerateChatCompletionID() | |
| started := false | |
| toolCallIndex := 0 | |
| writeChunk := func(delta models.StreamDelta, finishReason *string) { | |
| streamResp := models.NewChatCompletionStreamResponse(responseID, modelName, delta, finishReason) | |
| if jsonData, err := json.Marshal(streamResp); err == nil { | |
| WriteSSEEvent(c.Writer, "", string(jsonData)) | |
| } | |
| } | |
| // 处理流式数据 | |
| ctx := c.Request.Context() | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| logrus.Debug("Client disconnected during streaming") | |
| return | |
| case data, ok := <-chatGenerator: | |
| if !ok { | |
| // 通道关闭,发送完成事件 | |
| reason := "stop" | |
| if toolCallIndex > 0 { | |
| reason = "tool_calls" | |
| } | |
| writeChunk(models.StreamDelta{}, stringPtr(reason)) | |
| WriteSSEEvent(c.Writer, "", "[DONE]") | |
| return | |
| } | |
| switch v := data.(type) { | |
| case models.AssistantEvent: | |
| if !started { | |
| writeChunk(models.StreamDelta{Role: "assistant"}, nil) | |
| started = true | |
| } | |
| switch v.Kind { | |
| case models.AssistantEventText: | |
| if v.Text != "" { | |
| writeChunk(models.StreamDelta{Content: v.Text}, nil) | |
| } | |
| case models.AssistantEventToolCall: | |
| if v.ToolCall != nil { | |
| writeChunk(models.StreamDelta{ | |
| ToolCalls: []models.ToolCallDelta{ | |
| { | |
| Index: toolCallIndex, | |
| ID: v.ToolCall.ID, | |
| Type: v.ToolCall.Type, | |
| Function: &models.FunctionCallDelta{ | |
| Name: v.ToolCall.Function.Name, | |
| Arguments: v.ToolCall.Function.Arguments, | |
| }, | |
| }, | |
| }, | |
| }, nil) | |
| toolCallIndex++ | |
| } | |
| } | |
| case string: | |
| if !started { | |
| writeChunk(models.StreamDelta{Role: "assistant"}, nil) | |
| started = true | |
| } | |
| if v != "" { | |
| writeChunk(models.StreamDelta{Content: v}, nil) | |
| } | |
| case models.Usage: | |
| // 使用统计 - 通常在最后发送 | |
| continue | |
| case error: | |
| logrus.WithError(v).Error("Stream generator error") | |
| WriteSSEEvent(c.Writer, "", "[DONE]") | |
| return | |
| default: | |
| logrus.Warnf("Unknown data type in stream: %T", v) | |
| } | |
| } | |
| } | |
| } | |
| // NonStreamChatCompletion 处理非流式聊天完成 | |
| func NonStreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) { | |
| var fullContent strings.Builder | |
| var usage models.Usage | |
| toolCalls := make([]models.ToolCall, 0, 2) | |
| finishReason := "stop" | |
| // 收集所有数据 | |
| ctx := c.Request.Context() | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| c.JSON(http.StatusRequestTimeout, models.NewErrorResponse( | |
| "Request timeout", | |
| "timeout_error", | |
| "request_timeout", | |
| )) | |
| return | |
| case data, ok := <-chatGenerator: | |
| if !ok { | |
| // 数据收集完成,返回响应 | |
| responseID := GenerateChatCompletionID() | |
| message := models.Message{ | |
| Role: "assistant", | |
| } | |
| if fullContent.Len() > 0 || len(toolCalls) == 0 { | |
| message.Content = fullContent.String() | |
| } | |
| if len(toolCalls) > 0 { | |
| message.ToolCalls = toolCalls | |
| finishReason = "tool_calls" | |
| } | |
| response := models.NewChatCompletionResponse( | |
| responseID, | |
| modelName, | |
| message, | |
| finishReason, | |
| usage, | |
| ) | |
| c.JSON(http.StatusOK, response) | |
| return | |
| } | |
| switch v := data.(type) { | |
| case models.AssistantEvent: | |
| switch v.Kind { | |
| case models.AssistantEventText: | |
| fullContent.WriteString(v.Text) | |
| case models.AssistantEventToolCall: | |
| if v.ToolCall != nil { | |
| toolCalls = append(toolCalls, *v.ToolCall) | |
| } | |
| } | |
| case string: | |
| fullContent.WriteString(v) | |
| case models.Usage: | |
| usage = v | |
| case error: | |
| middleware.HandleError(c, v) | |
| return | |
| } | |
| } | |
| } | |
| } | |
| // ErrorWrapper 错误包装器 | |
| func ErrorWrapper(handler func(*gin.Context) error) gin.HandlerFunc { | |
| return func(c *gin.Context) { | |
| if err := handler(c); err != nil { | |
| logrus.WithError(err).Error("Handler error") | |
| if !c.Writer.Written() { | |
| c.JSON(http.StatusInternalServerError, models.NewErrorResponse( | |
| "Internal server error", | |
| "internal_error", | |
| "", | |
| )) | |
| } | |
| } | |
| } | |
| } | |
| // SafeStreamWrapper 安全流式包装器 | |
| func SafeStreamWrapper(handler func(*gin.Context, <-chan interface{}, string), c *gin.Context, chatGenerator <-chan interface{}, modelName string) { | |
| defer func() { | |
| if r := recover(); r != nil { | |
| logrus.WithField("panic", r).Error("Panic in stream handler") | |
| if !c.Writer.Written() { | |
| c.JSON(http.StatusInternalServerError, models.NewErrorResponse( | |
| "Internal server error", | |
| "panic_error", | |
| "", | |
| )) | |
| } | |
| } | |
| }() | |
| firstItem, ok := <-chatGenerator | |
| if !ok { | |
| middleware.HandleError(c, middleware.NewCursorWebError(http.StatusInternalServerError, "empty stream")) | |
| return | |
| } | |
| if err, isErr := firstItem.(error); isErr { | |
| middleware.HandleError(c, err) | |
| return | |
| } | |
| buffered := make(chan interface{}, 1) | |
| buffered <- firstItem | |
| ctx := c.Request.Context() | |
| go func() { | |
| defer close(buffered) | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| return | |
| case item, ok := <-chatGenerator: | |
| if !ok { | |
| return | |
| } | |
| select { | |
| case buffered <- item: | |
| case <-ctx.Done(): | |
| return | |
| } | |
| } | |
| } | |
| }() | |
| handler(c, buffered, modelName) | |
| } | |
| // CreateHTTPClient 创建HTTP客户端 | |
| func CreateHTTPClient(timeout time.Duration) *http.Client { | |
| return &http.Client{ | |
| Timeout: timeout, | |
| Transport: &http.Transport{ | |
| Proxy: http.ProxyFromEnvironment, | |
| MaxIdleConns: 100, | |
| MaxIdleConnsPerHost: 10, | |
| IdleConnTimeout: 90 * time.Second, | |
| }, | |
| } | |
| } | |
| // ReadSSEStream 读取SSE流 | |
| func ReadSSEStream(ctx context.Context, resp *http.Response, output chan<- interface{}) error { | |
| scanner := bufio.NewScanner(resp.Body) | |
| scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) | |
| defer resp.Body.Close() | |
| for scanner.Scan() { | |
| select { | |
| case <-ctx.Done(): | |
| return ctx.Err() | |
| default: | |
| } | |
| line := scanner.Text() | |
| data := ParseSSELine(line) | |
| if data == "" { | |
| continue | |
| } | |
| if data == "[DONE]" { | |
| return nil | |
| } | |
| // 尝试解析JSON数据 | |
| var eventData models.CursorEventData | |
| if err := json.Unmarshal([]byte(data), &eventData); err != nil { | |
| logrus.WithError(err).Debugf("Failed to parse SSE data: %s", data) | |
| continue | |
| } | |
| // 处理不同类型的事件 | |
| switch eventData.Type { | |
| case "error": | |
| if eventData.ErrorText != "" { | |
| return fmt.Errorf("cursor API error: %s", eventData.ErrorText) | |
| } | |
| case "finish": | |
| if eventData.MessageMetadata != nil && eventData.MessageMetadata.Usage != nil { | |
| usage := models.Usage{ | |
| PromptTokens: eventData.MessageMetadata.Usage.InputTokens, | |
| CompletionTokens: eventData.MessageMetadata.Usage.OutputTokens, | |
| TotalTokens: eventData.MessageMetadata.Usage.TotalTokens, | |
| } | |
| output <- usage | |
| } | |
| return nil | |
| default: | |
| if eventData.Delta != "" { | |
| output <- eventData.Delta | |
| } | |
| } | |
| } | |
| return scanner.Err() | |
| } | |
| // ValidateModel 验证模型名称 | |
| func ValidateModel(model string, validModels []string) bool { | |
| for _, validModel := range validModels { | |
| if validModel == model { | |
| return true | |
| } | |
| } | |
| return false | |
| } | |
| // SanitizeContent 清理内容 | |
| func SanitizeContent(content string) string { | |
| // 移除可能的恶意内容 | |
| content = strings.ReplaceAll(content, "\x00", "") | |
| return content | |
| } | |
| // stringPtr 返回字符串指针 | |
| func stringPtr(s string) *string { | |
| return &s | |
| } | |
| // CopyHeaders 复制HTTP头 | |
| func CopyHeaders(dst, src http.Header, skipHeaders []string) { | |
| skipMap := make(map[string]bool) | |
| for _, header := range skipHeaders { | |
| skipMap[strings.ToLower(header)] = true | |
| } | |
| for key, values := range src { | |
| if skipMap[strings.ToLower(key)] { | |
| continue | |
| } | |
| for _, value := range values { | |
| dst.Add(key, value) | |
| } | |
| } | |
| } | |
| // IsJSONContentType 检查是否为JSON内容类型 | |
| func IsJSONContentType(contentType string) bool { | |
| return strings.Contains(strings.ToLower(contentType), "application/json") | |
| } | |
| // ReadRequestBody 读取请求体 | |
| func ReadRequestBody(r *http.Request) ([]byte, error) { | |
| if r.Body == nil { | |
| return nil, nil | |
| } | |
| body, err := io.ReadAll(r.Body) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to read request body: %w", err) | |
| } | |
| return body, nil | |
| } | |
| // RunJS 执行JavaScript代码并返回标准输出内容 | |
| func RunJS(jsCode string) (string, error) { | |
| // 添加crypto模块导入并设置为全局变量 | |
| // 注意:使用stdin时,我们需要确保代码是自包含的 | |
| finalJS := `const crypto = require('crypto').webcrypto; | |
| global.crypto = crypto; | |
| globalThis.crypto = crypto; | |
| // 在Node.js环境中创建window对象 | |
| if (typeof window === 'undefined') { global.window = global; } | |
| window.crypto = crypto; | |
| this.crypto = crypto; | |
| ` + jsCode | |
| // 执行Node.js命令,使用stdin输入代码 | |
| cmd := exec.Command("node") | |
| // 设置输入 | |
| cmd.Stdin = strings.NewReader(finalJS) | |
| output, err := cmd.Output() | |
| if err != nil { | |
| if exitErr, ok := err.(*exec.ExitError); ok { | |
| return "", fmt.Errorf("node.js execution failed (exit code: %d)\nSTDOUT:\n%s\nSTDERR:\n%s", | |
| exitErr.ExitCode(), string(output), string(exitErr.Stderr)) | |
| } | |
| return "", fmt.Errorf("failed to execute node.js: %w", err) | |
| } | |
| return strings.TrimSpace(string(output)), nil | |
| } | |