diff --git a/cmd/tools/gen-error-codes/main.go b/cmd/tools/gen-error-codes/main.go index bd89808..abb30a1 100644 --- a/cmd/tools/gen-error-codes/main.go +++ b/cmd/tools/gen-error-codes/main.go @@ -28,6 +28,7 @@ import ( // so the registry is populated before we read it. _ "github.com/sqlrush/opendbx/internal/app/skills" // spec-2.1 D-6: SKILL.* codes _ "github.com/sqlrush/opendbx/internal/app/skills/invoke" // spec-2.3 D-6: SKILL.* invoke codes + _ "github.com/sqlrush/opendbx/internal/app/tools/dbquery" // spec-2.3a D-5: DB.QUERY_INPUT_INVALID _ "github.com/sqlrush/opendbx/internal/entrypoints" "github.com/sqlrush/opendbx/internal/platform/errcode" _ "github.com/sqlrush/opendbx/internal/platform/logger" diff --git a/internal/app/tools/dbquery/doc.go b/internal/app/tools/dbquery/doc.go new file mode 100644 index 0000000..29a6a55 --- /dev/null +++ b/internal/app/tools/dbquery/doc.go @@ -0,0 +1,20 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +// Package dbquery provides the db_query diagnose ToolExecutor (spec-2.3a): +// a read-only SQL entry the LLM can call against the active PostgreSQL +// connection. It is the tool substrate the bundled DB diagnostic skills +// (spec-2.3b) drive via `allowed-tools: db_query`. +// +// Read-only is enforced SERVER-side (the db.QueryConn driver runs every +// statement inside a READ ONLY transaction); this package never inspects +// SQL text. Failures are two-track per the spec-1.21 ToolExecutor +// contract: ctx cancel/deadline → fatal Go error; semantic failures +// (bad input, query error, read-only violation) → recoverable ToolOutput +// with IsError so the model self-corrects. +// +// Layer hygiene: this package imports db / diagnose / llm / errcode and +// stdlib only. The connection factory is injected as a closure by +// bootstrap (openFn) so dbquery imports neither config nor bootstrap. 
+package dbquery diff --git a/internal/app/tools/dbquery/errors.go b/internal/app/tools/dbquery/errors.go new file mode 100644 index 0000000..a9370e6 --- /dev/null +++ b/internal/app/tools/dbquery/errors.go @@ -0,0 +1,37 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +// File errors.go — db_query input-validation errcode (spec-2.3a D-5). +// Chinese message/hint to match the existing DB.* convention in +// domain/db/errors.go (spec-2.3a Q11 user decision — same DB package +// family, do not mix languages within it). +// +// This is the only NEW code owned by this package; the read/connect/ +// timeout/readonly codes all live in domain/db (registered by spec-1.18 +// and spec-2.3a D-5) and are reused via the sanitized classify path. + +package dbquery + +import "github.com/sqlrush/opendbx/internal/platform/errcode" + +//nolint:gochecknoglobals // spec-0.6 contract: errcode sentinels are package-level. +var ( + // ErrQueryInputInvalid — the db_query tool input is malformed (missing / + // non-string / blank "sql"). Feedback class: the composed Content is + // written into a recoverable ToolResult so the model self-corrects. + ErrQueryInputInvalid = errcode.Register( + "DB.QUERY_INPUT_INVALID", + "db_query 入参无效", + `用 {"sql": "<一条只读语句>"} 调用 db_query`, + ) +) + +// inputInvalidContent composes the recoverable ToolOutput.Content for a +// malformed db_query input (spec-2.3a D-5). detail is the fixed-variant +// reason (missing/non-string sql field etc.). Mirrors the spec-2.3 +// invoke template shape: "[CODE] message: detail. Hint: hint." +func inputInvalidContent(detail string) string { + return "[" + ErrQueryInputInvalid.Code() + "] " + ErrQueryInputInvalid.Message() + + ": " + detail + ". Hint: " + ErrQueryInputInvalid.Hint() + "."
+} diff --git a/internal/app/tools/dbquery/errors_test.go b/internal/app/tools/dbquery/errors_test.go new file mode 100644 index 0000000..d0db8a3 --- /dev/null +++ b/internal/app/tools/dbquery/errors_test.go @@ -0,0 +1,38 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +package dbquery + +import ( + "strings" + "testing" +) + +// TestErrQueryInputInvalid_Registered — the new code carries the Rule 7 +// triple with the SKILL-less DB. prefix and Chinese text (spec-2.3a Q11). +func TestErrQueryInputInvalid_Registered(t *testing.T) { + t.Parallel() + if ErrQueryInputInvalid.Code() != "DB.QUERY_INPUT_INVALID" { + t.Errorf("code = %q", ErrQueryInputInvalid.Code()) + } + if ErrQueryInputInvalid.Message() == "" || ErrQueryInputInvalid.Hint() == "" { + t.Error("missing message/hint (Rule 7 triple)") + } +} + +// TestInputInvalidContent — template carries code, message, the detail +// variant, and the call-shape hint (spec-2.3a D-5). +func TestInputInvalidContent(t *testing.T) { + t.Parallel() + got := inputInvalidContent("缺少或非字符串 sql 字段") + for _, want := range []string{ + "DB.QUERY_INPUT_INVALID", + "缺少或非字符串 sql 字段", + `{"sql":`, + } { + if !strings.Contains(got, want) { + t.Errorf("inputInvalidContent missing %q in %q", want, got) + } + } +} diff --git a/internal/app/tools/dbquery/render.go b/internal/app/tools/dbquery/render.go new file mode 100644 index 0000000..8693371 --- /dev/null +++ b/internal/app/tools/dbquery/render.go @@ -0,0 +1,116 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +// File render.go — aligned text-table rendering of a db.QueryResult for +// the tool result content (spec-2.3a D-3 / Q5). +// +// This is WIRE text for the LLM (and /report), NOT terminal UI — column +// alignment uses rune counts, not East-Asian display width (that is the +// § 3.9 UI table concern, explicitly out of scope here, ❌-9). 
The total +// output is capped at 16KiB with a UTF-8-safe cut so a multibyte rune at +// the boundary is never split (spec-2.3a codex MED-1; the 64KiB→16KiB cap +// also removes the conflict with spec-2.3 R-7's body warning). + +package dbquery + +import ( + "fmt" + "strings" + "unicode/utf8" + + "github.com/sqlrush/opendbx/internal/domain/db" +) + +// maxOutputBytes caps the rendered table (spec-2.3a Q6, user decision +// 64KiB→16KiB; ≈4K tokens, within the 规则 19 context budget). +const maxOutputBytes = 16 << 10 + +// truncSuffix is appended when the rendered table exceeds maxOutputBytes. +const truncSuffix = "\n…(output truncated at 16KiB)" + +// renderTable renders the result as an aligned text table plus a footer. +// Empty column set → "(0 rows)". +func renderTable(res db.QueryResult) string { + if len(res.Columns) == 0 { + return "(0 rows)" + } + + widths := make([]int, len(res.Columns)) + for i, c := range res.Columns { + widths[i] = runeLen(c) + } + for _, row := range res.Rows { + for i, cell := range row { + if w := runeLen(cell); i < len(widths) && w > widths[i] { + widths[i] = w + } + } + } + + var b strings.Builder + writeRow(&b, res.Columns, widths) + seps := make([]string, len(widths)) + for i, w := range widths { + seps[i] = strings.Repeat("-", w) + } + writeRow(&b, seps, widths) + for _, row := range res.Rows { + writeRow(&b, row, widths) + } + b.WriteString(footer(res)) + + return capUTF8(b.String(), maxOutputBytes) +} + +// writeRow writes one left-aligned row. The trailing column is not padded +// (no dangling spaces — keeps goldens tight). Columns are separated by two +// spaces. 
+func writeRow(b *strings.Builder, cells []string, widths []int) {
+	for i, cell := range cells {
+		if i > 0 {
+			b.WriteString(" ")
+		}
+		b.WriteString(cell)
+		if i < len(cells)-1 && i < len(widths) { // ragged row: never index past widths
+			if pad := widths[i] - runeLen(cell); pad > 0 {
+				b.WriteString(strings.Repeat(" ", pad))
+			}
+		}
+	}
+	b.WriteByte('\n')
+}
+
+// footer is the row-count line with truncation annotations (spec-2.3a Q5).
+func footer(res db.QueryResult) string {
+	n := len(res.Rows)
+	var s string
+	if res.RowsTruncated {
+		s = fmt.Sprintf("(first %d rows; result truncated)", n)
+	} else {
+		s = fmt.Sprintf("(%d rows)", n)
+	}
+	if res.CellsTruncated {
+		s += " (some cells truncated)"
+	}
+	return s
+}
+
+// capUTF8 caps s at max bytes (suffix included when max >= len(truncSuffix)),
+// cutting on a rune boundary and appending truncSuffix (spec-2.3a codex MED-1).
+func capUTF8(s string, max int) string {
+	if len(s) <= max {
+		return s
+	}
+	budget := max - len(truncSuffix)
+	if budget < 0 {
+		budget = 0
+	}
+	// Back up to a rune start (<=3 steps on valid UTF-8); an invalid byte earlier in s no longer erases the output.
+	for budget > 0 && !utf8.RuneStart(s[budget]) {
+		budget--
+	}
+	return s[:budget] + truncSuffix
+}
+
+func runeLen(s string) int { return utf8.RuneCountInString(s) } // runes, not bytes or display width
diff --git a/internal/app/tools/dbquery/render_test.go b/internal/app/tools/dbquery/render_test.go
new file mode 100644
index 0000000..f1d03f6
--- /dev/null
+++ b/internal/app/tools/dbquery/render_test.go
@@ -0,0 +1,89 @@
+// Copyright 2026 opendbx contributors. See LICENSE.
+//
+// Author: sqlrush
+
+package dbquery
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/sqlrush/opendbx/internal/domain/db"
+)
+
+// TestRenderTable_Aligned — header + separator + rows + footer, columns
+// left-aligned to max width, two-space gap, no trailing pad (spec-2.3a Q5).
+func TestRenderTable_Aligned(t *testing.T) { + t.Parallel() + res := db.QueryResult{ + Columns: []string{"id", "name"}, + Rows: [][]string{{"1", "alice"}, {"20", "bo"}}, + } + got := renderTable(res) + want := "id name\n" + + "-- -----\n" + + "1 alice\n" + + "20 bo\n" + + "(2 rows)" + if got != want { + t.Errorf("renderTable =\n%q\nwant\n%q", got, want) + } +} + +// TestRenderTable_Empty — zero columns → "(0 rows)". +func TestRenderTable_Empty(t *testing.T) { + t.Parallel() + if got := renderTable(db.QueryResult{}); got != "(0 rows)" { + t.Errorf("empty = %q; want (0 rows)", got) + } +} + +// TestRenderTable_TruncationFooters — rows/cells truncation annotations. +func TestRenderTable_TruncationFooters(t *testing.T) { + t.Parallel() + res := db.QueryResult{ + Columns: []string{"c"}, + Rows: [][]string{{"x"}}, + RowsTruncated: true, + CellsTruncated: true, + } + got := renderTable(res) + if !strings.Contains(got, "(first 1 rows; result truncated) (some cells truncated)") { + t.Errorf("footer missing truncation notes:\n%s", got) + } +} + +// TestCapUTF8_RuneSafe — a 16KiB cut never splits a multibyte rune +// (spec-2.3a codex MED-1; the 1.20.2 CJK mojibake class). +func TestCapUTF8_RuneSafe(t *testing.T) { + t.Parallel() + // Build a string well over the cap of all CJK runes (3 bytes each). + big := strings.Repeat("世", maxOutputBytes) // 3*maxOutputBytes bytes + got := capUTF8(big, maxOutputBytes) + if len(got) > maxOutputBytes { + t.Errorf("capUTF8 len %d > max %d", len(got), maxOutputBytes) + } + if !strings.HasSuffix(got, truncSuffix) { + t.Error("missing truncation suffix") + } + // The body (minus suffix) must be valid UTF-8 (no split rune). + body := strings.TrimSuffix(got, truncSuffix) + if strings.ContainsRune(body, '�') { + t.Error("replacement char — a rune was split") + } + for _, r := range body { + if r != '世' { + t.Errorf("unexpected rune %q — boundary split", r) + break + } + } +} + +// TestCapUTF8_UnderCap — short output passes through unchanged. 
+func TestCapUTF8_UnderCap(t *testing.T) { + t.Parallel() + s := "small table\n(1 rows)" + if got := capUTF8(s, maxOutputBytes); got != s { + t.Errorf("under-cap mutated: %q", got) + } +} diff --git a/internal/app/tools/dbquery/tool.go b/internal/app/tools/dbquery/tool.go new file mode 100644 index 0000000..2bac8df --- /dev/null +++ b/internal/app/tools/dbquery/tool.go @@ -0,0 +1,173 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +// File tool.go — the db_query ToolExecutor (spec-2.3a D-3). +// +// Two-track failure (spec-1.21 contract, codex CRIT-1): ctx cancel / +// deadline is FATAL — returned as a Go error so the Loop's +// classifyToolErr produces FinishCancelled / DIAGNOSE.TOOL_TIMEOUT rather +// than hiding it behind a recoverable [DB.TIMEOUT] tool result. Every +// open/query error is checked for a ctx root FIRST (the postgres classify +// preserves it under Unwrap). Only non-ctx failures become recoverable +// IsError tool results the model can self-correct on. + +package dbquery + +import ( + "context" + "errors" + "strings" + "sync" + + "github.com/sqlrush/opendbx/internal/app/diagnose" + "github.com/sqlrush/opendbx/internal/domain/db" + "github.com/sqlrush/opendbx/internal/domain/llm" + "github.com/sqlrush/opendbx/internal/platform/errcode" +) + +// ToolName is the wire-visible tool identifier — lowercase opendbx-native +// (clock/echo family; the CC-parity capitalization rule applies only to +// CC tools, and DB query has no CC counterpart — spec-2.3a Q3). +const ToolName = "db_query" + +// errDriverNoQuery is an internal sentinel: the opened db.Conn does not +// implement db.QueryConn (a future non-query driver). Never escapes +// Execute — it is converted to a recoverable tool result. +var errDriverNoQuery = errors.New("driver does not support read-only queries") + +// Tool is the db_query ToolExecutor. The connection is opened lazily on +// first Execute (startup stays DB-I/O-free, spec-2.3a Q4) and cached. 
+type Tool struct { + openFn func(ctx context.Context) (db.Conn, error) + mu sync.Mutex // defense-in-depth only: Loop is serial and the + // registry is built per-session, so there is no + // concurrent Execute today (spec-2.3a arch HIGH-1). + conn db.QueryConn // cached after first successful open+assert; nil until then +} + +// Compile-time: *Tool is a diagnose-loop-callable tool. +var _ diagnose.ToolExecutor = (*Tool)(nil) + +// New builds the tool with a lazy connection factory. bootstrap wraps +// OpenConnection in openFn so this package imports neither config nor +// bootstrap (layer hygiene). +func New(openFn func(ctx context.Context) (db.Conn, error)) *Tool { + return &Tool{openFn: openFn} +} + +// Name implements diagnose.ToolExecutor. +func (t *Tool) Name() string { return ToolName } + +// Schema implements diagnose.ToolExecutor. Decision-tree description +// (CLAUDE.md § 3.3): when to use / when NOT to use, not a usage manual. +func (t *Tool) Schema() llm.ToolSchema { + return llm.ToolSchema{ + Name: ToolName, + Description: "Run ONE read-only SQL statement against the active PostgreSQL " + + "connection (SELECT / WITH / EXPLAIN / SHOW). Use to inspect live database " + + "state — pg_stat_* views, catalog queries, plans. Do NOT use for writes: " + + "INSERT/UPDATE/DELETE/DDL are rejected server-side. Results are capped at " + + "100 rows and 16KiB.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "sql": map[string]any{ + "type": "string", + "description": "One read-only SQL statement.", + }, + }, + "required": []string{"sql"}, + }, + } +} + +// Execute implements diagnose.ToolExecutor (spec-2.3a D-3 steps 1-5). +func (t *Tool) Execute(ctx context.Context, input map[string]any) (diagnose.ToolOutput, error) { + // errcode-lint:exempt -- spec-1.21 D-4 two-track: ctx errors pass through unchanged; the Loop classifies cancel-vs-timeout. 
+ if err := ctx.Err(); err != nil { + return diagnose.ToolOutput{}, err + } + sql, ok := input["sql"].(string) + if !ok || strings.TrimSpace(sql) == "" { + return diagnose.ToolOutput{ + Content: inputInvalidContent("缺少或非字符串 sql 字段"), + IsError: true, + }, nil + } + + qc, err := t.connection(ctx) + if err != nil { + if isCtxErr(ctx, err) { + // errcode-lint:exempt -- spec-2.3a D-3: ctx fatal precedes the DB.* recoverable path (codex CRIT-1). + return diagnose.ToolOutput{}, err + } + if errors.Is(err, errDriverNoQuery) { + return diagnose.ToolOutput{ + Content: "[" + db.ErrQueryFailed.Code() + "] 当前数据库 driver 不支持查询. Hint: 确认 driver 实现 QueryConn 能力", + IsError: true, + }, nil + } + return diagnose.ToolOutput{Content: dbErrContent(err), IsError: true}, nil + } + + res, err := qc.Query(ctx, sql, db.QueryOptions{}) + if err != nil { + if isCtxErr(ctx, err) { + // errcode-lint:exempt -- spec-2.3a D-3: ctx fatal precedes the DB.* recoverable path. + return diagnose.ToolOutput{}, err + } + return diagnose.ToolOutput{Content: dbErrContent(err), IsError: true}, nil + } + return diagnose.ToolOutput{Content: renderTable(res)}, nil +} + +// connection returns the cached QueryConn, opening it lazily on first use. +// A failed open is NOT cached (retryable on the next Execute — DB recovery +// needs no restart, spec-2.3a Q4). On a successful open whose Conn lacks +// the query capability, the Conn is Closed before erroring (no pool leak, +// codex HIGH-2). 
+func (t *Tool) connection(ctx context.Context) (db.QueryConn, error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.conn != nil { + return t.conn, nil + } + c, err := t.openFn(ctx) + if err != nil { + return nil, err // not cached → retryable + } + qc, ok := c.(db.QueryConn) + if !ok { + _ = c.Close() // prevent pool leak (codex HIGH-2) + return nil, errDriverNoQuery + } + t.conn = qc + return qc, nil +} + +// isCtxErr reports whether err (or the live ctx) is a cancel/deadline that +// must terminate as a Go error rather than a recoverable tool result. The +// postgres classify wraps ctx into DB.TIMEOUT but preserves the root under +// Unwrap, so errors.Is still matches (spec-2.3a codex CRIT-1). +func isCtxErr(ctx context.Context, err error) bool { + return ctx.Err() != nil || + errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) +} + +// dbErrContent renders a sanitized DB.* error to recoverable tool-result +// text. The classify path guarantees errcode.Error with no DSN leak; the +// hint carries the actionable guidance (read-only violations already say +// "写操作被服务端拒绝" in their hint, spec-2.3a D-5). +func dbErrContent(err error) string { + var ec errcode.Error + if errors.As(err, &ec) { + if h := ec.Hint(); h != "" { + return ec.Error() + ". Hint: " + h + } + return ec.Error() + } + // Unreachable: classify always returns errcode.Error. Defensive only. + return "[" + db.ErrQueryFailed.Code() + "] " + err.Error() +} diff --git a/internal/app/tools/dbquery/tool_test.go b/internal/app/tools/dbquery/tool_test.go new file mode 100644 index 0000000..ca30cb8 --- /dev/null +++ b/internal/app/tools/dbquery/tool_test.go @@ -0,0 +1,229 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +package dbquery + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + + "github.com/sqlrush/opendbx/internal/domain/db" +) + +// --- fakes --- + +// fakeQueryConn implements db.QueryConn. 
+type fakeQueryConn struct { + result db.QueryResult + queryErr error + closed int +} + +func (f *fakeQueryConn) Ping(context.Context) error { return nil } +func (f *fakeQueryConn) HealthCheck(context.Context) (db.Health, error) { + return db.Health{}, nil +} +func (f *fakeQueryConn) Close() error { f.closed++; return nil } +func (f *fakeQueryConn) Query(_ context.Context, _ string, _ db.QueryOptions) (db.QueryResult, error) { + return f.result, f.queryErr +} + +// fakeConn implements db.Conn only (NOT QueryConn) — for the capability +// assertion-failure path. +type fakeConn struct{ closed int } + +func (f *fakeConn) Ping(context.Context) error { return nil } +func (f *fakeConn) HealthCheck(context.Context) (db.Health, error) { return db.Health{}, nil } +func (f *fakeConn) Close() error { f.closed++; return nil } + +// --- Schema / Name --- + +func TestTool_NameAndSchema(t *testing.T) { + t.Parallel() + tool := New(func(context.Context) (db.Conn, error) { return &fakeQueryConn{}, nil }) + if tool.Name() != "db_query" { + t.Errorf("Name = %q", tool.Name()) + } + sch := tool.Schema() + if sch.Name != tool.Name() { + t.Errorf("Schema.Name %q must equal Name()", sch.Name) + } + req, _ := sch.InputSchema["required"].([]string) + if len(req) != 1 || req[0] != "sql" { + t.Errorf("required = %v; want [sql]", req) + } +} + +// --- input validation --- + +func TestTool_InputInvalid(t *testing.T) { + t.Parallel() + tool := New(func(context.Context) (db.Conn, error) { return &fakeQueryConn{}, nil }) + for _, in := range []map[string]any{ + {}, // missing + {"sql": 42}, // non-string + {"sql": " "}, // blank + } { + out, err := tool.Execute(context.Background(), in) + if err != nil { + t.Fatalf("semantic failure must not return Go error: %v", err) + } + if !out.IsError || !strings.Contains(out.Content, "DB.QUERY_INPUT_INVALID") { + t.Errorf("input %v → %+v; want IsError INPUT_INVALID", in, out) + } + } +} + +// --- ctx two-track (codex CRIT-1) --- + +func 
TestTool_CtxFatalAtEntry(t *testing.T) { + t.Parallel() + tool := New(func(context.Context) (db.Conn, error) { return &fakeQueryConn{}, nil }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + out, err := tool.Execute(ctx, map[string]any{"sql": "SELECT 1"}) + if !errors.Is(err, context.Canceled) { + t.Fatalf("cancelled ctx at entry → err %v; want context.Canceled", err) + } + if out.IsError || out.Content != "" { + t.Errorf("ctx fatal must not yield recoverable output: %+v", out) + } +} + +func TestTool_CtxFatalOnOpen(t *testing.T) { + t.Parallel() + // openFn returns a ctx-wrapped error (as bootstrap's OpenConnection would + // under a deadline) — must be fatal, not a recoverable DB.TIMEOUT result. + tool := New(func(context.Context) (db.Conn, error) { + return nil, context.DeadlineExceeded + }) + out, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT 1"}) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("open ctx error → %v; want DeadlineExceeded fatal", err) + } + if out.IsError { + t.Errorf("must not be recoverable: %+v", out) + } +} + +func TestTool_CtxFatalOnQuery(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{queryErr: context.DeadlineExceeded} + tool := New(func(context.Context) (db.Conn, error) { return qc, nil }) + out, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT 1"}) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("query ctx error → %v; want DeadlineExceeded fatal", err) + } + if out.IsError { + t.Errorf("must not be recoverable: %+v", out) + } +} + +// --- capability assertion failure closes the conn (codex HIGH-2) --- + +func TestTool_CapabilityFailureClosesConn(t *testing.T) { + t.Parallel() + fc := &fakeConn{} + tool := New(func(context.Context) (db.Conn, error) { return fc, nil }) + out, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT 1"}) + if err != nil { + t.Fatalf("capability failure is recoverable, not Go error: %v", err) 
+ } + if !out.IsError || !strings.Contains(out.Content, "不支持查询") { + t.Errorf("want recoverable 'driver does not support queries': %+v", out) + } + if fc.closed != 1 { + t.Errorf("non-query Conn must be Closed exactly once, got %d", fc.closed) + } +} + +// --- lazy open: failure not cached, retryable (Q4) --- + +func TestTool_LazyOpenRetryable(t *testing.T) { + t.Parallel() + var calls int + qc := &fakeQueryConn{result: db.QueryResult{Columns: []string{"c"}, Rows: [][]string{{"1"}}}} + tool := New(func(context.Context) (db.Conn, error) { + calls++ + if calls == 1 { + return nil, db.ErrConnectFailed // first open fails + } + return qc, nil + }) + // First Execute: open fails → recoverable, not cached. + out1, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT 1"}) + if err != nil || !out1.IsError { + t.Fatalf("first call should be recoverable failure: out=%+v err=%v", out1, err) + } + // Second Execute: retried open succeeds. + out2, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT 1"}) + if err != nil || out2.IsError { + t.Fatalf("second call should succeed (retry): out=%+v err=%v", out2, err) + } + if calls != 2 { + t.Errorf("openFn calls = %d; want 2 (first not cached)", calls) + } + // Third Execute: cached conn reused, no new open. 
+ if _, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT 2"}); err != nil { + t.Fatal(err) + } + if calls != 2 { + t.Errorf("openFn calls after cache = %d; want still 2", calls) + } +} + +// --- readonly violation surfaces the read-only hint --- + +func TestTool_ReadOnlyViolationContent(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{queryErr: db.ErrReadOnlyViolation} + tool := New(func(context.Context) (db.Conn, error) { return qc, nil }) + out, err := tool.Execute(context.Background(), map[string]any{"sql": "INSERT ..."}) + if err != nil || !out.IsError { + t.Fatalf("readonly violation is recoverable: out=%+v err=%v", out, err) + } + if !strings.Contains(out.Content, "DB.READONLY_VIOLATION") || !strings.Contains(out.Content, "只读") { + t.Errorf("content must carry readonly code + hint: %q", out.Content) + } +} + +// --- success renders a table --- + +func TestTool_SuccessRendersTable(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{result: db.QueryResult{ + Columns: []string{"id", "name"}, + Rows: [][]string{{"1", "alice"}, {"2", "bob"}}, + }} + tool := New(func(context.Context) (db.Conn, error) { return qc, nil }) + out, err := tool.Execute(context.Background(), map[string]any{"sql": "SELECT id,name FROM t"}) + if err != nil || out.IsError { + t.Fatalf("success: out=%+v err=%v", out, err) + } + for _, want := range []string{"id", "name", "alice", "(2 rows)"} { + if !strings.Contains(out.Content, want) { + t.Errorf("table missing %q:\n%s", want, out.Content) + } + } +} + +// --- concurrency: mu guards conn (defense-in-depth) --- + +func TestTool_ConcurrentExecuteRace(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{result: db.QueryResult{Columns: []string{"c"}, Rows: [][]string{{"x"}}}} + tool := New(func(context.Context) (db.Conn, error) { return qc, nil }) + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = tool.Execute(context.Background(), map[string]any{"sql": "SELECT 
1"}) + }() + } + wg.Wait() +} diff --git a/internal/bootstrap/connection.go b/internal/bootstrap/connection.go index c8bdf39..e3d40bd 100644 --- a/internal/bootstrap/connection.go +++ b/internal/bootstrap/connection.go @@ -18,7 +18,10 @@ package bootstrap import ( "context" + "errors" + "github.com/sqlrush/opendbx/internal/app/diagnose" + "github.com/sqlrush/opendbx/internal/app/tools/dbquery" "github.com/sqlrush/opendbx/internal/domain/db" "github.com/sqlrush/opendbx/internal/platform/config" "github.com/sqlrush/opendbx/internal/platform/errcode" @@ -90,7 +93,11 @@ func warnIfInsecureSSL(conn config.ConnectionConfig) { } switch mode { case "disable", "allow", "prefer": - logger.L().Warn("数据库连接未强制 TLS,凭据与查询可能明文传输", + // file-only: spec-2.3a moved OpenConnection to a lazy tool-execution + // path; a normal logger.Warn would tear the TUI cell grid under + // --debug-to-stderr (codex MED-3). WarnForceFile is file-only and + // never touches stderr (mirrors the spec-2.3 InfoForceFile precedent). + logger.WarnForceFile("数据库连接未强制 TLS,凭据与查询可能明文传输", logger.Attr{Key: "alias", Value: conn.Alias}, logger.Attr{Key: "sslmode", Value: mode}, logger.Attr{Key: "hint", Value: "生产环境请设 sslmode=require 或更高"}, @@ -115,3 +122,39 @@ func OpenConnection(ctx context.Context, cfg *config.Config, cliAlias string) (d // errcode-lint:exempt -- spec-1.18 D-4: db.Open returns a sanitized db.* errcode (or nil); this is the single sanctioned secret.Expose() call site (spec-1.19 D-8). return db.Open(ctx, conn.Driver, secret.Expose()) } + +// DBQueryExecutors returns the db_query tool executor when a connection can +// be selected from config (spec-2.3a D-4). It does NOT open the connection — +// selection only — so startup stays DB-I/O-free (the tool opens lazily on +// first Execute, spec-2.3a Q4). When no connection is usable it registers +// nothing and logs a differentiated reason (distinguishing "none configured" +// from "ambiguous — set default_connection"; codex/cr MED). 
+func DBQueryExecutors(cfg *config.Config) []diagnose.ToolExecutor { + if _, err := ActiveConnection(cfg, ""); err != nil { + logger.InfoForceFile("db_query tool not registered", + "spec", "2.3a", "reason", connUnavailableReason(err)) + return nil + } + openFn := func(ctx context.Context) (db.Conn, error) { + return OpenConnection(ctx, cfg, "") + } + return []diagnose.ToolExecutor{dbquery.New(openFn)} +} + +// connUnavailableReason maps an ActiveConnection error to an actionable +// debug-log reason (spec-2.3a D-4 / cr MED-2: ambiguous must not look like +// "no connection"). +func connUnavailableReason(err error) string { + var ec errcode.Error + if errors.As(err, &ec) { + switch ec.Code() { + case ErrNoConnection.Code(): + return "no database connection configured" + case ErrAmbiguous.Code(): + return "multiple connections but no default_connection set; set default_connection or pass --connection-alias" + case ErrUnknownAlias.Code(): + return "configured connection alias not found" + } + } + return "connection selection failed" +} diff --git a/internal/bootstrap/connection_test.go b/internal/bootstrap/connection_test.go index 6aeab3f..15c4866 100644 --- a/internal/bootstrap/connection_test.go +++ b/internal/bootstrap/connection_test.go @@ -7,6 +7,8 @@ package bootstrap import ( "context" "errors" + "io" + "os" "strings" "testing" @@ -14,6 +16,7 @@ import ( // postgres driver is registered by the production drivers.go side-effect // import (spec-1.19 R-fix); tests rely on that, not a test-only import. "github.com/sqlrush/opendbx/internal/platform/config" + "github.com/sqlrush/opendbx/internal/platform/logger" ) // noComposeDriver implements db.Driver but NOT db.DSNComposer. @@ -139,3 +142,90 @@ func TestOpenConnection(t *testing.T) { t.Errorf("OpenConnection error leaked password: %s", err.Error()) } } + +// --- spec-2.3a D-4: db_query registration --- + +// TestDBQueryExecutors_NoConnection — zero connections → not registered. 
+func TestDBQueryExecutors_NoConnection(t *testing.T) { + t.Parallel() + if got := DBQueryExecutors(&config.Config{}); got != nil { + t.Errorf("no connection → %v; want nil (not registered)", got) + } +} + +// TestDBQueryExecutors_Ambiguous — multiple connections, no default → not +// registered (the differentiated-reason path; codex/cr MED-2). +func TestDBQueryExecutors_Ambiguous(t *testing.T) { + t.Parallel() + cfg := &config.Config{Connections: []config.ConnectionConfig{ + {Alias: "a", Driver: "postgres", Host: "h", Database: "d", User: "u"}, + {Alias: "b", Driver: "postgres", Host: "h", Database: "d", User: "u"}, + }} + if got := DBQueryExecutors(cfg); got != nil { + t.Errorf("ambiguous → %v; want nil (not registered)", got) + } +} + +// TestDBQueryExecutors_Registered — a selectable connection registers exactly +// one db_query executor WITHOUT opening it (startup is DB-I/O-free). +func TestDBQueryExecutors_Registered(t *testing.T) { + t.Parallel() + cfg := &config.Config{Connections: []config.ConnectionConfig{ + {Alias: "only", Driver: "postgres", Host: "h", Database: "d", User: "u"}, + }} + got := DBQueryExecutors(cfg) + if len(got) != 1 || got[0].Name() != "db_query" { + t.Fatalf("registered = %v; want one db_query executor", got) + } +} + +// TestConnUnavailableReason — each ActiveConnection error maps to a distinct +// actionable reason (codex/cr MED-2). 
+func TestConnUnavailableReason(t *testing.T) { + t.Parallel() + _, noneErr := ActiveConnection(&config.Config{}, "") + _, ambErr := ActiveConnection(&config.Config{Connections: []config.ConnectionConfig{ + {Alias: "a"}, {Alias: "b"}, + }}, "") + none := connUnavailableReason(noneErr) + amb := connUnavailableReason(ambErr) + if none == amb { + t.Errorf("no-connection (%q) and ambiguous (%q) reasons must differ", none, amb) + } + if !strings.Contains(amb, "default_connection") { + t.Errorf("ambiguous reason should hint default_connection: %q", amb) + } +} + +// TestWarnIfInsecureSSL_FileOnly — the warning must reach the debug file +// only, never stderr, so it cannot tear the TUI cell grid under +// --debug-to-stderr (spec-2.3a codex MED-3). NOT parallel: mutates +// os.Stderr + the logger global. +func TestWarnIfInsecureSSL_FileOnly(t *testing.T) { + logPath := t.TempDir() + "/ssl.log" + oldStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + os.Stderr = w + defer func() { os.Stderr = oldStderr }() + + if err := logger.Init(logger.InitInput{SessionID: "ssl", LogPath: logPath, DebugToStderr: true}); err != nil { + t.Fatalf("logger.Init: %v", err) + } + warnIfInsecureSSL(config.ConnectionConfig{Alias: "a", SSLMode: "disable"}) + + if err := w.Close(); err != nil { + t.Fatalf("close pipe: %v", err) + } + stderrRaw, _ := io.ReadAll(r) + if strings.Contains(string(stderrRaw), "TLS") { + t.Errorf("insecure-SSL warning tore the TUI via stderr: %q", stderrRaw) + } + // It must still be recorded in the debug file. 
+ fileRaw, _ := os.ReadFile(logPath) + if !strings.Contains(string(fileRaw), "TLS") { + t.Errorf("warning missing from debug file:\n%s", fileRaw) + } +} diff --git a/internal/bootstrap/tui_launcher.go b/internal/bootstrap/tui_launcher.go index 1e30d16..9ab1a64 100644 --- a/internal/bootstrap/tui_launcher.go +++ b/internal/bootstrap/tui_launcher.go @@ -140,10 +140,16 @@ func newChatModel() program.Model { // log + skip; interact always starts (user decision 4/4 — no panic). // Skipped on the provider-error path: the session cannot chat, so the // filesystem scan would be wasted I/O (post-impl cr LOW-1). - var skillExecs []diagnose.ToolExecutor + var execs []diagnose.ToolExecutor var skillPrompt string if perr == nil { + var skillExecs []diagnose.ToolExecutor skillExecs, skillPrompt = skillsForChat(DiscoverSkills(cfg)) + execs = append(execs, skillExecs...) + // spec-2.3a D-4: register db_query when a connection can be selected + // (without opening it — lazy; startup stays DB-I/O-free). No usable + // connection → not registered + a differentiated debug log. + execs = append(execs, DBQueryExecutors(cfg)...) } opts := llmapp.Options{ ModelName: cfg.LLM.ActiveModel, @@ -151,7 +157,7 @@ func newChatModel() program.Model { StripThink: cfg.LLM.StripThink, ThinkingMode: thinkingModeFromConfig(cfg.LLM.ThinkingMode), ThinkingBudget: cfg.LLM.ThinkingBudget, - Registry: diagnoseRegistryWith(skillExecs...), + Registry: diagnoseRegistryWith(execs...), SystemPrompt: skillPrompt, // spec-1.21 D-6 user-config knobs reach the runtime here. // Per-turn LLM timeout reuses LLMConfig.RequestTimeout per spec diff --git a/internal/domain/db/driver.go b/internal/domain/db/driver.go index a750379..cca6954 100644 --- a/internal/domain/db/driver.go +++ b/internal/domain/db/driver.go @@ -5,10 +5,10 @@ // File driver.go — provider-agnostic database driver interface (spec-1.18 // D-1). 
The minimal common set across PG / MySQL / Oracle / openGauss // (§ 3.7 multi-DB matrix): connection lifecycle + liveness + a version -// health probe. NOTHING query-shaped — query execution arrives in spec-2.1 -// as a CAPABILITY interface (QueryConn = Conn + Query...), never by adding -// methods to Conn (that would be a Go interface break for every fake and -// future driver). +// health probe. NOTHING query-shaped — query execution arrived in spec-2.3a +// (errata: the original note said spec-2.1) as a CAPABILITY interface +// (QueryConn = Conn + Query...), never by adding methods to Conn (that would +// be a Go interface break for every fake and future driver). // // Design: spec-1.18-pg-driver. @@ -34,7 +34,7 @@ type Driver interface { // Conn is a live connection (or pool) to one database instance. FROZEN to // lifecycle + health for spec-1.18; query capability is added separately in -// spec-2.1 via a composed QueryConn interface. +// spec-2.3a via the composed QueryConn interface (query.go). type Conn interface { // Ping verifies the connection is alive, honouring ctx. Ping(ctx context.Context) error diff --git a/internal/domain/db/errors.go b/internal/domain/db/errors.go index d79585e..48e66b8 100644 --- a/internal/domain/db/errors.go +++ b/internal/domain/db/errors.go @@ -67,4 +67,16 @@ var ( "未注册的数据库 driver", "检查 driver 名是否受支持 (当前: postgres); 确认对应 driver 包已 import", ) + // ErrReadOnlyViolation — a write / DDL statement was attempted inside the + // server-side READ ONLY transaction the query tool runs every statement + // in (SQLSTATE 25006 read_only_sql_transaction). This is the db_query + // behavior contract surfaced to the model: it must distinguish "SQL is + // wrong" (ErrQueryFailed) from "this tool refuses writes" so it can + // self-correct toward a read-only statement (spec-2.3a D-5; 中文 per + // the existing DB.* convention, spec-2.3a Q11). 
+ ErrReadOnlyViolation = errcode.Register( + "DB.READONLY_VIOLATION", + "语句被拒绝: 只读会话", + "db_query 仅执行只读语句; 写操作 (INSERT/UPDATE/DDL) 被服务端拒绝 (SQLSTATE 25006)", + ) ) diff --git a/internal/domain/db/postgres/errors_classify.go b/internal/domain/db/postgres/errors_classify.go index c4f31ab..f0b85c2 100644 --- a/internal/domain/db/postgres/errors_classify.go +++ b/internal/domain/db/postgres/errors_classify.go @@ -81,6 +81,11 @@ func classify(err error) error { return sanitized(db.ErrAuthFailed.Code(), err) case "3D000": // invalid_catalog_name — target database does not exist return sanitized(db.ErrConnectFailed.Code(), err) + case "25006": // read_only_sql_transaction — a write/DDL hit the db_query + // READ ONLY tx. This is the ONLY 25-class code mapped specially + // (spec-2.3a D-2 / cr HIGH-1): 0A000 (feature_not_supported) is + // NOT a read-only signal and stays QUERY_FAILED via the default. + return sanitized(db.ErrReadOnlyViolation.Code(), err) } switch pgErr.Code[:2] { // SQLSTATE class case "28": // invalid_authorization_specification / invalid_password @@ -89,7 +94,10 @@ func classify(err error) error { return sanitized(db.ErrConnectFailed.Code(), err) case "53", "57": // insufficient_resources / operator_intervention return sanitized(db.ErrUnavailable.Code(), err) - default: // 42 syntax / 3D other / 23 integrity / ... + default: // 42 syntax / 3D other / 23 integrity / 25 (non-25006) txn-state / ... + // opendbx: 25-class other than 25006 (e.g. 25001 active_sql_transaction) + // falls through to QUERY_FAILED — a simplification; 25006 is handled by + // the override above (spec-2.3a ❌-11 / cr LOW-1). return sanitized(db.ErrQueryFailed.Code(), err) } } diff --git a/internal/domain/db/postgres/health.go b/internal/domain/db/postgres/health.go index 64f1681..69ed1f1 100644 --- a/internal/domain/db/postgres/health.go +++ b/internal/domain/db/postgres/health.go @@ -24,10 +24,12 @@ import ( ) // pgxPool is the subset of *pgxpool.Pool that pgConn uses. 
Kept minimal so a -// fake can stand in for unit tests. +// fake can stand in for unit tests. spec-2.3a D-2 adds BeginTx for the +// read-only query path; *pgxpool.Pool satisfies it unchanged. type pgxPool interface { Ping(ctx context.Context) error QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) Close() } diff --git a/internal/domain/db/postgres/integration_test.go b/internal/domain/db/postgres/integration_test.go index d14ceee..7b2cf5a 100644 --- a/internal/domain/db/postgres/integration_test.go +++ b/internal/domain/db/postgres/integration_test.go @@ -19,6 +19,7 @@ package postgres import ( "context" + "errors" "os" "strings" "testing" @@ -100,3 +101,93 @@ func TestIntegrationConnectRefusedNoLeak(t *testing.T) { t.Errorf("error leaked credentials: %s", err.Error()) } } + +// --- spec-2.3a D-6: real-PG read-only Query parked cases --- + +func openOrSkip(t *testing.T) (db.QueryConn, func()) { + t.Helper() + dsn := dsnOrSkip(t) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + conn, err := db.Open(ctx, "postgres", dsn) + if err != nil { + cancel() + t.Fatalf("Open: %v", err) + } + qc, ok := conn.(db.QueryConn) + if !ok { + _ = conn.Close() + cancel() + t.Fatal("postgres Conn does not implement db.QueryConn") + } + return qc, func() { _ = conn.Close(); cancel() } +} + +// TestIntegrationQuerySmoke — SELECT 1 round-trips through the read-only tx. 
+func TestIntegrationQuerySmoke(t *testing.T) { + qc, done := openOrSkip(t) + defer done() + res, err := qc.Query(context.Background(), "SELECT 1 AS one", db.QueryOptions{}) + if err != nil { + t.Fatalf("Query: %v", err) + } + if len(res.Columns) != 1 || res.Columns[0] != "one" || len(res.Rows) != 1 || res.Rows[0][0] != "1" { + t.Errorf("unexpected result: %+v", res) + } +} + +// TestIntegrationPgStatActivity — a real pg_stat_* diagnostic query returns +// known columns (the scenario bundled skills will drive via db_query). +func TestIntegrationPgStatActivity(t *testing.T) { + qc, done := openOrSkip(t) + defer done() + res, err := qc.Query(context.Background(), + "SELECT datname, state FROM pg_stat_activity LIMIT 5", db.QueryOptions{}) + if err != nil { + t.Fatalf("Query: %v", err) + } + if len(res.Columns) != 2 || res.Columns[0] != "datname" || res.Columns[1] != "state" { + t.Errorf("columns = %v; want [datname state]", res.Columns) + } +} + +// TestIntegrationReadOnlyViolation — a write is rejected server-side as +// SQLSTATE 25006 → DB.READONLY_VIOLATION (spec-2.3a R-1 authoritative gate). +func TestIntegrationReadOnlyViolation(t *testing.T) { + qc, done := openOrSkip(t) + defer done() + _, err := qc.Query(context.Background(), + "CREATE TABLE opendbx_readonly_probe (id int)", db.QueryOptions{}) + if err == nil { + t.Fatal("CREATE TABLE in read-only tx returned nil error") + } + var ec interface{ Code() string } + if !errors.As(err, &ec) || ec.Code() != db.ErrReadOnlyViolation.Code() { + t.Errorf("write err = %v; want DB.READONLY_VIOLATION", err) + } +} + +// TestIntegrationMultiStatementRejected — extended protocol rejects a +// multi-statement string (the injection-shape second gate). 
+func TestIntegrationMultiStatementRejected(t *testing.T) { + qc, done := openOrSkip(t) + defer done() + _, err := qc.Query(context.Background(), + "SELECT 1; DROP TABLE IF EXISTS opendbx_x", db.QueryOptions{}) + if err == nil { + t.Fatal("multi-statement query returned nil error; extended protocol should reject") + } +} + +// TestIntegrationNumericRendering — a NUMERIC column renders as decimal text +// (the formatCell type-table contract against a real value, not a struct). +func TestIntegrationNumericRendering(t *testing.T) { + qc, done := openOrSkip(t) + defer done() + res, err := qc.Query(context.Background(), "SELECT 12345.67::numeric AS n", db.QueryOptions{}) + if err != nil { + t.Fatalf("Query: %v", err) + } + if res.Rows[0][0] != "12345.67" { + t.Errorf("numeric rendered as %q; want 12345.67", res.Rows[0][0]) + } +} diff --git a/internal/domain/db/postgres/postgres.go b/internal/domain/db/postgres/postgres.go index 9a2da28..efb1c2f 100644 --- a/internal/domain/db/postgres/postgres.go +++ b/internal/domain/db/postgres/postgres.go @@ -31,6 +31,15 @@ var _ db.Driver = Driver{} // Open's success path is unit-testable without a real PostgreSQL (spec-1.18 // R-fix; post-impl go-reviewer MED-1). // +// INVARIANT (spec-2.3a arch MED-2): the pool must NOT be configured with +// pgx.QueryExecModeSimpleProtocol. db_query's read-only safety leans on the +// extended protocol rejecting multi-statement input as a secondary gate +// (the READ ONLY tx is the authoritative gate); simple protocol would batch +// `SELECT 1; DROP ...` in one round trip. pgxpool.New uses the pgx default +// (cache_statement, an extended-protocol mode), so this holds as long as no +// DefaultQueryExecMode override is introduced — guarded by +// TestNewPoolUsesExtendedProtocol. +// //nolint:gochecknoglobals // spec-1.18 R-fix: test seam, mirrors the database/sql driver-constructor pattern. 
var newPool = func(ctx context.Context, dsn string) (pgxPool, error) { return pgxpool.New(ctx, dsn) diff --git a/internal/domain/db/postgres/postgres_test.go b/internal/domain/db/postgres/postgres_test.go index 8b2b612..f8e55b0 100644 --- a/internal/domain/db/postgres/postgres_test.go +++ b/internal/domain/db/postgres/postgres_test.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" "github.com/sqlrush/opendbx/internal/domain/db" ) @@ -69,11 +70,19 @@ type fakePool struct { pingErr error row pgx.Row closes int + // spec-2.3a D-2: read-only query path. + tx pgx.Tx + beginErr error + lastTxOpts pgx.TxOptions } func (p *fakePool) Ping(context.Context) error { return p.pingErr } func (p *fakePool) QueryRow(context.Context, string, ...any) pgx.Row { return p.row } func (p *fakePool) Close() { p.closes++ } +func (p *fakePool) BeginTx(_ context.Context, o pgx.TxOptions) (pgx.Tx, error) { + p.lastTxOpts = o + return p.tx, p.beginErr +} func TestPgConnPing(t *testing.T) { ok := &pgConn{pool: &fakePool{}} @@ -170,3 +179,18 @@ func TestPgConnCloseIdempotent(t *testing.T) { t.Errorf("pool.Close called %d times, want 1 (idempotent)", fp.closes) } } + +// TestNewPoolUsesExtendedProtocol guards the spec-2.3a arch MED-2 invariant: +// the pool must not use simple protocol, or the multi-statement rejection +// gate (a secondary read-only defense) breaks. We parse a representative DSN +// and assert the pgx default exec mode is not SimpleProtocol. 
+func TestNewPoolUsesExtendedProtocol(t *testing.T) { + t.Parallel() + cfg, err := pgxpool.ParseConfig("postgres://u:p@127.0.0.1:5432/db") + if err != nil { + t.Fatalf("ParseConfig: %v", err) + } + if cfg.ConnConfig.DefaultQueryExecMode == pgx.QueryExecModeSimpleProtocol { + t.Error("pool default exec mode is SimpleProtocol — multi-statement gate broken (spec-2.3a arch MED-2)") + } +} diff --git a/internal/domain/db/postgres/query.go b/internal/domain/db/postgres/query.go new file mode 100644 index 0000000..f5856f0 --- /dev/null +++ b/internal/domain/db/postgres/query.go @@ -0,0 +1,188 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +// File query.go — pgConn's read-only Query capability (spec-2.3a D-2), +// making *pgConn satisfy db.QueryConn. +// +// Read-only is enforced SERVER-side: every statement runs inside a +// `BeginTx(AccessMode: ReadOnly)` transaction that is ALWAYS rolled back. +// Persistent-table writes / DDL raise SQLSTATE 25006 (→ DB.READONLY_ +// VIOLATION via classify). This is NOT a zero-side-effect guarantee — +// sequence advancement, session advisory locks, temp tables, and remote +// (dblink/FDW) effects escape a read-only transaction; the real boundary +// is a least-privilege connection role (spec-2.3a § 1 / R-5). +// +// The pgx.Tx / pgx.Rows surfaces are consumed through the narrow queryTx / +// queryRows interfaces so unit fakes implement a handful of methods rather +// than pgx's full 11-method Tx / 9-method Rows (spec-2.3a Q10 / cr HIGH-3). + +package postgres + +import ( + "context" + "database/sql/driver" + "encoding/hex" + "fmt" + "strconv" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/sqlrush/opendbx/internal/domain/db" +) + +// Compile-time: *pgConn satisfies the query capability (spec-1.18 R-8). +var _ db.QueryConn = (*pgConn)(nil) + +// queryTx is the narrow tx surface pgConn.Query uses. *pgx.Tx satisfies it. 
+type queryTx interface { + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + Rollback(ctx context.Context) error +} + +// queryRows is the narrow rows surface pgConn.Query uses. pgx.Rows satisfies it. +type queryRows interface { + FieldDescriptions() []pgconn.FieldDescription + Next() bool + Values() ([]any, error) + Err() error + Close() +} + +// Query runs ONE read-only SQL statement and returns a bounded text page. +// ctx cancel/deadline is propagated through classify, which preserves the +// ctx root under Unwrap so the tool layer's errors.Is honors the +// spec-1.21 two-track contract (spec-2.3a D-1 / codex CRIT-1). +func (c *pgConn) Query(ctx context.Context, sql string, opts db.QueryOptions) (db.QueryResult, error) { + rawTx, err := c.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) + if err != nil { + // errcode-lint:exempt -- spec-1.18 D-4: classify returns a sanitized db.Err* (or nil); ctx root preserved under Unwrap. + return db.QueryResult{}, classify(err) + } + var tx queryTx = rawTx + // Always roll back — a read-only tx has nothing to commit, and rollback + // is the zero-effect close for the persistent-write/DDL set (residual + // session effects are documented, spec-2.3a R-5). Rollback after a + // committed/closed tx is a harmless no-op in pgx. + // + // WithoutCancel: if the query was cancelled/timed-out, rolling back with + // the SAME cancelled ctx makes pgx fail the ROLLBACK wire send and HARD- + // CLOSE (die) the pooled conn, discarding it. A detached ctx lets ROLLBACK + // complete so the conn returns cleanly to the pool (post-impl go MED-1). + defer func() { _ = tx.Rollback(context.WithoutCancel(ctx)) }() + + rawRows, err := tx.Query(ctx, sql) + if err != nil { + // errcode-lint:exempt -- spec-1.18 D-4: classify returns a sanitized db.Err*; 25006 → DB.READONLY_VIOLATION. 
+ return db.QueryResult{}, classify(err) + } + var rows queryRows = rawRows + defer rows.Close() + + return scanBounded(rows, opts) +} + +// scanBounded reads up to MaxRows rows plus one sentinel to decide +// truncation, rendering each cell to text. A mid-iteration Values() error +// breaks immediately; rows.Err() then carries the cause (spec-2.3a cr +// MED-3 — no partial nil rows emitted). +func scanBounded(rows queryRows, opts db.QueryOptions) (db.QueryResult, error) { + limit := opts.EffectiveMaxRows() + maxCell := opts.EffectiveMaxCellRunes() + + fields := rows.FieldDescriptions() + cols := make([]string, len(fields)) + for i, f := range fields { + cols[i] = f.Name + } + + res := db.QueryResult{Columns: cols} + for rows.Next() { + if len(res.Rows) == limit { + // One more row exists beyond the cap → truncated. Stop fetching. + res.RowsTruncated = true + break + } + vals, err := rows.Values() + if err != nil { + break // rows.Err() carries it; do not emit a partial row + } + row := make([]string, len(vals)) + for i, v := range vals { + text, truncated := formatCell(v, maxCell) + row[i] = text + if truncated { + res.CellsTruncated = true + } + } + res.Rows = append(res.Rows, row) + } + if err := rows.Err(); err != nil { + // errcode-lint:exempt -- spec-1.18 D-4: classify returns a sanitized db.Err*. + return db.QueryResult{}, classify(err) + } + return res, nil +} + +// formatCell renders one pgx value to a stable text cell and reports +// whether it was rune-truncated (spec-2.3a D-2 pinned type table; cr/codex +// HIGH-2). byte-stability is load-bearing for dedup keys + goldens. +func formatCell(v any, maxRunes int) (text string, truncated bool) { + s := renderValue(v) + r := []rune(s) + if len(r) > maxRunes { + return string(r[:maxRunes]) + "…", true + } + return s, false +} + +// renderValue is the type table (spec-2.3a D-2 pinned; cr/codex HIGH-2). 
+// Explicit scalar cases keep output byte-stable and avoid fmt.Sprint traps: +// floats must NOT print in scientific notation (pgx returns bare float64 for +// float8 — e.g. pg_stat checkpoint_write_time would render "3.6e+06" without +// FormatFloat 'f', post-impl go MED-2). driver.Valuer covers the pgtype +// family (pgtype.Numeric.Value() → canonical decimal string). +func renderValue(v any) string { + switch x := v.(type) { + case nil: + return "NULL" + case string: + return x + case []byte: + return "\\x" + hex.EncodeToString(x) // bytea-style, not "[97 98]" + case time.Time: + return x.Format(time.RFC3339Nano) + case bool: + return strconv.FormatBool(x) + case int: + return strconv.FormatInt(int64(x), 10) + case int16: + return strconv.FormatInt(int64(x), 10) + case int32: + return strconv.FormatInt(int64(x), 10) + case int64: + return strconv.FormatInt(x, 10) + case uint32: // pgx OID type + return strconv.FormatUint(uint64(x), 10) + case float32: + return strconv.FormatFloat(float64(x), 'f', -1, 32) + case float64: + return strconv.FormatFloat(x, 'f', -1, 64) // "3600000" not "3.6e+06" + case driver.Valuer: + dv, err := x.Value() + if err != nil || dv == nil { + return fmt.Sprintf("%v", v) + } + if b, ok := dv.([]byte); ok { + return "\\x" + hex.EncodeToString(b) + } + if s, ok := dv.(string); ok { // pgtype.Numeric → decimal string + return s + } + return fmt.Sprintf("%v", dv) + default: + return fmt.Sprintf("%v", v) // unknown — stable best-effort + } +} diff --git a/internal/domain/db/postgres/query_test.go b/internal/domain/db/postgres/query_test.go new file mode 100644 index 0000000..f3cc12e --- /dev/null +++ b/internal/domain/db/postgres/query_test.go @@ -0,0 +1,269 @@ +// Copyright 2026 opendbx contributors. See LICENSE. 
+// +// Author: sqlrush + +package postgres + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/sqlrush/opendbx/internal/domain/db" +) + +// --- narrow fakes (embed the pgx interface; only override what Query uses) --- +// +// pgConn.Query touches only queryTx{Query, Rollback} and queryRows{ +// FieldDescriptions, Next, Values, Err, Close}. The embedded nil pgx.Tx / +// pgx.Rows satisfy the wide interfaces at compile time; any OTHER method +// would nil-panic, so adding a call in Query means overriding it here +// (post-impl go LOW-1). + +type fakeTx struct { + pgx.Tx // embedded nil interface — satisfies the type; unused methods unreached + rows pgx.Rows + queryErr error + rolledBack bool +} + +func (f *fakeTx) Query(context.Context, string, ...any) (pgx.Rows, error) { + return f.rows, f.queryErr +} +func (f *fakeTx) Rollback(context.Context) error { f.rolledBack = true; return nil } + +type fakeRows struct { + pgx.Rows // embedded nil interface + fields []pgconn.FieldDescription + data [][]any + idx int + valuesErrAt int // 1-based row index where Values() fails; 0 = never + err error + closed bool +} + +func (r *fakeRows) FieldDescriptions() []pgconn.FieldDescription { return r.fields } +func (r *fakeRows) Next() bool { + if r.idx < len(r.data) { + r.idx++ + return true + } + return false +} +func (r *fakeRows) Values() ([]any, error) { + if r.valuesErrAt != 0 && r.idx == r.valuesErrAt { + // Real pgx decode failures surface a SQLSTATE-bearing error; 22P02 + // (invalid_text_representation, class 22) → DB.QUERY_FAILED. 
+ r.err = &pgconn.PgError{Code: "22P02", Message: "invalid input syntax"} + return nil, r.err + } + return r.data[r.idx-1], nil +} +func (r *fakeRows) Err() error { return r.err } +func (r *fakeRows) Close() { r.closed = true } + +func cols(names ...string) []pgconn.FieldDescription { + out := make([]pgconn.FieldDescription, len(names)) + for i, n := range names { + out[i] = pgconn.FieldDescription{Name: n} + } + return out +} + +// nRows builds a fakeRows with n single-column rows "r0".."r(n-1)". +func nRows(n int) *fakeRows { + data := make([][]any, n) + for i := range data { + data[i] = []any{"r" + itoa(i)} + } + return &fakeRows{fields: cols("c"), data: data} +} + +func itoa(i int) string { return strings.TrimSpace(string(rune('0' + i%10))) } // small i only + +func newQueryConn(rows pgx.Rows, queryErr, beginErr error) (*pgConn, *fakeTx, *fakePool) { + tx := &fakeTx{rows: rows, queryErr: queryErr} + fp := &fakePool{tx: tx, beginErr: beginErr} + return &pgConn{pool: fp}, tx, fp +} + +// --- tests --- + +// TestQuery_ReadOnlyAndAlwaysRollback — every Query opens a READ ONLY tx +// and rolls back even on the success path (spec-2.3a D-2 / Q2). +func TestQuery_ReadOnlyAndAlwaysRollback(t *testing.T) { + t.Parallel() + conn, tx, fp := newQueryConn(nRows(2), nil, nil) + res, err := conn.Query(context.Background(), "SELECT 1", db.QueryOptions{}) + if err != nil { + t.Fatalf("Query: %v", err) + } + if fp.lastTxOpts.AccessMode != pgx.ReadOnly { + t.Errorf("AccessMode = %v; want ReadOnly", fp.lastTxOpts.AccessMode) + } + if !tx.rolledBack { + t.Error("success path must still Rollback (zero-commit read-only tx)") + } + if len(res.Rows) != 2 || res.Columns[0] != "c" { + t.Errorf("result = %+v; want 2 rows / column c", res) + } +} + +// TestQuery_SentinelTruncation — 99/100/101 rows distinguish exactly-cap +// from over-cap via the MaxRows+1 sentinel (spec-2.3a codex HIGH-1). 
+func TestQuery_SentinelTruncation(t *testing.T) { + t.Parallel() + tests := []struct { + rows int + wantKept int + wantTruncated bool + }{ + {99, 99, false}, + {100, 100, false}, + {101, 100, true}, + } + for _, tc := range tests { + t.Run(itoa(tc.rows%10)+"_rows", func(t *testing.T) { + t.Parallel() + // build N rows with distinct content (itoa only handles small i, + // so use a plain loop here) + data := make([][]any, tc.rows) + for i := range data { + data[i] = []any{"x"} + } + rows := &fakeRows{fields: cols("c"), data: data} + conn, _, _ := newQueryConn(rows, nil, nil) + res, err := conn.Query(context.Background(), "SELECT 1", db.QueryOptions{}) + if err != nil { + t.Fatal(err) + } + if len(res.Rows) != tc.wantKept { + t.Errorf("kept %d rows; want %d", len(res.Rows), tc.wantKept) + } + if res.RowsTruncated != tc.wantTruncated { + t.Errorf("RowsTruncated = %v; want %v (input %d rows)", res.RowsTruncated, tc.wantTruncated, tc.rows) + } + if !rows.closed { + t.Error("rows must be Closed") + } + }) + } +} + +// TestQuery_FormatCellTypeTable — the pinned type table (spec-2.3a D-2; +// cr/codex HIGH-2): NUMERIC/bytea/time/nil/bool do not fall to fmt.Sprint +// garbage. +func TestQuery_FormatCellTypeTable(t *testing.T) { + t.Parallel() + num := pgtype.Numeric{} + if err := num.Scan("12345.67"); err != nil { + t.Fatalf("seed numeric: %v", err) + } + ts := time.Date(2026, 6, 8, 12, 30, 0, 0, time.UTC) + // float64(3.6e6): must render as plain decimal, NOT "3.6e+06" (go MED-2 — + // pgx returns bare float64 for float8 columns like checkpoint_write_time). 
+ row := []any{nil, "hello", []byte{0x61, 0x62}, ts, true, int64(42), num, float64(3.6e6), int32(7)} + rows := &fakeRows{ + fields: cols("nullc", "str", "bytes", "tstamp", "flag", "num64", "numeric", "f8", "i32"), + data: [][]any{row}, + } + conn, _, _ := newQueryConn(rows, nil, nil) + res, err := conn.Query(context.Background(), "SELECT *", db.QueryOptions{}) + if err != nil { + t.Fatal(err) + } + got := res.Rows[0] + want := []string{"NULL", "hello", "\\x6162", "2026-06-08T12:30:00Z", "true", "42", "12345.67", "3600000", "7"} + for i := range want { + if got[i] != want[i] { + t.Errorf("cell[%d] = %q; want %q", i, got[i], want[i]) + } + } +} + +// TestQuery_CellTruncation — a cell over MaxCellRunes is rune-truncated +// with an ellipsis and flags CellsTruncated (spec-2.3a D-2). +func TestQuery_CellTruncation(t *testing.T) { + t.Parallel() + long := strings.Repeat("世", 10) // 10 runes (CJK — rune-aware, not byte) + rows := &fakeRows{fields: cols("c"), data: [][]any{{long}}} + conn, _, _ := newQueryConn(rows, nil, nil) + res, err := conn.Query(context.Background(), "SELECT 1", db.QueryOptions{MaxCellRunes: 4}) + if err != nil { + t.Fatal(err) + } + if !res.CellsTruncated { + t.Error("want CellsTruncated") + } + if res.Rows[0][0] != "世世世世…" { + t.Errorf("cell = %q; want 4 runes + ellipsis", res.Rows[0][0]) + } +} + +// TestQuery_RowValuesErrorBreaks — a mid-scan Values() error breaks the +// loop and surfaces via rows.Err()→classify, with no partial rows +// (spec-2.3a cr MED-3). 
+func TestQuery_RowValuesErrorBreaks(t *testing.T) { + t.Parallel() + rows := &fakeRows{fields: cols("c"), data: [][]any{{"a"}, {"b"}, {"c"}}, valuesErrAt: 2} + conn, _, _ := newQueryConn(rows, nil, nil) + _, err := conn.Query(context.Background(), "SELECT 1", db.QueryOptions{}) + if err == nil { + t.Fatal("want error from mid-scan Values failure") + } + var ec interface{ Code() string } + if !errors.As(err, &ec) || ec.Code() != db.ErrQueryFailed.Code() { + t.Errorf("err = %v; want DB.QUERY_FAILED", err) + } +} + +// TestQuery_ReadOnlyViolation — SQLSTATE 25006 maps to READONLY_VIOLATION, +// 0A000 does NOT (stays QUERY_FAILED) (spec-2.3a D-2 / cr HIGH-1). +func TestQuery_ReadOnlyViolation(t *testing.T) { + t.Parallel() + conn, _, _ := newQueryConn(nil, &pgconn.PgError{Code: "25006", Message: "cannot execute INSERT in a read-only transaction"}, nil) + _, err := conn.Query(context.Background(), "INSERT ...", db.QueryOptions{}) + var ec interface{ Code() string } + if !errors.As(err, &ec) || ec.Code() != db.ErrReadOnlyViolation.Code() { + t.Errorf("25006 err = %v; want DB.READONLY_VIOLATION", err) + } + + conn2, _, _ := newQueryConn(nil, &pgconn.PgError{Code: "0A000", Message: "feature not supported"}, nil) + _, err2 := conn2.Query(context.Background(), "SELECT ...", db.QueryOptions{}) + if !errors.As(err2, &ec) || ec.Code() != db.ErrQueryFailed.Code() { + t.Errorf("0A000 err = %v; want DB.QUERY_FAILED (not readonly)", err2) + } +} + +// TestQuery_CtxRootPreservedForTwoTrack — a ctx-deadline query error is +// classified as DB.TIMEOUT but still matches context.DeadlineExceeded via +// errors.Is, so the tool layer can honor the spec-1.21 two-track contract +// (spec-2.3a D-1 / codex CRIT-1). 
+func TestQuery_CtxRootPreservedForTwoTrack(t *testing.T) { + t.Parallel() + conn, _, _ := newQueryConn(nil, context.DeadlineExceeded, nil) + _, err := conn.Query(context.Background(), "SELECT 1", db.QueryOptions{}) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("ctx root not preserved: %v", err) + } +} + +// TestQuery_BeginTxFailureSanitized — a BeginTx error is sanitized (no DSN +// leak) and classified (spec-2.3a R-4 sanitize regression延伸). +func TestQuery_BeginTxFailureSanitized(t *testing.T) { + t.Parallel() + conn, _, _ := newQueryConn(nil, nil, &pgconn.PgError{Code: "08006", Message: "connection failure to host=secret password=leak"}) + _, err := conn.Query(context.Background(), "SELECT 1", db.QueryOptions{}) + if err == nil { + t.Fatal("want BeginTx error") + } + if strings.Contains(err.Error(), "leak") { + t.Errorf("BeginTx error leaked credentials: %s", err.Error()) + } +} diff --git a/internal/domain/db/query.go b/internal/domain/db/query.go new file mode 100644 index 0000000..bf80d03 --- /dev/null +++ b/internal/domain/db/query.go @@ -0,0 +1,93 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +// File query.go — the QueryConn query CAPABILITY over Conn (spec-2.3a +// D-1). This realizes the spec-1.18 R-8 forward design: query arrives as +// a COMPOSED interface, never by adding a method to the frozen db.Conn +// (which would break every fake and future driver). +// +// § 3.7 multi-DB: the minimal common result shape is a text page. Each +// driver renders its own native types to strings driver-side (a single +// any-typed row set would leak every driver's type system upward). MySQL/ +// Oracle/openGauss implement QueryConn later against the SAME contract. + +package db + +import "context" + +// Default result caps (spec-2.3a Q6, R2 user decision). Package-level +// constants — config knobs are deferred (spec-2.3a ❌-3). +const ( + // DefaultMaxRows bounds rows kept before truncation. 
The driver fetches + // MaxRows+1 to distinguish "exactly MaxRows" from "more existed". + DefaultMaxRows = 100 + // DefaultMaxCellRunes bounds per-cell rune length before truncation. + DefaultMaxCellRunes = 200 +) + +// QueryConn is the OPTIONAL query capability over Conn. Acquire via a type +// assertion: qc, ok := conn.(db.QueryConn). The caller MUST Close() the +// underlying Conn when the assertion FAILS — a real driver may return a +// live pool from Open before it implements QueryConn (spec-2.3a D-3 / +// codex HIGH-2: not closing leaks the pool on every retry). +type QueryConn interface { + Conn + // Query executes ONE read-only SQL statement inside a server-side READ + // ONLY transaction and returns the bounded, text-rendered page. + // Enforcement is SERVER-side (persistent-table writes / DDL → SQLSTATE + // 25006); callers must NOT rely on client-side SQL inspection. The ctx + // deadline is the only v1 cancellation mechanism. Implementations MUST + // let ctx cancel/deadline remain matchable via errors.Is on the + // returned error (so the tool layer can honor the spec-1.21 two-track + // timeout/cancel contract rather than treating it as a recoverable DB + // error) — the postgres driver satisfies this because its sanitized + // classify preserves the ctx root under Unwrap. + Query(ctx context.Context, sql string, opts QueryOptions) (QueryResult, error) +} + +// QueryOptions bounds a single Query. A zero field falls back to the +// Default* constant, so QueryOptions{} is the canonical "use defaults". +type QueryOptions struct { + MaxRows int // rows kept before truncation; 0 → DefaultMaxRows + MaxCellRunes int // per-cell rune cap; 0 → DefaultMaxCellRunes +} + +// QueryResult is a fully text-rendered result page. +// +// Ownership (spec-2.3a D-1, arch HIGH-2): the driver does NOT retain or +// mutate the result after Query returns — the caller owns it. 
This is NOT +// a deep-copy guarantee: Rows is [][]string whose backing arrays are not +// cloned, so callers must not assume mutating one returned result is +// isolated from another. (No consumer needs that; documenting it prevents +// a false copy-isolation assumption.) +type QueryResult struct { + // Columns are the result column names in query order. + Columns []string + // Rows are text cells (driver rendered each native value to a string; + // a NULL DB value renders as the literal "NULL"). + Rows [][]string + // RowsTruncated is true iff a MaxRows+1 sentinel row existed — i.e. the + // query produced strictly more than MaxRows rows. Exactly MaxRows rows + // leaves this false. + RowsTruncated bool + // CellsTruncated is true iff at least one cell hit MaxCellRunes. + CellsTruncated bool +} + +// EffectiveMaxRows resolves the row cap (0 → default). Pure helper so the +// driver and tests agree on the fallback. +func (o QueryOptions) EffectiveMaxRows() int { + if o.MaxRows <= 0 { + return DefaultMaxRows + } + return o.MaxRows +} + +// EffectiveMaxCellRunes resolves the per-cell cap (0 → default). +func (o QueryOptions) EffectiveMaxCellRunes() int { + if o.MaxCellRunes <= 0 { + return DefaultMaxCellRunes + } + return o.MaxCellRunes +} diff --git a/internal/domain/db/query_test.go b/internal/domain/db/query_test.go new file mode 100644 index 0000000..e2ec8bf --- /dev/null +++ b/internal/domain/db/query_test.go @@ -0,0 +1,39 @@ +// Copyright 2026 opendbx contributors. See LICENSE. +// +// Author: sqlrush + +package db + +import "testing" + +// TestQueryOptions_EffectiveDefaults — zero fields fall back to the +// Default* constants; positive values pass through (spec-2.3a D-1). 
+func TestQueryOptions_EffectiveDefaults(t *testing.T) { + t.Parallel() + var zero QueryOptions + if got := zero.EffectiveMaxRows(); got != DefaultMaxRows { + t.Errorf("zero MaxRows = %d; want default %d", got, DefaultMaxRows) + } + if got := zero.EffectiveMaxCellRunes(); got != DefaultMaxCellRunes { + t.Errorf("zero MaxCellRunes = %d; want default %d", got, DefaultMaxCellRunes) + } + custom := QueryOptions{MaxRows: 5, MaxCellRunes: 10} + if custom.EffectiveMaxRows() != 5 || custom.EffectiveMaxCellRunes() != 10 { + t.Errorf("custom passthrough failed: %+v", custom) + } + // Negative is treated as unset (defensive). + neg := QueryOptions{MaxRows: -1, MaxCellRunes: -1} + if neg.EffectiveMaxRows() != DefaultMaxRows || neg.EffectiveMaxCellRunes() != DefaultMaxCellRunes { + t.Error("negative caps should fall back to defaults") + } +} + +// TestQueryResult_ZeroValue — a zero QueryResult is a usable empty page +// (no panic on len; both truncation flags false). +func TestQueryResult_ZeroValue(t *testing.T) { + t.Parallel() + var r QueryResult + if len(r.Columns) != 0 || len(r.Rows) != 0 || r.RowsTruncated || r.CellsTruncated { + t.Errorf("zero QueryResult not empty: %+v", r) + } +} diff --git a/internal/platform/errcode/testdata/error-codes-frozen.txt b/internal/platform/errcode/testdata/error-codes-frozen.txt index 1042a75..22689b0 100644 --- a/internal/platform/errcode/testdata/error-codes-frozen.txt +++ b/internal/platform/errcode/testdata/error-codes-frozen.txt @@ -11,6 +11,8 @@ DB.AUTH_FAILED DB.CONNECT_FAILED DB.DRIVER_UNKNOWN DB.QUERY_FAILED +DB.QUERY_INPUT_INVALID +DB.READONLY_VIOLATION DB.TIMEOUT DB.UNAVAILABLE DIAGNOSE.MAX_TURNS diff --git a/tests/integration/dbquery/dbquery_test.go b/tests/integration/dbquery/dbquery_test.go new file mode 100644 index 0000000..3d15dc5 --- /dev/null +++ b/tests/integration/dbquery/dbquery_test.go @@ -0,0 +1,259 @@ +// Copyright 2026 opendbx contributors. See LICENSE. 
+// +// Author: sqlrush + +// Package dbquery_test — spec-2.3a D-6 integration: the REAL dbquery.Tool +// wired into the REAL diagnose.Loop with a scripted fake provider and a +// fake QueryConn, plus interplay with the spec-2.3 SkillTool ToolFilter +// (a skill whose allowed-tools lists db_query). +package dbquery_test + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/sqlrush/opendbx/internal/app/diagnose" + "github.com/sqlrush/opendbx/internal/app/skills" + "github.com/sqlrush/opendbx/internal/app/skills/invoke" + "github.com/sqlrush/opendbx/internal/app/tools/dbquery" + "github.com/sqlrush/opendbx/internal/domain/db" + "github.com/sqlrush/opendbx/internal/domain/llm" + "github.com/sqlrush/opendbx/internal/domain/llm/fake" +) + +// fakeQueryConn implements db.QueryConn for the integration wiring. +type fakeQueryConn struct { + result db.QueryResult + queryErr error +} + +func (f *fakeQueryConn) Ping(context.Context) error { return nil } +func (f *fakeQueryConn) HealthCheck(context.Context) (db.Health, error) { return db.Health{}, nil } +func (f *fakeQueryConn) Close() error { return nil } +func (f *fakeQueryConn) Query(context.Context, string, db.QueryOptions) (db.QueryResult, error) { + return f.result, f.queryErr +} + +func newDBTool(qc db.QueryConn, openErr error) *dbquery.Tool { + return dbquery.New(func(context.Context) (db.Conn, error) { + if openErr != nil { + return nil, openErr + } + return qc, nil + }) +} + +func userReq(text string) llm.Request { + return llm.Request{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: []llm.ContentBlock{{Type: llm.BlockText, Text: text}}}}, + MaxTokens: 1024, + } +} + +func toolResults(msgs []llm.Message) []llm.ToolResult { + var out []llm.ToolResult + for _, m := range msgs { + for _, c := range m.Content { + if c.Type == llm.BlockToolResult && c.ToolResult != nil { + out = append(out, *c.ToolResult) + } + } + } + return out +} + +// TestDBQuery_HappyPath — the model calls db_query, the 
rendered table lands +// in the transcript, and the run terminates normally. +func TestDBQuery_HappyPath(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{result: db.QueryResult{ + Columns: []string{"datname", "state"}, + Rows: [][]string{{"app", "active"}, {"app", "idle"}}, + }} + reg, err := diagnose.NewRegistry(newDBTool(qc, nil)) + if err != nil { + t.Fatal(err) + } + prov := fake.NewScriptedTurns( + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c1", Name: "db_query", Input: map[string]any{"sql": "SELECT datname, state FROM pg_stat_activity"}}}}, + fake.Turn{Text: "done", Finish: llm.FinishStop}, + ) + loop, err := diagnose.NewLoop(diagnose.Options{Provider: prov, Registry: reg, MaxTurns: 5}) + if err != nil { + t.Fatal(err) + } + res, err := loop.Run(context.Background(), userReq("why so many connections"), func(context.Context, diagnose.Event) error { return nil }) + if err != nil { + t.Fatalf("Run: %v", err) + } + trs := toolResults(res.Messages) + if len(trs) != 1 || trs[0].IsError { + t.Fatalf("tool results = %+v; want 1 success", trs) + } + for _, want := range []string{"datname", "active", "(2 rows)"} { + if !strings.Contains(trs[0].Content, want) { + t.Errorf("table missing %q:\n%s", want, trs[0].Content) + } + } +} + +// TestDBQuery_CtxCancelIsFatal — a ctx-deadline query error must terminate +// the run on the spec-1.21 two-track path (FinishCancelled), NOT surface as +// a recoverable [DB.TIMEOUT] tool result (codex CRIT-1 end-to-end). 
+func TestDBQuery_CtxCancelIsFatal(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{queryErr: context.Canceled} + reg, _ := diagnose.NewRegistry(newDBTool(qc, nil)) + prov := fake.NewScriptedTurns( + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c1", Name: "db_query", Input: map[string]any{"sql": "SELECT 1"}}}}, + fake.Turn{Text: "unreached", Finish: llm.FinishStop}, + ) + loop, _ := diagnose.NewLoop(diagnose.Options{Provider: prov, Registry: reg, MaxTurns: 5}) + res, err := loop.Run(context.Background(), userReq("go"), func(context.Context, diagnose.Event) error { return nil }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("Run err = %v; want context.Canceled (fatal two-track)", err) + } + if res.FinishReason != llm.FinishCancelled { + t.Errorf("FinishReason = %v; want FinishCancelled", res.FinishReason) + } + // No recoverable tool result was committed for the cancelled call. + for _, tr := range toolResults(res.Messages) { + if strings.Contains(tr.Content, "DB.TIMEOUT") { + t.Error("ctx cancel leaked as a recoverable DB.TIMEOUT tool result") + } + } +} + +// TestDBQuery_RecoverableError — a non-ctx query error (read-only violation) +// is a recoverable tool result and the run continues. 
+func TestDBQuery_RecoverableError(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{queryErr: db.ErrReadOnlyViolation} + reg, _ := diagnose.NewRegistry(newDBTool(qc, nil)) + prov := fake.NewScriptedTurns( + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c1", Name: "db_query", Input: map[string]any{"sql": "DELETE FROM t"}}}}, + fake.Turn{Text: "ok, read-only understood", Finish: llm.FinishStop}, + ) + loop, _ := diagnose.NewLoop(diagnose.Options{Provider: prov, Registry: reg, MaxTurns: 5}) + res, err := loop.Run(context.Background(), userReq("delete stuff"), func(context.Context, diagnose.Event) error { return nil }) + if err != nil { + t.Fatalf("recoverable error must not terminate: %v", err) + } + trs := toolResults(res.Messages) + if len(trs) != 1 || !trs[0].IsError || !strings.Contains(trs[0].Content, "DB.READONLY_VIOLATION") { + t.Fatalf("want 1 recoverable readonly result: %+v", trs) + } +} + +// TestDBQuery_DedupCachedReplay — the same SQL within the dedup window +// returns a byte-identical cached result. 
+func TestDBQuery_DedupCachedReplay(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{result: db.QueryResult{Columns: []string{"c"}, Rows: [][]string{{"1"}}}} + reg, _ := diagnose.NewRegistry(newDBTool(qc, nil)) + prov := fake.NewScriptedTurns( + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c1", Name: "db_query", Input: map[string]any{"sql": "SELECT 1"}}}}, + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c2", Name: "db_query", Input: map[string]any{"sql": "SELECT 1"}}}}, + fake.Turn{Text: "done", Finish: llm.FinishStop}, + ) + loop, _ := diagnose.NewLoop(diagnose.Options{Provider: prov, Registry: reg, MaxTurns: 5, DedupEnabled: true}) + var c2Cached bool + emit := func(_ context.Context, e diagnose.Event) error { + if e.Kind == diagnose.EventToolResult && e.ToolResult.ToolUseID == "c2" { + c2Cached = e.Cached + } + return nil + } + res, err := loop.Run(context.Background(), userReq("go"), emit) + if err != nil { + t.Fatal(err) + } + if !c2Cached { + t.Error("second identical SELECT should be a dedup cache hit") + } + trs := toolResults(res.Messages) + if len(trs) != 2 || trs[0].Content != trs[1].Content { + t.Errorf("cached replay not byte-identical:\n%q\nvs\n%q", trs[0].Content, trs[1].Content) + } +} + +// TestDBQuery_SkillScopeAllows — a skill whose allowed-tools lists db_query +// keeps it callable inside the skill scope (spec-2.3 ToolFilter interplay). 
+func TestDBQuery_SkillScopeAllows(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{result: db.QueryResult{Columns: []string{"c"}, Rows: [][]string{{"x"}}}} + skill := skills.Skill{ + Schema: skills.Schema{Name: "db-doctor", Description: "DB diagnostics.", AllowedTools: "db_query"}, + Body: "Use db_query to inspect pg_stat_activity.", + } + st, err := invoke.NewSkillTool([]skills.Skill{skill}) + if err != nil { + t.Fatal(err) + } + reg, err := diagnose.NewRegistry(st, newDBTool(qc, nil)) + if err != nil { + t.Fatal(err) + } + prov := fake.NewScriptedTurns( + // Enter the skill scope (allowed-tools: db_query). + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c1", Name: "Skill", Input: map[string]any{"skill": "db-doctor"}}}}, + // db_query is in scope → executes. + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c2", Name: "db_query", Input: map[string]any{"sql": "SELECT 1"}}}}, + fake.Turn{Text: "done", Finish: llm.FinishStop}, + ) + loop, _ := diagnose.NewLoop(diagnose.Options{Provider: prov, Registry: reg, MaxTurns: 5}) + res, err := loop.Run(context.Background(), userReq("diagnose"), func(context.Context, diagnose.Event) error { return nil }) + if err != nil { + t.Fatal(err) + } + trs := toolResults(res.Messages) + // c1 = skill body, c2 = db_query table (not SCOPE_TOOL_DENIED). + if len(trs) != 2 || trs[1].IsError || !strings.Contains(trs[1].Content, "(1 rows)") { + t.Errorf("db_query should run inside the db-doctor scope: %+v", trs) + } +} + +// TestDBQuery_SkillScopeDenied — a skill whose allowed-tools does NOT list +// db_query keeps it OUT of scope: an attempt returns SCOPE_TOOL_DENIED, not +// an execution (spec-2.3a D-6 / spec-2.3 ToolFilter regression). 
+func TestDBQuery_SkillScopeDenied(t *testing.T) { + t.Parallel() + qc := &fakeQueryConn{result: db.QueryResult{Columns: []string{"c"}, Rows: [][]string{{"x"}}}} + skill := skills.Skill{ + Schema: skills.Schema{Name: "clock-only", Description: "Time only.", AllowedTools: "clock"}, + Body: "Use clock.", + } + st, err := invoke.NewSkillTool([]skills.Skill{skill}) + if err != nil { + t.Fatal(err) + } + reg, err := diagnose.NewRegistry(st, newDBTool(qc, nil), diagnose.ClockTool{}) + if err != nil { + t.Fatal(err) + } + prov := fake.NewScriptedTurns( + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c1", Name: "Skill", Input: map[string]any{"skill": "clock-only"}}}}, + fake.Turn{Finish: llm.FinishToolUse, ToolUses: []llm.ToolUse{ + {ID: "c2", Name: "db_query", Input: map[string]any{"sql": "SELECT 1"}}}}, + fake.Turn{Text: "denied, ok", Finish: llm.FinishStop}, + ) + loop, _ := diagnose.NewLoop(diagnose.Options{Provider: prov, Registry: reg, MaxTurns: 5}) + res, err := loop.Run(context.Background(), userReq("go"), func(context.Context, diagnose.Event) error { return nil }) + if err != nil { + t.Fatal(err) + } + trs := toolResults(res.Messages) + if len(trs) != 2 || !trs[1].IsError || !strings.Contains(trs[1].Content, "SKILL.SCOPE_TOOL_DENIED") { + t.Errorf("db_query outside scope must be denied: %+v", trs) + } +}