From d545638d868892fac0559b996f370798ebfb43e7 Mon Sep 17 00:00:00 2001 From: Manan Tyagi Date: Tue, 16 Jun 2026 10:41:47 +0530 Subject: [PATCH 1/2] roe flag bugs fix on share folder and share-record --- .../commands/pam/discovery/__init__.py | 5 +- .../commands/pam/discovery/discover.py | 64 +++++++++---- .../commands/pam/pam_gateway_action.py | 5 +- .../keepercli/commands/pam/pam_rotation.py | 29 ++++-- .../src/keepercli/commands/shares.py | 82 +++++++++++------ .../src/keepercli/helpers/report_utils.py | 3 + .../src/keepersdk/helpers/keeper_dag/dag.py | 11 ++- .../keepersdk/helpers/keeper_dag/dag_types.py | 9 ++ .../keepersdk/helpers/keeper_dag/dag_utils.py | 10 +++ .../src/keepersdk/helpers/keeper_dag/jobs.py | 7 +- .../helpers/keeper_dag/record_link.py | 10 +-- .../helpers/keeper_dag/struct/protobuf.py | 9 +- .../keepersdk/vault/share_management_utils.py | 90 ++++++++++++++++++- .../test_share_rotate_on_expiration.py | 39 ++++++++ 14 files changed, 307 insertions(+), 66 deletions(-) create mode 100644 keepersdk-package/unit_tests/test_share_rotate_on_expiration.py diff --git a/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py b/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py index 61382748..f941005d 100644 --- a/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py +++ b/keepercli-package/src/keepercli/commands/pam/discovery/__init__.py @@ -200,7 +200,10 @@ def from_gateway(vault: vault_online.VaultOnline, gateway: str, configuration_ui application_id = utils.base64_url_encode(found_gateway.applicationUid) application = vault.vault_data.load_record(application_id) if application is None: - logger.debug(f"cannot find application for gateway {gateway}, skipping.") + logger.warning( + f"KSM application for gateway {gateway} is not in the vault " + f"(record {application_id}); discovery may still work via the router." + ) if (utils.base64_url_encode(found_gateway.controllerUid) == gateway or found_gateway.controllerName.lower() == gateway.lower()): diff --git a/keepercli-package/src/keepercli/commands/pam/discovery/discover.py b/keepercli-package/src/keepercli/commands/pam/discovery/discover.py index f65f2775..118b4fa9 100644 --- a/keepercli-package/src/keepercli/commands/pam/discovery/discover.py +++ b/keepercli-package/src/keepercli/commands/pam/discovery/discover.py @@ -18,7 +18,7 @@ from keepersdk.helpers.pam_user_record_facade import PamUserRecordFacade from keepersdk.helpers.keeper_dag.jobs import Jobs -from keepersdk.helpers.keeper_dag.dag_types import (CredentialBase, DiscoveryDelta, DiscoveryObject, JobItem, UserAcl, DirectoryInfo, +from keepersdk.helpers.keeper_dag.dag_types import (CredentialBase, DiscoveryDelta, DiscoveryObject, JobItem, Settings, UserAcl, DirectoryInfo, BulkRecordConvert, BulkRecordAdd, BulkRecordSuccess, BulkProcessResults, NormalizedRecord, BulkRecordFail, PromptResult, PromptActionEnum, RecordField) from keepersdk.helpers.keeper_dag.dag_vertex import DAGVertex @@ -152,7 +152,7 @@ def print_job_detail(vault: vault_online.VaultOnline, job_id: str): def _find_job(configuration_record) -> Optional[Dict]: - jobs_obj = Jobs(record=configuration_record) + jobs_obj = Jobs(record=configuration_record, vault=vault) job_item = jobs_obj.get_job(job_id) if job_item is not None: return { @@ -167,7 +167,7 @@ def _find_job(configuration_record) -> Optional[Dict]: if gateway_context is not None: jobs = payload["jobs"] job = jobs.get_job(job_id) - infra = Infrastructure(record=gateway_context.configuration) + infra = Infrastructure(record=gateway_context.configuration, vault=vault) status = "RUNNING" if job.end_ts is not None and not job.error: @@ -296,7 +296,7 @@ def execute(self, context: KeeperParams, **kwargs): if len(gateway_context.gateway_name) > max_gateway_name: max_gateway_name = len(gateway_context.gateway_name) - jobs = Jobs(record=configuration_record) + jobs = Jobs(record=configuration_record, vault=vault) if show_history is True: job_list = reversed(jobs.history) else: @@ -391,7 +391,7 @@ def execute(self, context: KeeperParams, **kwargs): multi_conf_msg(gateway, err) return - jobs = Jobs(record=gateway_context.configuration) + jobs = Jobs(record=gateway_context.configuration, vault=vault) current_job_item = jobs.current_job removed_prior_job = None if current_job_item is not None: @@ -467,15 +467,20 @@ def execute(self, context: KeeperParams, **kwargs): setattr(c, key, obj[key]) credentials.append(c.model_dump()) + user_map_entries = self.make_protobuf_user_map( + context=context, + gateway_context=gateway_context + ) + if len(user_map_entries) == 0: + logger.info( + "No pamUser records are linked to this configuration; " + "discovery will run without an existing user map." + ) + action_inputs = GatewayActionDiscoverJobStartInputs( configuration_uid=gateway_context.configuration_uid, resource_uid=kwargs.get('resource_uid'), - user_map=gateway_context.encrypt( - self.make_protobuf_user_map( - context=context, - gateway_context=gateway_context - )[0] - ), + user_map=gateway_context.encrypt({"users": user_map_entries}), shared_folder_uid=gateway_context.default_shared_folder_uid, languages=[kwargs.get('language')], @@ -507,16 +512,39 @@ def execute(self, context: KeeperParams, **kwargs): logger.error(f"The router returned a failure.") return + discovery_settings = Settings( + credentials=[CredentialBase(**c) for c in credentials], + default_shared_folder_uid=gateway_context.default_shared_folder_uid, + include_azure_aadds=kwargs.get('include_azure_aadds', False), + skip_rules=kwargs.get('skip_rules', False), + skip_machines=kwargs.get('skip_machines', False), + skip_databases=kwargs.get('skip_databases', False), + skip_directories=kwargs.get('skip_directories', False), + skip_cloud_users=kwargs.get('skip_cloud_users', False), + user_map=user_map_entries or None, + ) + job_id = jobs.start( + settings=discovery_settings, + resource_uid=kwargs.get('resource_uid'), + conversation_id=conversation_id, + ) + jobs.close() + if "has been queued" in data.get("Response", ""): if removed_prior_job is None: - logger.info("The discovery job is currently running.") + logger.info(f"Discovery job {job_id} is running.") else: - logger.info(f"Active discovery job {removed_prior_job} has been removed and new discovery job is running.") + logger.info( + f"Active discovery job {removed_prior_job} has been removed; " + f"discovery job {job_id} is running." + ) logger.info(f"To check the status, use the command 'pam action discover status'.") - logger.info(f"To stop and remove the current job, use the command 'pam action discover remove -j '.") + logger.info(f"To stop and remove the current job, use the command 'pam action discover remove -j {job_id}'.") else: router_utils.print_router_response(router_response, "job_info", conversation_id, gateway_uid=gateway_context.gateway_uid) + logger.info(f"Discovery job {job_id} was recorded on the configuration.") + logger.info(f"To check the status, use the command 'pam action discover status -j {job_id}'.") @staticmethod def make_protobuf_user_map(context: KeeperParams, gateway_context: GatewayContext) -> List[dict]: @@ -580,7 +608,7 @@ def execute(self, context: KeeperParams, **kwargs): all_gateways = GatewayContext.all_gateways(vault) def _find_job(configuration_record) -> Optional[Dict]: - jobs_obj = Jobs(record=configuration_record) + jobs_obj = Jobs(record=configuration_record, vault=vault) job_item = jobs_obj.get_job(job_id) if job_item is not None: return { @@ -1775,7 +1803,7 @@ def _get_directory_info(domain: str, def remove_job(context: KeeperParams, configuration_record: vault_record.KeeperRecord, job_id: str): try: - jobs = Jobs(record=configuration_record, context=context) + jobs = Jobs(record=configuration_record, vault=context.vault) jobs.cancel(job_id) logger.info(f"No items left to process. Removing completed discovery job.") except Exception as err: @@ -1786,7 +1814,7 @@ def preview(self, job_item: JobItem, context: KeeperParams, gateway_context: Gat sync_point = job_item.sync_point infra = Infrastructure(record=gateway_context.configuration, - context=context, + vault=context.vault, logger=logger, debug_level=debug_level) infra.load(sync_point) @@ -1941,7 +1969,7 @@ def execute(self, context: KeeperParams, **kwargs): # Get the current job. # There can only be one active job. - jobs = Jobs(record=configuration_record, context=context, logger=logger, debug_level=debug_level) + jobs = Jobs(record=configuration_record, vault=vault, logger=logger, debug_level=debug_level) job_item = jobs.current_job if job_item is None: continue diff --git a/keepercli-package/src/keepercli/commands/pam/pam_gateway_action.py b/keepercli-package/src/keepercli/commands/pam/pam_gateway_action.py index 777ce710..6299f7cd 100644 --- a/keepercli-package/src/keepercli/commands/pam/pam_gateway_action.py +++ b/keepercli-package/src/keepercli/commands/pam/pam_gateway_action.py @@ -255,7 +255,10 @@ def record_rotate(self, context: KeeperParams, record_uid, slient:bool = False): config_uid = facade.controller_uid if not resource_uid: - tmp_dag = tunnel_graph.TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, record.record_uid) + tmp_dag = tunnel_graph.TunnelDAG( + vault, encrypted_session_token, encrypted_transmission_key, record.record_uid, + transmission_key=transmission_key, + ) resource_uid = tmp_dag.get_resource_uid(record_uid) if not resource_uid: is_noop = False diff --git a/keepercli-package/src/keepercli/commands/pam/pam_rotation.py b/keepercli-package/src/keepercli/commands/pam/pam_rotation.py index a47bd3cc..c303be5b 100644 --- a/keepercli-package/src/keepercli/commands/pam/pam_rotation.py +++ b/keepercli-package/src/keepercli/commands/pam/pam_rotation.py @@ -295,7 +295,8 @@ def execute(self, context: KeeperParams, **kwargs): def config_resource(_dag, target_record, target_config_uid, silent=None): if not _dag.linking_dag.has_graph: if target_config_uid: - _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_config_uid) + _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_config_uid, + transmission_key=transmission_key) _dag.edit_tunneling_config(rotation=True) else: raise base.CommandError(f'Resource "{target_record.record_uid}" is not associated ' @@ -305,7 +306,7 @@ def config_resource(_dag, target_record, target_config_uid, silent=None): resource_dag = None if not _dag.resource_belongs_to_config(target_record.record_uid): resource_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, - target_record.record_uid) + target_record.record_uid, transmission_key=transmission_key) _dag.link_resource_to_config(target_record.record_uid) admin = kwargs.get('admin') @@ -401,10 +402,12 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): return if _dag and not _dag.linking_dag.has_graph: - _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_iam_aad_config_uid) + _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_iam_aad_config_uid, + transmission_key=transmission_key) if not _dag or not _dag.linking_dag.has_graph: _dag.edit_tunneling_config(rotation=True) - old_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_record.record_uid) + old_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_record.record_uid, + transmission_key=transmission_key) if old_dag.linking_dag.has_graph and old_dag.record.record_uid != target_iam_aad_config_uid: old_dag.remove_from_dag(target_record.record_uid) @@ -621,7 +624,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None return if isinstance(target_resource_uid, str) and len(target_resource_uid) > 0: - _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_resource_uid) + _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_resource_uid, + transmission_key=transmission_key) if not _dag or not _dag.linking_dag.has_graph: if target_config_uid and target_resource_uid: config_resource(_dag, target_record, target_config_uid, silent=silent) @@ -639,7 +643,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None current_record_rotation = context.get_record_rotation(target_record.record_uid) if not _dag or not _dag.linking_dag.has_graph: - _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_resource_uid) + _dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, target_resource_uid, + transmission_key=transmission_key) if not _dag.linking_dag.has_graph: raise base.CommandError(f'Resource "{target_resource_uid}" is not associated ' f'with any configuration. ' @@ -824,6 +829,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None if record_name: if record_name in vault.vault_data._records: record_uids.add(record_name) + elif vault.vault_data.load_record(record_name): + record_uids.add(record_name) else: rs = folder_utils.try_resolve_path(context, record_name) if rs is not None: @@ -866,7 +873,10 @@ def add_folders(folder: vault_types.Folder): if folder_uids: regex = re.compile(fnmatch.translate(record_pattern), re.IGNORECASE).match if record_pattern else None for folder_uid in folder_uids: - folder_records = vault.vault_data.get_folder(folder_uid).records + folder = vault.vault_data.get_folder(folder_uid) + if not folder: + continue + folder_records = folder.records if not folder_records: continue if record_pattern and record_pattern in folder_records: @@ -957,7 +967,8 @@ def add_folders(folder: vault_types.Folder): r_requests = [] for _record in pam_records: - tmp_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, _record.record_uid) + tmp_dag = TunnelDAG(vault, encrypted_session_token, encrypted_transmission_key, _record.record_uid, + transmission_key=transmission_key) if _record.record_type in ['pamMachine', 'pamDatabase', 'pamDirectory', 'pamRemoteBrowser']: config_resource(tmp_dag, _record, config_uid, silent=kwargs.get('silent')) elif _record.record_type == 'pamUser': @@ -1108,7 +1119,7 @@ def is_resource_ok(resource_id, vault, configuration_uid): logger.info(f"Is Rotation Disabled: {rri.disabled}") rq = pam_pb2.PAMGenericUidsRequest() - schedules_proto = router_utils.router_get_rotation_schedules(context, rq) + schedules_proto = router_utils.router_get_rotation_schedules(vault, rq) if schedules_proto: schedules = list(schedules_proto.schedules) for s in schedules: diff --git a/keepercli-package/src/keepercli/commands/shares.py b/keepercli-package/src/keepercli/commands/shares.py index 223f337e..f1107516 100644 --- a/keepercli-package/src/keepercli/commands/shares.py +++ b/keepercli-package/src/keepercli/commands/shares.py @@ -110,6 +110,10 @@ def add_arguments_to_parser(parser: argparse.ArgumentParser): metavar='[(mi)nutes|(h)ours|(d)ays|(mo)nths|(y)ears]', help='share expiration: never or period' ) + parser.add_argument( + '-roe', '--rotate-on-expiration', dest='rotate_on_expiration', action='store_true', + help='Rotate pamUser password when share access expires (requires positive expiration)' + ) parser.add_argument( 'record', nargs='?', type=str, action='store', help='record/shared folder path/UID' ) @@ -136,22 +140,31 @@ def execute(self, context: KeeperParams, **kwargs) -> None: vault.sync_down() return - share_expiration = share_management_utils.get_share_expiration( - kwargs.get('expire_at'), kwargs.get('expire_in') - ) + rotate_on_expiration = kwargs.get('rotate_on_expiration') is True + try: + share_expiration = share_management_utils.get_share_expiration( + kwargs.get('expire_at'), kwargs.get('expire_in') + ) + share_management_utils.validate_rotate_on_expiration(share_expiration, rotate_on_expiration) + except share_management_utils.ShareValidationError as err: + raise base.CommandError(str(err)) from err - request = RecordShares.prep_request( - vault=vault, - enterprise=context.enterprise_data, - uid_or_name=uid_or_name, - emails=emails, - share_expiration=share_expiration, - action=action, - dry_run=kwargs.get('dry_run', False), - can_edit=kwargs.get('can_edit'), - can_share=kwargs.get('can_share'), - recursive=kwargs.get('recursive') - ) + try: + request = RecordShares.prep_request( + vault=vault, + enterprise=context.enterprise_data, + uid_or_name=uid_or_name, + emails=emails, + share_expiration=share_expiration, + action=action, + dry_run=kwargs.get('dry_run', False), + can_edit=kwargs.get('can_edit'), + can_share=kwargs.get('can_share'), + recursive=kwargs.get('recursive'), + rotate_on_expiration=rotate_on_expiration, + ) + except (share_management_utils.ShareValidationError, ValueError) as err: + raise base.CommandError(str(err)) from err if request: success_responses, failed_responses = RecordShares.send_requests(vault, [request]) if success_responses: @@ -261,6 +274,10 @@ def add_arguments_to_parser(parser: argparse.ArgumentParser): '--expire-in', dest='expire_in', action='store', metavar='PERIOD', help='share expiration: never or period ([(y)ears|(mo)nths|(d)ays|(h)ours(mi)nutes]' ) + parser.add_argument( + '-roe', '--rotate-on-expiration', dest='rotate_on_expiration', action='store_true', + help='Rotate pamUser passwords when share access expires (requires positive expiration)' + ) parser.add_argument( 'folder', nargs='+', type=str, action='store', help='shared folder path or UID' ) @@ -277,7 +294,19 @@ def execute(self, context: KeeperParams, **kwargs) -> None: raise ValueError('Enter name of at least one existing folder') action = kwargs.get('action') or ShareAction.GRANT.value - share_expiration = self._get_share_expiration(action, kwargs) + rotate_on_expiration = kwargs.get('rotate_on_expiration') is True + try: + share_expiration = self._get_share_expiration(action, kwargs) + share_management_utils.validate_rotate_on_expiration(share_expiration, rotate_on_expiration) + if rotate_on_expiration and action != ShareAction.GRANT.value: + raise share_management_utils.ShareValidationError( + '--rotate-on-expiration is only valid with --action grant' + ) + share_management_utils.validate_folder_shares_rotate_on_expiration( + vault, shared_folder_uids, rotate_on_expiration + ) + except share_management_utils.ShareValidationError as err: + raise base.CommandError(str(err)) from err user_data = self._parse_user_arguments(vault, kwargs) record_data = self._parse_record_arguments(vault, kwargs) @@ -287,8 +316,8 @@ def execute(self, context: KeeperParams, **kwargs) -> None: return rq_groups = self._prepare_request_groups( - vault, shared_folder_uids, user_data, record_data, - action, share_expiration, kwargs + vault, shared_folder_uids, user_data, record_data, + action, share_expiration, kwargs, rotate_on_expiration ) success_responses, failed_responses = FolderShares.send_requests(vault=vault, partitioned_requests=rq_groups) if success_responses: @@ -523,15 +552,16 @@ def _is_nothing_to_do(self, user_data: Dict, record_data: Dict) -> bool: def _prepare_request_groups(self, vault: vault_online.VaultOnline, shared_folder_uids: Set, user_data: Dict, record_data: Dict, action: str, - share_expiration, kwargs: Dict) -> List: + share_expiration, kwargs: Dict, + rotate_on_expiration: bool = False) -> List: """Prepare request groups for all shared folders.""" rq_groups = [] shared_folder_cache = {x.shared_folder_uid: x for x in vault.vault_data.shared_folders()} for sf_uid in shared_folder_uids: folder_requests = self._prepare_folder_requests( - vault, sf_uid, shared_folder_cache, user_data, - record_data, action, share_expiration, kwargs + vault, sf_uid, shared_folder_cache, user_data, + record_data, action, share_expiration, kwargs, rotate_on_expiration ) rq_groups.extend(folder_requests) @@ -540,7 +570,7 @@ def _prepare_request_groups(self, vault: vault_online.VaultOnline, shared_folder def _prepare_folder_requests(self, vault: vault_online.VaultOnline, sf_uid: str, shared_folder_cache: Dict, user_data: Dict, record_data: Dict, action: str, share_expiration, - kwargs: Dict) -> List: + kwargs: Dict, rotate_on_expiration: bool = False) -> List: """Prepare requests for a single shared folder.""" sf_users = user_data['users'].copy() sf_teams = user_data['teams'].copy() @@ -557,7 +587,8 @@ def _prepare_folder_requests(self, vault: vault_online.VaultOnline, sf_uid: str, return self._chunk_and_prepare_requests( vault, kwargs, sh_fol, sf_uid, sf_users, sf_teams, sf_records, - record_data['default_record'], user_data['default_account'], share_expiration + record_data['default_record'], user_data['default_account'], share_expiration, + rotate_on_expiration, ) def _load_or_create_shared_folder(self, vault: vault_online.VaultOnline, sf_uid: str, shared_folder_cache: Dict, @@ -613,7 +644,7 @@ def _update_from_existing_folder(self, sh_fol, auth: keeper_auth.KeeperAuth, def _chunk_and_prepare_requests(self, vault: vault_online.VaultOnline, kwargs: Dict, sh_fol, sf_uid: str, sf_users: Set, sf_teams: Set, sf_records: Set, default_record: bool, default_account: bool, - share_expiration) -> List: + share_expiration, rotate_on_expiration: bool = False) -> List: """Chunk records and users, then prepare requests.""" rec_list = list(sf_records) user_list = list(sf_users) @@ -644,7 +675,8 @@ def _chunk_and_prepare_requests(self, vault: vault_online.VaultOnline, kwargs: D vault, kwargs, sf_info, u_chunk, sf_teams, r_chunk, default_record=default_record, default_account=default_account, - share_expiration=share_expiration + share_expiration=share_expiration, + rotate_on_expiration=rotate_on_expiration, ) rq_groups[group_idx].append(request) group_idx += 1 diff --git a/keepercli-package/src/keepercli/helpers/report_utils.py b/keepercli-package/src/keepercli/helpers/report_utils.py index 0e4f6395..c1969d88 100644 --- a/keepercli-package/src/keepercli/helpers/report_utils.py +++ b/keepercli-package/src/keepercli/helpers/report_utils.py @@ -118,6 +118,9 @@ def dump_report_data(data: List[List[Any]], # sort_desc: bool - Descending Sort # right_align: Sequence[int] - Force right align + if filename: + filename = os.path.expanduser(filename) + append = kwargs.get('append') is True title = kwargs.get('title') sort_by = kwargs.get('sort_by') diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag.py index 152a743e..e2360f5d 100644 --- a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag.py +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag.py @@ -9,7 +9,7 @@ from typing import Any, List, Optional, Tuple, Union from . import dag_utils, dag_crypto -from .dag_types import EdgeType, Ref, RefType, ENDPOINT_TO_GRAPH_ID_MAP, DAGData +from .dag_types import EdgeType, Ref, RefType, ENDPOINT_TO_GRAPH_ID_MAP, DAGData, endpoint_for_graph_id from .dag_vertex import DAGVertex from .struct.protobuf import DataStruct as ProtobufDataStruct from .struct.default import DataStruct as DefaultDataStruct @@ -171,6 +171,15 @@ def __init__(self, self.conn = conn + if self.conn.use_write_protobuf and self.write_endpoint is None and self.graph_id is not None: + mapped_endpoint = endpoint_for_graph_id(self.graph_id) + if mapped_endpoint is not None: + self.write_endpoint = mapped_endpoint + if self.conn.use_read_protobuf and self.read_endpoint is None and self.graph_id is not None: + mapped_endpoint = endpoint_for_graph_id(self.graph_id) + if mapped_endpoint is not None: + self.read_endpoint = mapped_endpoint + self.read_struct_obj: Union[ProtobufDataStruct, DefaultDataStruct] = ProtobufDataStruct() \ if conn.use_read_protobuf else DefaultDataStruct() self.write_struct_obj: Union[ProtobufDataStruct, DefaultDataStruct] = ProtobufDataStruct() \ diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py index bbe7f192..627811a8 100644 --- a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_types.py @@ -49,6 +49,15 @@ class PamEndpoints(BaseEnum): } +def endpoint_for_graph_id(graph_id: Union[int, "PamGraphId", Enum]) -> Optional[str]: + if isinstance(graph_id, Enum): + graph_id = graph_id.value + for endpoint, gid in ENDPOINT_TO_GRAPH_ID_MAP.items(): + if gid == graph_id: + return endpoint + return None + + class SyncQuery(BaseModel): streamId: Optional[str] = None # base64 of a user's ID who is syncing. deviceId: Optional[str] = None diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py index 72d1d6d4..c7e084a8 100644 --- a/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/dag_utils.py @@ -33,6 +33,10 @@ def get_connection(**kwargs): return kwargs.get("connection") vault = kwargs.get("vault") + if vault is None: + context = kwargs.get("context") + if context is not None: + vault = getattr(context, "vault", None) logger = kwargs.get("logger") if value_to_boolean(os.environ.get("USE_LOCAL_DAG")): from ..keeper_dag.connection.local import Connection @@ -40,6 +44,12 @@ def get_connection(**kwargs): else: use_read_protobuf = kwargs.get("use_read_protobuf") use_write_protobuf = kwargs.get("use_write_protobuf") + if use_read_protobuf is None: + env_val = os.environ.get("GS_USE_READ_PROTOBUF") + use_read_protobuf = True if env_val is None else value_to_boolean(env_val) + if use_write_protobuf is None: + env_val = os.environ.get("GS_USE_WRITE_PROTOBUF") + use_write_protobuf = True if env_val is None else value_to_boolean(env_val) if vault is not None: from ..keeper_dag.connection.commander import Connection diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py index eb24648a..c15d65e9 100644 --- a/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/jobs.py @@ -32,10 +32,11 @@ def __init__(self, record: Any, logger: Optional[Any] = None, debug_level: int = log_prefix: str = "GS Jobs", save_batch_count: int = 200, agent: Optional[str] = None, **kwargs): - self.conn = get_connection(logger=logger, **kwargs) - self.record = record self._dag = None + self.conn = None + + self.conn = get_connection(logger=logger, **kwargs) if logger is None: logger = logging.getLogger() logger.propagate = False @@ -89,7 +90,7 @@ def close(self): Clean up resources held by this Jobs instance. Releases the DAG instance and connection to prevent memory leaks. """ - if self._dag is not None: + if getattr(self, "_dag", None) is not None: self._dag = None self.conn = None diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py index 9cbb7f43..29e8c1d8 100644 --- a/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/record_link.py @@ -24,13 +24,12 @@ def __init__(self, use_write_protobuf: bool = True, **kwargs): + self.record = record + self._dag = None self.conn = get_connection(logger=logger, use_read_protobuf=use_read_protobuf, use_write_protobuf=use_write_protobuf, **kwargs) - - self.record = record - self._dag = None if logger is None: logger = logging.getLogger() self.logger = logger @@ -81,9 +80,10 @@ def close(self): Clean up resources held by this RecordLink instance. Releases the DAG instance and connection to prevent memory leaks. """ - if self._dag is not None: + if getattr(self, "_dag", None) is not None: self._dag = None - self.conn = None + if getattr(self, "conn", None) is not None: + self.conn = None def __enter__(self): """Context manager entry.""" diff --git a/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/protobuf.py b/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/protobuf.py index b584d3a0..28e9a2fa 100644 --- a/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/protobuf.py +++ b/keepersdk-package/src/keepersdk/helpers/keeper_dag/struct/protobuf.py @@ -109,11 +109,16 @@ def get_sync_result(results: bytes) -> SyncData: data_list: List[SyncDataItem] = [] for item in message.data: + raw_content = item.data.content + if raw_content: + encoded_content = dag_crypto.bytes_to_urlsafe_str(raw_content) + else: + encoded_content = None data_list.append( SyncDataItem( type=DataStruct.PB_TO_DATA_MAP.get(item.data.type), - content=item.data.content, - content_is_base64=False, + content=encoded_content, + content_is_base64=True, ref=Ref( type=DataStruct.PB_TO_REF_MAP.get(item.data.ref.type), value=dag_crypto.bytes_to_urlsafe_str(item.data.ref.value), diff --git a/keepersdk-package/src/keepersdk/vault/share_management_utils.py b/keepersdk-package/src/keepersdk/vault/share_management_utils.py index cf9a55e7..74d268af 100644 --- a/keepersdk-package/src/keepersdk/vault/share_management_utils.py +++ b/keepersdk-package/src/keepersdk/vault/share_management_utils.py @@ -6,7 +6,7 @@ from typing import Optional, Dict, List, Any, Generator, Iterable, Set, Tuple, Union from .. import crypto, utils -from ..proto import enterprise_pb2, folder_pb2, record_pb2 +from ..proto import enterprise_pb2, folder_pb2, pam_pb2, record_pb2, router_pb2 from . import vault_data, storage_types, vault_online, vault_record, vault_types, vault_utils, sync_down from ..enterprise import enterprise_data @@ -94,6 +94,94 @@ def get_share_expiration(expire_at: Optional[str], expire_in: Optional[str]) -> raise ShareValidationError(f'Invalid expiration format: {e}') from e +PAM_USER_RECORD_TYPE = 'pamUser' + + +def validate_rotate_on_expiration(share_expiration: int, rotate_on_expiration: bool) -> None: + """Require a positive expiration when rotate-on-expiration is requested.""" + if not rotate_on_expiration: + return + if share_expiration <= 0 or share_expiration == NEVER_EXPIRES: + raise ShareValidationError( + '--rotate-on-expiration requires a positive --expire-at or --expire-in (not "never")' + ) + + +def is_pam_user_record(vault: vault_online.VaultOnline, record_uid: str) -> bool: + info = vault.vault_data.get_record(record_uid) + return bool(info and info.record_type == PAM_USER_RECORD_TYPE) + + +def pam_user_has_rotation_configured(vault: vault_online.VaultOnline, record_uid: str) -> bool: + """True when pam/get_rotation_info reports an enabled rotation configuration.""" + try: + rq = pam_pb2.PAMGenericUidRequest() + rq.uid = utils.base64_url_decode(record_uid) + rs = vault.keeper_auth.execute_auth_rest( + rest_endpoint='pam/get_rotation_info', + request=rq, + response_type=router_pb2.RouterRotationInfo, + ) + except Exception: + return False + if not rs or rs.disabled: + return False + return bool(rs.configurationUid) + + +def get_shared_folder_record_uids(vault: vault_online.VaultOnline, shared_folder_uid: str) -> Set[str]: + """Collect record UIDs contained in a shared folder tree.""" + record_uids: Set[str] = set() + folder = vault.vault_data.get_folder(shared_folder_uid) + if not folder: + return record_uids + + def add_records(folder_obj: vault_types.Folder) -> None: + record_uids.update(folder_obj.records) + + vault_utils.traverse_folder_tree(vault.vault_data, folder, add_records) + return record_uids + + +def validate_record_shares_rotate_on_expiration( + vault: vault_online.VaultOnline, + record_uids: Iterable[str], + rotate_on_expiration: bool, +) -> None: + if not rotate_on_expiration: + return + for record_uid in record_uids: + if not is_pam_user_record(vault, record_uid): + info = vault.vault_data.get_record(record_uid) + title = info.title if info else record_uid + raise ShareValidationError( + f'--rotate-on-expiration is supported only for pamUser records ' + f'("{title}" / {record_uid})' + ) + + +def validate_folder_shares_rotate_on_expiration( + vault: vault_online.VaultOnline, + shared_folder_uids: Iterable[str], + rotate_on_expiration: bool, +) -> None: + if not rotate_on_expiration: + return + for sf_uid in shared_folder_uids: + record_uids = get_shared_folder_record_uids(vault, sf_uid) + for record_uid in record_uids: + if is_pam_user_record(vault, record_uid) and pam_user_has_rotation_configured(vault, record_uid): + return + sf_name = sf_uid + sf_info = vault.vault_data.get_shared_folder(sf_uid) + if sf_info: + sf_name = sf_info.name or sf_uid + raise ShareValidationError( + f'--rotate-on-expiration requires at least one pamUser record with rotation ' + f'configured in shared folder "{sf_name}"' + ) + + def get_share_objects(vault: vault_online.VaultOnline) -> Dict[str, Dict[str, Any]]: try: request = record_pb2.GetShareObjectsRequest() diff --git a/keepersdk-package/unit_tests/test_share_rotate_on_expiration.py b/keepersdk-package/unit_tests/test_share_rotate_on_expiration.py new file mode 100644 index 00000000..006a0f91 --- /dev/null +++ b/keepersdk-package/unit_tests/test_share_rotate_on_expiration.py @@ -0,0 +1,39 @@ +import unittest +from unittest import mock + +from keepersdk.vault import share_management_utils + + +class TestRotateOnExpirationValidation(unittest.TestCase): + + def test_validate_rotate_on_expiration_requires_positive_expiration(self): + with self.assertRaises(share_management_utils.ShareValidationError): + share_management_utils.validate_rotate_on_expiration(0, True) + with self.assertRaises(share_management_utils.ShareValidationError): + share_management_utils.validate_rotate_on_expiration(-1, True) + + def test_validate_rotate_on_expiration_allows_positive_expiration(self): + share_management_utils.validate_rotate_on_expiration(1_700_000_000, True) + + def test_set_expiration_fields_sets_rotate_flag(self): + from keepersdk.proto import record_pb2 + from keepersdk.vault.shares_management import set_expiration_fields + + ro = record_pb2.SharedRecord() + set_expiration_fields(ro, 1_700_000_000, rotate_on_expiration=True) + self.assertTrue(ro.rotateOnExpiration) + self.assertGreater(ro.expiration, 0) + + def test_validate_record_shares_requires_pam_user(self): + vault = mock.Mock() + info = mock.Mock(record_type='login', title='Not PAM') + vault.vault_data.get_record.return_value = info + + with self.assertRaises(share_management_utils.ShareValidationError): + share_management_utils.validate_record_shares_rotate_on_expiration( + vault, ['abc123'], True + ) + + +if __name__ == '__main__': + unittest.main() From 01e5dbd11be952241b0e517ba41df30cacebf9c1 Mon Sep 17 00:00:00 2001 From: Manan Tyagi Date: Tue, 16 Jun 2026 10:44:53 +0530 Subject: [PATCH 2/2] added left file --- .../src/keepersdk/vault/shares_management.py | 56 ++++++++++++------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/keepersdk-package/src/keepersdk/vault/shares_management.py b/keepersdk-package/src/keepersdk/vault/shares_management.py index b145464d..8f582ec6 100644 --- a/keepersdk-package/src/keepersdk/vault/shares_management.py +++ b/keepersdk-package/src/keepersdk/vault/shares_management.py @@ -52,12 +52,14 @@ class ManagePermission(Enum): FOLDER_TYPE_SHARED_FOLDER = 'shared_folder' FOLDER_TYPE_SHARED_FOLDER_FOLDER = 'shared_folder_folder' -def set_expiration_fields(obj, expiration): - """Set expiration and timerNotificationType fields on proto object if expiration is provided.""" +def set_expiration_fields(obj, expiration, rotate_on_expiration: bool = False): + """Set expiration, notification, and optional rotateOnExpiration on a share proto object.""" if isinstance(expiration, int): if expiration > 0: obj.expiration = expiration * TIMESTAMP_MILLISECONDS_FACTOR obj.timerNotificationType = record_pb2.NOTIFY_OWNER + if rotate_on_expiration: + obj.rotateOnExpiration = True elif expiration < 0: obj.expiration = -1 @@ -205,7 +207,8 @@ def _encrypt_record_key_for_user(vault, record_key, email, ro): @staticmethod def _build_shared_record(vault, email, record_uid, record_path, action, - can_edit, can_share, share_expiration, existing_shares): + can_edit, can_share, share_expiration, existing_shares, + rotate_on_expiration: bool = False): """Build a SharedRecord proto object for a user.""" ro = record_pb2.SharedRecord() ro.toUsername = email @@ -227,21 +230,22 @@ def _build_shared_record(vault, email, record_uid, record_path, action, else: ro.editable = bool(can_edit) ro.shareable = bool(can_share) - set_expiration_fields(ro, share_expiration) + set_expiration_fields(ro, share_expiration, rotate_on_expiration) else: if can_share or can_edit: if email in existing_shares: current = existing_shares[email] ro.editable = False if can_edit else current.get('editable') ro.shareable = False if can_share else current.get('shareable') - set_expiration_fields(ro, share_expiration) + set_expiration_fields(ro, share_expiration, rotate_on_expiration) return ro @staticmethod def _process_record_shares(vault, record_uids, all_users, action, can_edit, can_share, share_expiration, record_cache, - not_owned_records, is_share_admin, enterprise): + not_owned_records, is_share_admin, enterprise, + rotate_on_expiration: bool = False): """Process shares for all records and users, building the request.""" rq = record_pb2.RecordShareUpdateRequest() @@ -284,7 +288,8 @@ def _process_record_shares(vault, record_uids, all_users, action, can_edit, for email in all_users: ro = RecordShares._build_shared_record( vault, email, record_uid, record_path, action, - can_edit, can_share, share_expiration, existing_shares + can_edit, can_share, share_expiration, existing_shares, + rotate_on_expiration, ) if action in {ShareAction.GRANT.value, ShareAction.OWNER.value}: @@ -316,7 +321,8 @@ def prep_request(vault: vault_online.VaultOnline, enterprise_access: bool = False, recursive: bool = False, can_edit: bool = False, - can_share: bool = False): + can_share: bool = False, + rotate_on_expiration: bool = False): """Prepare a record share update request.""" # Build caches record_cache = {x.record_uid: x for x in vault.vault_data.records()} @@ -345,6 +351,14 @@ def prep_request(vault: vault_online.VaultOnline, if not record_uids: raise ValueError('There are no records to share selected') + + if rotate_on_expiration: + if action != ShareAction.GRANT.value: + raise ValueError('--rotate-on-expiration is only valid with --action grant') + share_management_utils.validate_rotate_on_expiration(share_expiration, rotate_on_expiration) + share_management_utils.validate_record_shares_rotate_on_expiration( + vault, record_uids, rotate_on_expiration + ) if action == ShareAction.OWNER.value and len(emails) > 1: raise ValueError('You can transfer ownership to a single account only') @@ -382,7 +396,8 @@ def prep_request(vault: vault_online.VaultOnline, # Build the request return RecordShares._process_record_shares( vault, record_uids, all_users, action, can_edit, can_share, - share_expiration, record_cache, not_owned_records, is_share_admin, enterprise + share_expiration, record_cache, not_owned_records, is_share_admin, enterprise, + rotate_on_expiration, ) @staticmethod @@ -521,7 +536,8 @@ def _process_default_account_permissions(rq, action, mr, mu, default_account): rq.defaultManageUsers = FolderShares._convert_manage_permission(mu) @staticmethod - def _process_users(vault, rq, curr_sf, users, action, mr, mu, share_expiration): + def _process_users(vault, rq, curr_sf, users, action, mr, mu, share_expiration, + rotate_on_expiration: bool = False): """Process user shares for the shared folder.""" if not users: return @@ -531,7 +547,7 @@ def _process_users(vault, rq, curr_sf, users, action, mr, mu, share_expiration): for email in users: uo = folder_pb2.SharedFolderUpdateUser() uo.username = email - set_expiration_fields(uo, share_expiration) + set_expiration_fields(uo, share_expiration, rotate_on_expiration) if email in existing_users: if action == ShareAction.GRANT.value: @@ -563,7 +579,8 @@ def _process_users(vault, rq, curr_sf, users, action, mr, mu, share_expiration): logger.warning('User %s not found', email) @staticmethod - def _process_teams(vault, rq, curr_sf, teams, action, mr, mu, share_expiration): + def _process_teams(vault, rq, curr_sf, teams, action, mr, mu, share_expiration, + rotate_on_expiration: bool = False): """Process team shares for the shared folder.""" if not teams: return @@ -573,7 +590,7 @@ def _process_teams(vault, rq, curr_sf, teams, action, mr, mu, share_expiration): for team_uid in teams: to = folder_pb2.SharedFolderUpdateTeam() to.teamUid = utils.base64_url_decode(team_uid) - set_expiration_fields(to, share_expiration) + set_expiration_fields(to, share_expiration, rotate_on_expiration) if team_uid in existing_teams: team = existing_teams[team_uid] @@ -608,7 +625,8 @@ def _process_default_record_permissions(rq, action, ce, cs, default_record): rq.defaultCanShare = FolderShares._convert_manage_permission(cs) @staticmethod - def _process_records(vault, rq, curr_sf, rec_uids, action, ce, cs, share_expiration): + def _process_records(vault, rq, curr_sf, rec_uids, action, ce, cs, share_expiration, + rotate_on_expiration: bool = False): """Process record shares for the shared folder.""" if not rec_uids: return @@ -618,7 +636,7 @@ def _process_records(vault, rq, curr_sf, rec_uids, action, ce, cs, share_expirat for record_uid in rec_uids: ro = folder_pb2.SharedFolderUpdateRecord() ro.recordUid = utils.base64_url_decode(record_uid) - set_expiration_fields(ro, share_expiration) + set_expiration_fields(ro, share_expiration, rotate_on_expiration) if record_uid in existing_records: if action == ShareAction.GRANT.value: @@ -650,7 +668,7 @@ def _process_records(vault, rq, curr_sf, rec_uids, action, ce, cs, share_expirat @staticmethod def prepare_request(vault: vault_online.VaultOnline, kwargs, curr_sf, users, teams, rec_uids, *, default_record=False, default_account=False, - share_expiration=None): + share_expiration=None, rotate_on_expiration: bool = False): """Prepare a shared folder update request.""" rq = folder_pb2.SharedFolderUpdateV3Request() FolderShares._initialize_request(rq, curr_sf) @@ -662,10 +680,10 @@ def prepare_request(vault: vault_online.VaultOnline, kwargs, curr_sf, users, tea cs = kwargs.get('can_share') FolderShares._process_default_account_permissions(rq, action, mr, mu, default_account) - FolderShares._process_users(vault, rq, curr_sf, users, action, mr, mu, share_expiration) - FolderShares._process_teams(vault, rq, curr_sf, teams, action, mr, mu, share_expiration) + FolderShares._process_users(vault, rq, curr_sf, users, action, mr, mu, share_expiration, rotate_on_expiration) + FolderShares._process_teams(vault, rq, curr_sf, teams, action, mr, mu, share_expiration, rotate_on_expiration) FolderShares._process_default_record_permissions(rq, action, ce, cs, default_record) - FolderShares._process_records(vault, rq, curr_sf, rec_uids, action, ce, cs, share_expiration) + FolderShares._process_records(vault, rq, curr_sf, rec_uids, action, ce, cs, share_expiration, rotate_on_expiration) return rq