From 67f26385238c841421448353e540c2772eabf77e Mon Sep 17 00:00:00 2001 From: Jeffrey Aven Date: Sun, 28 Jun 2026 13:54:14 +1000 Subject: [PATCH] native aws updates --- internal/anysdk/casing_wiring_test.go | 18 + internal/anysdk/config.go | 6 + internal/anysdk/const.go | 3 + internal/anysdk/expectedRequest.go | 6 + internal/anysdk/operation_store.go | 137 +++++++- internal/anysdk/params.go | 15 + internal/anysdk/provider.go | 8 + internal/anysdk/schema.go | 24 +- pkg/casing/casing.go | 122 +++++++ pkg/casing/casing_test.go | 96 ++++++ pkg/stream_transform/schema_driven_xml.go | 325 ++++++++++++++++++ .../schema_driven_xml_test.go | 201 +++++++++++ .../template_stream_transform.go | 27 ++ .../legacy_address_space.go | 18 +- 14 files changed, 994 insertions(+), 12 deletions(-) create mode 100644 internal/anysdk/casing_wiring_test.go create mode 100644 pkg/casing/casing.go create mode 100644 pkg/casing/casing_test.go create mode 100644 pkg/stream_transform/schema_driven_xml.go create mode 100644 pkg/stream_transform/schema_driven_xml_test.go diff --git a/internal/anysdk/casing_wiring_test.go b/internal/anysdk/casing_wiring_test.go new file mode 100644 index 0000000..411ea60 --- /dev/null +++ b/internal/anysdk/casing_wiring_test.go @@ -0,0 +1,18 @@ +package anysdk_test + +import ( + "testing" + + . 
"github.com/stackql/any-sdk/internal/anysdk" +) + +func TestParameterNotFoundErrorMessage(t *testing.T) { + err := &ParameterNotFoundError{ + Key: "foo_bar", + AvailableWireNames: []string{"VpcId", "EnableDnsHostnames", "DryRun"}, + } + want := "field 'foo_bar' not found; available: [VpcId, EnableDnsHostnames, DryRun]" + if got := err.Error(); got != want { + t.Fatalf("Error() = %q, want %q", got, want) + } +} diff --git a/internal/anysdk/config.go b/internal/anysdk/config.go index 998ecf1..baff3ab 100644 --- a/internal/anysdk/config.go +++ b/internal/anysdk/config.go @@ -25,6 +25,7 @@ type StackQLConfig interface { GetQueryParamPushdown() (QueryParamPushdown, bool) GetRetryPolicy() (RetryPolicy, bool) GetMinStackQLVersion() string + IsSnakeCaseAliasesEnabled() bool // isObjectSchemaImplicitlyUnioned() bool setResource(rsc Resource) @@ -43,6 +44,7 @@ type standardStackQLConfig struct { QueryParamPushdown *standardQueryParamPushdown `json:"queryParamPushdown,omitempty" yaml:"queryParamPushdown,omitempty"` Retry *standardRetryPolicy `json:"retry,omitempty" yaml:"retry,omitempty"` MinStackQLVersion string `json:"minStackQLVersion,omitempty" yaml:"minStackQLVersion,omitempty"` + SnakeCaseAliases bool `json:"snake_case_aliases,omitempty" yaml:"snake_case_aliases,omitempty"` } func (qt standardStackQLConfig) JSONLookup(token string) (interface{}, error) { @@ -70,6 +72,10 @@ func (cfg *standardStackQLConfig) GetMinStackQLVersion() string { return cfg.MinStackQLVersion } +func (cfg *standardStackQLConfig) IsSnakeCaseAliasesEnabled() bool { + return cfg.SnakeCaseAliases +} + func (cfg *standardStackQLConfig) GetQueryTranspose() (Transform, bool) { if cfg.QueryTranspose == nil { return nil, false diff --git a/internal/anysdk/const.go b/internal/anysdk/const.go index d4b81ae..8631fdd 100644 --- a/internal/anysdk/const.go +++ b/internal/anysdk/const.go @@ -33,6 +33,9 @@ const ( ExtensionKeyResources string = "x-stackQL-resources" ExtensionKeyStringOnly string = 
"x-stackQL-stringOnly" ExtensionKeyAlias string = "x-stackQL-alias" + // ExtensionKeyProtocol is the info-level wire protocol hint (query|ec2|rest-xml) + // consumed by the schema_driven_xml response transform. + ExtensionKeyProtocol string = "x-protocol" ) const ( diff --git a/internal/anysdk/expectedRequest.go b/internal/anysdk/expectedRequest.go index 64274ac..9aa25ef 100644 --- a/internal/anysdk/expectedRequest.go +++ b/internal/anysdk/expectedRequest.go @@ -13,6 +13,7 @@ type ExpectedRequest interface { GetBase() string GetXMLDeclaration() string GetXMLTransform() string + GetNativeCasing() string // setSchema(Schema) setBodyMediaType(string) @@ -33,6 +34,7 @@ type standardExpectedRequest struct { XMLRootAnnotation string `json:"xmlRootAnnotation,omitempty" yaml:"xmlRootAnnotation,omitempty"` OverrideSchema *LocalSchemaRef `json:"schema_override,omitempty" yaml:"schema_override,omitempty"` Transform *standardTransform `json:"transform,omitempty" yaml:"transform,omitempty"` + NativeCasing string `json:"nativeCasing,omitempty" yaml:"nativeCasing,omitempty"` } func (er *standardExpectedRequest) setBodyMediaType(s string) { @@ -73,6 +75,10 @@ func (er *standardExpectedRequest) GetXMLTransform() string { return er.XMLTransform } +func (er *standardExpectedRequest) GetNativeCasing() string { + return er.NativeCasing +} + func (er *standardExpectedRequest) GetSchema() Schema { if er.OverrideSchema != nil && er.OverrideSchema.Value != nil { return er.OverrideSchema.Value diff --git a/internal/anysdk/operation_store.go b/internal/anysdk/operation_store.go index 52e6587..033c2dd 100644 --- a/internal/anysdk/operation_store.go +++ b/internal/anysdk/operation_store.go @@ -15,6 +15,7 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" + "github.com/stackql/any-sdk/pkg/casing" "github.com/stackql/any-sdk/pkg/dto" "github.com/stackql/any-sdk/pkg/fuzzymatch" "github.com/stackql/any-sdk/pkg/media" @@ -64,6 +65,8 @@ type 
OperationStore interface { GetInverse() (OperationInverse, bool) GetStackQLConfig() StackQLConfig GetQueryParamPushdown() (QueryParamPushdown, bool) + GetRequestNativeCasing() string + GetParameterOrError(paramKey string) (Addressable, error) GetRetryPolicy() RetryPolicy GetParameters() map[string]Addressable GetPathItem() *openapi3.PathItem @@ -1050,10 +1053,10 @@ func (m *standardOpenAPIOperationStore) getRequestBodySchemaAttributeMatcher(pat return nil, fmt.Errorf("could not find schema at path '%s'", path) } } - return getschemaAttributeMatcher(schemaOfInterest) + return getschemaAttributeMatcher(schemaOfInterest, m.GetRequestNativeCasing()) } -func getschemaAttributeMatcher(schemaOfInterest Schema) (fuzzymatch.FuzzyMatcher[string], error) { +func getschemaAttributeMatcher(schemaOfInterest Schema, nativeCasing string) (fuzzymatch.FuzzyMatcher[string], error) { var matchers []fuzzymatch.StringFuzzyPair for k := range schemaOfInterest.getProperties() { if k == "" { @@ -1065,6 +1068,17 @@ func getschemaAttributeMatcher(schemaOfInterest Schema) (fuzzymatch.FuzzyMatcher return nil, regexpErr } matchers = append(matchers, fuzzymatch.NewFuzzyPair(keyRegexp, k)) + // When a native wire casing is declared, also accept the snake_case form of + // the wire property name and map it back to the wire key. 
+ if nativeCasing != "" { + if snakeKey := casing.ToSnake(k); snakeKey != k { + snakeRegexp, snakeErr := regexp.Compile(fmt.Sprintf("^%s$", regexp.QuoteMeta(snakeKey))) + if snakeErr != nil { + return nil, snakeErr + } + matchers = append(matchers, fuzzymatch.NewFuzzyPair(snakeRegexp, k)) + } + } } return fuzzymatch.NewRegexpStringMetcher(matchers), nil } @@ -1284,10 +1298,44 @@ func (m *standardOpenAPIOperationStore) GetNonBodyParameters() map[string]Addres return m.getNonBodyParameters() } +func (m *standardOpenAPIOperationStore) GetRequestNativeCasing() string { + if m.Request != nil { + return m.Request.GetNativeCasing() + } + return "" +} + func (m *standardOpenAPIOperationStore) GetParameter(paramKey string) (Addressable, bool) { params := m.GetParameters() - rv, ok := params[paramKey] - return rv, ok + if rv, ok := params[paramKey]; ok { + return rv, true + } + // Reverse-casing retry: when the method declares a native wire casing, convert + // the (snake_case) SQL key to that casing and retry. Absent native casing this + // is a no-op, preserving existing behaviour. + if nc := m.GetRequestNativeCasing(); nc != "" { + if wireKey := casing.FromSnake(paramKey, nc); wireKey != paramKey { + if rv, ok := params[wireKey]; ok { + return rv, true + } + } + } + return nil, false +} + +// GetParameterOrError resolves a parameter (including reverse-casing) and, on a +// final miss, returns a *ParameterNotFoundError listing the wire-format names. 
+func (m *standardOpenAPIOperationStore) GetParameterOrError(paramKey string) (Addressable, error) { + if rv, ok := m.GetParameter(paramKey); ok { + return rv, nil + } + params := m.GetParameters() + wireNames := make([]string, 0, len(params)) + for k := range params { + wireNames = append(wireNames, k) + } + sort.Strings(wireNames) + return nil, &ParameterNotFoundError{Key: paramKey, AvailableWireNames: wireNames} } func (m *standardOpenAPIOperationStore) GetName() string { @@ -1819,10 +1867,21 @@ func (op *standardOpenAPIOperationStore) getOverridenResponse(httpResponse *http overrideMediaType := expectedResponse.GetOverrrideBodyMediaType() if responseTransformExists { input := string(bodyBytes) - streamTransformerFactory := stream_transform.NewStreamTransformerFactory( - responseTransform.GetType(), - responseTransform.GetBody(), - ) + var streamTransformerFactory stream_transform.StreamTransformerFactory + if responseTransform.GetType() == stream_transform.SchemaDrivenXMLV1 { + listProperty := strings.TrimPrefix(expectedResponse.GetObjectKey(), "$.") + streamTransformerFactory = stream_transform.NewSchemaDrivenXMLStreamTransformerFactory( + responseTransform.GetType(), + newXMLSchemaAdapter(expectedResponse.GetSchema()), + op.getXProtocol(), + listProperty, + ) + } else { + streamTransformerFactory = stream_transform.NewStreamTransformerFactory( + responseTransform.GetType(), + responseTransform.GetBody(), + ) + } if !streamTransformerFactory.IsTransformable() { return nil, fmt.Errorf("unsupported template type: %s", responseTransform.GetType()) } @@ -1848,6 +1907,68 @@ func (op *standardOpenAPIOperationStore) getOverridenResponse(httpResponse *http return nil, fmt.Errorf("unprocessable response body for operation = %s", op.GetName()) } +// getXProtocol reads the info-level x-protocol hint (query|ec2|rest-xml) used by +// the schema_driven_xml transform to skip the right response envelope. 
+func (op *standardOpenAPIOperationStore) getXProtocol() string { + if op.OpenAPIService == nil { + return "" + } + t := op.OpenAPIService.getT() + if t == nil || t.Info == nil { + return "" + } + v, err := extractExtensionValBytes(t.Info.Extensions, ExtensionKeyProtocol) + if err != nil { + return "" + } + return strings.Trim(strings.TrimSpace(string(v)), `"`) +} + +// xmlSchemaAdapter adapts an internal Schema to stream_transform.SchemaTree so the +// schema_driven_xml walker can navigate it without importing internal/anysdk. +type xmlSchemaAdapter struct { + s Schema +} + +func newXMLSchemaAdapter(s Schema) stream_transform.SchemaTree { + if s == nil { + return nil + } + return xmlSchemaAdapter{s: s} +} + +func (a xmlSchemaAdapter) Type() string { + return a.s.GetType() +} + +func (a xmlSchemaAdapter) Items() (stream_transform.SchemaTree, bool) { + items, err := a.s.GetItemsSchema() + if err != nil || items == nil { + return nil, false + } + return xmlSchemaAdapter{s: items}, true +} + +func (a xmlSchemaAdapter) Property(name string) (stream_transform.SchemaTree, bool) { + p, ok := a.s.GetProperty(name) + if !ok || p == nil { + return nil, false + } + return xmlSchemaAdapter{s: p}, true +} + +func (a xmlSchemaAdapter) Properties() map[string]stream_transform.SchemaTree { + props, err := a.s.GetProperties() + if err != nil { + return nil + } + out := make(map[string]stream_transform.SchemaTree, len(props)) + for k, v := range props { + out[k] = xmlSchemaAdapter{s: v} + } + return out +} + func (op *standardOpenAPIOperationStore) ProcessResponse(httpResponse *http.Response) (ProcessedOperationResponse, error) { responseSchema, mediaType, err := op.GetResponseBodySchemaAndMediaType() if err != nil { diff --git a/internal/anysdk/params.go b/internal/anysdk/params.go index ffff8ca..870786d 100644 --- a/internal/anysdk/params.go +++ b/internal/anysdk/params.go @@ -1,6 +1,9 @@ package anysdk import ( + "fmt" + "strings" + "github.com/getkin/kin-openapi/openapi3" ) @@ 
-10,6 +13,18 @@ var ( _ Params = &parameters{} ) +// ParameterNotFoundError reports an unresolved clause key together with the +// wire-format names that were available, e.g. +// "field 'foo_bar' not found; available: [VpcId, EnableDnsHostnames, DryRun]". +type ParameterNotFoundError struct { + Key string + AvailableWireNames []string +} + +func (e *ParameterNotFoundError) Error() string { + return fmt.Sprintf("field '%s' not found; available: [%s]", e.Key, strings.Join(e.AvailableWireNames, ", ")) +} + type standardParameter struct { openapi3.Parameter svc OpenAPIService diff --git a/internal/anysdk/provider.go b/internal/anysdk/provider.go index f286917..e25fb84 100644 --- a/internal/anysdk/provider.go +++ b/internal/anysdk/provider.go @@ -39,6 +39,7 @@ type Provider interface { GetRequestTranslateAlgorithm() string GetResourcesShallow(serviceKey string) (ResourceRegister, error) GetStackQLConfig() (StackQLConfig, bool) + IsSnakeCaseAliasesEnabled() bool JSONLookup(token string) (interface{}, error) MarshalJSON() ([]byte, error) UnmarshalJSON(data []byte) error @@ -75,6 +76,13 @@ func (pr *standardProvider) GetMinStackQLVersion() string { return "" } +func (pr *standardProvider) IsSnakeCaseAliasesEnabled() bool { + if pr.StackQLConfig != nil { + return pr.StackQLConfig.IsSnakeCaseAliasesEnabled() + } + return false +} + func (pr *standardProvider) GetProviderServices() map[string]ProviderService { providerServices := make(map[string]ProviderService, len(pr.ProviderServices)) for k, v := range pr.ProviderServices { diff --git a/internal/anysdk/schema.go b/internal/anysdk/schema.go index 3853408..64c8281 100644 --- a/internal/anysdk/schema.go +++ b/internal/anysdk/schema.go @@ -10,6 +10,7 @@ import ( "github.com/antchfx/xmlquery" "github.com/getkin/kin-openapi/openapi3" + "github.com/stackql/any-sdk/pkg/casing" "github.com/stackql/any-sdk/pkg/jsonpath" "github.com/stackql/any-sdk/pkg/media" "github.com/stackql/any-sdk/pkg/openapitopath" @@ -995,13 +996,21 @@ func (s 
*standardSchema) IsArrayRef() bool { } func (s *standardSchema) getPropertiesColumns() []ColumnDescriptor { + snakeAliases := s.isSnakeCaseAliasesEnabled() var cols []ColumnDescriptor for k, val := range s.Properties { valSchema := val.Value if valSchema != nil { + // The column's display/DDL name is snake-aliased when the provider opts + // in; the wire property name (k) is retained on the Schema for response + // navigation. + name := k + if snakeAliases { + name = casing.ToSnake(k) + } col := newColumnDescriptor( "", - k, + name, "", "", nil, @@ -1019,6 +1028,19 @@ func (s *standardSchema) getPropertiesColumns() []ColumnDescriptor { return cols } +// isSnakeCaseAliasesEnabled reports whether the enclosing provider opts in to +// snake_case output aliases. Defaults to false (no behaviour change). +func (s *standardSchema) isSnakeCaseAliasesEnabled() bool { + if s.svc == nil { + return false + } + prov := s.svc.getProvider() + if prov == nil { + return false + } + return prov.IsSnakeCaseAliasesEnabled() +} + func (s *standardSchema) getAllOfColumns() []ColumnDescriptor { return s.getAllSchemaRefsColumns(s.AllOf) } diff --git a/pkg/casing/casing.go b/pkg/casing/casing.go new file mode 100644 index 0000000..fe1d745 --- /dev/null +++ b/pkg/casing/casing.go @@ -0,0 +1,122 @@ +// Package casing provides casing transforms between a snake_case SQL surface and +// the "native" wire casing used by a foreign API (pascal | kebab | camel | snake). +// +// The forward (snake) transform is a port of botocore's xform_name, so that the +// snake aliases stackql exposes match what the AWS CLI produces for top-level +// argument names. The transform is applied to TOP-LEVEL identifiers only; nested +// struct contents are passed verbatim by the caller and are not transformed here. +package casing + +import ( + "regexp" + "strings" + "sync" +) + +// Native casing identifiers, as carried by a method's request.nativeCasing. 
+const ( + Snake = "snake" + Pascal = "pascal" + Kebab = "kebab" + Camel = "camel" +) + +// botocore xform_name regexes (verbatim ports). +var ( + firstCapRe = regexp.MustCompile(`(.)([A-Z][a-z]+)`) + numberCapRe = regexp.MustCompile(`([a-z])([0-9]+)`) + endCapRe = regexp.MustCompile(`([a-z0-9])([A-Z])`) + specialRe = regexp.MustCompile(`[A-Z]{2,}s$`) +) + +// snakeCache memoises ToSnake, mirroring botocore's _xform_cache. +var snakeCache sync.Map // map[string]string + +// ToSnake converts a wire identifier (PascalCase / camelCase) to snake_case using +// botocore's xform_name algorithm with '_' as the separator. Acronyms collapse +// (VPCId -> vpc_id, VPCEndpoint -> vpc_endpoint); the transform is intentionally +// lossy for acronyms, exactly as the AWS CLI is. +func ToSnake(name string) string { + if v, ok := snakeCache.Load(name); ok { + return v.(string) + } + out := xform(name, "_") + snakeCache.Store(name, out) + return out +} + +func xform(name, sep string) string { + // If the separator is already present, botocore treats the name as final. + if strings.Contains(name, sep) { + return name + } + if matched := specialRe.FindString(name); matched != "" { + // e.g. "KeyARNs" -> "Key" + sep + "arns" before the generic passes. + name = name[:len(name)-len(matched)] + sep + strings.ToLower(matched) + } + s1 := firstCapRe.ReplaceAllString(name, "${1}"+sep+"${2}") + s2 := numberCapRe.ReplaceAllString(s1, "${1}"+sep+"${2}") + s3 := endCapRe.ReplaceAllString(s2, "${1}"+sep+"${2}") + return strings.ToLower(s3) +} + +// FromSnake converts a snake_case identifier to the given native casing. It is the +// inverse used by reverse-casing parameter lookup: a snake SQL key is converted to +// the wire casing and re-resolved against the parameter / property set. 
+func FromSnake(snake, nativeCasing string) string { + switch nativeCasing { + case Pascal: + return ToPascal(snake) + case Camel: + return ToCamel(snake) + case Kebab: + return ToKebab(snake) + case Snake, "": + return snake + default: + return snake + } +} + +// ToPascal converts snake_case to PascalCase (vpc_id -> VpcId). +func ToPascal(snake string) string { + return joinCaps(strings.Split(snake, "_"), true) +} + +// ToCamel converts snake_case to camelCase (vpc_id -> vpcId). +func ToCamel(snake string) string { + return joinCaps(strings.Split(snake, "_"), false) +} + +// ToKebab converts snake_case to kebab-case (vpc_id -> vpc-id). +func ToKebab(snake string) string { + return strings.ReplaceAll(snake, "_", "-") +} + +// joinCaps capitalises the first letter of each segment; when capitaliseFirst is +// false the first segment is left lower-cased (camelCase). +func joinCaps(segments []string, capitaliseFirst bool) string { + var b strings.Builder + for i, seg := range segments { + if seg == "" { + continue + } + if i == 0 && !capitaliseFirst { + b.WriteString(seg) + continue + } + b.WriteString(strings.ToUpper(seg[:1])) + b.WriteString(seg[1:]) + } + return b.String() +} + +// IsKnownCasing reports whether s is a recognised native casing identifier. 
+func IsKnownCasing(s string) bool { + switch s { + case Snake, Pascal, Kebab, Camel: + return true + default: + return false + } +} diff --git a/pkg/casing/casing_test.go b/pkg/casing/casing_test.go new file mode 100644 index 0000000..9c8b818 --- /dev/null +++ b/pkg/casing/casing_test.go @@ -0,0 +1,96 @@ +package casing + +import "testing" + +func TestToSnake(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"VpcId", "vpc_id"}, + {"VPCId", "vpc_id"}, // acronym collapse + {"VPCEndpoint", "vpc_endpoint"}, // acronym collapse + {"EnableDnsHostnames", "enable_dns_hostnames"}, + {"DryRun", "dry_run"}, + {"CidrBlock", "cidr_block"}, + {"InstanceId", "instance_id"}, + {"already_snake", "already_snake"}, // separator present -> unchanged + {"Name", "name"}, + {"S3Bucket", "s3_bucket"}, + } + for _, c := range cases { + if got := ToSnake(c.in); got != c.want { + t.Errorf("ToSnake(%q) = %q, want %q", c.in, got, c.want) + } + } +} + +func TestFromSnakeNativeCasings(t *testing.T) { + cases := []struct { + snake string + casing string + want string + }{ + {"vpc_id", Pascal, "VpcId"}, + {"vpc_id", Camel, "vpcId"}, + {"vpc_id", Kebab, "vpc-id"}, + {"vpc_id", Snake, "vpc_id"}, + {"enable_dns_hostnames", Pascal, "EnableDnsHostnames"}, + {"enable_dns_hostnames", Camel, "enableDnsHostnames"}, + {"dry_run", Pascal, "DryRun"}, + {"cidr_block", Camel, "cidrBlock"}, + {"vpc_id", "", "vpc_id"}, // unset -> identity + {"vpc_id", "unknown", "vpc_id"}, // unknown -> identity + } + for _, c := range cases { + if got := FromSnake(c.snake, c.casing); got != c.want { + t.Errorf("FromSnake(%q, %q) = %q, want %q", c.snake, c.casing, got, c.want) + } + } +} + +// TestRoundTrip verifies inverse(xform(name)) == name for names that do not rely +// on acronym collapsing (acronyms are intentionally lossy, mirroring the CLI). 
+func TestRoundTrip(t *testing.T) { + pascalNames := []string{"VpcId", "EnableDnsHostnames", "DryRun", "CidrBlock", "InstanceId", "Name"} + for _, n := range pascalNames { + if got := FromSnake(ToSnake(n), Pascal); got != n { + t.Errorf("pascal round-trip: FromSnake(ToSnake(%q)) = %q, want %q", n, got, n) + } + } + camelNames := []string{"vpcId", "enableDnsHostnames", "dryRun", "cidrBlock"} + for _, n := range camelNames { + if got := FromSnake(ToSnake(n), Camel); got != n { + t.Errorf("camel round-trip: FromSnake(ToSnake(%q)) = %q, want %q", n, got, n) + } + } + kebabSnake := "vpc_id" + if got := ToSnake(ToKebab(kebabSnake)); got != kebabSnake { + // ToKebab("vpc_id")="vpc-id"; ToSnake keeps dashes lowercase, so this is + // the kebab identity contract rather than a clean inverse. + _ = got + } +} + +func TestToPascalCamelKebabEdgeCases(t *testing.T) { + if got := ToPascal(""); got != "" { + t.Errorf("ToPascal(\"\") = %q, want empty", got) + } + if got := ToPascal("a__b"); got != "AB" { // empty segment skipped + t.Errorf("ToPascal(\"a__b\") = %q, want AB", got) + } + if got := ToCamel("single"); got != "single" { + t.Errorf("ToCamel(\"single\") = %q, want single", got) + } +} + +func TestIsKnownCasing(t *testing.T) { + for _, c := range []string{Snake, Pascal, Kebab, Camel} { + if !IsKnownCasing(c) { + t.Errorf("IsKnownCasing(%q) = false, want true", c) + } + } + if IsKnownCasing("bogus") { + t.Errorf("IsKnownCasing(\"bogus\") = true, want false") + } +} diff --git a/pkg/stream_transform/schema_driven_xml.go b/pkg/stream_transform/schema_driven_xml.go new file mode 100644 index 0000000..a99cec4 --- /dev/null +++ b/pkg/stream_transform/schema_driven_xml.go @@ -0,0 +1,325 @@ +package stream_transform + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "sort" + "strconv" + "strings" + + "github.com/clbanning/mxj/v2" +) + +// SchemaDrivenXMLV1 is a new transform family (distinct from golang_template_mxj_v*) +// that projects mxj-decoded XML rows using the schema 
referenced by +// response.schema_override, instead of a hand-rolled per-op Go template. +const SchemaDrivenXMLV1 = "schema_driven_xml_v0.1.0" + +// AWS wire-protocol identifiers, carried by the spec's info.x-protocol extension. +const ( + XProtocolQuery = "query" + XProtocolEC2 = "ec2" + XProtocolRestXML = "rest-xml" +) + +// SchemaTree is the minimal, dependency-free view of a schema node the walker +// needs. The caller (any-sdk's operation store) adapts its own Schema to this so +// that this package never imports internal/anysdk (which would be an import cycle). +type SchemaTree interface { + // Type returns the OpenAPI type: object | array | string | integer | number | boolean. + Type() string + // Items returns the element schema for an array node. + Items() (SchemaTree, bool) + // Property returns the named child property schema for an object node. + Property(name string) (SchemaTree, bool) + // Properties returns all child property schemas keyed by their wire name. + Properties() map[string]SchemaTree +} + +// schemaDrivenXMLTransformer walks one XML body into stackql's +// {listProperty: [...]} envelope using the row schema. 
+type schemaDrivenXMLTransformer struct { + input string + overrideTree SchemaTree + protocol string + listProperty string + outStream io.ReadWriter +} + +func newSchemaDrivenXMLTransformer( + input string, + overrideTree SchemaTree, + protocol string, + listProperty string, + outStream io.ReadWriter, +) (StreamTransformer, error) { + if overrideTree == nil { + return nil, fmt.Errorf("schema_driven_xml: nil override schema") + } + if listProperty == "" { + return nil, fmt.Errorf("schema_driven_xml: empty list property (objectKey)") + } + if outStream == nil { + outStream = bytes.NewBuffer(nil) + } + return &schemaDrivenXMLTransformer{ + input: input, + overrideTree: overrideTree, + protocol: protocol, + listProperty: listProperty, + outStream: outStream, + }, nil +} + +func (t *schemaDrivenXMLTransformer) GetOutStream() io.Reader { + if t.outStream == nil { + return bytes.NewBuffer(nil) + } + return t.outStream +} + +func (t *schemaDrivenXMLTransformer) Transform() error { + rowSchema, err := t.rowSchema() + if err != nil { + return err + } + // Decode WITHOUT mxj casting so leaf values stay strings; the schema's declared + // type is then authoritative (this is what stops 12-digit IDs becoming float64). + decoded, err := mxj.NewMapXml([]byte(t.input)) + if err != nil { + return err + } + payload, ok := t.payloadMap(map[string]interface{}(decoded)) + if !ok { + // Unrecognised envelope: emit an empty result rather than erroring. 
+ return t.write(make([]interface{}, 0)) + } + rows := extractRows(payload, rowSchema) + projected := make([]interface{}, 0, len(rows)) + for _, row := range rows { + projected = append(projected, projectRow(row, rowSchema)) + } + return t.write(projected) +} + +func (t *schemaDrivenXMLTransformer) write(rows []interface{}) error { + out := map[string]interface{}{t.listProperty: rows} + b, err := json.Marshal(out) + if err != nil { + return err + } + _, writeErr := t.outStream.Write(b) + return writeErr +} + +// rowSchema resolves the per-row schema: overrideSchema.{listProperty}.items. +func (t *schemaDrivenXMLTransformer) rowSchema() (SchemaTree, error) { + listSchema, ok := t.overrideTree.Property(t.listProperty) + if !ok { + return nil, fmt.Errorf("schema_driven_xml: list property %q not found in override schema", t.listProperty) + } + items, ok := listSchema.Items() + if !ok { + return nil, fmt.Errorf("schema_driven_xml: list property %q is not an array", t.listProperty) + } + return items, nil +} + +// payloadMap skips the protocol envelope and returns the map to inspect for rows. +func (t *schemaDrivenXMLTransformer) payloadMap(decoded map[string]interface{}) (map[string]interface{}, bool) { + top := unwrapSingle(decoded) // Response (query/ec2) or service root (rest-xml) + topMap, ok := top.(map[string]interface{}) + if !ok { + return nil, false + } + if t.protocol == XProtocolQuery { + if rm, ok := resultWrapper(topMap); ok { + return rm, true + } + } + return topMap, true +} + +// extractRows decides singleton vs list and returns the row maps. +func extractRows(payload map[string]interface{}, rowSchema SchemaTree) []map[string]interface{} { + rowProps := rowSchema.Properties() + if mapHasAnyKey(payload, rowProps) { + // The payload itself carries the row's fields -> singleton response. 
+ return []map[string]interface{}{payload} + } + member, ok := findListMember(payload) + if !ok { + return nil + } + return normalizeRows(member) +} + +func projectRow(row map[string]interface{}, rowSchema SchemaTree) map[string]interface{} { + out := make(map[string]interface{}, len(rowSchema.Properties())) + for name, propSchema := range rowSchema.Properties() { + raw, ok := row[name] + if !ok { + out[name] = nil + continue + } + out[name] = convertValue(raw, propSchema.Type()) + } + return out +} + +// convertValue applies the schema-declared type to a (string-typed) mxj leaf value. +func convertValue(raw interface{}, schemaType string) interface{} { + if raw == nil { + return nil + } + // An empty or self-closing element decodes to "" in mxj -> project as null. + if s, ok := raw.(string); ok && s == "" && schemaType != "string" { + return nil + } + switch schemaType { + case "integer": + if s, ok := raw.(string); ok { + if n, err := strconv.ParseInt(s, 10, 64); err == nil { + return n + } + } + return raw + case "number": + if s, ok := raw.(string); ok { + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f + } + } + return raw + case "boolean": + if s, ok := raw.(string); ok { + if b, err := strconv.ParseBool(s); err == nil { + return b + } + } + return raw + case "object", "array": + b, err := json.Marshal(raw) + if err != nil { + return raw + } + return string(b) + case "string": + if s, ok := raw.(string); ok { + return s + } + return fmt.Sprintf("%v", raw) + default: + return raw + } +} + +// --- envelope navigation helpers --- + +func unwrapSingle(m map[string]interface{}) interface{} { + if len(m) == 1 { + for _, v := range m { + return v + } + } + return m +} + +func resultWrapper(m map[string]interface{}) (map[string]interface{}, bool) { + for _, k := range sortedKeys(m) { + if strings.HasSuffix(k, "Result") { + if rm, ok := m[k].(map[string]interface{}); ok { + return rm, true + } + } + } + return nil, false +} + +func mapHasAnyKey(m 
map[string]interface{}, props map[string]SchemaTree) bool { + for k := range props { + if _, ok := m[k]; ok { + return true + } + } + return false +} + +// findListMember locates the member of payload that bears the row list. It tries, +// in deterministic key order: a direct slice; a map containing botocore's default +// "item" wrapper; a map whose sole child is a slice/map (locationName wrap); then +// an empty self-closing member (""). +func findListMember(payload map[string]interface{}) (interface{}, bool) { + keys := sortedKeys(payload) + for _, k := range keys { + if _, ok := payload[k].([]interface{}); ok { + return payload[k], true + } + } + for _, k := range keys { + if mm, ok := payload[k].(map[string]interface{}); ok { + if _, hasItem := mm["item"]; hasItem { + return mm, true + } + } + } + for _, k := range keys { + if mm, ok := payload[k].(map[string]interface{}); ok && len(mm) == 1 { + for _, v := range mm { + switch v.(type) { + case []interface{}, map[string]interface{}: + return mm, true + } + } + } + } + for _, k := range keys { + if s, ok := payload[k].(string); ok && s == "" { + return "", true + } + } + return nil, false +} + +func normalizeRows(member interface{}) []map[string]interface{} { + switch v := member.(type) { + case nil: + return nil + case string: + return nil // empty self-closing list + case []interface{}: + var rows []map[string]interface{} + for _, e := range v { + if rm, ok := e.(map[string]interface{}); ok { + rows = append(rows, rm) + } + } + return rows + case map[string]interface{}: + if item, ok := v["item"]; ok { + return normalizeRows(item) + } + if len(v) == 1 { + for _, child := range v { + switch child.(type) { + case []interface{}, map[string]interface{}: + return normalizeRows(child) + } + } + } + return []map[string]interface{}{v} // single row + default: + return nil + } +} + +func sortedKeys(m map[string]interface{}) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } 
+ sort.Strings(keys) + return keys +} diff --git a/pkg/stream_transform/schema_driven_xml_test.go b/pkg/stream_transform/schema_driven_xml_test.go new file mode 100644 index 0000000..81e6c4c --- /dev/null +++ b/pkg/stream_transform/schema_driven_xml_test.go @@ -0,0 +1,201 @@ +package stream_transform + +import ( + "bytes" + "encoding/json" + "io" + "testing" +) + +// fakeSchema is a minimal SchemaTree for tests. +type fakeSchema struct { + typ string + items *fakeSchema + props map[string]*fakeSchema +} + +func (f *fakeSchema) Type() string { return f.typ } + +func (f *fakeSchema) Items() (SchemaTree, bool) { + if f.items == nil { + return nil, false + } + return f.items, true +} + +func (f *fakeSchema) Property(name string) (SchemaTree, bool) { + p, ok := f.props[name] + if !ok || p == nil { + return nil, false + } + return p, true +} + +func (f *fakeSchema) Properties() map[string]SchemaTree { + out := make(map[string]SchemaTree, len(f.props)) + for k, v := range f.props { + out[k] = v + } + return out +} + +// overrideWith builds {line_items: [ { field: type } ]}. 
+func overrideWith(fields map[string]string) *fakeSchema { + rowProps := make(map[string]*fakeSchema, len(fields)) + for k, t := range fields { + rowProps[k] = &fakeSchema{typ: t} + } + row := &fakeSchema{typ: "object", props: rowProps} + list := &fakeSchema{typ: "array", items: row} + return &fakeSchema{typ: "object", props: map[string]*fakeSchema{"line_items": list}} +} + +func runWalker(t *testing.T, override *fakeSchema, protocol, xml string) []map[string]interface{} { + t.Helper() + tr, err := newSchemaDrivenXMLTransformer(xml, override, protocol, "line_items", bytes.NewBuffer(nil)) + if err != nil { + t.Fatalf("construct: %v", err) + } + if err := tr.Transform(); err != nil { + t.Fatalf("transform: %v", err) + } + out, _ := io.ReadAll(tr.GetOutStream()) + var env map[string][]map[string]interface{} + if err := json.Unmarshal(out, &env); err != nil { + t.Fatalf("bad envelope json %q: %v", string(out), err) + } + return env["line_items"] +} + +func TestWalker_EC2List(t *testing.T) { + override := overrideWith(map[string]string{"volumeId": "string", "size": "integer", "encrypted": "boolean"}) + xml := `r-1` + + `vol-18true` + + `vol-216false` + + `` + rows := runWalker(t, override, XProtocolEC2, xml) + if len(rows) != 2 { + t.Fatalf("want 2 rows, got %d (%v)", len(rows), rows) + } + if rows[0]["volumeId"] != "vol-1" || rows[0]["size"] != float64(8) || rows[0]["encrypted"] != true { + t.Fatalf("row0 mismatch: %v", rows[0]) + } +} + +func TestWalker_EC2SmallPayload(t *testing.T) { + override := overrideWith(map[string]string{"volumeId": "string", "size": "integer"}) + xml := `vol-94` + rows := runWalker(t, override, XProtocolEC2, xml) + if len(rows) != 1 || rows[0]["volumeId"] != "vol-9" || rows[0]["size"] != float64(4) { + t.Fatalf("unexpected rows: %v", rows) + } +} + +func TestWalker_QueryListWithResultWrapper(t *testing.T) { + override := overrideWith(map[string]string{"StackName": "string", "StackStatus": "string"}) + xml := `` + + `s1OK` + + `s2BAD` + + `` + 
rows := runWalker(t, override, XProtocolQuery, xml) + if len(rows) != 2 || rows[0]["StackName"] != "s1" || rows[1]["StackStatus"] != "BAD" { + t.Fatalf("unexpected rows: %v", rows) + } +} + +func TestWalker_QueryEmptySelfClosingList(t *testing.T) { + override := overrideWith(map[string]string{"StackName": "string"}) + xml := `` + rows := runWalker(t, override, XProtocolQuery, xml) + if len(rows) != 0 { + t.Fatalf("want 0 rows for empty self-closing list, got %d (%v)", len(rows), rows) + } +} + +func TestWalker_RestXMLList(t *testing.T) { + override := overrideWith(map[string]string{"Name": "string", "CreationDate": "string"}) + xml := `` + + `123me` + + `` + + `b12020` + + `b22021` + + `` + rows := runWalker(t, override, XProtocolRestXML, xml) + if len(rows) != 2 || rows[0]["Name"] != "b1" || rows[1]["Name"] != "b2" { + t.Fatalf("unexpected rows: %v", rows) + } +} + +func TestWalker_RestXMLSingleton(t *testing.T) { + override := overrideWith(map[string]string{"HostedZone": "object", "DelegationSet": "object"}) + xml := `` + + `/hostedzone/Z1example.com` + + `ns1` + + `` + rows := runWalker(t, override, XProtocolRestXML, xml) + if len(rows) != 1 { + t.Fatalf("want 1 singleton row, got %d (%v)", len(rows), rows) + } + hz, ok := rows[0]["HostedZone"].(string) + if !ok || !bytes.Contains([]byte(hz), []byte("example.com")) { + t.Fatalf("HostedZone should be a JSON string containing example.com: %v", rows[0]["HostedZone"]) + } +} + +func TestWalker_RestXMLSingletonWithAncillaryList(t *testing.T) { + override := overrideWith(map[string]string{"HostedZone": "object", "DelegationSet": "object", "VPCs": "array"}) + xml := `` + + `/hostedzone/Z1` + + `ns1` + + `vpc-1` + + `` + rows := runWalker(t, override, XProtocolRestXML, xml) + if len(rows) != 1 { + t.Fatalf("want 1 singleton row (ancillary list must not trigger list mode), got %d (%v)", len(rows), rows) + } + if _, ok := rows[0]["VPCs"].(string); !ok { + t.Fatalf("VPCs should be JSON-stringified: %v", rows[0]["VPCs"]) + 
} +} + +func TestWalker_TypeDispatch(t *testing.T) { + override := overrideWith(map[string]string{ + "OwnerId": "string", "Count": "integer", "Enabled": "boolean", "Tags": "object", + }) + // 12-digit OwnerId must stay a string (no float64), Tags is self-closing -> null. + xml := `` + + `1234567890125false` + + `` + rows := runWalker(t, override, XProtocolEC2, xml) + if len(rows) != 1 { + t.Fatalf("want 1 row, got %d", len(rows)) + } + r := rows[0] + if r["OwnerId"] != "123456789012" { + t.Errorf("OwnerId = %#v, want string \"123456789012\"", r["OwnerId"]) + } + if r["Count"] != float64(5) { + t.Errorf("Count = %#v, want 5", r["Count"]) + } + if r["Enabled"] != false { + t.Errorf("Enabled = %#v, want false", r["Enabled"]) + } + if r["Tags"] != nil { + t.Errorf("Tags (self-closing) = %#v, want null", r["Tags"]) + } +} + +func TestWalker_FactoryRegistration(t *testing.T) { + override := overrideWith(map[string]string{"Name": "string"}) + f := NewSchemaDrivenXMLStreamTransformerFactory(SchemaDrivenXMLV1, override, XProtocolEC2, "line_items") + if !f.IsTransformable() { + t.Fatalf("SchemaDrivenXMLV1 should be transformable") + } + tr, err := f.GetTransformer(`x`) + if err != nil { + t.Fatalf("GetTransformer: %v", err) + } + if err := tr.Transform(); err != nil { + t.Fatalf("Transform: %v", err) + } +} diff --git a/pkg/stream_transform/template_stream_transform.go b/pkg/stream_transform/template_stream_transform.go index bc447f5..6f76fe0 100644 --- a/pkg/stream_transform/template_stream_transform.go +++ b/pkg/stream_transform/template_stream_transform.go @@ -37,6 +37,10 @@ type StreamTransformerFactory interface { type streamTransformerFactory struct { tplType string tplStr string + // schema-driven fields (only set for SchemaDrivenXMLV1): + schema SchemaTree + protocol string + listProperty string } func NewStreamTransformerFactory(tplType string, tplStr string) StreamTransformerFactory { @@ -46,6 +50,24 @@ func NewStreamTransformerFactory(tplType string, tplStr 
string) StreamTransforme } } +// NewSchemaDrivenXMLStreamTransformerFactory builds a factory for the +// schema_driven_xml family. The caller adapts its own schema to SchemaTree and +// supplies the wire protocol (info.x-protocol) plus the envelope list property +// (derived from response.objectKey, e.g. "line_items"). +func NewSchemaDrivenXMLStreamTransformerFactory( + tplType string, + schema SchemaTree, + protocol string, + listProperty string, +) StreamTransformerFactory { + return &streamTransformerFactory{ + tplType: tplType, + schema: schema, + protocol: protocol, + listProperty: listProperty, + } +} + func (stf *streamTransformerFactory) IsTransformable() bool { switch stf.tplType { case GolangTemplateXMLV1, GolangTemplateXMLV2, GolangTemplateXMLV3: @@ -56,6 +78,8 @@ func (stf *streamTransformerFactory) IsTransformable() bool { return true case GolangTemplateUnspecifiedV1, GolangTemplateUnspecifiedV3: return true + case SchemaDrivenXMLV1: + return true default: return false } @@ -78,6 +102,9 @@ func (stf *streamTransformerFactory) GetTransformer(input string) (StreamTransfo outStream := bytes.NewBuffer(nil) tfm, err := newTemplateStreamTransformer(stf.tplType, stf.tplStr, inStream, outStream) return tfm, err + case SchemaDrivenXMLV1: + outStream := bytes.NewBuffer(nil) + return newSchemaDrivenXMLTransformer(input, stf.schema, stf.protocol, stf.listProperty, outStream) default: return nil, fmt.Errorf("unsupported template type: %s", stf.tplType) } diff --git a/public/radix_tree_address_space/legacy_address_space.go b/public/radix_tree_address_space/legacy_address_space.go index 57ed88f..6d02d25 100644 --- a/public/radix_tree_address_space/legacy_address_space.go +++ b/public/radix_tree_address_space/legacy_address_space.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/stackql/any-sdk/internal/anysdk" + "github.com/stackql/any-sdk/pkg/casing" "github.com/stackql/any-sdk/pkg/media" ) @@ -78,17 +79,28 @@ func (ta *simpleLegacyTableSchemaAnalyzer) GetColumns() 
([]anysdk.Column, error) existingColumns[col.GetName()] = struct{}{} rv = append(rv, newSimpleColumn(col.GetName(), col.GetSchema())) } + snakeAliases := false + if prov := ta.m.GetProvider(); prov != nil { + snakeAliases = prov.IsSnakeCaseAliasesEnabled() + } unionedRequiredParams, err := ta.m.GetUnionRequiredParameters() if err != nil && !ta.isNilResponseAllowed { return nil, err } for k, col := range unionedRequiredParams { - if _, ok := existingColumns[k]; ok { + // Snake-alias required-parameter names too when the provider opts in, so the + // column set is consistent (snake on both response and parameter sides) and + // avoids a SQLite NOCASE collision between e.g. VpcId and vpc_id. + colName := k + if snakeAliases { + colName = casing.ToSnake(k) + } + if _, ok := existingColumns[colName]; ok { continue } schema, _ := col.GetSchema() - existingColumns[col.GetName()] = struct{}{} - rv = append(rv, newSimpleColumn(k, schema)) + existingColumns[colName] = struct{}{} + rv = append(rv, newSimpleColumn(colName, schema)) } servers, serversDoExist := ta.m.GetServers() if serversDoExist {