Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions internal/anysdk/casing_wiring_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package anysdk_test

import (
"testing"

. "github.com/stackql/any-sdk/internal/anysdk"
)

func TestParameterNotFoundErrorMessage(t *testing.T) {
err := &ParameterNotFoundError{
Key: "foo_bar",
AvailableWireNames: []string{"VpcId", "EnableDnsHostnames", "DryRun"},
}
want := "field 'foo_bar' not found; available: [VpcId, EnableDnsHostnames, DryRun]"
if got := err.Error(); got != want {
t.Fatalf("Error() = %q, want %q", got, want)
}
}
6 changes: 6 additions & 0 deletions internal/anysdk/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type StackQLConfig interface {
GetQueryParamPushdown() (QueryParamPushdown, bool)
GetRetryPolicy() (RetryPolicy, bool)
GetMinStackQLVersion() string
IsSnakeCaseAliasesEnabled() bool
//
isObjectSchemaImplicitlyUnioned() bool
setResource(rsc Resource)
Expand All @@ -43,6 +44,7 @@ type standardStackQLConfig struct {
QueryParamPushdown *standardQueryParamPushdown `json:"queryParamPushdown,omitempty" yaml:"queryParamPushdown,omitempty"`
Retry *standardRetryPolicy `json:"retry,omitempty" yaml:"retry,omitempty"`
MinStackQLVersion string `json:"minStackQLVersion,omitempty" yaml:"minStackQLVersion,omitempty"`
SnakeCaseAliases bool `json:"snake_case_aliases,omitempty" yaml:"snake_case_aliases,omitempty"`
}

func (qt standardStackQLConfig) JSONLookup(token string) (interface{}, error) {
Expand Down Expand Up @@ -70,6 +72,10 @@ func (cfg *standardStackQLConfig) GetMinStackQLVersion() string {
return cfg.MinStackQLVersion
}

func (cfg *standardStackQLConfig) IsSnakeCaseAliasesEnabled() bool {
return cfg.SnakeCaseAliases
}

func (cfg *standardStackQLConfig) GetQueryTranspose() (Transform, bool) {
if cfg.QueryTranspose == nil {
return nil, false
Expand Down
3 changes: 3 additions & 0 deletions internal/anysdk/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ const (
ExtensionKeyResources string = "x-stackQL-resources"
ExtensionKeyStringOnly string = "x-stackQL-stringOnly"
ExtensionKeyAlias string = "x-stackQL-alias"
// ExtensionKeyProtocol is the info-level wire protocol hint (query|ec2|rest-xml)
// consumed by the schema_driven_xml response transform.
ExtensionKeyProtocol string = "x-protocol"
)

const (
Expand Down
6 changes: 6 additions & 0 deletions internal/anysdk/expectedRequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type ExpectedRequest interface {
GetBase() string
GetXMLDeclaration() string
GetXMLTransform() string
GetNativeCasing() string
//
setSchema(Schema)
setBodyMediaType(string)
Expand All @@ -33,6 +34,7 @@ type standardExpectedRequest struct {
XMLRootAnnotation string `json:"xmlRootAnnotation,omitempty" yaml:"xmlRootAnnotation,omitempty"`
OverrideSchema *LocalSchemaRef `json:"schema_override,omitempty" yaml:"schema_override,omitempty"`
Transform *standardTransform `json:"transform,omitempty" yaml:"transform,omitempty"`
NativeCasing string `json:"nativeCasing,omitempty" yaml:"nativeCasing,omitempty"`
}

func (er *standardExpectedRequest) setBodyMediaType(s string) {
Expand Down Expand Up @@ -73,6 +75,10 @@ func (er *standardExpectedRequest) GetXMLTransform() string {
return er.XMLTransform
}

func (er *standardExpectedRequest) GetNativeCasing() string {
return er.NativeCasing
}

func (er *standardExpectedRequest) GetSchema() Schema {
if er.OverrideSchema != nil && er.OverrideSchema.Value != nil {
return er.OverrideSchema.Value
Expand Down
137 changes: 129 additions & 8 deletions internal/anysdk/operation_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/stackql/any-sdk/pkg/casing"
"github.com/stackql/any-sdk/pkg/dto"
"github.com/stackql/any-sdk/pkg/fuzzymatch"
"github.com/stackql/any-sdk/pkg/media"
Expand Down Expand Up @@ -64,6 +65,8 @@ type OperationStore interface {
GetInverse() (OperationInverse, bool)
GetStackQLConfig() StackQLConfig
GetQueryParamPushdown() (QueryParamPushdown, bool)
GetRequestNativeCasing() string
GetParameterOrError(paramKey string) (Addressable, error)
GetRetryPolicy() RetryPolicy
GetParameters() map[string]Addressable
GetPathItem() *openapi3.PathItem
Expand Down Expand Up @@ -1050,10 +1053,10 @@ func (m *standardOpenAPIOperationStore) getRequestBodySchemaAttributeMatcher(pat
return nil, fmt.Errorf("could not find schema at path '%s'", path)
}
}
return getschemaAttributeMatcher(schemaOfInterest)
return getschemaAttributeMatcher(schemaOfInterest, m.GetRequestNativeCasing())
}

func getschemaAttributeMatcher(schemaOfInterest Schema) (fuzzymatch.FuzzyMatcher[string], error) {
func getschemaAttributeMatcher(schemaOfInterest Schema, nativeCasing string) (fuzzymatch.FuzzyMatcher[string], error) {
var matchers []fuzzymatch.StringFuzzyPair
for k := range schemaOfInterest.getProperties() {
if k == "" {
Expand All @@ -1065,6 +1068,17 @@ func getschemaAttributeMatcher(schemaOfInterest Schema) (fuzzymatch.FuzzyMatcher
return nil, regexpErr
}
matchers = append(matchers, fuzzymatch.NewFuzzyPair(keyRegexp, k))
// When a native wire casing is declared, also accept the snake_case form of
// the wire property name and map it back to the wire key.
if nativeCasing != "" {
if snakeKey := casing.ToSnake(k); snakeKey != k {
snakeRegexp, snakeErr := regexp.Compile(fmt.Sprintf("^%s$", regexp.QuoteMeta(snakeKey)))
if snakeErr != nil {
return nil, snakeErr
}
matchers = append(matchers, fuzzymatch.NewFuzzyPair(snakeRegexp, k))
}
}
}
return fuzzymatch.NewRegexpStringMetcher(matchers), nil
}
Expand Down Expand Up @@ -1284,10 +1298,44 @@ func (m *standardOpenAPIOperationStore) GetNonBodyParameters() map[string]Addres
return m.getNonBodyParameters()
}

func (m *standardOpenAPIOperationStore) GetRequestNativeCasing() string {
if m.Request != nil {
return m.Request.GetNativeCasing()
}
return ""
}

func (m *standardOpenAPIOperationStore) GetParameter(paramKey string) (Addressable, bool) {
params := m.GetParameters()
rv, ok := params[paramKey]
return rv, ok
if rv, ok := params[paramKey]; ok {
return rv, true
}
// Reverse-casing retry: when the method declares a native wire casing, convert
// the (snake_case) SQL key to that casing and retry. Absent native casing this
// is a no-op, preserving existing behaviour.
if nc := m.GetRequestNativeCasing(); nc != "" {
if wireKey := casing.FromSnake(paramKey, nc); wireKey != paramKey {
if rv, ok := params[wireKey]; ok {
return rv, true
}
}
}
return nil, false
}

// GetParameterOrError resolves a parameter (including reverse-casing) and, on a
// final miss, returns a *ParameterNotFoundError listing the wire-format names.
func (m *standardOpenAPIOperationStore) GetParameterOrError(paramKey string) (Addressable, error) {
if rv, ok := m.GetParameter(paramKey); ok {
return rv, nil
}
params := m.GetParameters()
wireNames := make([]string, 0, len(params))
for k := range params {
wireNames = append(wireNames, k)
}
sort.Strings(wireNames)
return nil, &ParameterNotFoundError{Key: paramKey, AvailableWireNames: wireNames}
}

func (m *standardOpenAPIOperationStore) GetName() string {
Expand Down Expand Up @@ -1819,10 +1867,21 @@ func (op *standardOpenAPIOperationStore) getOverridenResponse(httpResponse *http
overrideMediaType := expectedResponse.GetOverrrideBodyMediaType()
if responseTransformExists {
input := string(bodyBytes)
streamTransformerFactory := stream_transform.NewStreamTransformerFactory(
responseTransform.GetType(),
responseTransform.GetBody(),
)
var streamTransformerFactory stream_transform.StreamTransformerFactory
if responseTransform.GetType() == stream_transform.SchemaDrivenXMLV1 {
listProperty := strings.TrimPrefix(expectedResponse.GetObjectKey(), "$.")
streamTransformerFactory = stream_transform.NewSchemaDrivenXMLStreamTransformerFactory(
responseTransform.GetType(),
newXMLSchemaAdapter(expectedResponse.GetSchema()),
op.getXProtocol(),
listProperty,
)
} else {
streamTransformerFactory = stream_transform.NewStreamTransformerFactory(
responseTransform.GetType(),
responseTransform.GetBody(),
)
}
if !streamTransformerFactory.IsTransformable() {
return nil, fmt.Errorf("unsupported template type: %s", responseTransform.GetType())
}
Expand All @@ -1848,6 +1907,68 @@ func (op *standardOpenAPIOperationStore) getOverridenResponse(httpResponse *http
return nil, fmt.Errorf("unprocessable response body for operation = %s", op.GetName())
}

// getXProtocol reads the info-level x-protocol hint (query|ec2|rest-xml) used by
// the schema_driven_xml transform to skip the right response envelope.
func (op *standardOpenAPIOperationStore) getXProtocol() string {
if op.OpenAPIService == nil {
return ""
}
t := op.OpenAPIService.getT()
if t == nil || t.Info == nil {
return ""
}
v, err := extractExtensionValBytes(t.Info.Extensions, ExtensionKeyProtocol)
if err != nil {
return ""
}
return strings.Trim(strings.TrimSpace(string(v)), `"`)
}

// xmlSchemaAdapter adapts an internal Schema to stream_transform.SchemaTree so the
// schema_driven_xml walker can navigate it without importing internal/anysdk.
type xmlSchemaAdapter struct {
s Schema
}

func newXMLSchemaAdapter(s Schema) stream_transform.SchemaTree {
if s == nil {
return nil
}
return xmlSchemaAdapter{s: s}
}

func (a xmlSchemaAdapter) Type() string {
return a.s.GetType()
}

func (a xmlSchemaAdapter) Items() (stream_transform.SchemaTree, bool) {
items, err := a.s.GetItemsSchema()
if err != nil || items == nil {
return nil, false
}
return xmlSchemaAdapter{s: items}, true
}

func (a xmlSchemaAdapter) Property(name string) (stream_transform.SchemaTree, bool) {
p, ok := a.s.GetProperty(name)
if !ok || p == nil {
return nil, false
}
return xmlSchemaAdapter{s: p}, true
}

func (a xmlSchemaAdapter) Properties() map[string]stream_transform.SchemaTree {
props, err := a.s.GetProperties()
if err != nil {
return nil
}
out := make(map[string]stream_transform.SchemaTree, len(props))
for k, v := range props {
out[k] = xmlSchemaAdapter{s: v}
}
return out
}

func (op *standardOpenAPIOperationStore) ProcessResponse(httpResponse *http.Response) (ProcessedOperationResponse, error) {
responseSchema, mediaType, err := op.GetResponseBodySchemaAndMediaType()
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions internal/anysdk/params.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package anysdk

import (
"fmt"
"strings"

"github.com/getkin/kin-openapi/openapi3"
)

Expand All @@ -10,6 +13,18 @@ var (
_ Params = &parameters{}
)

// ParameterNotFoundError reports an unresolved clause key together with the
// wire-format names that were available, e.g.
// "field 'foo_bar' not found; available: [VpcId, EnableDnsHostnames, DryRun]".
type ParameterNotFoundError struct {
Key string
AvailableWireNames []string
}

func (e *ParameterNotFoundError) Error() string {
return fmt.Sprintf("field '%s' not found; available: [%s]", e.Key, strings.Join(e.AvailableWireNames, ", "))
}

type standardParameter struct {
openapi3.Parameter
svc OpenAPIService
Expand Down
8 changes: 8 additions & 0 deletions internal/anysdk/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Provider interface {
GetRequestTranslateAlgorithm() string
GetResourcesShallow(serviceKey string) (ResourceRegister, error)
GetStackQLConfig() (StackQLConfig, bool)
IsSnakeCaseAliasesEnabled() bool
JSONLookup(token string) (interface{}, error)
MarshalJSON() ([]byte, error)
UnmarshalJSON(data []byte) error
Expand Down Expand Up @@ -75,6 +76,13 @@ func (pr *standardProvider) GetMinStackQLVersion() string {
return ""
}

func (pr *standardProvider) IsSnakeCaseAliasesEnabled() bool {
if pr.StackQLConfig != nil {
return pr.StackQLConfig.IsSnakeCaseAliasesEnabled()
}
return false
}

func (pr *standardProvider) GetProviderServices() map[string]ProviderService {
providerServices := make(map[string]ProviderService, len(pr.ProviderServices))
for k, v := range pr.ProviderServices {
Expand Down
24 changes: 23 additions & 1 deletion internal/anysdk/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/antchfx/xmlquery"
"github.com/getkin/kin-openapi/openapi3"
"github.com/stackql/any-sdk/pkg/casing"
"github.com/stackql/any-sdk/pkg/jsonpath"
"github.com/stackql/any-sdk/pkg/media"
"github.com/stackql/any-sdk/pkg/openapitopath"
Expand Down Expand Up @@ -995,13 +996,21 @@ func (s *standardSchema) IsArrayRef() bool {
}

func (s *standardSchema) getPropertiesColumns() []ColumnDescriptor {
snakeAliases := s.isSnakeCaseAliasesEnabled()
var cols []ColumnDescriptor
for k, val := range s.Properties {
valSchema := val.Value
if valSchema != nil {
// The column's display/DDL name is snake-aliased when the provider opts
// in; the wire property name (k) is retained on the Schema for response
// navigation.
name := k
if snakeAliases {
name = casing.ToSnake(k)
}
col := newColumnDescriptor(
"",
k,
name,
"",
"",
nil,
Expand All @@ -1019,6 +1028,19 @@ func (s *standardSchema) getPropertiesColumns() []ColumnDescriptor {
return cols
}

// isSnakeCaseAliasesEnabled reports whether the enclosing provider opts in to
// snake_case output aliases. Defaults to false (no behaviour change).
func (s *standardSchema) isSnakeCaseAliasesEnabled() bool {
if s.svc == nil {
return false
}
prov := s.svc.getProvider()
if prov == nil {
return false
}
return prov.IsSnakeCaseAliasesEnabled()
}

func (s *standardSchema) getAllOfColumns() []ColumnDescriptor {
return s.getAllSchemaRefsColumns(s.AllOf)
}
Expand Down
Loading
Loading