diff --git a/README.md b/README.md index 7b15ffd..31d3cf6 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [SurveyController](https://github.com/SurveyController/SurveyController) 的核心 HTTP 提交 API 服务。 -负责解析问卷、创建提交任务、查询任务、停止任务、读取任务日志和解析二维码。 +负责解析问卷、创建提交任务、查询任务、停止任务、读取任务日志、导入导出配置、导出报告和解析二维码。 > [!CAUTION] > @@ -69,7 +69,7 @@ model = "deepseek-chat" api_key = "" ``` -服务固定监听 `127.0.0.1`,配置文件只改端口。 +服务固定监听 `127.0.0.1`,配置文件只改端口。AI 密钥和模型配置属于服务端私有配置,不从任务运行配置 JSON 中传入。 ## 接口列表 @@ -80,12 +80,33 @@ api_key = "" | `GET` | `/api/tasks` | 查询任务列表。按创建时间倒序返回。 | | `GET` | `/api/tasks/{id}` | 查询单个任务详情。 | | `GET` | `/api/tasks/{id}/logs` | 分页读取指定任务日志。支持 `after` 游标和 `limit` 条数参数。 | +| `GET` | `/api/tasks/{id}/config` | 导出指定任务的运行配置 JSON。 | +| `GET` | `/api/tasks/{id}/report` | 导出指定任务报告。默认 JSON,`?format=csv` 导出日志表。 | | `POST` | `/api/surveys/parse` | 解析问卷链接,返回问卷标题、平台和题目结构。不会提交答案。 | | `POST` | `/api/configs` | 生成默认运行配置。传入问卷链接时会先解析问卷,再补全题目配置;不传链接时返回空模板。 | -| `POST` | `/api/tasks` | 创建提交任务。任务异步运行,创建成功只表示已进入任务队列。 | +| `POST` | `/api/configs/import` | 导入并标准化 Python/Go 兼容运行配置 JSON。支持直接配置对象或 `{ "config": ... }` 包络。 | +| `POST` | `/api/configs/export` | 导出标准化运行配置 JSON 文件。支持直接配置对象或 `{ "config": ... }` 包络。 | +| `POST` | `/api/tasks` | 创建提交任务。任务异步运行,创建成功只表示已进入任务队列。支持直接配置对象或 `{ "config": ... }` 包络。 | | `POST` | `/api/tasks/{id}/stop` | 停止指定任务。任务不存在时返回错误。 | +| `POST` | `/api/ai/test` | 测试传入的 AI 连接参数是否可用,成功时返回模型回复预览。 | | `POST` | `/api/qrcode/decode` | 从二维码图片中解析问卷链接。 | +## 配置兼容 + +`POST /api/tasks`、`POST /api/configs/import` 和 `POST /api/configs/export` 按兼容模式读取运行配置,既支持直接传运行配置对象,也支持 `{ "config": ... }` 包络。 + +Go 会兼容 Python codec 的宽松输入形态:数字字段可接受字符串数字,布尔字段可接受 `true/false`、`1/0`、`yes/no`,`answer_duration` 可接受旧版单值或单元素数组并转换为上下浮动范围,`answer_datetime_window` 会按 `YYYY-MM-DD HH:MM:SS` 归一化。随机 UA 支持 Python 当前 preset 键 `wechat_android`、`mobile_android`、`pc_web`,同时兼容旧版 Go 键 `wechat`、`mobile`、`pc`。 + +SurveyCore 不包含 SurveyController 原有的账号、额度、设备身份或私有服务接入逻辑。来自旧配置的此类字段会被视为遗留字段并忽略;二次开发者可在 SurveyCore 外层自行实现鉴权和业务服务。 + +Go 生成或导出的配置默认带 `config_schema_version=6`。其他请求包络(例如 `/api/surveys/parse`、`/api/configs`)保持严格 JSON 校验,避免调用方把错误参数静默传入。 + +## AI 服务 + +填空题可通过题目配置中的 `ai_enabled` 或 `multi_text_blank_ai_flags` 启用服务端 AI。实际 AI 密钥、Base URL 和模型默认值由 `configs/surveycore.toml` 的 `[ai]` 分区提供,并在任务执行时注入执行配置。 + +当前支持 OpenAI 兼容 Chat Completions、自定义 Base URL、Responses API 和自动协议 fallback。SurveyCore 不内置免费 AI 私有服务链路。 + ## 错误响应 API 错误统一返回稳定错误码、用户消息和调试详情: @@ -109,10 +130,25 @@ API 错误统一返回稳定错误码、用户消息和调试详情: | `validation_error` | 业务参数未通过校验。 | | `not_found` | 任务或资源不存在。 | | `upstream_error` | 问卷平台解析、配置生成等上游调用失败。 | +| `ai_config_error` | AI 配置不完整或不支持。 | +| `ai_connection_failed` | AI 连接测试或调用失败。 | | `internal_error` | 服务内部错误。 | ## 任务状态 +任务详情响应会包含稳定进度和失败字段: + +| 字段 | 含义 | +|---|---| +| `progress.current` | 当前已成功提交数。 | +| `progress.target` | 目标提交数。 | +| `progress.success` | 成功提交数。 | +| `progress.fail` | 失败提交数。 | +| `progress.percent` | 完成比例,范围 `0` 到 `1`。 | +| `error_code` | 标准化任务错误码,例如 `fill_failed`、`submission_verification_required`、`survey_provider_unavailable`、`user_stopped`。 | +| `failure_reason` | 失败原因,优先使用运行时终止原因,其次使用错误消息或停止消息。 | +| `terminal_stop_category` | 终止类别,例如 `fail_threshold`、`reverse_fill_exhausted`、`submission_verification`、`target_reached`。 | + | 状态 | 含义 | |---|---| | `pending` | 已创建,等待运行。 | @@ -122,6 +158,12 @@ API 错误统一返回稳定错误码、用户消息和调试详情: | `stopped` | 已停止。 | | `interrupted` | 服务重启导致中断。 | +## 能力边界 + +SurveyCore 是本地 HTTP/API 化的核心执行内核。Python 项目继续负责桌面 GUI、安装更新和用户交互;Go 项目负责解析、配置、任务执行、状态、日志和可被桌面端调用的稳定 API。 + +SurveyCore 不包含 PySide GUI,也不引入 Playwright、Selenium 或浏览器兼容提交层。 + ## 许可证 Mozilla Public License Version 2.0 diff --git a/docs/sdk/api.md b/docs/sdk/api.md index 8aa7a82..5fd0572 100644 --- a/docs/sdk/api.md +++ b/docs/sdk/api.md @@ -118,7 +118,9 @@ curl -X POST http://localhost:19178/api/surveys/parse \ ```json { - "error": "url 不能为空" + "error": "url 不能为空", + "code": "validation_error", + "message": "url 不能为空" } ``` @@ -236,6 +238,18 @@ Accept: application/json 完整字段见 [数据结构](./schemas#runtimeconfig)。 +也支持 Python 桌面端常用包络: + +```json +{ + "config": { + "url": "https://www.wjx.cn/vm/example.aspx", + "target": 10, + "threads": 2 + } +} +``` + ### 请求示例 ```bash @@ -441,6 +455,121 @@ curl "http://localhost:19178/api/tasks/9f4c6b2b1a2d4e9f/logs?after=12&limit=100" `event` 字段来自 Go 结构体,目前没有 JSON 标签,所以返回字段是大写驼峰。 +## 导出任务配置 + +```http +GET /api/tasks/{id}/config +``` + +返回指定任务保存的运行配置 JSON,并设置附件下载响应头。 + +### 请求示例 + +```bash +curl -OJ http://localhost:19178/api/tasks/9f4c6b2b1a2d4e9f/config +``` + +## 导出任务报告 + +```http +GET /api/tasks/{id}/report +``` + +默认导出 JSON 报告。传 `?format=csv` 时导出日志 CSV。 + +### 请求示例 + +```bash +curl -OJ http://localhost:19178/api/tasks/9f4c6b2b1a2d4e9f/report +curl -OJ "http://localhost:19178/api/tasks/9f4c6b2b1a2d4e9f/report?format=csv" +``` + +### JSON 返回字段 + +| 字段 | 说明 | +|---|---| +| `task_id` | 任务 ID。 | +| `status` | 任务状态。 | +| `config` | 任务配置摘要。 | +| `progress` | 当前进度摘要。 | +| `error_code` | 标准化错误码。 | +| `failure_reason` | 失败原因。 | +| `terminal_stop_category` | 运行时终止类别。 | +| `logs` | 完整任务日志。 | + +## 导入兼容配置 + +```http +POST /api/configs/import +``` + +导入并标准化运行配置。支持直接传 `RuntimeConfig`,也支持 `{ "config": ... }` 包络。旧配置中的私有业务字段会被忽略。 + +### 请求示例 + +```bash +curl -X POST http://localhost:19178/api/configs/import \ + -H "Content-Type: application/json" \ + -d "{\"config\":{\"url\":\"https://www.wjx.cn/vm/example.aspx\",\"target\":\"10\"}}" +``` + +## 导出兼容配置 + +```http +POST /api/configs/export +``` + +读取兼容配置后返回标准化 JSON 文件。 + +### 请求示例 + +```bash +curl -OJ -X POST http://localhost:19178/api/configs/export \ + -H "Content-Type: application/json" \ + -d "{\"url\":\"https://www.wjx.cn/vm/example.aspx\",\"target\":10}" +``` + +## 测试 AI 连接 + +```http +POST /api/ai/test +``` + +用于测试一组 AI 连接参数是否可用。SurveyCore 不内置免费 AI 私有服务链路;生产任务的 AI 默认值建议放在服务端 `configs/surveycore.toml`。 + +### 请求体 + +```json +{ + "ai_provider": "custom", + "ai_api_key": "sk-...", + "ai_base_url": "https://api.example.com/v1", + "ai_api_protocol": "responses", + "ai_model": "example-model", + "question": "这是一个测试问题" +} +``` + +| 字段 | 类型 | 必填 | 说明 | +|---|---|---|---| +| `ai_provider` | `string` | 否 | `custom` 表示自定义 OpenAI 兼容服务。 | +| `ai_api_key` | `string` | 是 | AI API Key。 | +| `ai_base_url` | `string` | 否 | Base URL 或完整 endpoint。为空时使用默认 DeepSeek Base URL。 | +| `ai_api_protocol` | `string` | 否 | `auto`、`chat_completions` 或 `responses`。 | +| `ai_model` | `string` | 否 | 模型名。为空时使用默认模型。 | +| `ai_system_prompt` | `string` | 否 | 系统提示词。 | +| `question` | `string` | 否 | 测试问题。 | + +### 返回示例 + +```json +{ + "ok": true, + "message": "AI 连接测试成功", + "preview": "连接成功" +} +``` + ## 解析二维码 ```http diff --git a/docs/sdk/errors.md b/docs/sdk/errors.md index 88aeda8..29b6a6e 100644 --- a/docs/sdk/errors.md +++ b/docs/sdk/errors.md @@ -8,11 +8,14 @@ outline: deep ```json { - "error": "错误原因" + "error": "任务配置无效", + "code": "validation_error", + "message": "任务配置无效", + "detail": "url 不能为空" } ``` -当前版本没有独立的业务错误码字段。错误处理以 HTTP 状态码为准。 +客户端应优先看 HTTP 状态码和 `code`,`detail` 只用于调试展示。 ## 状态码 @@ -37,12 +40,20 @@ outline: deep | `POST /api/surveys/parse` | `502` | 问卷平台访问失败、链接不支持、问卷关闭、需要登录。 | | `POST /api/configs` | `400` | JSON 无效、字段名错误。 | | `POST /api/configs` | `502` | 传入 `url` 后解析问卷失败。 | +| `POST /api/configs/import` | `400` | JSON 无效或配置包络无效。 | +| `POST /api/configs/export` | `400` | JSON 无效或配置包络无效。 | | `POST /api/tasks` | `400` | JSON 无效、字段名错误、任务保存失败。 | | `GET /api/tasks` | `200` | 成功返回任务列表。 | | `GET /api/tasks/{id}` | `404` | 任务 ID 不存在。 | | `POST /api/tasks/{id}/stop` | `404` | 任务 ID 不存在。 | | `GET /api/tasks/{id}/logs` | `400` | `after` 或 `limit` 参数不合法。 | | `GET /api/tasks/{id}/logs` | `404` | 任务 ID 不存在。 | +| `GET /api/tasks/{id}/config` | `404` | 任务 ID 不存在。 | +| `GET /api/tasks/{id}/report` | `400` | `format` 参数不合法。 | +| `GET /api/tasks/{id}/report` | `404` | 任务 ID 不存在。 | +| `GET /api/tasks/{id}/report` | `500` | 日志读取或报告生成失败。 | +| `POST /api/ai/test` | `400` | JSON 无效或 AI 配置不完整。 | +| `POST /api/ai/test` | `502` | AI 上游连接失败。 | | `POST /api/qrcode/decode` | `400` | 表单无效、缺少 `image` 文件、二维码无法解析出问卷链接。 | | `POST /api/qrcode/decode` | `500` | 临时文件、文件读取或服务内部处理失败。 | @@ -65,7 +76,10 @@ outline: deep ```json { - "error": "JSON 请求体无效: json: unknown field \"taskTarget\"" + "error": "JSON 请求体无效", + "code": "invalid_json", + "message": "JSON 请求体无效", + "detail": "json: unknown field \"taskTarget\"" } ``` @@ -73,7 +87,9 @@ URL 为空: ```json { - "error": "url 不能为空" + "error": "url 不能为空", + "code": "validation_error", + "message": "url 不能为空" } ``` @@ -81,7 +97,10 @@ URL 为空: ```json { - "error": "日志游标必须是非负整数" + "error": "日志查询参数无效", + "code": "invalid_query", + "message": "日志查询参数无效", + "detail": "日志游标必须是非负整数" } ``` @@ -89,7 +108,10 @@ URL 为空: ```json { - "error": "日志条数必须是 1 到 1000 之间的整数" + "error": "日志查询参数无效", + "code": "invalid_query", + "message": "日志查询参数无效", + "detail": "日志条数必须是 1 到 1000 之间的整数" } ``` @@ -97,7 +119,10 @@ URL 为空: ```json { - "error": "缺少 image 文件" + "error": "缺少 image 文件", + "code": "validation_error", + "message": "缺少 image 文件", + "detail": "http: no such file" } ``` @@ -111,7 +136,9 @@ URL 为空: ```json { - "error": "任务不存在" + "error": "任务不存在", + "code": "not_found", + "message": "任务不存在" } ``` @@ -125,7 +152,10 @@ URL 为空: ```json { - "error": "无法获取题目: 问卷已停止,无法作答" + "error": "问卷解析失败", + "code": "upstream_error", + "message": "问卷解析失败", + "detail": "无法获取题目: 问卷已停止,无法作答" } ``` diff --git a/docs/sdk/examples.md b/docs/sdk/examples.md index b5f781e..db4f1af 100644 --- a/docs/sdk/examples.md +++ b/docs/sdk/examples.md @@ -42,6 +42,15 @@ async function createTaskFromURL(url) { async function getTask(taskID) { return requestJSON(`/api/tasks/${taskID}`); } + +async function exportTaskReport(taskID) { + const response = await fetch(`${baseURL}/api/tasks/${taskID}/report`); + if (!response.ok) { + const data = await response.json(); + throw new Error(data.message || data.error || `HTTP ${response.status}`); + } + return response.json(); +} ``` ## Python @@ -75,6 +84,9 @@ task = request_json("POST", "/api/tasks", json=config) task_id = task["task_id"] print(task_id) + +report = request_json("GET", f"/api/tasks/{task_id}/report") +print(report["status"]) ``` ## 上传二维码 diff --git a/docs/sdk/index.md b/docs/sdk/index.md index bb83460..c0d091b 100644 --- a/docs/sdk/index.md +++ b/docs/sdk/index.md @@ -68,7 +68,10 @@ Accept: application/json 4. 调 `POST /api/tasks` 创建任务。 5. 调 `GET /api/tasks/{id}` 查询任务状态。 6. 调 `GET /api/tasks/{id}/logs` 分页读取日志。 -7. 需要停止时调 `POST /api/tasks/{id}/stop`。 +7. 需要归档结果时调 `GET /api/tasks/{id}/config` 或 `GET /api/tasks/{id}/report`。 +8. 需要停止时调 `POST /api/tasks/{id}/stop`。 + +SurveyCore 不包含 SurveyController 原有的账号、额度、设备身份或私有服务接入逻辑。二次开发者应在 SurveyCore 外层实现自己的鉴权和业务服务。 ## 接口清单 @@ -78,9 +81,14 @@ Accept: application/json | `GET` | `/api/version` | 读取服务版本号。 | | `POST` | `/api/surveys/parse` | 解析问卷链接。 | | `POST` | `/api/configs` | 生成默认运行配置。 | +| `POST` | `/api/configs/import` | 导入并标准化兼容运行配置。 | +| `POST` | `/api/configs/export` | 导出标准化运行配置文件。 | | `POST` | `/api/tasks` | 创建提交任务。 | | `GET` | `/api/tasks` | 查询任务列表。 | | `GET` | `/api/tasks/{id}` | 查询单个任务。 | | `POST` | `/api/tasks/{id}/stop` | 停止任务。 | | `GET` | `/api/tasks/{id}/logs` | 分页读取任务日志。 | +| `GET` | `/api/tasks/{id}/config` | 导出任务配置。 | +| `GET` | `/api/tasks/{id}/report` | 导出任务报告。 | +| `POST` | `/api/ai/test` | 测试 AI 连接参数。 | | `POST` | `/api/qrcode/decode` | 解析二维码图片里的问卷链接。 | diff --git a/internal/api/ai.go b/internal/api/ai.go new file mode 100644 index 0000000..39296a1 --- /dev/null +++ b/internal/api/ai.go @@ -0,0 +1,57 @@ +package api + +import ( + "errors" + "net/http" + "strings" + + "github.com/SurveyController/SurveyCore/internal/questions" +) + +type aiTestRequest struct { + AIProvider string `json:"ai_provider,omitempty"` + AIAPIKey string `json:"ai_api_key,omitempty"` + AIBaseURL string `json:"ai_base_url,omitempty"` + AIAPIProtocol string `json:"ai_api_protocol,omitempty"` + AIModel string `json:"ai_model,omitempty"` + AISystemPrompt string `json:"ai_system_prompt,omitempty"` + Question string `json:"question,omitempty"` +} + +func (s *Server) handleTestAI(w http.ResponseWriter, r *http.Request) { + var req aiTestRequest + if err := decodeStrictJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_json", "JSON 请求体无效", err) + return + } + + question := strings.TrimSpace(req.Question) + if question == "" { + question = "这是一个测试问题,请回复'连接成功'" + } + client := questions.NewAIClient(questions.AIConfig{ + Provider: req.AIProvider, + APIKey: req.AIAPIKey, + BaseURL: req.AIBaseURL, + Protocol: req.AIAPIProtocol, + Model: req.AIModel, + SystemPrompt: req.AISystemPrompt, + }) + preview, err := client.GenerateAnswer(question, "fill_blank", 1) + if err != nil { + status := http.StatusBadGateway + code := "ai_connection_failed" + var aiErr *questions.AIError + if errors.As(err, &aiErr) && aiErr.Kind == questions.AIErrorConfig { + status = http.StatusBadRequest + code = "ai_config_error" + } + writeError(w, status, code, "AI 连接测试失败", err) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "message": "AI 连接测试成功", + "preview": preview, + }) +} diff --git a/internal/api/export.go b/internal/api/export.go new file mode 100644 index 0000000..7d1f8ff --- /dev/null +++ b/internal/api/export.go @@ -0,0 +1,264 @@ +package api + +import ( + "bytes" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/SurveyController/SurveyCore/internal/config" + "github.com/SurveyController/SurveyCore/internal/models" + "github.com/SurveyController/SurveyCore/internal/tasks" +) + +const reportLogPageSize = 1000 + +type configImportEnvelope struct { + Config *models.RuntimeConfig `json:"config"` +} + +type taskReport struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + FinishedAt *time.Time `json:"finished_at,omitempty"` + DurationMS int64 `json:"duration_ms,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + FailureReason string `json:"failure_reason,omitempty"` + TerminalStopCategory string `json:"terminal_stop_category,omitempty"` + Error string `json:"error,omitempty"` + StopMessage string `json:"stop_message,omitempty"` + Config taskReportConfig `json:"config"` + Progress *tasks.TaskProgress `json:"progress,omitempty"` + ThreadProgress []map[string]any `json:"thread_progress,omitempty"` + Logs []tasks.TaskLog `json:"logs"` +} + +type taskReportConfig struct { + URL string `json:"url,omitempty"` + SurveyTitle string `json:"survey_title,omitempty"` + SurveyProvider string `json:"survey_provider,omitempty"` + Target int `json:"target,omitempty"` + Threads int `json:"threads,omitempty"` + RandomUAEnabled bool `json:"random_ua_enabled,omitempty"` + ReverseFillEnabled bool `json:"reverse_fill_enabled,omitempty"` +} + +func (s *Server) handleImportConfig(w http.ResponseWriter, r *http.Request) { + cfg, err := readCompatibleRuntimeConfig(r) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_json", "配置 JSON 请求体无效", err) + return + } + config.MergeDefaults(cfg) + writeJSON(w, http.StatusOK, cfg) +} + +func (s *Server) handleExportConfig(w http.ResponseWriter, r *http.Request) { + cfg, err := readCompatibleRuntimeConfig(r) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_json", "配置 JSON 请求体无效", err) + return + } + config.MergeDefaults(cfg) + writeRuntimeConfigDownload(w, cfg, "surveycore-config.json") +} + +func (s *Server) handleExportTaskConfig(w http.ResponseWriter, r *http.Request) { + task, ok := s.manager.Get(r.PathValue("id")) + if !ok { + writeError(w, http.StatusNotFound, "not_found", "任务不存在", nil) + return + } + writeRuntimeConfigDownload(w, task.Config, "surveycore-task-"+task.ID+"-config.json") +} + +func (s *Server) handleExportTaskReport(w http.ResponseWriter, r *http.Request) { + task, ok := s.manager.Get(r.PathValue("id")) + if !ok { + writeError(w, http.StatusNotFound, "not_found", "任务不存在", nil) + return + } + logs, err := s.loadAllTaskLogs(task.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "任务日志读取失败", err) + return + } + + format := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("format"))) + if format == "" || format == "json" { + writeJSONDownload(w, http.StatusOK, buildTaskReport(task, logs), "surveycore-task-"+task.ID+"-report.json") + return + } + if format == "csv" { + writeCSVReport(w, task, logs) + return + } + writeError(w, http.StatusBadRequest, "invalid_query", "报告格式无效", errors.New("format 仅支持 json 或 csv")) +} + +func readCompatibleRuntimeConfig(r *http.Request) (*models.RuntimeConfig, error) { + defer r.Body.Close() + data, err := io.ReadAll(io.LimitReader(r.Body, 16<<20)) + if err != nil { + return nil, err + } + data = bytes.TrimSpace(data) + if len(data) == 0 { + return nil, errors.New("请求体为空") + } + var envelope configImportEnvelope + if err := json.Unmarshal(data, &envelope); err == nil && envelope.Config != nil { + return envelope.Config, nil + } + var cfg models.RuntimeConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func writeRuntimeConfigDownload(w http.ResponseWriter, cfg *models.RuntimeConfig, filename string) { + data, err := models.SerializeRuntimeConfig(cfg) + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "配置序列化失败", err) + return + } + writeBytesDownload(w, http.StatusOK, "application/json; charset=utf-8", filename, data) +} + +func writeJSONDownload(w http.ResponseWriter, status int, value any, filename string) { + data, err := json.MarshalIndent(value, "", " ") + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "报告序列化失败", err) + return + } + data = append(data, '\n') + writeBytesDownload(w, status, "application/json; charset=utf-8", filename, data) +} + +func writeBytesDownload(w http.ResponseWriter, status int, contentType, filename string, data []byte) { + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename)) + w.WriteHeader(status) + _, _ = w.Write(data) +} + +func (s *Server) loadAllTaskLogs(taskID string) ([]tasks.TaskLog, error) { + var ( + all []tasks.TaskLog + after int64 + ) + for { + page, err := s.manager.Logs(taskID, after, reportLogPageSize) + if err != nil { + return nil, err + } + all = append(all, page.Logs...) + if !page.HasMore || page.NextCursor == 0 { + return all, nil + } + after = page.NextCursor + } +} + +func buildTaskReport(task *tasks.TaskRecord, logs []tasks.TaskLog) taskReport { + report := taskReport{ + TaskID: task.ID, + Status: task.Status, + CreatedAt: task.CreatedAt, + StartedAt: task.StartedAt, + FinishedAt: task.FinishedAt, + ErrorCode: task.ErrorCode, + FailureReason: task.FailureReason, + TerminalStopCategory: task.TerminalStopCategory, + Error: task.Error, + StopMessage: task.StopMessage, + Config: summarizeReportConfig(task.Config), + Progress: task.Progress, + Logs: logs, + } + if task.StartedAt != nil { + end := time.Now() + if task.FinishedAt != nil { + end = *task.FinishedAt + } + report.DurationMS = end.Sub(*task.StartedAt).Milliseconds() + } + if task.State != nil { + report.ThreadProgress = task.State.SnapshotThreadProgress() + } + return report +} + +func summarizeReportConfig(cfg *models.RuntimeConfig) taskReportConfig { + if cfg == nil { + return taskReportConfig{} + } + return taskReportConfig{ + URL: cfg.URL, + SurveyTitle: cfg.SurveyTitle, + SurveyProvider: cfg.SurveyProvider, + Target: cfg.Target, + Threads: cfg.Threads, + RandomUAEnabled: cfg.RandomUAEnabled, + ReverseFillEnabled: cfg.ReverseFillEnabled, + } +} + +func writeCSVReport(w http.ResponseWriter, task *tasks.TaskRecord, logs []tasks.TaskLog) { + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + _ = writer.Write([]string{"task_id", "status", "log_id", "timestamp", "level", "message", "worker", "current", "total", "success", "fail"}) + for _, entry := range logs { + worker, current, total, success, fail := eventFields(entry) + _ = writer.Write([]string{ + task.ID, + task.Status, + strconv.FormatInt(entry.ID, 10), + entry.Timestamp.Format(time.RFC3339Nano), + entry.Level, + entry.Message, + worker, + current, + total, + success, + fail, + }) + } + writer.Flush() + if err := writer.Error(); err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "CSV 报告生成失败", err) + return + } + writeBytesDownload(w, http.StatusOK, "text/csv; charset=utf-8", "surveycore-task-"+task.ID+"-report.csv", buf.Bytes()) +} + +func eventFields(entry tasks.TaskLog) (worker, current, total, success, fail string) { + if entry.Event == nil { + return fieldString(entry.Fields, "worker"), fieldString(entry.Fields, "current"), fieldString(entry.Fields, "total"), "", "" + } + return entry.Event.ThreadName, + strconv.Itoa(entry.Event.Current), + strconv.Itoa(entry.Event.Total), + strconv.FormatBool(entry.Event.Success), + strconv.FormatBool(entry.Event.Fail) +} + +func fieldString(fields map[string]any, key string) string { + if fields == nil { + return "" + } + value, ok := fields[key] + if !ok || value == nil { + return "" + } + return fmt.Sprint(value) +} diff --git a/internal/api/server.go b/internal/api/server.go index 8be0d0e..9e04581 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -48,8 +48,13 @@ func (s *Server) Handler() http.Handler { mux.HandleFunc("GET /api/tasks/{id}", s.handleGetTask) mux.HandleFunc("POST /api/tasks/{id}/stop", s.handleStopTask) mux.HandleFunc("GET /api/tasks/{id}/logs", s.handleTaskLogs) + mux.HandleFunc("GET /api/tasks/{id}/config", s.handleExportTaskConfig) + mux.HandleFunc("GET /api/tasks/{id}/report", s.handleExportTaskReport) mux.HandleFunc("POST /api/surveys/parse", s.handleParseSurvey) mux.HandleFunc("POST /api/configs", s.handleCreateConfig) + mux.HandleFunc("POST /api/configs/import", s.handleImportConfig) + mux.HandleFunc("POST /api/configs/export", s.handleExportConfig) + mux.HandleFunc("POST /api/ai/test", s.handleTestAI) mux.HandleFunc("POST /api/qrcode/decode", s.handleDecodeQR) return loggingMiddleware(mux) } @@ -63,12 +68,12 @@ func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleCreateTask(w http.ResponseWriter, r *http.Request) { - var cfg models.RuntimeConfig - if err := decodeStrictJSON(r, &cfg); err != nil { + cfg, err := readCompatibleRuntimeConfig(r) + if err != nil { writeError(w, http.StatusBadRequest, "invalid_json", "配置 JSON 请求体无效", err) return } - task, err := s.manager.Create(context.Background(), &cfg) + task, err := s.manager.Create(context.Background(), cfg) if err != nil { writeError(w, http.StatusBadRequest, "validation_error", "任务配置无效", err) return diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 53e80c6..8ef1cfa 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -41,24 +41,73 @@ func TestCreateTaskReturnsTaskID(t *testing.T) { } } -func TestCreateTaskRejectsUnknownLegacyFields(t *testing.T) { +func TestCreateTaskDropsPrivateLegacyFields(t *testing.T) { server := newTestServer(t) reqBody := `{ "url":"https://www.wjx.cn/vm/test.aspx", "target":1, - "proxy_source":"default" + "proxy_source":"default", + "random_ip_user_id":88, + "random_ip_device_id":"device-88" }` req := httptest.NewRequest(http.MethodPost, "/api/tasks", strings.NewReader(reqBody)) rec := httptest.NewRecorder() server.Handler().ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Fatalf("status = %d body=%s, want 400", rec.Code, rec.Body.String()) + if rec.Code != http.StatusAccepted { + t.Fatalf("status = %d body=%s, want 202", rec.Code, rec.Body.String()) } - apiErr := decodeAPIError(t, rec) - if apiErr.Code != "invalid_json" || apiErr.Detail == "" { - t.Fatalf("error = %#v, want invalid_json with detail", apiErr) + var created map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &created); err != nil { + t.Fatal(err) + } + task, ok := server.manager.Get(created["task_id"].(string)) + if !ok || task.Config == nil { + t.Fatalf("created task not found") + } + for _, field := range []string{"proxy_source", "random_ip_user_id", "random_ip_device_id"} { + if _, exists := task.Config.ExtraFields[field]; exists { + t.Fatalf("legacy private field %q was preserved in extras: %#v", field, task.Config.ExtraFields) + } + } +} + +func TestCreateTaskAcceptsPythonConfigEnvelope(t *testing.T) { + server := newTestServer(t) + reqBody := `{ + "config":{ + "url":"https://www.wjx.cn/vm/test.aspx", + "target":1, + "_ai_config_present":true, + "python_only_future_field":"task-envelope" + } + }` + req := httptest.NewRequest(http.MethodPost, "/api/tasks", strings.NewReader(reqBody)) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusAccepted { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + var created map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &created); err != nil { + t.Fatal(err) + } + taskID, _ := created["task_id"].(string) + if taskID == "" { + t.Fatalf("response = %#v, want task_id", created) + } + task, ok := server.manager.Get(taskID) + if !ok { + t.Fatalf("created task %q not found", taskID) + } + if task.Config == nil || task.Config.URL != "https://www.wjx.cn/vm/test.aspx" { + t.Fatalf("task config = %#v, want envelope config", task.Config) + } + if string(task.Config.ExtraFields["python_only_future_field"]) != `"task-envelope"` { + t.Fatalf("extra fields = %#v, want envelope extras preserved", task.Config.ExtraFields) } } @@ -193,6 +242,38 @@ func TestDecodeQRMissingImageReturns400(t *testing.T) { } } +func TestAITestEndpointReturnsPreview(t *testing.T) { + aiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"output_text":"连接成功"}`)) + })) + defer aiServer.Close() + + server := newTestServer(t) + reqBody := `{ + "ai_provider":"custom", + "ai_api_key":"test-key", + "ai_base_url":"` + aiServer.URL + `/v1", + "ai_api_protocol":"responses", + "ai_model":"test-model" + }` + req := httptest.NewRequest(http.MethodPost, "/api/ai/test", strings.NewReader(reqBody)) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want 200", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "连接成功") { + t.Fatalf("response body = %s, want AI preview", rec.Body.String()) + } +} + func TestTaskLogsReturnsCursorPage(t *testing.T) { server := newTestServer(t) task, err := server.manager.Create(t.Context(), nilRuntimeConfig()) @@ -232,6 +313,160 @@ func TestTaskLogsRejectsInvalidCursor(t *testing.T) { } } +func TestImportConfigAcceptsPythonCompatibleJSON(t *testing.T) { + server := newTestServer(t) + reqBody := `{ + "config":{ + "url":"https://www.wjx.cn/vm/test.aspx", + "target":2, + "_ai_config_present":true, + "python_only_future_field":"ignored" + } + }` + req := httptest.NewRequest(http.MethodPost, "/api/configs/import", strings.NewReader(reqBody)) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want 200", rec.Code, rec.Body.String()) + } + var cfg models.RuntimeConfig + if err := json.Unmarshal(rec.Body.Bytes(), &cfg); err != nil { + t.Fatal(err) + } + if cfg.Target != 2 || cfg.Threads != 1 { + t.Fatalf("config = %#v, want imported target with defaults", cfg) + } + var raw map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &raw); err != nil { + t.Fatal(err) + } + if raw["_ai_config_present"] != true || raw["python_only_future_field"] != "ignored" { + t.Fatalf("raw config = %#v, want python-only fields preserved", raw) + } +} + +func TestExportConfigReturnsDownload(t *testing.T) { + server := newTestServer(t) + req := httptest.NewRequest(http.MethodPost, "/api/configs/export", strings.NewReader(`{"url":"https://www.wjx.cn/vm/test.aspx","target":3,"config_schema_version":6}`)) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want 200", rec.Code, rec.Body.String()) + } + if got := rec.Header().Get("Content-Disposition"); !strings.Contains(got, "surveycore-config.json") { + t.Fatalf("Content-Disposition = %q, want config filename", got) + } + var cfg models.RuntimeConfig + if err := json.Unmarshal(rec.Body.Bytes(), &cfg); err != nil { + t.Fatal(err) + } + if cfg.Target != 3 || cfg.Threads != 1 { + t.Fatalf("config = %#v, want exported config with defaults", cfg) + } + var raw map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &raw); err != nil { + t.Fatal(err) + } + if raw["config_schema_version"] != float64(6) { + t.Fatalf("raw config = %#v, want preserved schema version", raw) + } +} + +func TestExportTaskConfigReturnsPersistedConfig(t *testing.T) { + server := newTestServer(t) + importedCfg, err := models.DeserializeRuntimeConfig([]byte(`{ + "url":"https://www.wjx.cn/vm/test.aspx", + "survey_title":"问卷", + "target":1, + "threads":1, + "python_only_future_field":"task-keep" + }`)) + if err != nil { + t.Fatal(err) + } + task, err := server.manager.Create(t.Context(), &models.RuntimeConfig{ + URL: importedCfg.URL, + SurveyTitle: importedCfg.SurveyTitle, + Target: importedCfg.Target, + Threads: importedCfg.Threads, + ExtraFields: importedCfg.ExtraFields, + }) + if err != nil { + t.Fatal(err) + } + req := httptest.NewRequest(http.MethodGet, "/api/tasks/"+task.ID+"/config", nil) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want 200", rec.Code, rec.Body.String()) + } + var exportedCfg models.RuntimeConfig + if err := json.Unmarshal(rec.Body.Bytes(), &exportedCfg); err != nil { + t.Fatal(err) + } + if exportedCfg.SurveyTitle != "问卷" || exportedCfg.Target != 1 { + t.Fatalf("config = %#v, want task config", exportedCfg) + } + var raw map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &raw); err != nil { + t.Fatal(err) + } + if raw["python_only_future_field"] != "task-keep" { + t.Fatalf("raw task config = %#v, want persisted extra field", raw) + } +} + +func TestExportTaskReportReturnsProgressAndLogs(t *testing.T) { + server := newTestServer(t) + task, err := server.manager.Create(t.Context(), nilRuntimeConfig()) + if err != nil { + t.Fatal(err) + } + req := httptest.NewRequest(http.MethodGet, "/api/tasks/"+task.ID+"/report", nil) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want 200", rec.Code, rec.Body.String()) + } + if got := rec.Header().Get("Content-Disposition"); !strings.Contains(got, "report.json") { + t.Fatalf("Content-Disposition = %q, want report filename", got) + } + var report taskReport + if err := json.Unmarshal(rec.Body.Bytes(), &report); err != nil { + t.Fatal(err) + } + if report.TaskID != task.ID || report.Progress == nil || report.Progress.Target != 1 || len(report.Logs) == 0 { + t.Fatalf("report = %#v, want task progress and logs", report) + } +} + +func TestExportTaskReportCSV(t *testing.T) { + server := newTestServer(t) + task, err := server.manager.Create(t.Context(), nilRuntimeConfig()) + if err != nil { + t.Fatal(err) + } + req := httptest.NewRequest(http.MethodGet, "/api/tasks/"+task.ID+"/report?format=csv", nil) + rec := httptest.NewRecorder() + + server.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want 200", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "task_id,status,log_id") { + t.Fatalf("csv = %s, want header", rec.Body.String()) + } +} + func decodeAPIError(t *testing.T, rec *httptest.ResponseRecorder) errorResponse { t.Helper() var apiErr errorResponse diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 4a7e22b..8ac8fab 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -17,13 +17,13 @@ import ( // StatusEvent represents a status update from the engine. type StatusEvent struct { - ThreadName string - StatusText string - Success bool - Fail bool - Current int - Total int - Timestamp time.Time + ThreadName string `json:"thread_name"` + StatusText string `json:"status_text"` + Success bool `json:"success"` + Fail bool `json:"fail"` + Current int `json:"current"` + Total int `json:"total"` + Timestamp time.Time `json:"timestamp"` } // StatusHandler is called when a status event occurs. @@ -283,9 +283,12 @@ func sampleIntervalDelay(bounds [2]int) time.Duration { } var userAgentProfiles = map[string]string{ - "wechat": "Mozilla/5.0 (Linux; Android 14; Pixel 8 Pro) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36 MicroMessenger/8.0.44", - "mobile": "Mozilla/5.0 (Linux; Android 14; Pixel 8 Pro) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36", - "pc": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "pc_web": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36", + "mobile_android": "Mozilla/5.0 (Linux; Android 16; Pixel 8) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Mobile Safari/537.36", + "wechat_android": "Mozilla/5.0 (Linux; Android 16; Pixel 8 Build/BP22.250124.009; wv) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/121.0.0.0 Mobile Safari/537.36 MicroMessenger/8.0.43.2460(0x28002B3B) Process/appbrand0 WeChat/arm64 Weixin NetType/WIFI Language/zh_CN ABI/arm64", + "wechat": "Mozilla/5.0 (Linux; Android 16; Pixel 8 Build/BP22.250124.009; wv) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/121.0.0.0 Mobile Safari/537.36 MicroMessenger/8.0.43.2460(0x28002B3B) Process/appbrand0 WeChat/arm64 Weixin NetType/WIFI Language/zh_CN ABI/arm64", + "mobile": "Mozilla/5.0 (Linux; Android 16; Pixel 8) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Mobile Safari/537.36", + "pc": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36", } func sampleUserAgent(cfg *execution.ExecutionConfig) string { @@ -294,7 +297,7 @@ func sampleUserAgent(cfg *execution.ExecutionConfig) string { } keys := cfg.RandomUserAgentKeys if len(keys) == 0 { - keys = []string{"wechat", "mobile", "pc"} + keys = []string{"wechat_android", "mobile_android", "pc_web"} } total := 0 @@ -302,7 +305,7 @@ func sampleUserAgent(cfg *execution.ExecutionConfig) string { if _, ok := userAgentProfiles[key]; !ok { continue } - weight := cfg.UserAgentRatios[key] + weight := cfg.UserAgentRatios[userAgentRatioKey(key)] if weight <= 0 { weight = 1 } @@ -318,7 +321,7 @@ func sampleUserAgent(cfg *execution.ExecutionConfig) string { if !ok { continue } - weight := cfg.UserAgentRatios[key] + weight := cfg.UserAgentRatios[userAgentRatioKey(key)] if weight <= 0 { weight = 1 } @@ -330,6 +333,19 @@ func sampleUserAgent(cfg *execution.ExecutionConfig) string { return "" } +func userAgentRatioKey(key string) string { + switch key { + case "wechat_android": + return "wechat" + case "mobile_android": + return "mobile" + case "pc_web": + return "pc" + default: + return key + } +} + // ParseSurvey parses a survey URL using the appropriate provider. func (e *Engine) ParseSurvey(ctx context.Context, surveyURL string) (*models.SurveyDefinition, error) { adapter, err := e.registry.GetByURL(surveyURL) diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go index 1afaa41..89c9386 100644 --- a/internal/engine/engine_test.go +++ b/internal/engine/engine_test.go @@ -33,14 +33,26 @@ func TestSampleUserAgentHonorsDisabledAndRatios(t *testing.T) { cfg := &execution.ExecutionConfig{ RandomUserAgentEnabled: true, - RandomUserAgentKeys: []string{"pc"}, + RandomUserAgentKeys: []string{"pc_web"}, UserAgentRatios: map[string]int{"pc": 100}, } ua := sampleUserAgent(cfg) if ua == "" { t.Fatal("random UA should be selected when enabled") } - if ua != userAgentProfiles["pc"] { + if ua != userAgentProfiles["pc_web"] { t.Fatalf("UA = %q, want pc profile", ua) } } + +func TestSampleUserAgentKeepsLegacyKeys(t *testing.T) { + cfg := &execution.ExecutionConfig{ + RandomUserAgentEnabled: true, + RandomUserAgentKeys: []string{"mobile"}, + UserAgentRatios: map[string]int{"mobile": 100}, + } + ua := sampleUserAgent(cfg) + if ua != userAgentProfiles["mobile"] { + t.Fatalf("UA = %q, want legacy mobile profile", ua) + } +} diff --git a/internal/models/config.go b/internal/models/config.go index 2a05859..cc025f3 100644 --- a/internal/models/config.go +++ b/internal/models/config.go @@ -2,37 +2,48 @@ package models import ( "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "time" "github.com/SurveyController/SurveyCore/internal/domain" ) -// Default UA keys for random user agent -var DefaultRandomUAKeys = []string{"wechat", "mobile", "pc"} +// Default UA keys for random user agent, matching Python's USER_AGENT_PRESETS. +var DefaultRandomUAKeys = []string{"wechat_android", "mobile_android", "pc_web"} + +const ( + CurrentConfigSchemaVersion = 6 + maxAnswerDurationSeconds = 30 * 60 +) // RuntimeConfig is the top-level user-facing configuration object. type RuntimeConfig struct { - URL string `json:"url"` - SurveyTitle string `json:"survey_title,omitempty"` - SurveyProvider string `json:"survey_provider,omitempty"` - Target int `json:"target,omitempty"` - Threads int `json:"threads,omitempty"` - SubmitInterval [2]int `json:"submit_interval,omitempty"` - AnswerDuration [2]int `json:"answer_duration,omitempty"` - AnswerDatetimeWindow [2]string `json:"answer_datetime_window,omitempty"` - RandomUAEnabled bool `json:"random_ua_enabled,omitempty"` - RandomUAKeys []string `json:"random_ua_keys,omitempty"` - RandomUARatios map[string]int `json:"random_ua_ratios,omitempty"` - ReliabilityModeEnabled bool `json:"reliability_mode_enabled,omitempty"` - PsychoTargetAlpha float64 `json:"psycho_target_alpha,omitempty"` - ReverseFillEnabled bool `json:"reverse_fill_enabled,omitempty"` - ReverseFillSourcePath string `json:"reverse_fill_source_path,omitempty"` - ReverseFillFormat string `json:"reverse_fill_format,omitempty"` - ReverseFillStartRow int `json:"reverse_fill_start_row,omitempty"` - ReverseFillThreads int `json:"reverse_fill_threads,omitempty"` - AnswerRules []map[string]any `json:"answer_rules,omitempty"` - DimensionGroups []string `json:"dimension_groups,omitempty"` - QuestionEntries []QuestionEntry `json:"question_entries,omitempty"` - QuestionsInfo []SurveyQuestionMeta `json:"questions_info,omitempty"` + URL string `json:"url"` + SurveyTitle string `json:"survey_title,omitempty"` + SurveyProvider string `json:"survey_provider,omitempty"` + Target int `json:"target,omitempty"` + Threads int `json:"threads,omitempty"` + SubmitInterval [2]int `json:"submit_interval,omitempty"` + AnswerDuration [2]int `json:"answer_duration,omitempty"` + AnswerDatetimeWindow [2]string `json:"answer_datetime_window,omitempty"` + RandomUAEnabled bool `json:"random_ua_enabled,omitempty"` + RandomUAKeys []string `json:"random_ua_keys,omitempty"` + RandomUARatios map[string]int `json:"random_ua_ratios,omitempty"` + ReliabilityModeEnabled bool `json:"reliability_mode_enabled,omitempty"` + PsychoTargetAlpha float64 `json:"psycho_target_alpha,omitempty"` + ReverseFillEnabled bool `json:"reverse_fill_enabled,omitempty"` + ReverseFillSourcePath string `json:"reverse_fill_source_path,omitempty"` + ReverseFillFormat string `json:"reverse_fill_format,omitempty"` + ReverseFillStartRow int `json:"reverse_fill_start_row,omitempty"` + ReverseFillThreads int `json:"reverse_fill_threads,omitempty"` + AnswerRules []map[string]any `json:"answer_rules,omitempty"` + DimensionGroups []string `json:"dimension_groups,omitempty"` + QuestionEntries []QuestionEntry `json:"question_entries,omitempty"` + QuestionsInfo []SurveyQuestionMeta `json:"questions_info,omitempty"` + ExtraFields map[string]json.RawMessage `json:"-"` } // NewDefaultRuntimeConfig returns a RuntimeConfig with sensible defaults. @@ -66,6 +77,411 @@ func DeserializeRuntimeConfig(data []byte) (*RuntimeConfig, error) { return &cfg, nil } +// UnmarshalJSON keeps Python-only or future config fields for lossless round trips. +func (cfg *RuntimeConfig) UnmarshalJSON(data []byte) error { + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + normalized := normalizeRuntimeConfigJSON(raw) + + type runtimeConfigAlias RuntimeConfig + var alias runtimeConfigAlias + if err := json.Unmarshal(configMustJSON(normalized), &alias); err != nil { + return err + } + *cfg = RuntimeConfig(alias) + + for key := range runtimeConfigJSONKeys() { + delete(raw, key) + } + for key := range legacyPrivateRuntimeConfigJSONKeys() { + delete(raw, key) + } + if len(raw) == 0 { + cfg.ExtraFields = nil + return nil + } + cfg.ExtraFields = raw + return nil +} + +// MarshalJSON writes preserved Python-only fields back alongside Go-supported fields. +func (cfg RuntimeConfig) MarshalJSON() ([]byte, error) { + type runtimeConfigAlias RuntimeConfig + data, err := json.Marshal(runtimeConfigAlias(cfg)) + if err != nil { + return nil, err + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + known := runtimeConfigJSONKeys() + for key, value := range cfg.ExtraFields { + if len(value) == 0 { + continue + } + if _, ok := known[key]; ok { + continue + } + raw[key] = value + } + if _, ok := raw["config_schema_version"]; !ok { + raw["config_schema_version"] = configMustJSON(CurrentConfigSchemaVersion) + } + if _, ok := raw["_ai_config_present"]; !ok { + raw["_ai_config_present"] = configMustJSON(cfg.hasAIConfig()) + } + return json.Marshal(raw) +} + +func (cfg RuntimeConfig) hasAIConfig() bool { + return false +} + +func normalizeRuntimeConfigJSON(raw map[string]json.RawMessage) map[string]json.RawMessage { + normalized := make(map[string]json.RawMessage, len(raw)) + for key, value := range raw { + normalized[key] = value + } + for _, key := range []string{"target", "threads"} { + if value, ok := raw[key]; ok { + normalized[key] = rawConfigInt(value, 0) + } + } + if value, ok := raw["reverse_fill_start_row"]; ok { + normalized["reverse_fill_start_row"] = rawMinInt(value, 1, 1) + } + if value, ok := raw["reverse_fill_threads"]; ok { + normalized["reverse_fill_threads"] = rawMinInt(value, 1, 1) + } + for _, key := range []string{"random_ua_enabled", "reliability_mode_enabled", "reverse_fill_enabled"} { + if value, ok := raw[key]; ok { + normalized[key] = rawConfigBool(value, false) + } + } + if value, ok := raw["psycho_target_alpha"]; ok { + normalized["psycho_target_alpha"] = rawConfigFloat(value, 0) + } + for _, key := range []string{"url", "survey_title", "survey_provider", "reverse_fill_source_path", "reverse_fill_format"} { + if value, ok := raw[key]; ok { + normalized[key] = rawString(value) + } + } + if value, ok := raw["submit_interval"]; ok { + normalized["submit_interval"] = rawIntPair(value, [2]int{}) + } + if value, ok := raw["answer_duration"]; ok { + normalized["answer_duration"] = rawAnswerDuration(value) + } + if value, ok := raw["answer_datetime_window"]; ok { + normalized["answer_datetime_window"] = rawAnswerDatetimeWindow(value) + } + if value, ok := raw["random_ua_ratios"]; ok { + normalized["random_ua_ratios"] = rawRandomUARatios(value) + } + if value, ok := raw["random_ua_keys"]; ok { + normalized["random_ua_keys"] = rawRandomUAKeys(value) + } + if value, ok := raw["reverse_fill_format"]; ok { + normalized["reverse_fill_format"] = rawReverseFillFormat(value) + } + return normalized +} + +func rawConfigInt(raw json.RawMessage, fallback int) json.RawMessage { + return configMustJSON(configToInt(configAnyFromRaw(raw), fallback)) +} + +func rawConfigFloat(raw json.RawMessage, fallback float64) json.RawMessage { + return configMustJSON(configToFloat(configAnyFromRaw(raw), fallback)) +} + +func rawConfigBool(raw json.RawMessage, fallback bool) json.RawMessage { + return configMustJSON(configToBool(configAnyFromRaw(raw), fallback)) +} + +func rawMinInt(raw json.RawMessage, fallback, minValue int) json.RawMessage { + value := configToInt(configAnyFromRaw(raw), fallback) + if value < minValue { + value = minValue + } + return configMustJSON(value) +} + +func rawString(raw json.RawMessage) json.RawMessage { + value := configAnyFromRaw(raw) + if value == nil { + return configMustJSON("") + } + return configMustJSON(strings.TrimSpace(fmt.Sprint(value))) +} + +func rawNullableString(raw json.RawMessage) json.RawMessage { + value := configAnyFromRaw(raw) + if value == nil { + return configMustJSON(nil) + } + return rawString(raw) +} + +func rawIntPair(raw json.RawMessage, fallback [2]int) json.RawMessage { + value := configAnyFromRaw(raw) + items, ok := value.([]any) + if !ok || len(items) < 2 { + return configMustJSON(fallback) + } + first := configToInt(items[0], fallback[0]) + second := configToInt(items[1], fallback[1]) + return configMustJSON([2]int{first, second}) +} + +func rawAnswerDuration(raw json.RawMessage) json.RawMessage { + value := configAnyFromRaw(raw) + defaultRange := [2]int{60, 120} + if value == nil { + return configMustJSON(defaultRange) + } + if items, ok := value.([]any); ok { + if len(items) >= 2 { + low := configClampInt(configToInt(items[0], 0), 0, maxAnswerDurationSeconds) + high := configClampInt(configToInt(items[1], low), low, maxAnswerDurationSeconds) + if low == 0 && high == 0 { + return configMustJSON(defaultRange) + } + if low == high { + return configMustJSON(configLegacyAnswerDurationRange(low)) + } + return configMustJSON([2]int{low, high}) + } + if len(items) == 1 { + return configMustJSON(configLegacyAnswerDurationRange(configToInt(items[0], 0))) + } + return configMustJSON(defaultRange) + } + return configMustJSON(configLegacyAnswerDurationRange(configToInt(value, 0))) +} + +func rawAnswerDatetimeWindow(raw json.RawMessage) json.RawMessage { + value := configAnyFromRaw(raw) + items, ok := value.([]any) + if !ok { + return configMustJSON([2]string{}) + } + var result [2]string + if len(items) >= 1 { + result[0] = configNormalizeAnswerDatetimeString(items[0]) + } + if len(items) >= 2 { + result[1] = configNormalizeAnswerDatetimeString(items[1]) + } + return configMustJSON(result) +} + +func rawRandomUAKeys(raw json.RawMessage) json.RawMessage { + value := configAnyFromRaw(raw) + items, ok := value.([]any) + if !ok { + return configMustJSON([]string{}) + } + result := make([]string, 0, len(items)) + for _, item := range items { + key := strings.TrimSpace(fmt.Sprint(item)) + if isValidRandomUAKey(key) { + result = append(result, key) + } + } + return configMustJSON(result) +} + +func rawRandomUARatios(raw json.RawMessage) json.RawMessage { + var source map[string]any + if err := json.Unmarshal(raw, &source); err != nil { + return configMustJSON(defaultRandomUARatios()) + } + total := 0 + for _, value := range source { + total += configToInt(value, 0) + } + if total != 100 { + return configMustJSON(defaultRandomUARatios()) + } + result := map[string]int{ + "wechat": configToInt(source["wechat"], 33), + "mobile": configToInt(source["mobile"], 33), + "pc": configToInt(source["pc"], 34), + } + return configMustJSON(result) +} + +func rawReverseFillFormat(raw json.RawMessage) json.RawMessage { + value := strings.ToLower(strings.TrimSpace(fmt.Sprint(configAnyFromRaw(raw)))) + switch value { + case domain.ReverseFillFormatAuto, domain.ReverseFillFormatWJXSequence, domain.ReverseFillFormatWJXScore, domain.ReverseFillFormatWJXText: + return configMustJSON(value) + default: + return configMustJSON(domain.ReverseFillFormatAuto) + } +} + +func defaultRandomUARatios() map[string]int { + return map[string]int{"wechat": 33, "mobile": 33, "pc": 34} +} + +func isValidRandomUAKey(key string) bool { + switch key { + case "wechat_android", "mobile_android", "pc_web", "wechat", "mobile", "pc": + return true + default: + return false + } +} + +func configAnyFromRaw(raw json.RawMessage) any { + var value any + if err := json.Unmarshal(raw, &value); err != nil { + return nil + } + return value +} + +func configToInt(value any, fallback int) int { + switch typed := value.(type) { + case float64: + return int(typed) + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return parsed + } + case json.Number: + parsed, err := typed.Int64() + if err == nil { + return int(parsed) + } + } + return fallback +} + +func configToFloat(value any, fallback float64) float64 { + switch typed := value.(type) { + case float64: + return typed + case string: + parsed, err := strconv.ParseFloat(strings.TrimSpace(typed), 64) + if err == nil { + return parsed + } + case json.Number: + parsed, err := typed.Float64() + if err == nil { + return parsed + } + } + return fallback +} + +func configToBool(value any, fallback bool) bool { + switch typed := value.(type) { + case bool: + return typed + case float64: + return typed != 0 + case string: + text := strings.ToLower(strings.TrimSpace(typed)) + switch text { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off", "": + return false + } + } + return fallback +} + +func configLegacyAnswerDurationRange(value int) [2]int { + normalized := configClampInt(value, 0, maxAnswerDurationSeconds) + if normalized <= 0 { + return [2]int{60, 120} + } + low := int(float64(normalized)*0.9 + 0.5) + high := int(float64(normalized)*1.1 + 0.5) + return [2]int{configClampInt(low, 0, maxAnswerDurationSeconds), configClampInt(high, low, maxAnswerDurationSeconds)} +} + +func configClampInt(value, low, high int) int { + if value < low { + return low + } + if value > high { + return high + } + return value +} + +func configNormalizeAnswerDatetimeString(value any) string { + text := strings.TrimSpace(fmt.Sprint(value)) + if text == "" { + return "" + } + parsed, err := time.ParseInLocation("2006-01-02 15:04:05", text, time.Local) + if err != nil { + return "" + } + return parsed.Format("2006-01-02 15:04:05") +} + +func configMustJSON(value any) json.RawMessage { + data, err := json.Marshal(value) + if err != nil { + return json.RawMessage("null") + } + return data +} + +func runtimeConfigJSONKeys() map[string]struct{} { + result := make(map[string]struct{}) + t := reflect.TypeOf(RuntimeConfig{}) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + tag := field.Tag.Get("json") + name := strings.Split(tag, ",")[0] + if name == "-" { + continue + } + if name == "" { + name = field.Name + } + result[name] = struct{}{} + } + return result +} + +func legacyPrivateRuntimeConfigJSONKeys() map[string]struct{} { + return map[string]struct{}{ + "random_ip_enabled": {}, + "proxy_source": {}, + "custom_proxy_api": {}, + "proxy_area_code": {}, + "random_ip_user_id": {}, + "random_ip_device_id": {}, + "ip_extract_endpoint": {}, + "random_ip_lease_minute": {}, + "fail_stop_enabled": {}, + "pause_on_aliyun_captcha": {}, + "ai_mode": {}, + "ai_provider": {}, + "ai_api_key": {}, + "ai_base_url": {}, + "ai_api_protocol": {}, + "ai_model": {}, + "ai_system_prompt": {}, + "ai_free_endpoint": {}, + "random_ip_session_path": {}, + } +} + // Provider constants const ( ProviderWJX = domain.ProviderWJX diff --git a/internal/models/config_test.go b/internal/models/config_test.go new file mode 100644 index 0000000..fa702dd --- /dev/null +++ b/internal/models/config_test.go @@ -0,0 +1,207 @@ +package models + +import ( + "encoding/json" + "testing" +) + +func TestRuntimeConfigPreservesPythonExtraFields(t *testing.T) { + cfg, err := DeserializeRuntimeConfig([]byte(`{ + "url":"https://www.wjx.cn/vm/test.aspx", + "target":3, + "_ai_config_present":true, + "config_schema_version":6, + "python_future":{"nested":[1,"two",true]} + }`)) + if err != nil { + t.Fatal(err) + } + if cfg.URL != "https://www.wjx.cn/vm/test.aspx" || cfg.Target != 3 { + t.Fatalf("known fields = %#v, want decoded runtime config", cfg) + } + if len(cfg.ExtraFields) != 3 { + t.Fatalf("extra fields = %#v, want python-only fields preserved", cfg.ExtraFields) + } + + data, err := SerializeRuntimeConfig(cfg) + if err != nil { + t.Fatal(err) + } + var roundTrip map[string]any + if err := json.Unmarshal(data, &roundTrip); err != nil { + t.Fatal(err) + } + if roundTrip["_ai_config_present"] != true || roundTrip["config_schema_version"] != float64(6) { + t.Fatalf("round trip = %#v, want preserved python metadata", roundTrip) + } + future, ok := roundTrip["python_future"].(map[string]any) + if !ok || len(future["nested"].([]any)) != 3 { + t.Fatalf("python_future = %#v, want preserved nested object", roundTrip["python_future"]) + } +} + +func TestRuntimeConfigSerializationAddsPythonSchemaMetadata(t *testing.T) { + cfg := NewDefaultRuntimeConfig() + cfg.URL = "https://www.wjx.cn/vm/test.aspx" + + data, err := SerializeRuntimeConfig(&cfg) + if err != nil { + t.Fatal(err) + } + var payload map[string]any + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatal(err) + } + if payload["config_schema_version"] != float64(CurrentConfigSchemaVersion) { + t.Fatalf("payload = %#v, want Python schema version", payload) + } + if payload["_ai_config_present"] != false { + t.Fatalf("payload = %#v, want no request-side AI config marker", payload) + } +} + +func TestDefaultRuntimeConfigUsesPythonRandomUAKeys(t *testing.T) { + cfg := NewDefaultRuntimeConfig() + if len(cfg.RandomUAKeys) != 3 || + cfg.RandomUAKeys[0] != "wechat_android" || + cfg.RandomUAKeys[1] != "mobile_android" || + cfg.RandomUAKeys[2] != "pc_web" { + t.Fatalf("random UA keys = %#v, want Python preset defaults", cfg.RandomUAKeys) + } +} + +func TestRuntimeConfigSerializationPreservesPythonSchemaMetadata(t *testing.T) { + cfg, err := DeserializeRuntimeConfig([]byte(`{ + "url":"https://www.wjx.cn/vm/test.aspx", + "_ai_config_present":false, + "config_schema_version":5 + }`)) + if err != nil { + t.Fatal(err) + } + data, err := SerializeRuntimeConfig(cfg) + if err != nil { + t.Fatal(err) + } + var payload map[string]any + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatal(err) + } + if payload["config_schema_version"] != float64(5) || payload["_ai_config_present"] != false { + t.Fatalf("payload = %#v, want imported metadata preserved", payload) + } +} + +func TestRuntimeConfigCloneKeepsExtraFields(t *testing.T) { + original, err := DeserializeRuntimeConfig([]byte(`{ + "url":"https://www.wjx.cn/vm/test.aspx", + "python_only_future_field":"keep-me" + }`)) + if err != nil { + t.Fatal(err) + } + data, err := SerializeRuntimeConfig(original) + if err != nil { + t.Fatal(err) + } + cloned, err := DeserializeRuntimeConfig(data) + if err != nil { + t.Fatal(err) + } + if string(cloned.ExtraFields["python_only_future_field"]) != `"keep-me"` { + t.Fatalf("extra fields = %#v, want clone to keep unknown field", cloned.ExtraFields) + } +} + +func TestRuntimeConfigAcceptsPythonLooseScalarFields(t *testing.T) { + cfg, err := DeserializeRuntimeConfig([]byte(`{ + "url":123, + "target":"5", + "threads":"2", + "psycho_target_alpha":"0.91", + "submit_interval":["7","9"], + "answer_duration":100, + "answer_datetime_window":["2026-02-10 09:00:00","bad"], + "random_ua_keys":["pc_web","bad","wechat_android"], + "random_ua_ratios":{"wechat":"40","mobile":30,"pc":"30"} + }`)) + if err != nil { + t.Fatal(err) + } + if cfg.URL != "123" || cfg.Target != 5 || cfg.Threads != 2 { + t.Fatalf("basic fields = url %q target %d threads %d, want string/int coercion", cfg.URL, cfg.Target, cfg.Threads) + } + if cfg.PsychoTargetAlpha != 0.91 { + t.Fatalf("alpha = %v, want parsed float", cfg.PsychoTargetAlpha) + } + if cfg.SubmitInterval != [2]int{7, 9} { + t.Fatalf("submit interval = %#v, want parsed pair", cfg.SubmitInterval) + } + if cfg.AnswerDuration != [2]int{90, 110} { + t.Fatalf("answer duration = %#v, want legacy scalar range", cfg.AnswerDuration) + } + if cfg.AnswerDatetimeWindow != [2]string{"2026-02-10 09:00:00", ""} { + t.Fatalf("answer datetime window = %#v, want normalized valid side only", cfg.AnswerDatetimeWindow) + } + if cfg.RandomUARatios["wechat"] != 40 || cfg.RandomUARatios["mobile"] != 30 || cfg.RandomUARatios["pc"] != 30 { + t.Fatalf("ua ratios = %#v, want parsed int map", cfg.RandomUARatios) + } + if len(cfg.RandomUAKeys) != 2 || cfg.RandomUAKeys[0] != "pc_web" || cfg.RandomUAKeys[1] != "wechat_android" { + t.Fatalf("ua keys = %#v, want Python preset keys filtered", cfg.RandomUAKeys) + } +} + +func TestRuntimeConfigAcceptsPythonLooseAnswerDurationLists(t *testing.T) { + tests := []struct { + name string + raw string + want [2]int + }{ + {name: "single item", raw: `[120]`, want: [2]int{108, 132}}, + {name: "equal pair", raw: `[100,100]`, want: [2]int{90, 110}}, + {name: "zero pair uses default", raw: `[0,0]`, want: [2]int{60, 120}}, + {name: "ordered pair", raw: `[3,5]`, want: [2]int{3, 5}}, + {name: "empty list uses default", raw: `[]`, want: [2]int{60, 120}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := DeserializeRuntimeConfig([]byte(`{"answer_duration":` + tt.raw + `}`)) + if err != nil { + t.Fatal(err) + } + if cfg.AnswerDuration != tt.want { + t.Fatalf("answer duration = %#v, want %#v", cfg.AnswerDuration, tt.want) + } + }) + } +} + +func TestRuntimeConfigNormalizesPythonCodecBoundariesAndDropsPrivateFields(t *testing.T) { + cfg, err := DeserializeRuntimeConfig([]byte(`{ + "proxy_source":"bad", + "ai_mode":"unsupported", + "random_ip_user_id":88, + "random_ip_device_id":"device-88", + "reverse_fill_format":"spreadsheet", + "reverse_fill_start_row":0, + "reverse_fill_threads":"0", + "random_ua_ratios":{"wechat":20,"mobile":20,"pc":20} + }`)) + if err != nil { + t.Fatal(err) + } + for _, key := range []string{"proxy_source", "ai_mode", "random_ip_user_id", "random_ip_device_id"} { + if _, ok := cfg.ExtraFields[key]; ok { + t.Fatalf("private legacy field %q preserved in extras: %#v", key, cfg.ExtraFields) + } + } + if cfg.ReverseFillFormat != ReverseFillFormatAuto { + t.Fatalf("reverse fill format = %q, want auto", cfg.ReverseFillFormat) + } + if cfg.ReverseFillStartRow != 1 || cfg.ReverseFillThreads != 1 { + t.Fatalf("reverse fill start/threads = %d/%d, want 1/1", cfg.ReverseFillStartRow, cfg.ReverseFillThreads) + } + if cfg.RandomUARatios["wechat"] != 33 || cfg.RandomUARatios["mobile"] != 33 || cfg.RandomUARatios["pc"] != 34 { + t.Fatalf("ua ratios = %#v, want Python defaults for invalid total", cfg.RandomUARatios) + } +} diff --git a/internal/network/proxy/official_source.go b/internal/network/proxy/official_source.go deleted file mode 100644 index d651aa7..0000000 --- a/internal/network/proxy/official_source.go +++ /dev/null @@ -1,245 +0,0 @@ -package proxy - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/SurveyController/SurveyCore/internal/models" -) - -func fetchFromOfficial(source string, count int, opts officialOptions) ([]models.ProxyLease, error) { - count = normalizeProxyCount(count) - opts = normalizeOfficialOptions(opts) - if opts.Endpoint == "" { - return nil, fmt.Errorf("官方随机 IP 提取接口未配置") - } - if opts.UserID <= 0 { - return nil, fmt.Errorf("官方随机 IP 用户 ID 未配置") - } - if opts.DeviceID == "" { - return nil, fmt.Errorf("官方随机 IP 设备 ID 未配置") - } - - upstream := "default" - if source == "benefit" { - upstream = "idiot" - opts.Minute = 1 - } - body := map[string]any{ - "user_id": opts.UserID, - "minute": opts.Minute, - "pool": opts.Pool, - "upstream": upstream, - } - if count > 1 { - body["num"] = count - } - if opts.AreaCode != "" { - body["area"] = opts.AreaCode - } - bodyBytes, err := json.Marshal(body) - if err != nil { - return nil, err - } - - req, err := http.NewRequest(http.MethodPost, opts.Endpoint, bytes.NewReader(bodyBytes)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Device-ID", opts.DeviceID) - req.Header.Set("User-Agent", defaultUserAgent) - req.Header.Set("Accept", "application/json, text/plain, */*") - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("官方随机 IP 请求失败: %w", err) - } - defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("读取官方随机 IP 响应失败: %w", err) - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("官方随机 IP HTTP %d: %s", resp.StatusCode, officialErrorDetail(respBody)) - } - - leases, err := parseOfficialProxyPayload(respBody, source) - if err != nil { - return nil, err - } - if len(leases) > count { - leases = leases[:count] - } - return leases, nil -} - -const defaultUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36" - -func parseOfficialProxyPayload(body []byte, fallbackSource string) ([]models.ProxyLease, error) { - var data map[string]any - if err := json.Unmarshal(body, &data); err != nil { - return nil, fmt.Errorf("解析官方随机 IP 响应失败: %w", err) - } - source := officialSourceFromProvider(firstString(data, "provider"), fallbackSource) - if rawItems, ok := data["items"].([]any); ok { - leases := make([]models.ProxyLease, 0, len(rawItems)) - for _, raw := range rawItems { - item, ok := raw.(map[string]any) - if !ok { - continue - } - if lease, ok := officialLeaseFromMap(item, source); ok { - leases = append(leases, lease) - } - } - if len(leases) == 0 { - return nil, fmt.Errorf("官方随机 IP 批量响应中无有效代理") - } - return leases, nil - } - if lease, ok := officialLeaseFromMap(data, source); ok { - return []models.ProxyLease{lease}, nil - } - return nil, fmt.Errorf("官方随机 IP 响应缺少 host/port/account/password") -} - -func officialLeaseFromMap(data map[string]any, source string) (models.ProxyLease, bool) { - host := firstString(data, "host", "ip", "IP") - port := firstString(data, "port", "Port", "PORT") - account := firstString(data, "account", "username", "user") - password := firstString(data, "password", "pwd", "pass") - if host == "" || port == "" || account == "" || password == "" { - return models.ProxyLease{}, false - } - expireAt := firstString(data, "expire_at", "expireAt", "expire") - return models.ProxyLease{ - Address: fmt.Sprintf("%s:%s@%s:%s", account, password, host, port), - ExpireAt: expireAt, - ExpireTS: parseExpireAtToUnix(expireAt), - Poolable: expireAt != "", - Source: source, - }, true -} - -func officialSourceFromProvider(provider, fallback string) string { - switch strings.ToLower(strings.TrimSpace(provider)) { - case "idiot", "benefit": - return "benefit" - case "default": - return "default" - default: - if strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback) - } - return "default" - } -} - -func officialErrorDetail(body []byte) string { - var data map[string]any - if err := json.Unmarshal(body, &data); err == nil { - for _, key := range []string{"detail", "message", "error"} { - if value := firstString(data, key); value != "" { - return value - } - } - } - return truncate(string(body), 200) -} - -func parseExpireAtToUnix(value string) float64 { - text := strings.TrimSpace(value) - if text == "" { - return 0 - } - layouts := []string{ - time.RFC3339Nano, - time.RFC3339, - "2006-01-02T15:04:05", - "2006-01-02 15:04:05", - "2006-01-02 15:04", - } - for _, layout := range layouts { - parsed, err := time.Parse(layout, text) - if err == nil { - return float64(parsed.Unix()) - } - } - return 0 -} - -func normalizeOfficialOptions(opts officialOptions) officialOptions { - opts.Endpoint = strings.TrimSpace(opts.Endpoint) - if opts.Endpoint == "" { - opts.Endpoint = defaultOfficialEndpoint() - } - opts.AreaCode = normalizeAreaCode(opts.AreaCode) - if !isAllowedMinute(opts.Minute) { - opts.Minute = 1 - } - if opts.Pool == "" { - opts.Pool = resolveOfficialPool(opts.AreaCode) - } - return opts -} - -func normalizeProxyCount(count int) int { - if count <= 0 { - return 1 - } - if count > 80 { - return 80 - } - return count -} - -func defaultOfficialEndpoint() string { - return "https://api-wjx.hungrym0.top/api/ip/extract" -} - -func isAllowedMinute(minute int) bool { - switch minute { - case 1, 3, 5, 10, 15, 30: - return true - default: - return false - } -} - -func normalizeAreaCode(areaCode string) string { - areaCode = strings.TrimSpace(areaCode) - if len(areaCode) != 6 { - return "" - } - for _, ch := range areaCode { - if ch < '0' || ch > '9' { - return "" - } - } - return areaCode -} - -func resolveOfficialPool(areaCode string) string { - if areaCode == "" { - return "ordinary" - } - if strings.HasSuffix(areaCode, "0000") && ordinaryPoolProvinceCodes[areaCode] { - return "ordinary" - } - return "quality" -} - -var ordinaryPoolProvinceCodes = map[string]bool{ - "110000": true, "120000": true, "130000": true, "140000": true, "150000": true, - "210000": true, "220000": true, "230000": true, "320000": true, "330000": true, - "340000": true, "350000": true, "360000": true, "370000": true, "410000": true, - "420000": true, "430000": true, "440000": true, "460000": true, "500000": true, - "510000": true, "610000": true, "620000": true, "640000": true, -} diff --git a/internal/network/proxy/pool.go b/internal/network/proxy/pool.go index 3bb75c0..6f89526 100644 --- a/internal/network/proxy/pool.go +++ b/internal/network/proxy/pool.go @@ -11,57 +11,16 @@ import ( // Pool manages a pool of proxy leases. type Pool struct { - mu sync.RWMutex - leases []models.ProxyLease - cooldown map[string]time.Time // address -> cooldown until - source string - apiURL string - officialOptions officialOptions -} - -type officialOptions struct { - Endpoint string - UserID int - DeviceID string - AreaCode string - Minute int - Pool string + mu sync.RWMutex + leases []models.ProxyLease + cooldown map[string]time.Time // address -> cooldown until + source string + apiURL string } // Option configures a proxy pool. type Option func(*Pool) -// WithOfficialEndpoint overrides the official random-IP endpoint. -func WithOfficialEndpoint(endpoint string) Option { - return func(p *Pool) { - p.officialOptions.Endpoint = strings.TrimSpace(endpoint) - } -} - -// WithOfficialCredentials sets official random-IP session credentials. -func WithOfficialCredentials(userID int, deviceID string) Option { - return func(p *Pool) { - p.officialOptions.UserID = userID - p.officialOptions.DeviceID = strings.TrimSpace(deviceID) - } -} - -// WithOfficialAreaCode sets the official random-IP area code. -func WithOfficialAreaCode(areaCode string) Option { - return func(p *Pool) { - p.officialOptions.AreaCode = normalizeAreaCode(areaCode) - } -} - -// WithOfficialMinute sets the official random-IP lease minute. -func WithOfficialMinute(minute int) Option { - return func(p *Pool) { - if isAllowedMinute(minute) { - p.officialOptions.Minute = minute - } - } -} - // NewPool creates a new proxy pool. func NewPool(source, apiURL string, opts ...Option) *Pool { p := &Pool{ @@ -69,22 +28,12 @@ func NewPool(source, apiURL string, opts ...Option) *Pool { cooldown: make(map[string]time.Time), source: source, apiURL: apiURL, - officialOptions: officialOptions{ - Endpoint: defaultOfficialEndpoint(), - Minute: 1, - }, } for _, opt := range opts { if opt != nil { opt(p) } } - if p.officialOptions.Pool == "" { - p.officialOptions.Pool = resolveOfficialPool(p.officialOptions.AreaCode) - } - if p.officialOptions.Minute <= 0 { - p.officialOptions.Minute = 1 - } return p } @@ -130,12 +79,9 @@ func (p *Pool) FetchBatch(count int) ([]models.ProxyLease, error) { p.mu.RLock() source := p.source apiURL := p.apiURL - officialOpts := p.officialOptions p.mu.RUnlock() switch source { - case "default", "benefit": - return fetchFromOfficial(source, count, officialOpts) case "custom": if apiURL == "" { return nil, fmt.Errorf("自定义代理 API URL 未配置") diff --git a/internal/network/proxy/pool_test.go b/internal/network/proxy/pool_test.go deleted file mode 100644 index ab59902..0000000 --- a/internal/network/proxy/pool_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package proxy - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -func TestParseProxyFromNestedJSON(t *testing.T) { - payload := `{ - "data": { - "items": [ - {"ip": "1.1.1.1", "port": 8000, "account": "user", "password": "pass"}, - {"proxy": "http://2.2.2.2:9000"}, - {"nested": {"list": ["3.3.3.3:7000"]}} - ] - } - }` - - leases, err := parseProxyFromJSON(payload) - if err != nil { - t.Fatalf("parseProxyFromJSON failed: %v", err) - } - if len(leases) != 3 { - t.Fatalf("leases length = %d, want 3: %#v", len(leases), leases) - } - if leases[0].Address != "user:pass@1.1.1.1:8000" { - t.Fatalf("first proxy = %q", leases[0].Address) - } -} - -func TestFetchFromCustomChecksHTTPStatus(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "bad key", http.StatusForbidden) - })) - defer server.Close() - - _, err := fetchFromCustom(server.URL, 1) - if err == nil { - t.Fatal("expected error for HTTP 403") - } -} - -func TestFetchFromCustomParsesCredentials(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"items":["user:pass@4.4.4.4:8080"]}`)) - })) - defer server.Close() - - leases, err := fetchFromCustom(server.URL, 1) - if err != nil { - t.Fatalf("fetchFromCustom failed: %v", err) - } - if len(leases) != 1 || leases[0].Address != "user:pass@4.4.4.4:8080" { - t.Fatalf("leases = %#v", leases) - } - if got := ExtractProxyAddress(leases[0].Address); got != "http://user:pass@4.4.4.4:8080" { - t.Fatalf("normalized proxy = %q", got) - } -} - -func TestFetchFromOfficialPostsSessionAndParsesBatch(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Fatalf("method = %s, want POST", r.Method) - } - if got := r.Header.Get("X-Device-ID"); got != "device-1" { - t.Fatalf("X-Device-ID = %q, want device-1", got) - } - var body map[string]any - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - t.Fatalf("decode request body: %v", err) - } - if body["user_id"].(float64) != 33 || body["minute"].(float64) != 3 || body["num"].(float64) != 2 { - t.Fatalf("request body numeric fields = %#v", body) - } - if body["pool"] != "quality" || body["upstream"] != "default" || body["area"] != "110100" { - t.Fatalf("request body string fields = %#v", body) - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{ - "provider":"default", - "items":[ - {"host":"1.1.1.1","port":8000,"account":"user","password":"pass","expire_at":"2030-01-02T03:04:05Z"}, - {"host":"2.2.2.2","port":9000,"account":"user2","password":"pass2","expire_at":"2030-01-02T03:05:05Z"} - ] - }`)) - })) - defer server.Close() - - leases, err := fetchFromOfficial("default", 2, officialOptions{ - Endpoint: server.URL, - UserID: 33, - DeviceID: "device-1", - AreaCode: "110100", - Minute: 3, - Pool: "quality", - }) - if err != nil { - t.Fatalf("fetchFromOfficial failed: %v", err) - } - if len(leases) != 2 { - t.Fatalf("leases length = %d, want 2", len(leases)) - } - if leases[0].Address != "user:pass@1.1.1.1:8000" || leases[0].Source != "default" || !leases[0].Poolable { - t.Fatalf("first lease = %#v", leases[0]) - } - if leases[0].ExpireTS <= 0 { - t.Fatalf("first lease ExpireTS = %v, want parsed timestamp", leases[0].ExpireTS) - } -} - -func TestFetchFromOfficialBenefitUsesIdiotUpstreamAndSinglePayload(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var body map[string]any - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - t.Fatalf("decode request body: %v", err) - } - if body["upstream"] != "idiot" || body["minute"].(float64) != 1 { - t.Fatalf("benefit request body = %#v", body) - } - if _, hasNum := body["num"]; hasNum { - t.Fatalf("single request should omit num: %#v", body) - } - _, _ = w.Write([]byte(`{"provider":"idiot","host":"3.3.3.3","port":7000,"account":"u","password":"p"}`)) - })) - defer server.Close() - - leases, err := fetchFromOfficial("benefit", 1, officialOptions{ - Endpoint: server.URL, - UserID: 44, - DeviceID: "device-2", - Minute: 5, - }) - if err != nil { - t.Fatalf("fetchFromOfficial failed: %v", err) - } - if len(leases) != 1 || leases[0].Address != "u:p@3.3.3.3:7000" || leases[0].Source != "benefit" { - t.Fatalf("leases = %#v", leases) - } - if leases[0].Poolable { - t.Fatalf("lease without expire_at should not be poolable: %#v", leases[0]) - } -} - -func TestFetchFromOfficialRequiresCredentials(t *testing.T) { - _, err := fetchFromOfficial("default", 1, officialOptions{Endpoint: "http://example.invalid"}) - if err == nil { - t.Fatal("expected missing credentials error") - } -} - -func TestFetchFromOfficialReportsHTTPDetail(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"detail":"token_rate_limited"}`)) - })) - defer server.Close() - - _, err := fetchFromOfficial("default", 1, officialOptions{ - Endpoint: server.URL, - UserID: 55, - DeviceID: "device-3", - }) - if err == nil || !strings.Contains(err.Error(), "token_rate_limited") { - t.Fatalf("error = %v, want token_rate_limited", err) - } -} diff --git a/internal/questions/ai.go b/internal/questions/ai.go index b3bc697..d138957 100644 --- a/internal/questions/ai.go +++ b/internal/questions/ai.go @@ -14,19 +14,29 @@ const ( aiRequestTimeout = 12 * time.Second aiMaxAttempts = 4 aiRetryBackoff = 400 * time.Millisecond + + aiChatCompletionsSuffix = "/chat/completions" + aiResponsesSuffix = "/responses" + aiLegacyCompletions = "/completions" ) +// AIConfig holds server-side AI generation configuration. type AIConfig struct { - APIKey string - BaseURL string - Model string + Provider string + APIKey string + BaseURL string + Protocol string + Model string + SystemPrompt string } +// AIClient generates text answers using a configured AI provider. type AIClient struct { config AIConfig client *http.Client } +// AIError classifies an AI generation failure for callers and tests. type AIError struct { Kind string Err error @@ -54,6 +64,7 @@ const ( AIErrorNetwork = "network" ) +// NewAIClient creates a new AI client. func NewAIClient(config AIConfig) *AIClient { return &AIClient{ config: config, @@ -61,30 +72,38 @@ func NewAIClient(config AIConfig) *AIClient { } } +// GenerateAnswer generates a text answer for a question. func (a *AIClient) GenerateAnswer(questionTitle, questionType string, blankCount int) (string, error) { + return a.generateAPI(questionTitle, questionType, blankCount) +} + +func (a *AIClient) TestConnection() (string, error) { + return a.GenerateAnswer("这是一个测试问题,请回复'连接成功'", "fill_blank", 1) +} + +func (a *AIClient) generateAPI(questionTitle, questionType string, blankCount int) (string, error) { if strings.TrimSpace(a.config.APIKey) == "" { return "", classifyAIError(AIErrorConfig, fmt.Errorf("API key 未配置")) } - if strings.TrimSpace(a.config.BaseURL) == "" { - return "", classifyAIError(AIErrorConfig, fmt.Errorf("base_url 未配置")) - } - if strings.TrimSpace(a.config.Model) == "" { - return "", classifyAIError(AIErrorConfig, fmt.Errorf("model 未配置")) + + prompt := a.buildPrompt(questionTitle, questionType, blankCount) + systemPrompt := strings.TrimSpace(a.config.SystemPrompt) + if systemPrompt == "" { + systemPrompt = "你是一个问卷答题助手,请根据题目生成合理的答案。只输出答案内容,不要解释。" } - reqBody := map[string]any{ - "model": strings.TrimSpace(a.config.Model), - "messages": []map[string]string{ - {"role": "system", "content": "你是一个问卷答题助手,请根据题目生成合理的答案。只输出答案内容,不要解释。"}, - {"role": "user", "content": buildAIPrompt(questionTitle, questionType, blankCount)}, - }, - "temperature": 0.7, - "max_tokens": 200, + endpoint, err := a.resolveEndpoint() + if err != nil { + return "", err } var lastErr error for attempt := 1; attempt <= aiMaxAttempts; attempt++ { - answer, err := a.doGenerateAPI(reqBody) + answer, err := a.doGenerateAPI(endpoint.Protocol, endpoint.URL, endpoint.Model, prompt, systemPrompt) + if err != nil && endpoint.Protocol == "chat_completions" && endpoint.AutoFallbackToResponses && isEndpointMismatchAIError(err) { + fallbackURL := strings.TrimRight(endpoint.BaseURL, "/") + aiResponsesSuffix + answer, err = a.doGenerateAPI("responses", fallbackURL, endpoint.Model, prompt, systemPrompt) + } if err == nil { return answer, nil } @@ -97,9 +116,111 @@ func (a *AIClient) GenerateAnswer(questionTitle, questionType string, blankCount return "", lastErr } -func (a *AIClient) doGenerateAPI(reqBody map[string]any) (string, error) { +type aiEndpoint struct { + Protocol string + URL string + BaseURL string + Model string + AutoFallbackToResponses bool +} + +func (a *AIClient) resolveEndpoint() (aiEndpoint, error) { + provider := strings.ToLower(strings.TrimSpace(a.config.Provider)) + model := strings.TrimSpace(a.config.Model) + baseURL := normalizeAIEndpointURL(a.config.BaseURL) + protocol := normalizeAIProtocol(a.config.Protocol) + + if provider == "custom" { + if baseURL == "" { + return aiEndpoint{}, classifyAIError(AIErrorConfig, fmt.Errorf("自定义模式需要配置 Base URL")) + } + if model == "" { + return aiEndpoint{}, classifyAIError(AIErrorConfig, fmt.Errorf("自定义模式需要配置模型名称")) + } + return resolveCustomAIEndpoint(baseURL, protocol, model) + } + + if baseURL == "" { + baseURL = "https://api.deepseek.com/v1" + } + if model == "" { + model = "deepseek-chat" + } + if endpoint, err := resolveExplicitAIEndpoint(baseURL, protocol, model); err == nil { + return endpoint, nil + } else if strings.HasSuffix(strings.ToLower(strings.TrimRight(baseURL, "/")), aiLegacyCompletions) { + return aiEndpoint{}, err + } + return aiEndpoint{ + Protocol: "chat_completions", + URL: strings.TrimRight(baseURL, "/") + aiChatCompletionsSuffix, + BaseURL: baseURL, + Model: model, + }, nil +} + +func resolveCustomAIEndpoint(baseURL string, protocol string, model string) (aiEndpoint, error) { + if endpoint, err := resolveExplicitAIEndpoint(baseURL, protocol, model); err == nil { + return endpoint, nil + } else if strings.HasSuffix(strings.ToLower(strings.TrimRight(baseURL, "/")), aiLegacyCompletions) { + return aiEndpoint{}, err + } + + if protocol == "responses" { + return aiEndpoint{ + Protocol: "responses", + URL: strings.TrimRight(baseURL, "/") + aiResponsesSuffix, + BaseURL: baseURL, + Model: model, + }, nil + } + return aiEndpoint{ + Protocol: "chat_completions", + URL: strings.TrimRight(baseURL, "/") + aiChatCompletionsSuffix, + BaseURL: baseURL, + Model: model, + AutoFallbackToResponses: protocol == "auto", + }, nil +} + +func resolveExplicitAIEndpoint(baseURL string, protocol string, model string) (aiEndpoint, error) { + path := strings.ToLower(strings.TrimRight(baseURL, "/")) + switch { + case strings.HasSuffix(path, aiChatCompletionsSuffix): + return aiEndpoint{Protocol: "chat_completions", URL: baseURL, BaseURL: trimAISuffix(baseURL, aiChatCompletionsSuffix), Model: model}, nil + case strings.HasSuffix(path, aiResponsesSuffix): + return aiEndpoint{Protocol: "responses", URL: baseURL, BaseURL: trimAISuffix(baseURL, aiResponsesSuffix), Model: model}, nil + case strings.HasSuffix(path, aiLegacyCompletions): + return aiEndpoint{}, classifyAIError(AIErrorConfig, fmt.Errorf("暂不支持旧版 /completions 协议,请改用 /chat/completions 或 /responses")) + } + return aiEndpoint{}, classifyAIError(AIErrorConfig, fmt.Errorf("未配置完整 AI endpoint")) +} + +func normalizeAIEndpointURL(raw string) string { + return strings.TrimRight(strings.TrimSpace(raw), "/") +} + +func normalizeAIProtocol(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "chat_completions", "responses": + return strings.ToLower(strings.TrimSpace(raw)) + default: + return "auto" + } +} + +func trimAISuffix(rawURL string, suffix string) string { + trimmed := strings.TrimRight(rawURL, "/") + if len(trimmed) < len(suffix) { + return trimmed + } + return trimmed[:len(trimmed)-len(suffix)] +} + +func (a *AIClient) doGenerateAPI(protocol string, url string, model string, prompt string, systemPrompt string) (string, error) { + reqBody := buildAIRequestBody(protocol, model, prompt, systemPrompt) bodyBytes, _ := json.Marshal(reqBody) - req, err := http.NewRequest("POST", strings.TrimRight(a.config.BaseURL, "/")+"/chat/completions", bytes.NewReader(bodyBytes)) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(bodyBytes)) if err != nil { return "", classifyAIError(AIErrorConfig, err) } @@ -123,22 +244,103 @@ func (a *AIClient) doGenerateAPI(reqBody map[string]any) (string, error) { } return "", classifyAIError(kind, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateString(string(respBody), 200))) } - var result map[string]any if err := json.Unmarshal(respBody, &result); err != nil { return "", classifyAIError(AIErrorResponse, err) } + + answer, err := extractAIResponseText(protocol, result) + if err != nil { + return "", classifyAIError(AIErrorResponse, err) + } + return answer, nil +} + +func buildAIRequestBody(protocol string, model string, prompt string, systemPrompt string) map[string]any { + if protocol == "responses" { + return map[string]any{ + "model": model, + "instructions": systemPrompt, + "input": prompt, + "temperature": 0.7, + "max_output_tokens": 200, + } + } + return map[string]any{ + "model": model, + "messages": []map[string]string{ + {"role": "system", "content": systemPrompt}, + {"role": "user", "content": prompt}, + }, + "temperature": 0.7, + "max_tokens": 200, + } +} + +func extractAIResponseText(protocol string, result map[string]any) (string, error) { + if protocol == "responses" { + if text, ok := result["output_text"].(string); ok && strings.TrimSpace(text) != "" { + return strings.TrimSpace(text), nil + } + if output, ok := result["output"].([]any); ok { + for _, item := range output { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + if text := joinAITextParts(itemMap["content"]); text != "" { + return text, nil + } + } + } + return "", fmt.Errorf("Responses API 返回内容为空") + } + if choices, ok := result["choices"].([]any); ok && len(choices) > 0 { if choice, ok := choices[0].(map[string]any); ok { if message, ok := choice["message"].(map[string]any); ok { - if content, ok := message["content"].(string); ok { - return strings.TrimSpace(content), nil + if text := joinAITextParts(message["content"]); text != "" { + return text, nil } } } } + return "", fmt.Errorf("响应格式错误") +} - return "", classifyAIError(AIErrorResponse, fmt.Errorf("响应格式错误")) +func joinAITextParts(content any) string { + switch value := content.(type) { + case string: + return strings.TrimSpace(value) + case []any: + parts := make([]string, 0, len(value)) + for _, item := range value { + switch part := item.(type) { + case string: + if text := strings.TrimSpace(part); text != "" { + parts = append(parts, text) + } + case map[string]any: + itemType := strings.ToLower(strings.TrimSpace(fmt.Sprint(part["type"]))) + text := strings.TrimSpace(fmt.Sprint(firstNonEmpty(part["text"], part["content"]))) + if text != "" && (itemType == "text" || itemType == "output_text" || itemType == "input_text") { + parts = append(parts, text) + } + } + } + return strings.TrimSpace(strings.Join(parts, "\n")) + default: + return "" + } +} + +func firstNonEmpty(values ...any) any { + for _, value := range values { + if strings.TrimSpace(fmt.Sprint(value)) != "" && fmt.Sprint(value) != "" { + return value + } + } + return "" } func classifyAIError(kind string, err error) error { @@ -161,6 +363,19 @@ func isRetryableAIError(err error) bool { } } +func isEndpointMismatchAIError(err error) bool { + if err == nil { + return false + } + text := strings.ToLower(err.Error()) + for _, marker := range []string{"404", "405", "410", "not found", "no route", "no handler", "unsupported path", "invalid url", "method not allowed"} { + if strings.Contains(text, marker) { + return true + } + } + return false +} + func isTimeoutError(err error) bool { if err == nil { return false @@ -176,7 +391,7 @@ func truncateString(s string, n int) string { return s[:n] + "..." } -func buildAIPrompt(questionTitle, questionType string, blankCount int) string { +func (a *AIClient) buildPrompt(questionTitle, questionType string, blankCount int) string { cleaned := cleanQuestionTitle(questionTitle) if blankCount > 1 { return fmt.Sprintf("题目:%s\n题型:%s\n这是一个包含 %d 个空格的填空题,请为每个空格生成一个答案,用 | 分隔。", cleaned, questionType, blankCount) diff --git a/internal/questions/runtime_test.go b/internal/questions/runtime_test.go index 91d9261..246f3dc 100644 --- a/internal/questions/runtime_test.go +++ b/internal/questions/runtime_test.go @@ -119,6 +119,62 @@ func TestAIClientRetriesTransientHTTPFailure(t *testing.T) { } } +func TestAIClientCallsResponsesAPI(t *testing.T) { + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"output_text":"连接成功"}`)) + })) + defer server.Close() + + client := NewAIClient(AIConfig{ + Provider: "custom", + APIKey: "test-key", + BaseURL: server.URL + "/v1", + Protocol: "responses", + Model: "test-model", + }) + + got, err := client.GenerateAnswer("测试", "fill_blank", 1) + if err != nil { + t.Fatalf("GenerateAnswer returned error: %v", err) + } + if got != "连接成功" || gotPath != "/v1/responses" { + t.Fatalf("answer=%q path=%q, want responses answer and endpoint", got, gotPath) + } +} + +func TestAIClientAutoFallsBackToResponsesOnEndpointMismatch(t *testing.T) { + paths := []string{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + paths = append(paths, r.URL.Path) + if r.URL.Path == "/v1/chat/completions" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"output":[{"content":[{"type":"output_text","text":"fallback ok"}]}]}`)) + })) + defer server.Close() + + client := NewAIClient(AIConfig{ + Provider: "custom", + APIKey: "test-key", + BaseURL: server.URL + "/v1", + Protocol: "auto", + Model: "test-model", + }) + + got, err := client.GenerateAnswer("测试", "fill_blank", 1) + if err != nil { + t.Fatalf("GenerateAnswer returned error: %v", err) + } + if got != "fallback ok" || strings.Join(paths, ",") != "/v1/chat/completions,/v1/responses" { + t.Fatalf("answer=%q paths=%v, want chat then responses fallback", got, paths) + } +} + func TestAIClientClassifiesConfigErrorWithoutRetry(t *testing.T) { client := NewAIClient(AIConfig{}) diff --git a/internal/tasks/store.go b/internal/tasks/store.go index 339be73..b4c12c1 100644 --- a/internal/tasks/store.go +++ b/internal/tasks/store.go @@ -128,6 +128,7 @@ func (s *Store) SaveTask(task *TaskRecord) error { if task == nil || task.ID == "" { return errors.New("任务为空") } + syncTaskDerivedFields(task) data, err := json.Marshal(task) if err != nil { return err diff --git a/internal/tasks/tasks.go b/internal/tasks/tasks.go index 3c70a0a..4b0b320 100644 --- a/internal/tasks/tasks.go +++ b/internal/tasks/tasks.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "sort" + "strings" "sync" "time" @@ -52,6 +53,9 @@ func (m *TaskManager) Load() []error { task.Status = TaskInterrupted task.FinishedAt = &now task.Error = "服务重启,任务已中断" + task.ErrorCode = "task_interrupted" + task.FailureReason = task.Error + task.TerminalStopCategory = "task_interrupted" errs = appendSaveErr(errs, m.store.SaveTask(task)) } m.tasks[task.ID] = task @@ -139,6 +143,9 @@ func (m *TaskManager) Stop(id string) (*TaskRecord, error) { } task.Status = TaskStopped task.StopMessage = "用户请求停止" + task.ErrorCode = "user_stopped" + task.FailureReason = task.StopMessage + task.TerminalStopCategory = "user_stopped" if task.State != nil { task.State.SignalStop() } @@ -222,16 +229,27 @@ func (m *TaskManager) run(ctx context.Context, id string) { current.State = state current.FinishedAt = &finished delete(m.runtimes, id) + category, reason, message := terminalSnapshot(state) + current.TerminalStopCategory = category if current.Status == TaskStopped || ctx.Err() != nil { current.Status = TaskStopped if current.StopMessage == "" { current.StopMessage = "任务已停止" } + if current.TerminalStopCategory == "" { + current.TerminalStopCategory = "user_stopped" + } + current.ErrorCode = "user_stopped" + current.FailureReason = firstNonEmpty(reason, message, current.StopMessage) } else if err != nil { current.Status = TaskFailed current.Error = err.Error() + current.ErrorCode = classifyTaskErrorCode(category, reason, err) + current.FailureReason = firstNonEmpty(reason, message, err.Error()) } else { current.Status = TaskSucceeded + current.ErrorCode = "" + current.FailureReason = "" } snapshot := cloneTask(current) m.mu.Unlock() @@ -350,6 +368,7 @@ func (m *TaskManager) appendLog(id string, entry TaskLog) { } func (m *TaskManager) saveTask(task *TaskRecord) { + syncTaskDerivedFields(task) if err := m.store.SaveTask(task); err != nil { logging.ErrorFields("保存任务状态失败", logging.F("task_id", task.ID), logging.F("error", err)) } @@ -411,20 +430,144 @@ func cloneTask(task *TaskRecord) *TaskRecord { copy := *task copy.Config = cloneRuntimeConfig(task.Config) copy.State = task.State.Snapshot() + syncTaskDerivedFields(©) return © } +func syncTaskDerivedFields(task *TaskRecord) { + if task == nil { + return + } + task.Progress = buildTaskProgress(task) + if task.TerminalStopCategory == "" { + task.TerminalStopCategory = deriveTerminalStopCategory(task) + } + if task.FailureReason == "" { + task.FailureReason = deriveFailureReason(task) + } + if task.ErrorCode == "" { + task.ErrorCode = deriveErrorCode(task) + } +} + +func buildTaskProgress(task *TaskRecord) *TaskProgress { + target := 0 + if task.Config != nil { + target = task.Config.Target + } + success := 0 + fail := 0 + if task.State != nil { + success = task.State.GetCurNum() + fail = task.State.GetCurFail() + } + current := success + percent := 0.0 + if target > 0 { + percent = float64(current) / float64(target) + if percent > 1 { + percent = 1 + } + } + return &TaskProgress{ + Current: current, + Target: target, + Success: success, + Fail: fail, + Percent: percent, + } +} + +func deriveErrorCode(task *TaskRecord) string { + switch task.Status { + case TaskFailed: + return classifyTaskErrorCode(task.TerminalStopCategory, task.FailureReason, errors.New(task.Error)) + case TaskStopped: + return "user_stopped" + case TaskInterrupted: + return "task_interrupted" + default: + return "" + } +} + +func deriveTerminalStopCategory(task *TaskRecord) string { + if task.State == nil { + return "" + } + category, _, _ := task.State.GetTerminalStopSnapshot() + return category +} + +func deriveFailureReason(task *TaskRecord) string { + if task.State != nil { + _, reason, message := task.State.GetTerminalStopSnapshot() + if reason != "" { + return reason + } + if message != "" { + return message + } + } + return firstNonEmpty(task.Error, task.StopMessage) +} + +func terminalSnapshot(state *runstate.ExecutionState) (category, reason, message string) { + if state == nil { + return "", "", "" + } + return state.GetTerminalStopSnapshot() +} + +func classifyTaskErrorCode(category, reason string, err error) string { + for _, value := range []string{reason, category} { + switch value { + case "proxy_unavailable", + "page_load_failed", + "fill_failed", + "submission_verification_required", + "survey_provider_unavailable", + "user_stopped", + "reverse_fill_exhausted", + "ai_unstable": + return value + case "fail_threshold", + "proxy_unavailable_threshold", + "submission_verification_threshold": + return value + } + } + if err == nil { + return "execution_error" + } + message := err.Error() + switch { + case strings.Contains(message, "必须提供问卷链接"): + return "validation_error" + case strings.Contains(message, "解析问卷失败"): + return "survey_parse_failed" + case strings.Contains(message, "准备执行配置失败"): + return "config_error" + default: + return "execution_error" + } +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + func DefaultTaskManager() (*TaskManager, error) { return DefaultTaskManagerWithStore("data/surveycore.db") } func DefaultTaskManagerWithStore(dbPath string) (*TaskManager, error) { - store := NewStore(dbPath) - if err := store.Init(); err != nil { - return nil, err - } - manager := NewTaskManager(store, providers.Default()) - return manager, nil + return DefaultTaskManagerWithStoreAndExecutionDefaults(dbPath, nil) } func DefaultTaskManagerWithStoreAndExecutionDefaults(dbPath string, executionDefaults func(*execution.ExecutionConfig)) (*TaskManager, error) { diff --git a/internal/tasks/tasks_test.go b/internal/tasks/tasks_test.go index 8acc250..a7b2620 100644 --- a/internal/tasks/tasks_test.go +++ b/internal/tasks/tasks_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/SurveyController/SurveyCore/internal/execution" runstate "github.com/SurveyController/SurveyCore/internal/runtime" "github.com/SurveyController/SurveyCore/internal/models" @@ -72,6 +73,21 @@ func TestTaskManagerLoadMarksRunningInterruptedAndSkipsBadRecord(t *testing.T) { } } +func TestTaskManagerAppliesExecutionDefaults(t *testing.T) { + manager := NewTaskManagerWithExecutionDefaults(nil, nil, func(cfg *execution.ExecutionConfig) { + cfg.AIBaseURL = "https://ai.example.test/v1" + cfg.AIModel = "test-model" + cfg.AIAPIKey = "test-key" + }) + cfg := &execution.ExecutionConfig{} + + manager.applyExecutionDefaults(cfg) + + if cfg.AIBaseURL != "https://ai.example.test/v1" || cfg.AIModel != "test-model" || cfg.AIAPIKey != "test-key" { + t.Fatalf("execution defaults = %#v, want AI defaults", cfg) + } +} + func TestStoreLoadLogsUsesCursorPagination(t *testing.T) { store := NewStore(filepath.Join(t.TempDir(), "surveycore.db")) if err := store.Init(); err != nil { @@ -130,3 +146,73 @@ func TestCloneTaskSnapshotsExecutionState(t *testing.T) { t.Fatalf("cloned status = %q, want snapshot value", got) } } + +func TestCloneTaskAddsProgressAndFailureFields(t *testing.T) { + state := runstate.NewExecutionState() + state.IncrementSuccess() + state.MarkTerminalStop("fill_failed", "fill_failed", "填写失败") + + task := &TaskRecord{ + ID: "task-1", + Status: TaskFailed, + Config: &models.RuntimeConfig{Target: 4}, + State: state, + Error: "执行失败", + } + cloned := cloneTask(task) + + if cloned.Progress == nil || cloned.Progress.Target != 4 || cloned.Progress.Success != 1 || cloned.Progress.Percent != 0.25 { + t.Fatalf("progress = %#v, want stable summary", cloned.Progress) + } + if cloned.ErrorCode != "fill_failed" || cloned.FailureReason != "fill_failed" || cloned.TerminalStopCategory != "fill_failed" { + t.Fatalf("error fields = %q/%q/%q, want standardized fill failure", cloned.ErrorCode, cloned.FailureReason, cloned.TerminalStopCategory) + } +} + +func TestTaskErrorCodeUsesTerminalCategoryAndErrorFallback(t *testing.T) { + tests := []struct { + name string + task *TaskRecord + wantCode string + wantReason string + }{ + { + name: "submission verification", + task: terminalTask("submission_verification", "submission_verification_required", "触发智能验证"), + wantCode: "submission_verification_required", + wantReason: "submission_verification_required", + }, + { + name: "reverse fill exhausted category", + task: terminalTask("reverse_fill_exhausted", "", "反填样本已耗尽"), + wantCode: "reverse_fill_exhausted", + wantReason: "反填样本已耗尽", + }, + { + name: "parse failure fallback", + task: &TaskRecord{ID: "task-1", Status: TaskFailed, Error: "解析问卷失败: upstream timeout"}, + wantCode: "survey_parse_failed", + wantReason: "解析问卷失败: upstream timeout", + }, + { + name: "config failure fallback", + task: &TaskRecord{ID: "task-1", Status: TaskFailed, Error: "准备执行配置失败: answer window invalid"}, + wantCode: "config_error", + wantReason: "准备执行配置失败: answer window invalid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cloned := cloneTask(tt.task) + if cloned.ErrorCode != tt.wantCode || cloned.FailureReason != tt.wantReason { + t.Fatalf("error fields = %q/%q, want %q/%q", cloned.ErrorCode, cloned.FailureReason, tt.wantCode, tt.wantReason) + } + }) + } +} + +func terminalTask(category, reason, message string) *TaskRecord { + state := runstate.NewExecutionState() + state.MarkTerminalStop(category, reason, message) + return &TaskRecord{ID: "task-1", Status: TaskFailed, State: state, Error: message} +} diff --git a/internal/tasks/types.go b/internal/tasks/types.go index 1cfa71f..51497a6 100644 --- a/internal/tasks/types.go +++ b/internal/tasks/types.go @@ -21,15 +21,28 @@ const ( // TaskRecord is the persisted task state. type TaskRecord struct { - ID string `json:"id"` - Status string `json:"status"` - Config *models.RuntimeConfig `json:"config"` - State *runstate.ExecutionState `json:"state,omitempty"` - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` - FinishedAt *time.Time `json:"finished_at,omitempty"` - Error string `json:"error,omitempty"` - StopMessage string `json:"stop_message,omitempty"` + ID string `json:"id"` + Status string `json:"status"` + Config *models.RuntimeConfig `json:"config"` + State *runstate.ExecutionState `json:"state,omitempty"` + Progress *TaskProgress `json:"progress,omitempty"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + FinishedAt *time.Time `json:"finished_at,omitempty"` + Error string `json:"error,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + FailureReason string `json:"failure_reason,omitempty"` + TerminalStopCategory string `json:"terminal_stop_category,omitempty"` + StopMessage string `json:"stop_message,omitempty"` +} + +// TaskProgress is a stable summary of task progress for API consumers. +type TaskProgress struct { + Current int `json:"current"` + Target int `json:"target"` + Success int `json:"success"` + Fail int `json:"fail"` + Percent float64 `json:"percent"` } // TaskLog is one persisted task log entry.