diff --git a/cmd/serve.go b/cmd/serve.go index 80119fcbb..0bb9540d1 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -488,7 +488,8 @@ func buildAPIDependencies( metaschemaRepository := postgres.NewMetaSchemaRepository(logger, dbc) metaschemaService := metaschema.NewService(metaschemaRepository) - userPATService := userpat.NewService(logger, userPATRepo, cfg.App.PAT, organizationService, roleService, policyService, projectService, auditRecordRepository) + userPATService := userpat.NewService(logger, userPATRepo, cfg.App.PAT, organizationService, roleService, membershipService, projectService, auditRecordRepository) + membershipService.SetUserPATService(userPATService) patAlertService := userpat.NewAlertService(userPATRepo, userService, organizationService, mailDialer, dbc, cfg.App.PAT.Alert, logger, auditRecordRepository) auditRecordService := auditrecord.NewService(auditRecordRepository, userService, serviceUserService, sessionService, userPATService) diff --git a/core/membership/errors.go b/core/membership/errors.go index ce835ae6c..7b952b719 100644 --- a/core/membership/errors.go +++ b/core/membership/errors.go @@ -15,4 +15,5 @@ var ( ErrInvalidResourceType = errors.New("unsupported resource type") ErrInvalidGroupRole = errors.New("role is not valid for group scope") ErrLastGroupOwnerRole = errors.New("cannot change role: this is the last owner of the group") + ErrPrincipalExpired = errors.New("principal has expired") ) diff --git a/core/membership/mocks/user_pat_service.go b/core/membership/mocks/user_pat_service.go new file mode 100644 index 000000000..0987c7e57 --- /dev/null +++ b/core/membership/mocks/user_pat_service.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + models "github.com/raystack/frontier/core/userpat/models" +) + +// UserPATService is an autogenerated mock type for the UserPATService type +type UserPATService struct { + mock.Mock +} + +type UserPATService_Expecter struct { + mock *mock.Mock +} + +func (_m *UserPATService) EXPECT() *UserPATService_Expecter { + return &UserPATService_Expecter{mock: &_m.Mock} +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *UserPATService) GetByID(ctx context.Context, id string) (models.PAT, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UserPATService_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type UserPATService_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *UserPATService_Expecter) GetByID(ctx interface{}, id interface{}) *UserPATService_GetByID_Call { + return &UserPATService_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *UserPATService_GetByID_Call) Run(run func(ctx context.Context, id string)) *UserPATService_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *UserPATService_GetByID_Call) Return(_a0 models.PAT, _a1 error) *UserPATService_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *UserPATService_GetByID_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *UserPATService_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// NewUserPATService creates a new instance of UserPATService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewUserPATService(t interface { + mock.TestingT + Cleanup(func()) +}) *UserPATService { + mock := &UserPATService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/membership/service.go b/core/membership/service.go index 87e608984..5076c6997 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -4,11 +4,10 @@ import ( "context" "errors" "fmt" + "log/slog" "slices" "time" - "log/slog" - "github.com/raystack/frontier/core/audit" "github.com/raystack/frontier/core/auditrecord" "github.com/raystack/frontier/core/authenticate" @@ -20,6 +19,7 @@ import ( "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" + patmodels "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" pkgAuditRecord "github.com/raystack/frontier/pkg/auditrecord" "github.com/raystack/frontier/pkg/utils" @@ -64,6 +64,10 @@ type ServiceuserService interface { Get(ctx context.Context, id string) (serviceuser.ServiceUser, error) } +type UserPATService interface { + GetByID(ctx context.Context, id string) (patmodels.PAT, error) +} + type AuditRecordRepository interface { Create(ctx context.Context, auditRecord auditrecord.AuditRecord) (auditrecord.AuditRecord, error) } @@ -78,6 +82,7 @@ type Service struct { projectService ProjectService groupService GroupService serviceuserService ServiceuserService + userPATService UserPATService auditRecordRepository AuditRecordRepository } @@ -107,7 +112,13 @@ func NewService( } } -// AddOrganizationMember adds a principal (user or service user) to an organization +// SetUserPATService sets the PAT dependency after construction to break the +// circular init order between userpat and membership services. +func (s *Service) SetUserPATService(ups UserPATService) { + s.userPATService = ups +} + +// AddOrganizationMember adds a principal (user, service user, or PAT) to an organization // with an explicit role, bypassing the invitation flow. // Returns ErrAlreadyMember if the principal already has a policy on this org. func (s *Service) AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error { @@ -136,6 +147,7 @@ func (s *Service) AddOrganizationMember(ctx context.Context, orgID, principalID, if err != nil { return fmt.Errorf("list existing policies: %w", err) } + existing = excludePATAllProjects(existing, schema.OrganizationNamespace) if len(existing) > 0 { return ErrAlreadyMember } @@ -145,6 +157,12 @@ func (s *Service) AddOrganizationMember(ctx context.Context, orgID, principalID, return err } + // PATs don't get relations. + if principalType == schema.PATPrincipal { + s.auditOrgMemberAdded(ctx, org, principal, roleID) + return nil + } + relationName := orgRoleToRelation(fetchedRole) if err := s.createRelation(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, relationName); err != nil { // best-effort cleanup to avoid orphaned policy @@ -182,7 +200,7 @@ func (s *Service) AddOrganizationMember(ctx context.Context, orgID, principalID, } // SetOrganizationMemberRole changes an existing member's role in an organization. -// Supports user and service user principals. +// Supports user, service user, and PAT principals. // Skips the write if the member already has exactly the requested role. func (s *Service) SetOrganizationMemberRole(ctx context.Context, orgID, principalID, principalType, roleID string) error { org, err := s.orgService.Get(ctx, orgID) @@ -211,7 +229,9 @@ func (s *Service) SetOrganizationMemberRole(ctx context.Context, orgID, principa if err != nil { return fmt.Errorf("list existing policies: %w", err) } - if len(existing) == 0 { + // drop the PAT's all-projects policy — only the org role should be replaced here. + existing = excludePATAllProjects(existing, schema.OrganizationNamespace) + if len(existing) == 0 && principalType != schema.PATPrincipal { return ErrNotMember } @@ -220,15 +240,25 @@ func (s *Service) SetOrganizationMemberRole(ctx context.Context, orgID, principa return nil } - ownerRoleID, err := s.validateMinOwnerConstraint(ctx, orgID, resolvedRoleID, existing) - if err != nil { - return err + // only human users can be the last owner — skip for service users and PATs. + var ownerRoleID string + if principalType == schema.UserPrincipal { + ownerRoleID, err = s.validateMinOwnerConstraint(ctx, orgID, resolvedRoleID, existing) + if err != nil { + return err + } } if err := s.replacePolicy(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, resolvedRoleID, existing, ownerRoleID); err != nil { return err } + // PATs don't get relations. + if principalType == schema.PATPrincipal { + s.auditOrgMemberRoleChanged(ctx, org, principal, resolvedRoleID) + return nil + } + newRelation := orgRoleToRelation(fetchedRole) oldRelations := []string{schema.OwnerRelationName, schema.MemberRelationName} err = s.replaceRelation(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, oldRelations, newRelation) @@ -248,6 +278,104 @@ func (s *Service) SetOrganizationMemberRole(ctx context.Context, orgID, principa return nil } +// SetPATAllProjectsRole grants a PAT a project-scoped role across all projects +// in the org via the pat_granted relation. Idempotent — replaces any existing +// all-projects role for this PAT on this org. +func (s *Service) SetPATAllProjectsRole(ctx context.Context, orgID, patID, roleID string) error { + org, err := s.orgService.Get(ctx, orgID) + if err != nil { + return err + } + + principal, err := s.validatePrincipal(ctx, orgID, patID, schema.PATPrincipal) + if err != nil { + return err + } + + fetchedRole, err := s.validateProjectRole(ctx, roleID, orgID) + if err != nil { + return err + } + resolvedRoleID := fetchedRole.ID + + allPolicies, err := s.policyService.List(ctx, policy.Filter{ + OrgID: orgID, + PrincipalID: patID, + PrincipalType: schema.PATPrincipal, + }) + if err != nil { + return fmt.Errorf("list existing policies: %w", err) + } + + var existing []policy.Policy + for _, p := range allPolicies { + if p.GrantRelation == schema.PATGrantRelationName { + existing = append(existing, p) + } + } + + if len(existing) == 1 && existing[0].RoleID == resolvedRoleID { + return nil + } + + for _, p := range existing { + if err := s.policyService.Delete(ctx, p.ID); err != nil { + return fmt.Errorf("delete policy %s: %w", p.ID, err) + } + } + + if _, err := s.policyService.Create(ctx, policy.Policy{ + RoleID: resolvedRoleID, + ResourceID: orgID, + ResourceType: schema.OrganizationNamespace, + PrincipalID: patID, + PrincipalType: schema.PATPrincipal, + GrantRelation: schema.PATGrantRelationName, + }); err != nil { + s.log.ErrorContext(ctx, "membership state inconsistent: old pat_granted policies deleted but new policy creation failed, needs manual fix", + "org_id", orgID, + "pat_id", patID, + "role_id", resolvedRoleID, + "error", err, + ) + return fmt.Errorf("create policy: %w", err) + } + + s.auditOrgMemberRoleChanged(ctx, org, principal, resolvedRoleID) + return nil +} + +// ListPoliciesByPrincipal returns every policy held by the principal. +func (s *Service) ListPoliciesByPrincipal(ctx context.Context, principalID, principalType string) ([]policy.Policy, error) { + return s.policyService.List(ctx, policy.Filter{ + PrincipalID: principalID, + PrincipalType: principalType, + }) +} + +// RemoveAllPATPolicies deletes every policy held by a PAT. +func (s *Service) RemoveAllPATPolicies(ctx context.Context, patID string) error { + _, err := s.removePoliciesByFilter(ctx, policy.Filter{ + PrincipalID: patID, + PrincipalType: schema.PATPrincipal, + }) + return err +} + +// removePoliciesByFilter lists policies matching the filter and deletes them. +func (s *Service) removePoliciesByFilter(ctx context.Context, filter policy.Filter) (int, error) { + policies, err := s.policyService.List(ctx, filter) + if err != nil { + return 0, fmt.Errorf("list policies: %w", err) + } + for _, p := range policies { + if err := s.policyService.Delete(ctx, p.ID); err != nil { + return 0, fmt.Errorf("delete policy %s: %w", p.ID, err) + } + } + return len(policies), nil +} + // RemoveOrganizationMember removes a principal from an organization and cascades // the removal through all org projects and groups, cleaning up both policies and // relations. Returns ErrNotMember if the principal has no policies on this org. @@ -275,9 +403,13 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return ErrNotMember } - ownerRoleID, err := s.validateMinOwnerConstraint(ctx, orgID, "", orgPolicies) - if err != nil { - return err + // only humans can be the last owner — skip for service users and PATs. + var ownerRoleID string + if principalType == schema.UserPrincipal { + ownerRoleID, err = s.validateMinOwnerConstraint(ctx, orgID, "", orgPolicies) + if err != nil { + return err + } } if err := s.cascadeRemovePrincipal(ctx, org, principalID, principalType, ownerRoleID); err != nil { @@ -586,6 +718,25 @@ func (s *Service) validatePrincipal(ctx context.Context, orgID, principalID, pri Type: schema.ServiceUserPrincipal, Name: su.Title, }, nil + case schema.PATPrincipal: + if s.userPATService == nil { + return principalInfo{}, ErrInvalidPrincipal + } + pat, err := s.userPATService.GetByID(ctx, principalID) + if err != nil { + return principalInfo{}, err + } + if pat.OrgID != orgID { + return principalInfo{}, ErrPrincipalNotInOrg + } + if !pat.ExpiresAt.After(time.Now()) { + return principalInfo{}, ErrPrincipalExpired + } + return principalInfo{ + ID: pat.ID, + Type: schema.PATPrincipal, + Name: pat.Title, + }, nil default: return principalInfo{}, ErrInvalidPrincipal } @@ -765,7 +916,7 @@ func (s *Service) SetProjectMemberRole(ctx context.Context, projectID, principal // RemoveProjectMember removes a principal from a project by deleting all their project-level policies. func (s *Service) RemoveProjectMember(ctx context.Context, projectID, principalID, principalType string) error { switch principalType { - case schema.UserPrincipal, schema.ServiceUserPrincipal, schema.GroupPrincipal: + case schema.UserPrincipal, schema.ServiceUserPrincipal, schema.GroupPrincipal, schema.PATPrincipal: default: return ErrInvalidPrincipalType } @@ -820,6 +971,22 @@ func policyFilterForResource(resourceID, resourceType, principalID, principalTyp return f } +// excludePATAllProjects hides a PAT's all-projects grant from org member +// listings — that policy lives on the org but grants project access, not +// org membership. +func excludePATAllProjects(policies []policy.Policy, resourceType string) []policy.Policy { + if resourceType != schema.OrganizationNamespace { + return policies + } + filtered := policies[:0] + for _, p := range policies { + if p.GrantRelation != schema.PATGrantRelationName { + filtered = append(filtered, p) + } + } + return filtered +} + // validateProjectRole checks that the role is valid for project scope: // - a platform-wide role scoped to projects, or // - a custom role created for the project's parent organization. @@ -885,6 +1052,20 @@ func (s *Service) validateOrgMembership(ctx context.Context, orgID, principalID, if grp.OrganizationID != orgID { return ErrNotOrgMember } + case schema.PATPrincipal: + if s.userPATService == nil { + return ErrInvalidPrincipal + } + pat, err := s.userPATService.GetByID(ctx, principalID) + if err != nil { + return err + } + if pat.OrgID != orgID { + return ErrNotOrgMember + } + if !pat.ExpiresAt.After(time.Now()) { + return ErrPrincipalExpired + } default: return ErrInvalidPrincipalType } @@ -957,6 +1138,7 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso if err != nil { return nil, fmt.Errorf("list policies: %w", err) } + policies = excludePATAllProjects(policies, resourceType) // deduplicate by (principalID, principalType) preserving order memberIndex := make(map[string]int, len(policies)) @@ -981,6 +1163,7 @@ func (s *Service) ListPrincipalsByResource(ctx context.Context, resourceID, reso if err != nil { return nil, fmt.Errorf("list policies for role enrichment: %w", err) } + allPolicies = excludePATAllProjects(allPolicies, resourceType) principalRoleIDs := make(map[string][]string, len(members)) roleSeen := make(map[string]map[string]struct{}, len(members)) diff --git a/core/membership/service_test.go b/core/membership/service_test.go index d5fbfd734..7ee7576bf 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "io" "log/slog" @@ -334,6 +335,98 @@ func TestService_AddOrganizationMember_ServiceUser(t *testing.T) { }) } +func TestService_AddOrganizationMember_PAT(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + patID := uuid.New().String() + viewerRoleID := uuid.New().String() + + enabledOrg := organization.Organization{ID: orgID, Title: "Test Org"} + activePAT := pat.PAT{ID: patID, OrgID: orgID, Title: "test-pat", ExpiresAt: time.Now().Add(time.Hour)} + + t.Run("should add PAT without writing org member/owner relation", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{}, nil) + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.AddOrganizationMember(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.NoError(t, err) + }) + + t.Run("should reject PAT from different org", func(t *testing.T) { + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(pat.PAT{ID: patID, OrgID: "other-org", ExpiresAt: time.Now().Add(time.Hour)}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mocks.NewPolicyService(t), mocks.NewRelationService(t), mocks.NewRoleService(t), mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + svc.SetUserPATService(mockPATSvc) + err := svc.AddOrganizationMember(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.ErrorIs(t, err, membership.ErrPrincipalNotInOrg) + }) + + t.Run("should reject expired PAT", func(t *testing.T) { + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(pat.PAT{ID: patID, OrgID: orgID, ExpiresAt: time.Now().Add(-time.Hour)}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mocks.NewPolicyService(t), mocks.NewRelationService(t), mocks.NewRoleService(t), mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + svc.SetUserPATService(mockPATSvc) + err := svc.AddOrganizationMember(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.ErrorIs(t, err, membership.ErrPrincipalExpired) + }) + + t.Run("should reject PAT principal when userPATService is not wired", func(t *testing.T) { + mockOrgSvc := mocks.NewOrgService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mocks.NewPolicyService(t), mocks.NewRelationService(t), mocks.NewRoleService(t), mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + err := svc.AddOrganizationMember(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.ErrorIs(t, err, membership.ErrInvalidPrincipal) + }) + + t.Run("should allow adding an org role to a PAT that has only a pat_granted policy", func(t *testing.T) { + // PAT holds only an all-projects (pat_granted) policy. AddOrganizationMember + // should not treat that as existing org membership. + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{ + {ID: "pat-granted-pol", RoleID: uuid.New().String(), GrantRelation: schema.PATGrantRelationName}, + }, nil) + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.AddOrganizationMember(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.NoError(t, err) + }) +} + func TestService_SetOrganizationMemberRole(t *testing.T) { ctx := context.Background() orgID := uuid.New().String() @@ -563,9 +656,7 @@ func TestService_SetOrganizationMemberRole_ServiceUser(t *testing.T) { mockSuSvc.EXPECT().Get(ctx, suID).Return(serviceuser.ServiceUser{ID: suID, OrgID: orgID, Title: "test-su", State: string(serviceuser.Enabled)}, nil) mockRoleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: suID, PrincipalType: schema.ServiceUserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) - mockRoleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) - mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) - mockPolicySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) + mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(nil) mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) mockRelSvc.EXPECT().Delete(ctx, mock.Anything).Return(relation.ErrNotExist).Times(2) mockRelSvc.EXPECT().Create(ctx, mock.Anything).Return(relation.Relation{}, nil) @@ -601,6 +692,84 @@ func TestService_SetOrganizationMemberRole_ServiceUser(t *testing.T) { }) } +func TestService_SetOrganizationMemberRole_PAT(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + patID := uuid.New().String() + viewerRoleID := uuid.New().String() + oldRoleID := uuid.New().String() + + enabledOrg := organization.Organization{ID: orgID, Title: "Test Org"} + activePAT := pat.PAT{ID: patID, OrgID: orgID, Title: "test-pat", ExpiresAt: time.Now().Add(time.Hour)} + + t.Run("should replace PAT role without writing org member/owner relation", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: oldRoleID}}, nil) + mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(nil) + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.SetOrganizationMemberRole(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.NoError(t, err) + }) + + t.Run("should reject expired PAT", func(t *testing.T) { + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(pat.PAT{ID: patID, OrgID: orgID, ExpiresAt: time.Now().Add(-time.Hour)}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mocks.NewPolicyService(t), mocks.NewRelationService(t), mocks.NewRoleService(t), mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + svc.SetUserPATService(mockPATSvc) + err := svc.SetOrganizationMemberRole(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.ErrorIs(t, err, membership.ErrPrincipalExpired) + }) + + t.Run("should leave the pat_granted policy untouched when only the granted role changes", func(t *testing.T) { + // A PAT can hold both a granted org policy and a pat_granted all-projects + // policy on the same org. SetOrganizationMemberRole should only replace + // the granted one — the pat_granted policy is project-cascade scope and + // must not be wiped as collateral. + mockPolicySvc := mocks.NewPolicyService(t) + mockRelSvc := mocks.NewRelationService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{ + {ID: "granted-pol", RoleID: oldRoleID, GrantRelation: schema.RoleGrantRelationName}, + {ID: "pat-granted-pol", RoleID: uuid.New().String(), GrantRelation: schema.PATGrantRelationName}, + }, nil) + // Only the granted policy is deleted; pat-granted-pol stays. + mockPolicySvc.EXPECT().Delete(ctx, "granted-pol").Return(nil) + mockPolicySvc.EXPECT().Create(ctx, mock.MatchedBy(func(p policy.Policy) bool { + return p.RoleID == viewerRoleID && p.PrincipalID == patID + })).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.SetOrganizationMemberRole(ctx, orgID, patID, schema.PATPrincipal, viewerRoleID) + assert.NoError(t, err) + }) +} + func TestService_RemoveOrganizationMember(t *testing.T) { ctx := context.Background() orgID := uuid.New().String() @@ -1018,6 +1187,17 @@ func TestService_RemoveProjectMember(t *testing.T) { principalID: suID, principalType: schema.ServiceUserPrincipal, }, + { + name: "should succeed removing a PAT", + setup: func(policySvc *mocks.PolicyService, prjSvc *mocks.ProjectService, auditRepo *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{{ID: "p1"}}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + principalID: userID, + principalType: schema.PATPrincipal, + }, } for _, tt := range tests { @@ -1042,6 +1222,246 @@ func TestService_RemoveProjectMember(t *testing.T) { } } +func TestService_SetPATAllProjectsRole(t *testing.T) { + ctx := context.Background() + orgID := uuid.New().String() + patID := uuid.New().String() + projectRoleID := uuid.New().String() + oldRoleID := uuid.New().String() + + enabledOrg := organization.Organization{ID: orgID, Title: "Test Org"} + activePAT := pat.PAT{ID: patID, OrgID: orgID, Title: "test-pat", ExpiresAt: time.Now().Add(time.Hour)} + projectRole := role.Role{ID: projectRoleID, Scopes: []string{schema.ProjectNamespace}} + + t.Run("should write pat_granted policy on the org", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, projectRoleID).Return(projectRole, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{}, nil) + mockPolicySvc.EXPECT().Create(ctx, mock.MatchedBy(func(p policy.Policy) bool { + return p.RoleID == projectRoleID && + p.ResourceID == orgID && + p.ResourceType == schema.OrganizationNamespace && + p.PrincipalID == patID && + p.PrincipalType == schema.PATPrincipal && + p.GrantRelation == schema.PATGrantRelationName + })).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.NoError(t, err) + }) + + t.Run("should be a no-op when the same pat_granted role is already set", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, projectRoleID).Return(projectRole, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{ + {ID: "p1", RoleID: projectRoleID, GrantRelation: schema.PATGrantRelationName}, + }, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.NoError(t, err) + }) + + t.Run("should replace existing pat_granted policy with new role", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, projectRoleID).Return(projectRole, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{ + {ID: "p1", RoleID: oldRoleID, GrantRelation: schema.PATGrantRelationName}, + }, nil) + mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(nil) + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.NoError(t, err) + }) + + t.Run("should ignore an existing granted policy and only replace pat_granted", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, projectRoleID).Return(projectRole, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{ + {ID: "granted-pol", RoleID: oldRoleID, GrantRelation: schema.RoleGrantRelationName}, + }, nil) + // no Delete on the granted policy; only Create for the new pat_granted + mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.NoError(t, err) + }) + + t.Run("should replace only the pat_granted policy when both granted and pat_granted exist", func(t *testing.T) { + // Granted policy's role matches the requested role — the function must + // not treat this as a no-op (the role-match check is for pat_granted + // only) and must not delete the granted policy. + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, projectRoleID).Return(projectRole, nil) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: patID, PrincipalType: schema.PATPrincipal}).Return([]policy.Policy{ + {ID: "granted-pol", RoleID: projectRoleID, GrantRelation: schema.RoleGrantRelationName}, + {ID: "pat-granted-pol", RoleID: oldRoleID, GrantRelation: schema.PATGrantRelationName}, + }, nil) + mockPolicySvc.EXPECT().Delete(ctx, "pat-granted-pol").Return(nil) + mockPolicySvc.EXPECT().Create(ctx, mock.MatchedBy(func(p policy.Policy) bool { + return p.GrantRelation == schema.PATGrantRelationName && p.RoleID == projectRoleID + })).Return(policy.Policy{}, nil) + mockAuditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.NoError(t, err) + }) + + t.Run("should reject role that is not project-scoped", func(t *testing.T) { + mockRoleSvc := mocks.NewRoleService(t) + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(activePAT, nil) + mockRoleSvc.EXPECT().Get(ctx, projectRoleID).Return(role.Role{ID: projectRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mocks.NewPolicyService(t), mocks.NewRelationService(t), mockRoleSvc, mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.ErrorIs(t, err, membership.ErrInvalidProjectRole) + }) + + t.Run("should reject expired PAT", func(t *testing.T) { + mockOrgSvc := mocks.NewOrgService(t) + mockPATSvc := mocks.NewUserPATService(t) + + mockOrgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + mockPATSvc.EXPECT().GetByID(ctx, patID).Return(pat.PAT{ID: patID, OrgID: orgID, ExpiresAt: time.Now().Add(-time.Hour)}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mocks.NewPolicyService(t), mocks.NewRelationService(t), mocks.NewRoleService(t), mockOrgSvc, mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + svc.SetUserPATService(mockPATSvc) + err := svc.SetPATAllProjectsRole(ctx, orgID, patID, projectRoleID) + assert.ErrorIs(t, err, membership.ErrPrincipalExpired) + }) +} + +func TestService_ListPoliciesByPrincipal(t *testing.T) { + ctx := context.Background() + principalID := uuid.New().String() + + t.Run("returns every policy held by the principal", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: principalID, PrincipalType: schema.PATPrincipal}). + Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + got, err := svc.ListPoliciesByPrincipal(ctx, principalID, schema.PATPrincipal) + assert.NoError(t, err) + assert.Len(t, got, 2) + }) + + t.Run("surfaces policy list errors", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockPolicySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: principalID, PrincipalType: schema.PATPrincipal}). + Return(nil, errors.New("db down")) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + _, err := svc.ListPoliciesByPrincipal(ctx, principalID, schema.PATPrincipal) + assert.Error(t, err) + }) +} + +func TestService_RemoveAllPATPolicies(t *testing.T) { + ctx := context.Background() + patID := uuid.New().String() + + t.Run("should delete every policy held by the PAT", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + + mockPolicySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: patID, PrincipalType: schema.PATPrincipal}). + Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}, {ID: "p3"}}, nil) + mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(nil) + mockPolicySvc.EXPECT().Delete(ctx, "p2").Return(nil) + mockPolicySvc.EXPECT().Delete(ctx, "p3").Return(nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + err := svc.RemoveAllPATPolicies(ctx, patID) + assert.NoError(t, err) + }) + + t.Run("should be a no-op when the PAT has no policies", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + + mockPolicySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: patID, PrincipalType: schema.PATPrincipal}). + Return([]policy.Policy{}, nil) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + err := svc.RemoveAllPATPolicies(ctx, patID) + assert.NoError(t, err) + }) + + t.Run("should surface policy list errors", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + + mockPolicySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: patID, PrincipalType: schema.PATPrincipal}). + Return(nil, errors.New("db down")) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + err := svc.RemoveAllPATPolicies(ctx, patID) + assert.Error(t, err) + }) + + t.Run("should surface policy delete errors", func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + + mockPolicySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: patID, PrincipalType: schema.PATPrincipal}). + Return([]policy.Policy{{ID: "p1"}}, nil) + mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(errors.New("spicedb unavailable")) + + svc := membership.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mocks.NewAuditRecordRepository(t)) + err := svc.RemoveAllPATPolicies(ctx, patID) + assert.Error(t, err) + }) +} + func TestService_ListPrincipalsByResource(t *testing.T) { ctx := context.Background() orgID := uuid.New().String() diff --git a/core/userpat/errors/errors.go b/core/userpat/errors/errors.go index b403dd7c4..b4e4389ff 100644 --- a/core/userpat/errors/errors.go +++ b/core/userpat/errors/errors.go @@ -17,4 +17,5 @@ var ( ErrScopeMismatch = errors.New("role does not support the specified scope") ErrRoleNotFound = errors.New("one or more requested roles do not exist") ErrProjectForbidden = errors.New("user does not have access to one or more specified projects") + ErrDuplicateScope = errors.New("only one role per resource type is allowed") ) diff --git a/core/userpat/mocks/membership_service.go b/core/userpat/mocks/membership_service.go new file mode 100644 index 000000000..6ea366929 --- /dev/null +++ b/core/userpat/mocks/membership_service.go @@ -0,0 +1,293 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + policy "github.com/raystack/frontier/core/policy" + mock "github.com/stretchr/testify/mock" +) + +// MembershipService is an autogenerated mock type for the MembershipService type +type MembershipService struct { + mock.Mock +} + +type MembershipService_Expecter struct { + mock *mock.Mock +} + +func (_m *MembershipService) EXPECT() *MembershipService_Expecter { + return &MembershipService_Expecter{mock: &_m.Mock} +} + +// ListPoliciesByPrincipal provides a mock function with given fields: ctx, principalID, principalType +func (_m *MembershipService) ListPoliciesByPrincipal(ctx context.Context, principalID string, principalType string) ([]policy.Policy, error) { + ret := _m.Called(ctx, principalID, principalType) + + if len(ret) == 0 { + panic("no return value specified for ListPoliciesByPrincipal") + } + + var r0 []policy.Policy + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]policy.Policy, error)); ok { + return rf(ctx, principalID, principalType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) []policy.Policy); ok { + r0 = rf(ctx, principalID, principalType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]policy.Policy) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, principalID, principalType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListPoliciesByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPoliciesByPrincipal' +type MembershipService_ListPoliciesByPrincipal_Call struct { + *mock.Call +} + +// ListPoliciesByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principalID string +// - principalType string +func (_e *MembershipService_Expecter) ListPoliciesByPrincipal(ctx interface{}, principalID interface{}, principalType interface{}) *MembershipService_ListPoliciesByPrincipal_Call { + return &MembershipService_ListPoliciesByPrincipal_Call{Call: _e.mock.On("ListPoliciesByPrincipal", ctx, principalID, principalType)} +} + +func (_c *MembershipService_ListPoliciesByPrincipal_Call) Run(run func(ctx context.Context, principalID string, principalType string)) *MembershipService_ListPoliciesByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MembershipService_ListPoliciesByPrincipal_Call) Return(_a0 []policy.Policy, _a1 error) *MembershipService_ListPoliciesByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListPoliciesByPrincipal_Call) RunAndReturn(run func(context.Context, string, string) ([]policy.Policy, error)) *MembershipService_ListPoliciesByPrincipal_Call { + _c.Call.Return(run) + return _c +} + +// RemoveAllPATPolicies provides a mock function with given fields: ctx, patID +func (_m *MembershipService) RemoveAllPATPolicies(ctx context.Context, patID string) error { + ret := _m.Called(ctx, patID) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllPATPolicies") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_RemoveAllPATPolicies_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveAllPATPolicies' +type MembershipService_RemoveAllPATPolicies_Call struct { + *mock.Call +} + +// RemoveAllPATPolicies is a helper method to define mock.On call +// - ctx context.Context +// - patID string +func (_e *MembershipService_Expecter) RemoveAllPATPolicies(ctx interface{}, patID interface{}) *MembershipService_RemoveAllPATPolicies_Call { + return &MembershipService_RemoveAllPATPolicies_Call{Call: _e.mock.On("RemoveAllPATPolicies", ctx, patID)} +} + +func (_c *MembershipService_RemoveAllPATPolicies_Call) Run(run func(ctx context.Context, patID string)) *MembershipService_RemoveAllPATPolicies_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MembershipService_RemoveAllPATPolicies_Call) Return(_a0 error) *MembershipService_RemoveAllPATPolicies_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_RemoveAllPATPolicies_Call) RunAndReturn(run func(context.Context, string) error) *MembershipService_RemoveAllPATPolicies_Call { + _c.Call.Return(run) + return _c +} + +// SetOrganizationMemberRole provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID +func (_m *MembershipService) SetOrganizationMemberRole(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, orgID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetOrganizationMemberRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, orgID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetOrganizationMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetOrganizationMemberRole' +type MembershipService_SetOrganizationMemberRole_Call struct { + *mock.Call +} + +// SetOrganizationMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - orgID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetOrganizationMemberRole(ctx interface{}, orgID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetOrganizationMemberRole_Call { + return &MembershipService_SetOrganizationMemberRole_Call{Call: _e.mock.On("SetOrganizationMemberRole", ctx, orgID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetOrganizationMemberRole_Call) Run(run func(ctx context.Context, orgID string, principalID string, principalType string, roleID string)) *MembershipService_SetOrganizationMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetOrganizationMemberRole_Call) Return(_a0 error) *MembershipService_SetOrganizationMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetOrganizationMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetOrganizationMemberRole_Call { + _c.Call.Return(run) + return _c +} + +// SetPATAllProjectsRole provides a mock function with given fields: ctx, orgID, patID, roleID +func (_m *MembershipService) SetPATAllProjectsRole(ctx context.Context, orgID string, patID string, roleID string) error { + ret := _m.Called(ctx, orgID, patID, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetPATAllProjectsRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, orgID, patID, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetPATAllProjectsRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetPATAllProjectsRole' +type MembershipService_SetPATAllProjectsRole_Call struct { + *mock.Call +} + +// SetPATAllProjectsRole is a helper method to define mock.On call +// - ctx context.Context +// - orgID string +// - patID string +// - roleID string +func (_e *MembershipService_Expecter) SetPATAllProjectsRole(ctx interface{}, orgID interface{}, patID interface{}, roleID interface{}) *MembershipService_SetPATAllProjectsRole_Call { + return &MembershipService_SetPATAllProjectsRole_Call{Call: _e.mock.On("SetPATAllProjectsRole", ctx, orgID, patID, roleID)} +} + +func (_c *MembershipService_SetPATAllProjectsRole_Call) Run(run func(ctx context.Context, orgID string, patID string, roleID string)) *MembershipService_SetPATAllProjectsRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MembershipService_SetPATAllProjectsRole_Call) Return(_a0 error) *MembershipService_SetPATAllProjectsRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetPATAllProjectsRole_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MembershipService_SetPATAllProjectsRole_Call { + _c.Call.Return(run) + return _c +} + +// SetProjectMemberRole provides a mock function with given fields: ctx, projectID, principalID, principalType, roleID +func (_m *MembershipService) SetProjectMemberRole(ctx context.Context, projectID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, projectID, principalID, principalType, roleID) + + if len(ret) == 0 { + panic("no return value specified for SetProjectMemberRole") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, projectID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_SetProjectMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetProjectMemberRole' +type MembershipService_SetProjectMemberRole_Call struct { + *mock.Call +} + +// SetProjectMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetProjectMemberRole(ctx interface{}, projectID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetProjectMemberRole_Call { + return &MembershipService_SetProjectMemberRole_Call{Call: _e.mock.On("SetProjectMemberRole", ctx, projectID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetProjectMemberRole_Call) Run(run func(ctx context.Context, projectID string, principalID string, principalType string, roleID string)) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetProjectMemberRole_Call) Return(_a0 error) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetProjectMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Return(run) + return _c +} + +// NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMembershipService(t interface { + mock.TestingT + Cleanup(func()) +}) *MembershipService { + mock := &MembershipService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/userpat/mocks/policy_service.go b/core/userpat/mocks/policy_service.go index 56dbe4f9b..c859bb71a 100644 --- a/core/userpat/mocks/policy_service.go +++ b/core/userpat/mocks/policy_service.go @@ -22,110 +22,6 @@ func (_m *PolicyService) EXPECT() *PolicyService_Expecter { return &PolicyService_Expecter{mock: &_m.Mock} } -// Create provides a mock function with given fields: ctx, pol -func (_m *PolicyService) Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) { - ret := _m.Called(ctx, pol) - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 policy.Policy - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, policy.Policy) (policy.Policy, error)); ok { - return rf(ctx, pol) - } - if rf, ok := ret.Get(0).(func(context.Context, policy.Policy) policy.Policy); ok { - r0 = rf(ctx, pol) - } else { - r0 = ret.Get(0).(policy.Policy) - } - - if rf, ok := ret.Get(1).(func(context.Context, policy.Policy) error); ok { - r1 = rf(ctx, pol) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// PolicyService_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' -type PolicyService_Create_Call struct { - *mock.Call -} - -// Create is a helper method to define mock.On call -// - ctx context.Context -// - pol policy.Policy -func (_e *PolicyService_Expecter) Create(ctx interface{}, pol interface{}) *PolicyService_Create_Call { - return &PolicyService_Create_Call{Call: _e.mock.On("Create", ctx, pol)} -} - -func (_c *PolicyService_Create_Call) Run(run func(ctx context.Context, pol policy.Policy)) *PolicyService_Create_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(policy.Policy)) - }) - return _c -} - -func (_c *PolicyService_Create_Call) Return(_a0 policy.Policy, _a1 error) *PolicyService_Create_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *PolicyService_Create_Call) RunAndReturn(run func(context.Context, policy.Policy) (policy.Policy, error)) *PolicyService_Create_Call { - _c.Call.Return(run) - return _c -} - -// Delete provides a mock function with given fields: ctx, id -func (_m *PolicyService) Delete(ctx context.Context, id string) error { - ret := _m.Called(ctx, id) - - if len(ret) == 0 { - panic("no return value specified for Delete") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, id) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// PolicyService_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' -type PolicyService_Delete_Call struct { - *mock.Call -} - -// Delete is a helper method to define mock.On call -// - ctx context.Context -// - id string -func (_e *PolicyService_Expecter) Delete(ctx interface{}, id interface{}) *PolicyService_Delete_Call { - return &PolicyService_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} -} - -func (_c *PolicyService_Delete_Call) Run(run func(ctx context.Context, id string)) *PolicyService_Delete_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) - }) - return _c -} - -func (_c *PolicyService_Delete_Call) Return(_a0 error) *PolicyService_Delete_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *PolicyService_Delete_Call) RunAndReturn(run func(context.Context, string) error) *PolicyService_Delete_Call { - _c.Call.Return(run) - return _c -} - // List provides a mock function with given fields: ctx, flt func (_m *PolicyService) List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) { ret := _m.Called(ctx, flt) diff --git a/core/userpat/mocks/project_service.go b/core/userpat/mocks/project_service.go index 3dd430db6..80b1b68c6 100644 --- a/core/userpat/mocks/project_service.go +++ b/core/userpat/mocks/project_service.go @@ -5,9 +5,8 @@ package mocks import ( context "context" - mock "github.com/stretchr/testify/mock" - project "github.com/raystack/frontier/core/project" + mock "github.com/stretchr/testify/mock" ) // ProjectService is an autogenerated mock type for the ProjectService type @@ -94,4 +93,4 @@ func NewProjectService(t interface { t.Cleanup(func() { mock.AssertExpectations(t) }) return mock -} \ No newline at end of file +} diff --git a/core/userpat/service.go b/core/userpat/service.go index 98472e6b1..3fc2d8a87 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -43,10 +43,12 @@ type RoleService interface { List(ctx context.Context, f role.Filter) ([]role.Role, error) } -type PolicyService interface { - Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) - List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) - Delete(ctx context.Context, id string) error +type MembershipService interface { + SetOrganizationMemberRole(ctx context.Context, orgID, principalID, principalType, roleID string) error + SetPATAllProjectsRole(ctx context.Context, orgID, patID, roleID string) error + SetProjectMemberRole(ctx context.Context, projectID, principalID, principalType, roleID string) error + RemoveAllPATPolicies(ctx context.Context, patID string) error + ListPoliciesByPrincipal(ctx context.Context, principalID, principalType string) ([]policy.Policy, error) } type ProjectService interface { @@ -63,21 +65,22 @@ type Service struct { logger *slog.Logger orgService OrganizationService roleService RoleService - policyService PolicyService + membershipService MembershipService projectService ProjectService auditRecordRepository AuditRecordRepository deniedPerms map[string]struct{} } func NewService(logger *slog.Logger, repo Repository, config Config, orgService OrganizationService, - roleService RoleService, policyService PolicyService, projectService ProjectService, auditRecordRepository AuditRecordRepository) *Service { + roleService RoleService, membershipService MembershipService, + projectService ProjectService, auditRecordRepository AuditRecordRepository) *Service { return &Service{ repo: repo, config: config, logger: logger, orgService: orgService, roleService: roleService, - policyService: policyService, + membershipService: membershipService, projectService: projectService, auditRecordRepository: auditRecordRepository, deniedPerms: config.DeniedPermissionsSet(), @@ -156,7 +159,7 @@ func (s *Service) Delete(ctx context.Context, userID, id string) error { return fmt.Errorf("soft deleting PAT: %w", err) } - if err := s.deletePolicies(ctx, id); err != nil { + if err := s.membershipService.RemoveAllPATPolicies(ctx, id); err != nil { return fmt.Errorf("deleting policies: %w", err) } @@ -232,6 +235,9 @@ func (s *Service) Update(ctx context.Context, toUpdate patmodels.PAT) (patmodels if err != nil { return patmodels.PAT{}, err } + if !existing.ExpiresAt.After(time.Now()) { + return patmodels.PAT{}, paterrors.ErrExpired + } if err := s.validateScopes(ctx, toUpdate.Scopes); err != nil { return patmodels.PAT{}, err @@ -291,7 +297,7 @@ func (s *Service) captureOldScope(ctx context.Context, pat *patmodels.PAT) (stri // replacePolicies deletes existing policies and creates new ones from scopes. // Re-checks PAT existence after delete to guard against concurrent soft-delete. func (s *Service) replacePolicies(ctx context.Context, patID, orgID string, scopes []patmodels.PATScope) error { - if err := s.deletePolicies(ctx, patID); err != nil { + if err := s.membershipService.RemoveAllPATPolicies(ctx, patID); err != nil { return fmt.Errorf("deleting old policies: %w", err) } @@ -317,24 +323,6 @@ func (s *Service) auditUpdate(ctx context.Context, updated patmodels.PAT, toUpda } } -// deletePolicies removes all SpiceDB policies associated with a PAT. -// Each policy.Delete call removes SpiceDB relations first, then hard-deletes the Postgres policy row. -func (s *Service) deletePolicies(ctx context.Context, patID string) error { - policies, err := s.policyService.List(ctx, policy.Filter{ - PrincipalID: patID, - PrincipalType: schema.PATPrincipal, - }) - if err != nil { - return fmt.Errorf("listing policies for PAT %s: %w", patID, err) - } - for _, pol := range policies { - if err := s.policyService.Delete(ctx, pol.ID); err != nil { - return fmt.Errorf("deleting policy %s: %w", pol.ID, err) - } - } - return nil -} - // Create generates a new PAT and returns it with the plaintext value. // The plaintext value is only available at creation time. func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) { @@ -475,6 +463,7 @@ func (s *Service) validateScopes(ctx context.Context, scopes []patmodels.PATScop roleMap[r.ID] = r } + seen := make(map[string]bool, len(scopes)) for _, sc := range scopes { if !slices.Contains(supportedPATResourceTypes, sc.ResourceType) { return fmt.Errorf("resource type %s: %w", sc.ResourceType, paterrors.ErrUnsupportedScope) @@ -483,6 +472,10 @@ func (s *Service) validateScopes(ctx context.Context, scopes []patmodels.PATScop if !slices.Contains(r.Scopes, sc.ResourceType) { return fmt.Errorf("role %s does not support resource type %s: %w", sc.RoleID, sc.ResourceType, paterrors.ErrScopeMismatch) } + if seen[sc.ResourceType] { + return fmt.Errorf("resource type %s: %w", sc.ResourceType, paterrors.ErrDuplicateScope) + } + seen[sc.ResourceType] = true } return nil } @@ -526,17 +519,25 @@ func (s *Service) validateProjectAccess(ctx context.Context, userID, orgID strin return nil } -// createPolicies creates SpiceDB policies from pre-validated scopes. +// createPolicies writes the PAT's scopes via the membership package. func (s *Service) createPolicies(ctx context.Context, patID, orgID string, scopes []patmodels.PATScope) error { for _, sc := range scopes { switch sc.ResourceType { case schema.OrganizationNamespace: - if err := s.createOrgScopedPolicy(ctx, patID, orgID, sc.RoleID); err != nil { - return err + if err := s.membershipService.SetOrganizationMemberRole(ctx, orgID, patID, schema.PATPrincipal, sc.RoleID); err != nil { + return fmt.Errorf("set org role: %w", err) } case schema.ProjectNamespace: - if err := s.createProjectScopedPolicies(ctx, patID, orgID, sc.RoleID, sc.ResourceIDs); err != nil { - return err + if len(sc.ResourceIDs) == 0 { + if err := s.membershipService.SetPATAllProjectsRole(ctx, orgID, patID, sc.RoleID); err != nil { + return fmt.Errorf("set all-projects role: %w", err) + } + continue + } + for _, pid := range sc.ResourceIDs { + if err := s.membershipService.SetProjectMemberRole(ctx, pid, patID, schema.PATPrincipal, sc.RoleID); err != nil { + return fmt.Errorf("set project role on %s: %w", pid, err) + } } default: return fmt.Errorf("unsupported resource type %s: %w", sc.ResourceType, paterrors.ErrUnsupportedScope) @@ -606,51 +607,10 @@ func (s *Service) validateRolePermissions(roles []role.Role) error { return nil } -// createPATPolicy creates a single SpiceDB policy for a PAT. -func (s *Service) createPATPolicy(ctx context.Context, patID, roleID, resourceID, resourceType, grantRelation string) error { - if _, err := s.policyService.Create(ctx, policy.Policy{ - RoleID: roleID, - ResourceID: resourceID, - ResourceType: resourceType, - PrincipalID: patID, - PrincipalType: schema.PATPrincipal, - GrantRelation: grantRelation, - }); err != nil { - s.logger.Error("failed to create PAT policy", - "pat_id", patID, "role_id", roleID, "resource_id", resourceID, - "resource_type", resourceType, "grant_relation", grantRelation, "error", err) - return err - } - return nil -} - -// createOrgScopedPolicy creates a policy on the org with the default "granted" relation. -func (s *Service) createOrgScopedPolicy(ctx context.Context, patID, orgID, roleID string) error { - return s.createPATPolicy(ctx, patID, roleID, orgID, schema.OrganizationNamespace, schema.RoleGrantRelationName) -} - -// createProjectScopedPolicies creates policies for a project-scoped role. -// If resourceIDs is empty, it creates a single policy on the org with "pat_granted" relation -// (cascades to all projects). Otherwise, it creates one policy per project with default "granted". -func (s *Service) createProjectScopedPolicies(ctx context.Context, patID, orgID, roleID string, resourceIDs []string) error { - if len(resourceIDs) == 0 { - return s.createPATPolicy(ctx, patID, roleID, orgID, schema.OrganizationNamespace, schema.PATGrantRelationName) - } - for _, resourceID := range resourceIDs { - if err := s.createPATPolicy(ctx, patID, roleID, resourceID, schema.ProjectNamespace, schema.RoleGrantRelationName); err != nil { - return err - } - } - return nil -} - // enrichWithScope derives scopes from the PAT's policies. // Groups policies by role ID + resource type to reconstruct PATScope entries. func (s *Service) enrichWithScope(ctx context.Context, pat *patmodels.PAT) error { - policies, err := s.policyService.List(ctx, policy.Filter{ - PrincipalID: pat.ID, - PrincipalType: schema.PATPrincipal, - }) + policies, err := s.membershipService.ListPoliciesByPrincipal(ctx, pat.ID, schema.PATPrincipal) if err != nil { return fmt.Errorf("listing policies for PAT %s: %w", pat.ID, err) } diff --git a/core/userpat/service_test.go b/core/userpat/service_test.go index 498f37859..be832e8cf 100644 --- a/core/userpat/service_test.go +++ b/core/userpat/service_test.go @@ -34,7 +34,7 @@ var defaultConfig = userpat.Config{ MaxLifetime: "8760h", } -func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleService, *mocks.PolicyService, *mocks.ProjectService, *mocks.AuditRecordRepository) { +func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleService, *mocks.MembershipService, *mocks.ProjectService, *mocks.AuditRecordRepository) { t.Helper() orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). @@ -52,16 +52,22 @@ func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleServi Name: "test-role", Scopes: []string{schema.OrganizationNamespace}, }, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.Anything). - Return(policy.Policy{}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("RemoveAllPATPolicies", mock.Anything, mock.Anything). + Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return orgSvc, roleSvc, policySvc, projSvc, auditRepo + return orgSvc, roleSvc, membershipSvc, projSvc, auditRepo } func TestService_Create(t *testing.T) { @@ -253,8 +259,8 @@ func TestService_Create(t *testing.T) { ExpiresAt: futureExpiry, CreatedAt: time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC), }, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -285,8 +291,8 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -322,8 +328,8 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -358,13 +364,13 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, userpat.Config{ Enabled: true, Prefix: "custom", MaxPerUserPerOrg: 50, MaxLifetime: "8760h", - }, orgSvc, roleSvc, policySvc, nil, auditRepo) + }, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() @@ -389,8 +395,8 @@ func TestService_Create(t *testing.T) { Return(int64(49), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, }, { @@ -409,8 +415,8 @@ func TestService_Create(t *testing.T) { Return(int64(0), nil) repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, }, } @@ -444,8 +450,8 @@ func TestService_Create_UniquePATs(t *testing.T) { repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil).Times(2) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) req := userpat.CreateRequest{ UserID: "user-1", @@ -479,8 +485,8 @@ func TestService_Create_HashVerification(t *testing.T) { }). Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, tokenValue, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", @@ -532,18 +538,11 @@ func TestService_CreatePolicies_OrgScopedRole(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"org-role-1"}}).Return([]role.Role{orgRole}, nil) roleSvc.On("Get", mock.Anything, "org-role-1").Return(orgRole, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "org-role-1", - ResourceID: "org-1", - ResourceType: schema.OrganizationNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.RoleGrantRelationName, - }).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() - - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetOrganizationMemberRole(mock.Anything, "org-1", "pat-1", schema.PATPrincipal, "org-role-1").Return(nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -579,18 +578,11 @@ func TestService_CreatePolicies_ProjectScopedAllProjects(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{projRole}, nil) roleSvc.On("Get", mock.Anything, "proj-role-1").Return(projRole, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "proj-role-1", - ResourceID: "org-1", - ResourceType: schema.OrganizationNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.PATGrantRelationName, - }).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() - - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetPATAllProjectsRole(mock.Anything, "org-1", "pat-1", "proj-role-1").Return(nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -626,24 +618,10 @@ func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{projRole}, nil) roleSvc.On("Get", mock.Anything, "proj-role-1").Return(projRole, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "proj-role-1", - ResourceID: "proj-a", - ResourceType: schema.ProjectNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.RoleGrantRelationName, - }).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.EXPECT().Create(mock.Anything, policy.Policy{ - RoleID: "proj-role-1", - ResourceID: "proj-b", - ResourceType: schema.ProjectNamespace, - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - GrantRelation: schema.RoleGrantRelationName, - }).Return(policy.Policy{ID: "pol-2"}, nil) - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-a", "pat-1", schema.PATPrincipal, "proj-role-1").Return(nil) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-b", "pat-1", schema.PATPrincipal, "proj-role-1").Return(nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { @@ -652,7 +630,7 @@ func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { {ID: "proj-a"}, {ID: "proj-b"}, }, nil).Maybe() - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, projSvc, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -681,12 +659,12 @@ func TestService_CreatePolicies_DeniedPermission(t *testing.T) { Scopes: []string{schema.OrganizationNamespace}, }}, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) cfg := defaultConfig cfg.DeniedPermissions = []string{"app_organization_administer"} - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -714,9 +692,9 @@ func TestService_CreatePolicies_RoleFetchError(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"bad-role"}}). Return(nil, errors.New("role not found")) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -748,9 +726,9 @@ func TestService_CreatePolicies_UnsupportedScope(t *testing.T) { Scopes: []string{schema.GroupNamespace}, }}, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -782,9 +760,9 @@ func TestService_CreatePolicies_MissingRoleID(t *testing.T) { Scopes: []string{schema.OrganizationNamespace}, }}, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -812,8 +790,8 @@ func TestService_CreatePolicies_NoRoles(t *testing.T) { repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) - orgSvc, roleSvc, policySvc, _, auditRepo := newSuccessMocks(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + orgSvc, roleSvc, membershipSvc, _, auditRepo := newSuccessMocks(t) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", @@ -919,11 +897,31 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {RoleID: "org-viewer-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, }, }, + { + // Both the all-projects pat_granted policy and the org granted policy + // land on (org-1, PAT) — SetOrganizationMemberRole must skip the + // pat_granted row when replacing existing org policies, otherwise the + // project-all-projects access is silently dropped when scopes arrive + // in this order. + name: "ex5: project all-projects first, then org — order does not drop pat_granted", + scopes: []models.PATScope{ + {RoleID: "proj-owner-id", ResourceType: schema.ProjectNamespace}, + {RoleID: "org-mgr-id", ResourceType: schema.OrganizationNamespace}, + }, + roles: []role.Role{ + {ID: "proj-owner-id", Name: "app_project_owner", Permissions: []string{"app_project_get", "app_project_update", "app_project_delete"}, Scopes: []string{schema.ProjectNamespace}}, + {ID: "org-mgr-id", Name: "app_organization_manager", Permissions: []string{"app_organization_get", "app_organization_update"}, Scopes: []string{schema.OrganizationNamespace}}, + }, + want: []wantPolicy{ + {RoleID: "org-mgr-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, + {RoleID: "proj-owner-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "pat_granted"}, + }, + }, - // ── Multiple roles of same scope ───────────────────────────────── + // ── Duplicate scopes rejected (1 role per resource type) ───────── { - name: "multiple org roles create separate org policies", + name: "two org-scoped roles rejected", scopes: []models.PATScope{ {RoleID: "org-viewer-id", ResourceType: schema.OrganizationNamespace}, {RoleID: "org-billing-id", ResourceType: schema.OrganizationNamespace}, @@ -932,13 +930,12 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "org-viewer-id", Name: "app_organization_viewer", Permissions: []string{"app_organization_get"}, Scopes: []string{schema.OrganizationNamespace}}, {ID: "org-billing-id", Name: "app_organization_billing_viewer", Permissions: []string{"app_organization_billingview"}, Scopes: []string{schema.OrganizationNamespace}}, }, - want: []wantPolicy{ - {RoleID: "org-viewer-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, - {RoleID: "org-billing-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "granted"}, - }, + want: nil, + wantErr: true, + wantErrIs: paterrors.ErrDuplicateScope, }, { - name: "multiple project roles, all projects → separate pat_granted policies", + name: "two project-scoped roles (all projects) rejected", scopes: []models.PATScope{ {RoleID: "proj-viewer-id", ResourceType: schema.ProjectNamespace}, {RoleID: "proj-editor-id", ResourceType: schema.ProjectNamespace}, @@ -947,13 +944,12 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "proj-viewer-id", Name: "app_project_viewer", Permissions: []string{"app_project_get"}, Scopes: []string{schema.ProjectNamespace}}, {ID: "proj-editor-id", Name: "app_project_editor", Permissions: []string{"app_project_get", "app_project_update"}, Scopes: []string{schema.ProjectNamespace}}, }, - want: []wantPolicy{ - {RoleID: "proj-viewer-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "pat_granted"}, - {RoleID: "proj-editor-id", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, Grant: "pat_granted"}, - }, + want: nil, + wantErr: true, + wantErrIs: paterrors.ErrDuplicateScope, }, { - name: "multiple project roles, specific projects → policy per role per project", + name: "two project-scoped roles (specific projects) rejected", scopes: []models.PATScope{ {RoleID: "proj-viewer-id", ResourceType: schema.ProjectNamespace, ResourceIDs: []string{"proj-1", "proj-2"}}, {RoleID: "proj-editor-id", ResourceType: schema.ProjectNamespace, ResourceIDs: []string{"proj-1", "proj-2"}}, @@ -962,12 +958,9 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "proj-viewer-id", Name: "app_project_viewer", Permissions: []string{"app_project_get"}, Scopes: []string{schema.ProjectNamespace}}, {ID: "proj-editor-id", Name: "app_project_editor", Permissions: []string{"app_project_get", "app_project_update"}, Scopes: []string{schema.ProjectNamespace}}, }, - want: []wantPolicy{ - {RoleID: "proj-viewer-id", ResourceID: "proj-1", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - {RoleID: "proj-viewer-id", ResourceID: "proj-2", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - {RoleID: "proj-editor-id", ResourceID: "proj-1", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - {RoleID: "proj-editor-id", ResourceID: "proj-2", ResourceType: schema.ProjectNamespace, Grant: "granted"}, - }, + want: nil, + wantErr: true, + wantErrIs: paterrors.ErrDuplicateScope, }, // ── Scope isolation ────────────────────────────────────────────── @@ -1161,15 +1154,45 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { } } - // --- policyService: capture all Create calls + // --- membershipService: capture every write the service makes, translate + // each membership call into the equivalent policy.Policy shape so the + // existing wantPolicy assertions keep working. var captured []policy.Policy - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.AnythingOfType("policy.Policy")). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + captured = append(captured, policy.Policy{ + RoleID: args.String(4), + ResourceID: args.String(1), + ResourceType: schema.OrganizationNamespace, + PrincipalID: args.String(2), + PrincipalType: args.String(3), + GrantRelation: schema.RoleGrantRelationName, + }) + }).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + captured = append(captured, policy.Policy{ + RoleID: args.String(3), + ResourceID: args.String(1), + ResourceType: schema.OrganizationNamespace, + PrincipalID: args.String(2), + PrincipalType: schema.PATPrincipal, + GrantRelation: schema.PATGrantRelationName, + }) + }).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - captured = append(captured, args.Get(1).(policy.Policy)) - }). - Return(policy.Policy{ID: "pol-gen"}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + captured = append(captured, policy.Policy{ + RoleID: args.String(4), + ResourceID: args.String(1), + ResourceType: schema.ProjectNamespace, + PrincipalID: args.String(2), + PrincipalType: args.String(3), + GrantRelation: schema.RoleGrantRelationName, + }) + }).Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { @@ -1178,7 +1201,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { {ID: "proj-1"}, {ID: "proj-2"}, {ID: "proj-3"}, {ID: "proj-a"}, {ID: "proj-b"}, }, nil).Maybe() - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, policySvc, projSvc, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, cfg, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -1267,32 +1290,29 @@ func TestService_CreatePolicies_PolicyCreateFailure(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) - orgViewerRole := role.Role{ID: "org-viewer-id", Name: "app_organization_viewer", Permissions: []string{"app_organization_get"}, Scopes: []string{schema.OrganizationNamespace}} - orgBillingRole := role.Role{ID: "org-billing-id", Name: "app_organization_billing", Permissions: []string{"app_organization_billingview"}, Scopes: []string{schema.OrganizationNamespace}} + projViewerRole := role.Role{ID: "proj-viewer-id", Name: "app_project_viewer", Permissions: []string{"app_project_get"}, Scopes: []string{schema.ProjectNamespace}} roleSvc := mocks.NewRoleService(t) - roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"org-viewer-id", "org-billing-id"}}). - Return([]role.Role{orgViewerRole, orgBillingRole}, nil) - roleSvc.On("Get", mock.Anything, "org-viewer-id").Return(orgViewerRole, nil).Maybe() - roleSvc.On("Get", mock.Anything, "org-billing-id").Return(orgBillingRole, nil).Maybe() - - // first policy Create succeeds, second fails - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.MatchedBy(func(p policy.Policy) bool { - return p.RoleID == "org-viewer-id" - })).Return(policy.Policy{ID: "pol-1"}, nil) - policySvc.On("Create", mock.Anything, mock.MatchedBy(func(p policy.Policy) bool { - return p.RoleID == "org-billing-id" - })).Return(policy.Policy{}, errors.New("spicedb unavailable")) - - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-viewer-id"}}). + Return([]role.Role{projViewerRole}, nil) + roleSvc.On("Get", mock.Anything, "proj-viewer-id").Return(projViewerRole, nil).Maybe() + + // one scope with two project IDs invokes SetProjectMemberRole twice; first succeeds, second fails + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-1", "pat-1", schema.PATPrincipal, "proj-viewer-id").Return(nil) + membershipSvc.EXPECT().SetProjectMemberRole(mock.Anything, "proj-2", "pat-1", schema.PATPrincipal, "proj-viewer-id").Return(errors.New("spicedb unavailable")) + + projSvc := mocks.NewProjectService(t) + projSvc.On("List", mock.Anything, mock.Anything). + Return([]project.Project{{ID: "proj-1"}, {ID: "proj-2"}}, nil).Maybe() + + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", Title: "fail-token", Scopes: []models.PATScope{ - {RoleID: "org-viewer-id", ResourceType: schema.OrganizationNamespace}, - {RoleID: "org-billing-id", ResourceType: schema.OrganizationNamespace}, + {RoleID: "proj-viewer-id", ResourceType: schema.ProjectNamespace, ResourceIDs: []string{"proj-1", "proj-2"}}, }, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -1636,10 +1656,10 @@ func TestService_Get(t *testing.T) { patID: "pat-1", setup: func() *userpat.Service { repo := mocks.NewRepository(t) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, userpat.Config{ Enabled: false, - }, orgSvc, nil, policySvc, nil, auditRepo) + }, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, wantErrIs: paterrors.ErrDisabled, @@ -1652,8 +1672,8 @@ func TestService_Get(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(models.PAT{}, paterrors.ErrNotFound) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, wantErrIs: paterrors.ErrNotFound, @@ -1666,8 +1686,8 @@ func TestService_Get(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, wantErrIs: paterrors.ErrNotFound, @@ -1680,12 +1700,12 @@ func TestService_Get(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil) - orgSvc, _, policySvc, _, auditRepo := newSuccessMocks(t) - policySvc.On("List", mock.Anything, mock.Anything). + orgSvc, _, membershipSvc, _, auditRepo := newSuccessMocks(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{ {RoleID: "role-1", ResourceType: "app/organization", ResourceID: "org-1"}, }, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -1698,11 +1718,11 @@ func TestService_Get(t *testing.T) { repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil) orgSvc := mocks.NewOrganizationService(t) - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return(nil, errors.New("spicedb down")) auditRepo := mocks.NewAuditRecordRepository(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, }, @@ -1822,41 +1842,16 @@ func TestService_Delete(t *testing.T) { repo.EXPECT().Delete(mock.Anything, "pat-1"). Return(nil) orgSvc := mocks.NewOrganizationService(t) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return(nil, errors.New("spicedb down")) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1"). + Return(errors.New("spicedb down")) auditRepo := mocks.NewAuditRecordRepository(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: true, }, { - name: "should return error when policy delete fails after soft-delete", - userID: "user-1", - patID: "pat-1", - setup: func() *userpat.Service { - repo := mocks.NewRepository(t) - repo.EXPECT().GetByID(mock.Anything, "pat-1"). - Return(testPAT, nil) - repo.EXPECT().Delete(mock.Anything, "pat-1"). - Return(nil) - orgSvc := mocks.NewOrganizationService(t) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{{ID: "pol-1"}}, nil) - policySvc.EXPECT().Delete(mock.Anything, "pol-1"). - Return(errors.New("spicedb unavailable")) - auditRepo := mocks.NewAuditRecordRepository(t) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) - }, - wantErr: true, - }, - { - name: "should delete successfully with policies", + name: "should delete successfully", userID: "user-1", patID: "pat-1", setup: func() *userpat.Service { @@ -1868,45 +1863,12 @@ func TestService_Delete(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{ - {ID: "pol-1"}, - {ID: "pol-2"}, - }, nil) - policySvc.EXPECT().Delete(mock.Anything, "pol-1").Return(nil) - policySvc.EXPECT().Delete(mock.Anything, "pol-2").Return(nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(nil) auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) - }, - wantErr: false, - }, - { - name: "should delete successfully with no policies", - userID: "user-1", - patID: "pat-1", - setup: func() *userpat.Service { - repo := mocks.NewRepository(t) - repo.EXPECT().GetByID(mock.Anything, "pat-1"). - Return(testPAT, nil) - repo.EXPECT().Delete(mock.Anything, "pat-1"). - Return(nil) - orgSvc := mocks.NewOrganizationService(t) - orgSvc.On("GetRaw", mock.Anything, mock.Anything). - Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) - auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything). - Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -1923,15 +1885,12 @@ func TestService_Delete(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(nil) auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, errors.New("audit db down")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2039,6 +1998,20 @@ func TestService_Update(t *testing.T) { wantErr: true, wantErrIs: paterrors.ErrNotFound, }, + { + name: "should return ErrExpired when PAT has already expired", + input: defaultInput, + setup: func() *userpat.Service { + expiredPAT := testPAT + expiredPAT.ExpiresAt = time.Now().Add(-time.Hour) + repo := mocks.NewRepository(t) + repo.EXPECT().GetByID(mock.Anything, "pat-1"). + Return(expiredPAT, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, nil, nil, nil) + }, + wantErr: true, + wantErrIs: paterrors.ErrExpired, + }, { name: "should return error when role validation fails", input: defaultInput, @@ -2064,14 +2037,11 @@ func TestService_Update(t *testing.T) { roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(models.PAT{}, errors.New("db error")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, }, @@ -2085,14 +2055,11 @@ func TestService_Update(t *testing.T) { roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(models.PAT{}, paterrors.ErrConflict) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, wantErrIs: paterrors.ErrConflict, @@ -2107,20 +2074,12 @@ func TestService_Update(t *testing.T) { roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - // captureOldScope call - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil).Once() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) - // deletePolicies call - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return(nil, errors.New("spicedb down")).Once() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(errors.New("spicedb down")) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, }, @@ -2129,23 +2088,20 @@ func TestService_Update(t *testing.T) { input: defaultInput, setup: func() *userpat.Service { repo := mocks.NewRepository(t) - // getOwnedPAT repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(testPAT, nil).Once() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, mock.Anything). Return([]role.Role{validRole}, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) + membershipSvc.EXPECT().RemoveAllPATPolicies(mock.Anything, "pat-1").Return(nil) // TOCTOU re-check returns not found (concurrent delete) repo.EXPECT().GetByID(mock.Anything, "pat-1"). Return(models.PAT{}, paterrors.ErrNotFound).Once() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, policySvc, nil, nil) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, roleSvc, membershipSvc, nil, nil) }, wantErr: true, }, @@ -2164,14 +2120,13 @@ func TestService_Update(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) // captureOldScope + enrichWithScope (after update) - policySvc.On("List", mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) - policySvc.On("Create", mock.Anything, mock.Anything). - Return(policy.Policy{}, nil) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) + membershipSvc.On("RemoveAllPATPolicies", mock.Anything, "pat-1").Return(nil) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) // TOCTOU re-check @@ -2180,7 +2135,7 @@ func TestService_Update(t *testing.T) { auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2198,13 +2153,12 @@ func TestService_Update(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) - policySvc.On("Create", mock.Anything, mock.Anything). - Return(policy.Policy{}, nil) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{}, nil) + membershipSvc.On("RemoveAllPATPolicies", mock.Anything, "pat-1").Return(nil) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() repo.EXPECT().Update(mock.Anything, mock.Anything). Return(updatedPAT, nil) repo.EXPECT().GetByID(mock.Anything, "pat-1"). @@ -2212,7 +2166,7 @@ func TestService_Update(t *testing.T) { auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, errors.New("audit db down")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2369,13 +2323,13 @@ func TestService_Regenerate(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2410,13 +2364,13 @@ func TestService_Regenerate(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, nil).Maybe() - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2434,13 +2388,13 @@ func TestService_Regenerate(t *testing.T) { orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() - policySvc := mocks.NewPolicyService(t) - policySvc.On("List", mock.Anything, mock.Anything). + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything). Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). Return(auditmodels.AuditRecord{}, errors.New("audit db down")) - return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, policySvc, nil, auditRepo) + return userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, nil, membershipSvc, nil, auditRepo) }, wantErr: false, }, @@ -2599,9 +2553,11 @@ func TestService_ValidateProjectAccess(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, mock.Anything).Return([]role.Role{ {ID: "role-1", Name: "proj_viewer", Scopes: []string{schema.ProjectNamespace}, Permissions: []string{"app_project_get"}}, }, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.Anything).Return(policy.Policy{}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { return f.OrgID == "org-1" && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal @@ -2611,7 +2567,7 @@ func TestService_ValidateProjectAccess(t *testing.T) { auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, projSvc, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, projSvc, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -2638,14 +2594,16 @@ func TestService_ValidateProjectAccess(t *testing.T) { roleSvc.EXPECT().List(mock.Anything, mock.Anything).Return([]role.Role{ {ID: "role-1", Name: "proj_viewer", Scopes: []string{schema.ProjectNamespace}, Permissions: []string{"app_project_get"}}, }, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.On("Create", mock.Anything, mock.Anything).Return(policy.Policy{}, nil).Maybe() - policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.On("SetOrganizationMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetProjectMemberRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("SetPATAllProjectsRole", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + membershipSvc.On("ListPoliciesByPrincipal", mock.Anything, mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() // No projectService mock needed — all-projects scope has empty ResourceIDs, skips validation - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, orgSvc, roleSvc, membershipSvc, nil, auditRepo) _, _, err := svc.Create(context.Background(), userpat.CreateRequest{ UserID: "user-1", OrgID: "org-1", @@ -2692,13 +2650,10 @@ func TestService_List(t *testing.T) { Return(models.PATList{ PATs: []models.PAT{{ID: "pat-1", UserID: "user-1", OrgID: "org-1"}}, }, nil) - policySvc := mocks.NewPolicyService(t) - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return(nil, errors.New("policy service down")) + membershipSvc := mocks.NewMembershipService(t) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return(nil, errors.New("policy service down")) auditRepo := mocks.NewAuditRecordRepository(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, membershipSvc, nil, auditRepo) _, err := svc.List(context.Background(), "user-1", "org-1", nil) if err == nil || !strings.Contains(err.Error(), "enriching PAT scope") { @@ -2715,21 +2670,15 @@ func TestService_List(t *testing.T) { {ID: "pat-2", UserID: "user-1", OrgID: "org-1", Title: "token-2"}, }, }, nil) - policySvc := mocks.NewPolicyService(t) + membershipSvc := mocks.NewMembershipService(t) // enrichWithScope for pat-1 - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-1", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{ + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-1", schema.PATPrincipal).Return([]policy.Policy{ {ID: "pol-1", RoleID: "role-1", ResourceID: "org-1", ResourceType: schema.OrganizationNamespace, GrantRelation: "granted"}, }, nil) // enrichWithScope for pat-2 - policySvc.EXPECT().List(mock.Anything, policy.Filter{ - PrincipalID: "pat-2", - PrincipalType: schema.PATPrincipal, - }).Return([]policy.Policy{}, nil) + membershipSvc.EXPECT().ListPoliciesByPrincipal(mock.Anything, "pat-2", schema.PATPrincipal).Return([]policy.Policy{}, nil) auditRepo := mocks.NewAuditRecordRepository(t) - svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, policySvc, nil, auditRepo) + svc := userpat.NewService(slog.New(slog.NewTextHandler(io.Discard, nil)), repo, defaultConfig, nil, nil, membershipSvc, nil, auditRepo) result, err := svc.List(context.Background(), "user-1", "org-1", nil) if err != nil { diff --git a/internal/api/v1beta1connect/user_pat.go b/internal/api/v1beta1connect/user_pat.go index 723940713..ed244364f 100644 --- a/internal/api/v1beta1connect/user_pat.go +++ b/internal/api/v1beta1connect/user_pat.go @@ -35,7 +35,8 @@ func (h *ConnectHandler) getLoggedInPrincipalWithUser(ctx context.Context) (*aut // mapPATError maps PAT service errors to Connect RPC error codes. func mapPATError(err error) *connect.Error { switch { - case errors.Is(err, paterrors.ErrDisabled): + case errors.Is(err, paterrors.ErrDisabled), + errors.Is(err, paterrors.ErrExpired): return connect.NewError(connect.CodeFailedPrecondition, err) case errors.Is(err, paterrors.ErrNotFound): return connect.NewError(connect.CodeNotFound, err) @@ -47,6 +48,7 @@ func mapPATError(err error) *connect.Error { errors.Is(err, paterrors.ErrDeniedRole), errors.Is(err, paterrors.ErrUnsupportedScope), errors.Is(err, paterrors.ErrScopeMismatch), + errors.Is(err, paterrors.ErrDuplicateScope), errors.Is(err, paterrors.ErrProjectForbidden), errors.Is(err, paterrors.ErrExpiryInPast), errors.Is(err, paterrors.ErrExpiryExceeded): diff --git a/test/e2e/regression/pat_test.go b/test/e2e/regression/pat_test.go index cec4be759..e372702ab 100644 --- a/test/e2e/regression/pat_test.go +++ b/test/e2e/regression/pat_test.go @@ -748,6 +748,34 @@ func (s *PATRegressionTestSuite) TestPATCRUD_CreateErrors() { s.Assert().Error(err) s.Assert().Equal(connect.CodeInvalidArgument, connect.CodeOf(err)) }) + + s.Run("two org-scoped roles in one request", func() { + _, err := s.testBench.Client.CreateCurrentUserPAT(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ + Title: "two-org-scopes", + OrgId: orgID, + Scopes: []*frontierv1beta1.PATScope{ + {RoleId: s.roleID(schema.RoleOrganizationViewer), ResourceType: schema.OrganizationNamespace}, + {RoleId: s.roleID(schema.RoleOrganizationManager), ResourceType: schema.OrganizationNamespace}, + }, + ExpiresAt: timestamppb.New(time.Now().Add(24 * time.Hour)), + })) + s.Assert().Error(err) + s.Assert().Equal(connect.CodeInvalidArgument, connect.CodeOf(err)) + }) + + s.Run("two project-scoped roles in one request", func() { + _, err := s.testBench.Client.CreateCurrentUserPAT(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ + Title: "two-project-scopes", + OrgId: orgID, + Scopes: []*frontierv1beta1.PATScope{ + {RoleId: s.roleID(schema.RoleProjectViewer), ResourceType: schema.ProjectNamespace}, + {RoleId: s.roleID(schema.RoleProjectOwner), ResourceType: schema.ProjectNamespace}, + }, + ExpiresAt: timestamppb.New(time.Now().Add(24 * time.Hour)), + })) + s.Assert().Error(err) + s.Assert().Equal(connect.CodeInvalidArgument, connect.CodeOf(err)) + }) } func TestEndToEndPATRegressionTestSuite(t *testing.T) {