diff --git a/keepercommander/__init__.py b/keepercommander/__init__.py index 845f8bb94..c54d16492 100644 --- a/keepercommander/__init__.py +++ b/keepercommander/__init__.py @@ -10,4 +10,4 @@ # Contact: commander@keepersecurity.com # -__version__ = '18.0.7' +__version__ = '18.0.8' diff --git a/keepercommander/auth/console_ui.py b/keepercommander/auth/console_ui.py index 0a60b077b..eca6bf3b4 100644 --- a/keepercommander/auth/console_ui.py +++ b/keepercommander/auth/console_ui.py @@ -18,6 +18,31 @@ def _stderr(msg=''): print(msg, file=sys.stderr) +_HEADLESS_AUTH_MSG_SHOWN = False + + +def _is_interactive(): + try: + return bool(sys.stdin) and sys.stdin.isatty() + except Exception: + return False + + +def _fail_headless_auth(step): + """In headless/service mode, persistent login often needs a follow-up prompt + (password, SSO, 2FA, device approval) that cannot be answered. Log once and + cancel so the caller exits cleanly instead of looping or spamming getpass.""" + global _HEADLESS_AUTH_MSG_SHOWN + if not _HEADLESS_AUTH_MSG_SHOWN: + _HEADLESS_AUTH_MSG_SHOWN = True + logging.error( + 'Persistent login is not working in this non-interactive environment ' + '(possibly due to an IP/location change). ' + 'Re-run Commander/Docker setup from this network, then restart the service.' + ) + step.cancel() + + class ConsoleLoginUi(login_steps.LoginUi): def __init__(self): self._show_device_approval_help = True @@ -28,6 +53,9 @@ def __init__(self): self._failed_password_attempt = 0 def on_device_approval(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return if self._show_device_approval_help: _stderr(f"\n{Fore.YELLOW}Device Approval Required{Fore.RESET}\n") _stderr(f"{Fore.CYAN}Select an approval method:{Fore.RESET}") @@ -123,6 +151,9 @@ def two_factor_channel_to_desc(channel): # type: (login_steps.TwoFactorChannel return 'Backup Codes' def on_two_factor(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return channels = step.get_channels() if self._show_two_factor_help: @@ -273,6 +304,9 @@ def on_two_factor(self, step): logging.warning(f'{Fore.YELLOW}Invalid 2FA code. Please try again.{Fore.RESET}') def on_password(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return if self._show_password_help: _stderr(f'{Fore.CYAN}Enter master password for {Fore.WHITE}{step.username}{Fore.RESET}') @@ -293,6 +327,9 @@ def on_password(self, step): step.cancel() def on_sso_redirect(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return try: wb = webbrowser.get() wrappers = set('xdg-open|gvfs-open|gnome-open|x-www-browser|www-browser'.split('|')) @@ -360,6 +397,9 @@ def on_sso_redirect(self, step): break def on_sso_data_key(self, step): + if not _is_interactive(): + _fail_headless_auth(step) + return if self._show_sso_data_key_help: _stderr(f'\n{Fore.YELLOW}Device Approval Required for SSO{Fore.RESET}\n') _stderr(f'{Fore.CYAN}Select an approval method:{Fore.RESET}') diff --git a/keepercommander/commands/credential_provision.py b/keepercommander/commands/credential_provision.py index d19711948..b1a9008bd 100644 --- a/keepercommander/commands/credential_provision.py +++ b/keepercommander/commands/credential_provision.py @@ -1773,7 +1773,7 @@ def _create_dag_link( pam_config_record = vault.KeeperRecord.load(params, pam_config_uid) # Create RecordLink instance - record_link = RecordLink(record=pam_config_record, params=params, fail_on_corrupt=False, use_per_graph_endpoints=True) + record_link = RecordLink(record=pam_config_record, params=params, fail_on_corrupt=False) # Create belongs_to relationship: PAM User belongs_to PAM Configuration record_link.belongs_to( diff --git a/keepercommander/commands/discover/job_remove.py b/keepercommander/commands/discover/job_remove.py index 9a15a1755..fed280496 100644 --- a/keepercommander/commands/discover/job_remove.py +++ b/keepercommander/commands/discover/job_remove.py @@ -39,7 +39,7 @@ def execute(self, params, **kwargs): all_gateways = GatewayContext.all_gateways(params) def _find_job(configuration_record) -> Optional[Dict]: - jobs_obj = Jobs(record=configuration_record, params=params, use_per_graph_endpoints=True) + jobs_obj = Jobs(record=configuration_record, params=params) job_item = jobs_obj.get_job(job_id) if job_item is not None: return { diff --git a/keepercommander/commands/discover/job_start.py b/keepercommander/commands/discover/job_start.py index 4c80016e8..1093ec402 100644 --- a/keepercommander/commands/discover/job_start.py +++ b/keepercommander/commands/discover/job_start.py @@ -112,7 +112,7 @@ def execute(self, params, **kwargs): multi_conf_msg(gateway, err) return - jobs = Jobs(record=gateway_context.configuration, params=params, use_per_graph_endpoints=True) + jobs = Jobs(record=gateway_context.configuration, params=params) current_job_item = jobs.current_job removed_prior_job = None if current_job_item is not None: diff --git a/keepercommander/commands/discover/job_status.py b/keepercommander/commands/discover/job_status.py index f4e8c046d..25450489a 100644 --- a/keepercommander/commands/discover/job_status.py +++ b/keepercommander/commands/discover/job_status.py @@ -6,7 +6,7 @@ from ...display import bcolors from ...discovery_common.jobs import Jobs from ...discovery_common.infrastructure import Infrastructure -from ...keeper_dag.types import PamEndpoints +from ...keeper_dag.types import PamGraphId from ...discovery_common.types import DiscoveryDelta, DiscoveryObject from ...keeper_dag.dag import DAG from typing import Optional, Dict, List, TYPE_CHECKING @@ -160,7 +160,7 @@ def print_job_detail(params: KeeperParams, job_id: str): def _find_job(configuration_record) -> Optional[Dict]: - jobs_obj = Jobs(record=configuration_record, params=params, use_per_graph_endpoints=True) + jobs_obj = Jobs(record=configuration_record, params=params) job_item = jobs_obj.get_job(job_id) if job_item is not None: return { @@ -175,7 +175,7 @@ def _find_job(configuration_record) -> Optional[Dict]: if gateway_context is not None: jobs = payload["jobs"] job = jobs.get_job(job_id) # type: JobItem - infra = Infrastructure(record=gateway_context.configuration, params=params, use_per_graph_endpoints=True) + infra = Infrastructure(record=gateway_context.configuration, params=params) color = bcolors.OKBLUE status = "RUNNING" @@ -257,8 +257,7 @@ def _find_job(configuration_record) -> Optional[Dict]: print("Fall back to raw graph.") print("") dag = DAG(conn=infra.conn, record=infra.record, - read_endpoint=PamEndpoints.INFRASTRUCTURE, - write_endpoint=PamEndpoints.INFRASTRUCTURE) + graph_id=PamGraphId.INFRASTRUCTURE) print(dag.to_dot_raw(sync_point=job.sync_point, rank_dir="RL")) else: @@ -325,7 +324,7 @@ def execute(self, params, **kwargs): if len(gateway_context.gateway_name) > max_gateway_name: max_gateway_name = len(gateway_context.gateway_name) - jobs = Jobs(record=configuration_record, params=params, use_per_graph_endpoints=True) + jobs = Jobs(record=configuration_record, params=params) if show_history is True: job_list = reversed(jobs.history) else: diff --git a/keepercommander/commands/discover/result_process.py b/keepercommander/commands/discover/result_process.py index 54dabd931..1c61d234b 100644 --- a/keepercommander/commands/discover/result_process.py +++ b/keepercommander/commands/discover/result_process.py @@ -1334,7 +1334,7 @@ def _get_directory_info(domain: str, def remove_job(params: KeeperParams, configuration_record: KeeperRecord, job_id: str): try: - jobs = Jobs(record=configuration_record, params=params, use_per_graph_endpoints=True) + jobs = Jobs(record=configuration_record, params=params) jobs.cancel(job_id) print(f"{bcolors.OKGREEN}No items left to process. Removing completed discovery job.{bcolors.ENDC}") except Exception as err: @@ -1352,8 +1352,7 @@ def preview(self, job_item: JobItem, params: KeeperParams, gateway_context: Gate infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, - use_per_graph_endpoints=True) + debug_level=debug_level) infra.load(sync_point) configuration = None @@ -1512,7 +1511,7 @@ def execute(self, params: KeeperParams, **kwargs): # Get the current job. # There can only be one active job. # This will give us the sync point for the delta - jobs = Jobs(record=configuration_record, params=params, logger=logging, debug_level=debug_level, use_per_graph_endpoints=True) + jobs = Jobs(record=configuration_record, params=params, logger=logging, debug_level=debug_level) job_item = jobs.current_job if job_item is None: continue diff --git a/keepercommander/commands/discover/rule_add.py b/keepercommander/commands/discover/rule_add.py index a472c3dd8..2bf1d7bb9 100644 --- a/keepercommander/commands/discover/rule_add.py +++ b/keepercommander/commands/discover/rule_add.py @@ -134,7 +134,7 @@ def execute(self, params, **kwargs): return # If the rule passes its validation, then add control DAG - rules = Rules(record=gateway_context.configuration, params=params, use_per_graph_endpoints=True) + rules = Rules(record=gateway_context.configuration, params=params) new_rule = ActionRuleItem( name=kwargs.get("name"), action=kwargs.get("rule_action"), diff --git a/keepercommander/commands/discover/rule_list.py b/keepercommander/commands/discover/rule_list.py index 6f2f4145d..9819fc361 100644 --- a/keepercommander/commands/discover/rule_list.py +++ b/keepercommander/commands/discover/rule_list.py @@ -101,7 +101,7 @@ def execute(self, params, **kwargs): multi_conf_msg(gateway, err) return - rules = Rules(record=gateway_context.configuration, params=params, use_per_graph_endpoints=True) + rules = Rules(record=gateway_context.configuration, params=params) rule_list = rules.rule_list(rule_type=RuleTypeEnum.ACTION, search=kwargs.get("search")) # type: List[RuleItem] if len(rule_list) == 0: diff --git a/keepercommander/commands/discover/rule_remove.py b/keepercommander/commands/discover/rule_remove.py index ad081d3e4..b093dfa57 100644 --- a/keepercommander/commands/discover/rule_remove.py +++ b/keepercommander/commands/discover/rule_remove.py @@ -46,7 +46,7 @@ def execute(self, params, **kwargs): return try: - rules = Rules(record=gateway_context.configuration, params=params, use_per_graph_endpoints=True) + rules = Rules(record=gateway_context.configuration, params=params) if remove_all: rules.remove_all(RuleTypeEnum.ACTION) print(f"{bcolors.OKGREEN}All rules removed.{bcolors.ENDC}") diff --git a/keepercommander/commands/discover/rule_update.py b/keepercommander/commands/discover/rule_update.py index 0f8aeef5d..b3e063bf5 100644 --- a/keepercommander/commands/discover/rule_update.py +++ b/keepercommander/commands/discover/rule_update.py @@ -64,7 +64,7 @@ def execute(self, params, **kwargs): try: rule_id = kwargs.get("rule_id") - rules = Rules(record=gateway_context.configuration, params=params, use_per_graph_endpoints=True) + rules = Rules(record=gateway_context.configuration, params=params) rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id) if rule_item is None: raise ValueError("Rule Id does not exist.") diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index 7671ef65d..d8e5b3b57 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -3372,7 +3372,7 @@ def record_rotate(self, params, record_uid, slient: bool = False): # Check the graph for the noop setting. record_link = RecordLink(record=pam_config, params=params, - fail_on_corrupt=False, use_per_graph_endpoints=True) + fail_on_corrupt=False) acl = record_link.get_acl(record_uid, pam_config.record_uid) if acl is not None and acl.rotation_settings is not None: is_noop = acl.rotation_settings.noop diff --git a/keepercommander/commands/helpers/record.py b/keepercommander/commands/helpers/record.py index 7769f5f7e..a2f97976f 100644 --- a/keepercommander/commands/helpers/record.py +++ b/keepercommander/commands/helpers/record.py @@ -1,9 +1,26 @@ +import re from typing import Set, Optional from ... import api +from ...error import CommandError from ...params import KeeperParams from ...subfolder import try_resolve_path +# Block shell chaining markers in `get` lookup tokens. +_GET_LOOKUP_CONTROL_CHARS_RE = re.compile(r'[\r\n\x00]') +_GET_LOOKUP_SHELL_METACHAR_RE = re.compile(r'[;|]') +_GET_LOOKUP_CHAIN_RE = re.compile(r'&&') + + +def raise_if_unsafe_get_lookup_token(token, command='get'): + # type: (str, str) -> None + if not token: + raise CommandError(command, 'Invalid record identifier: forbidden characters') + if (_GET_LOOKUP_CONTROL_CHARS_RE.search(token) + or _GET_LOOKUP_SHELL_METACHAR_RE.search(token) + or _GET_LOOKUP_CHAIN_RE.search(token)): + raise CommandError(command, 'Invalid record identifier: forbidden characters') + # Get record UID(s) given one of its identifiers: name (if current folder contains the record), path, or UID def get_record_uids(params, name): # type: (KeeperParams, str) -> Set[Optional[str]] diff --git a/keepercommander/commands/ksm.py b/keepercommander/commands/ksm.py index b4ec8ff17..e55463f19 100644 --- a/keepercommander/commands/ksm.py +++ b/keepercommander/commands/ksm.py @@ -10,6 +10,7 @@ # import argparse +import base64 import datetime import hmac import json @@ -1594,6 +1595,8 @@ def init_ksm_config(params, one_time_token, config_init, include_config_dict=Fal if 'KEY_OWNER_PUBLIC_KEY' in ConfigKeys.__members__ and ksm_conf_storage.config.get(ConfigKeys.KEY_OWNER_PUBLIC_KEY): config_dict[ConfigKeys.KEY_OWNER_PUBLIC_KEY.value] = ksm_conf_storage.config.get(ConfigKeys.KEY_OWNER_PUBLIC_KEY) + KSMCommand.validate_ksm_config_dict(config_dict) + converted_config = KSMCommand.convert_config_dict(config_dict, config_init) if include_config_dict: @@ -1604,13 +1607,49 @@ def init_ksm_config(params, one_time_token, config_init, include_config_dict=Fal else: return converted_config + @staticmethod + def validate_ksm_config_dict(config_dict): + """Verify a freshly generated KSM device config is intact. + + The config is handed out as an opaque base64 blob (gateway install, + k8s secret) and a corrupted clientId/privateKey only surfaces much + later as an unusable device, so fail loudly at the source instead. + + Note: if this validation passes but the consumer still receives a + malformed token, the base64 was most likely mangled by the console - + lines overwritten/lost during print (wrapped rows, redraws) or a bad + copy/paste. For comparison capture it losslessly with a redirect: + pam project import ... > out.json + """ + required_keys = ('hostname', 'clientId', 'privateKey', 'serverPublicKeyId', 'appKey') + for key in required_keys: + value = config_dict.get(key) + if not value or not isinstance(value, str): + raise Exception(f'Generated KSM config is invalid: "{key}" is missing or empty. ' + 'Please remove the client device and try again.') + for key in ('clientId', 'privateKey', 'appKey'): + try: + decoded = base64.b64decode(config_dict[key], validate=True) + except Exception: + raise Exception(f'Generated KSM config is invalid: "{key}" is not valid base64. ' + 'Please remove the client device and try again.') + if key == 'clientId' and len(decoded) != 64: # HMAC-SHA512 digest + raise Exception(f'Generated KSM config is invalid: "clientId" decodes to ' + f'{len(decoded)} bytes, expected 64. ' + 'Please remove the client device and try again.') + @staticmethod def convert_config_dict(config_dict, conversion_type='json'): config = json.dumps(config_dict) if conversion_type in ['b64', 'k8s']: - config = json_to_base64(config) + encoded = json_to_base64(config) + # the encoded blob must round-trip to the exact JSON it was built + # from - catches any corruption before the config is handed out + if base64.b64decode(encoded).decode('utf-8') != config: + raise Exception('KSM config base64 encoding failed the integrity check') + config = encoded if conversion_type == 'k8s': config = "\n" \ diff --git a/keepercommander/commands/pam_debug/acl.py b/keepercommander/commands/pam_debug/acl.py index cdf1918eb..9a7bca7d2 100644 --- a/keepercommander/commands/pam_debug/acl.py +++ b/keepercommander/commands/pam_debug/acl.py @@ -57,7 +57,7 @@ def execute(self, params: KeeperParams, **kwargs): record_link = RecordLink(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) user_record = vault.KeeperRecord.load(params, user_uid) # type: Optional[TypedRecord] if user_record is None: diff --git a/keepercommander/commands/pam_debug/gateway.py b/keepercommander/commands/pam_debug/gateway.py index 2f53a620a..3642668ce 100644 --- a/keepercommander/commands/pam_debug/gateway.py +++ b/keepercommander/commands/pam_debug/gateway.py @@ -49,11 +49,11 @@ def execute(self, params: KeeperParams, **kwargs): multi_conf_msg(gateway, err) return - infra = Infrastructure(record=gateway_context.configuration, params=params, fail_on_corrupt=False, use_per_graph_endpoints=True) + infra = Infrastructure(record=gateway_context.configuration, params=params, fail_on_corrupt=False) infra.load() - record_link = RecordLink(record=gateway_context.configuration, params=params, fail_on_corrupt=False, use_per_graph_endpoints=True) - user_service = UserService(record=gateway_context.configuration, params=params, fail_on_corrupt=False, use_per_graph_endpoints=True) + record_link = RecordLink(record=gateway_context.configuration, params=params, fail_on_corrupt=False) + user_service = UserService(record=gateway_context.configuration, params=params, fail_on_corrupt=False) if gateway_context is None: print(f" {self._f('Cannot get gateway information. Gateway may not be up.')}") diff --git a/keepercommander/commands/pam_debug/graph.py b/keepercommander/commands/pam_debug/graph.py index 8daeb4838..6c985a4cb 100644 --- a/keepercommander/commands/pam_debug/graph.py +++ b/keepercommander/commands/pam_debug/graph.py @@ -78,7 +78,7 @@ def _do_text_list_infra(self, params: KeeperParams, gateway_context: GatewayCont indent: int = 0): infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) infra.load(sync_point=0) try: @@ -164,7 +164,7 @@ def _do_text_list_rl(self, params: KeeperParams, gateway_context: GatewayContext record_link = RecordLink(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) configuration = record_link.dag.get_root record = vault.KeeperRecord.load(params, configuration.uid) # type: Optional[TypedRecord] @@ -316,7 +316,7 @@ def _do_text_list_service(self, params: KeeperParams, gateway_context: GatewayCo indent: int = 0): user_service = UserService(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) configuration = user_service.dag.get_root def _handle(current_vertex: DAGVertex, parent_vertex: Optional[DAGVertex] = None, indent: int = 0): @@ -364,7 +364,7 @@ def _do_text_list_jobs(self, params: KeeperParams, gateway_context: GatewayConte indent: int = 0): infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, fail_on_corrupt=False, use_per_graph_endpoints=True) + debug_level=debug_level, fail_on_corrupt=False) infra.load(sync_point=0) pad = "" @@ -461,7 +461,7 @@ def _do_render_infra(self, params: KeeperParams, gateway_context: GatewayContext debug_level: int = 0): infra = Infrastructure(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) infra.load(sync_point=0) print("") @@ -487,7 +487,7 @@ def _do_render_rl(self, params: KeeperParams, gateway_context: GatewayContext, f rl = RecordLink(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) print("") dot_instance = rl.to_dot( @@ -510,7 +510,7 @@ def _do_render_service(self, params: KeeperParams, gateway_context: GatewayConte graph_format: str, debug_level: int = 0): service = UserService(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) print("") dot_instance = service.to_dot( @@ -532,7 +532,7 @@ def _do_render_service(self, params: KeeperParams, gateway_context: GatewayConte def _do_render_jobs(self, params: KeeperParams, gateway_context: GatewayContext, filepath: str, graph_format: str, debug_level: int = 0): - jobs = Jobs(record=gateway_context.configuration, params=params, logger=logging, debug_level=debug_level, use_per_graph_endpoints=True) + jobs = Jobs(record=gateway_context.configuration, params=params, logger=logging, debug_level=debug_level) print("") dot_instance = jobs.dag.to_dot() diff --git a/keepercommander/commands/pam_debug/info.py b/keepercommander/commands/pam_debug/info.py index c8f690518..09d9bcfa6 100644 --- a/keepercommander/commands/pam_debug/info.py +++ b/keepercommander/commands/pam_debug/info.py @@ -65,7 +65,7 @@ def execute(self, params: KeeperParams, **kwargs): for configuration_record in configuration_records: - record_link = RecordLink(record=configuration_record, params=params, use_per_graph_endpoints=True) + record_link = RecordLink(record=configuration_record, params=params) record_vertex = record_link.dag.get_vertex(record.record_uid) if record_vertex is not None and record_vertex.active is True: controller_uid = configuration_record.record_uid @@ -95,10 +95,10 @@ def execute(self, params: KeeperParams, **kwargs): print(f"{bcolors.FAIL}Could not find the gateway for configuration record.{controller_uid}{bcolors.ENDC}") return - infra = Infrastructure(record=configuration_record, params=params, use_per_graph_endpoints=True) + infra = Infrastructure(record=configuration_record, params=params) infra.load() - record_link = RecordLink(record=configuration_record, params=params, use_per_graph_endpoints=True) - user_service = UserService(record=configuration_record, params=params, use_per_graph_endpoints=True) + record_link = RecordLink(record=configuration_record, params=params) + user_service = UserService(record=configuration_record, params=params) print("") print(self._h("Record Information")) diff --git a/keepercommander/commands/pam_debug/link.py b/keepercommander/commands/pam_debug/link.py index f2489d2f0..766d27d16 100644 --- a/keepercommander/commands/pam_debug/link.py +++ b/keepercommander/commands/pam_debug/link.py @@ -53,7 +53,7 @@ def execute(self, params: KeeperParams, **kwargs): record_link = RecordLink(record=gateway_context.configuration, params=params, logger=logging, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) resource_record = vault.KeeperRecord.load(params, resource_uid) # type: Optional[TypedRecord] if resource_record is None: diff --git a/keepercommander/commands/pam_debug/rotation_setting.py b/keepercommander/commands/pam_debug/rotation_setting.py index fd2d6ed19..8e475b487 100644 --- a/keepercommander/commands/pam_debug/rotation_setting.py +++ b/keepercommander/commands/pam_debug/rotation_setting.py @@ -164,7 +164,7 @@ def execute(self, params: KeeperParams, **kwargs): f"It's a {resource_record.record_type}.{bcolors.ENDC}") return - record_link = RecordLink(record=configuration_record, params=params, use_per_graph_endpoints=True) + record_link = RecordLink(record=configuration_record, params=params) parent_uid = resource_record_uid or configuration_record_uid parent_vertex = record_link.get_record_link(parent_uid) diff --git a/keepercommander/commands/pam_debug/vertex.py b/keepercommander/commands/pam_debug/vertex.py index 35bd26acd..3d37e61dc 100644 --- a/keepercommander/commands/pam_debug/vertex.py +++ b/keepercommander/commands/pam_debug/vertex.py @@ -53,7 +53,7 @@ def execute(self, params: KeeperParams, **kwargs): return infra = Infrastructure(record=gateway_context.configuration, params=params, fail_on_corrupt=False, - debug_level=debug_level, use_per_graph_endpoints=True) + debug_level=debug_level) infra.load() vertex_uid = kwargs.get("vertex_uid") diff --git a/keepercommander/commands/pam_import/README.md b/keepercommander/commands/pam_import/README.md index 233f8f37a..706225643 100644 --- a/keepercommander/commands/pam_import/README.md +++ b/keepercommander/commands/pam_import/README.md @@ -600,6 +600,7 @@ Workflow controls how privileged access to a resource is gated: how many approva "disable_paste": true, "color_scheme": "gray-black", "font_size": "18", + "scrollback": 5000, "public_host_key": "", "command": "/bin/bash", "sftp": { @@ -715,6 +716,7 @@ Workflow controls how privileged access to a resource is gated: how many approva "disable_paste": true, "color_scheme": "gray-black", "font_size": "18", + "scrollback": 5000, "username_regex": "regex: username", "password_regex": "regex: password", "login_success_regex": "regex: login success", @@ -768,6 +770,7 @@ Workflow controls how privileged access to a resource is gated: how many approva "recording_include_keys": true, "color_scheme": "gray-black", "font_size": "18", + "scrollback": 5000, "namespace": "namespace", "pod_name": "pod name", "container": "container name", @@ -820,7 +823,7 @@ Workflow controls how privileged access to a resource is gated: how many approva }, "connection" : { "protocol": "mysql", - "_comment": "protocol types: ", + "_comment": "DB protocol types: ", "port": "2222", "allow_supply_user": true, "administrative_credentials": "admin1", @@ -828,6 +831,9 @@ Workflow controls how privileged access to a resource is gated: how many approva "recording_include_keys": true, "disable_copy": true, "disable_paste": true, + "color_scheme": "gray-black", + "font_size": "18", + "scrollback": 5000, "disable_csv_import": true, "disable_csv_export": true, "default_database": "db1" @@ -899,6 +905,7 @@ Workflow controls how privileged access to a resource is gated: how many approva "disable_paste": true, "color_scheme": "gray-black", "font_size": "18", + "scrollback": 5000, "public_host_key": "", "command": "/bin/bash", "sftp": { @@ -952,7 +959,17 @@ Workflow controls how privileged access to a resource is gated: how many approva "allowed_url_patterns": "*.com\n*.org", "allowed_resource_url_patterns": "*.org\n*.gov", "autofill_targets": "autofil_target1\nautofil_target2", - "ignore_server_cert": true + "ignore_server_cert": true, + "session_persistence": "none", + "_comment_session_persistence": "none | user | resource", + "allow_file_uploads": true, + "allow_file_downloads": true, + "disable_audio": false, + "audio_channels": 2, + "_comment_audio_channels": "1 (mono) or 2 (stereo)", + "audio_bps": 16, + "_comment_audio_bps": "8 or 16", + "audio_sample_rate": 44100 } } }, diff --git a/keepercommander/commands/pam_import/base.py b/keepercommander/commands/pam_import/base.py index 16cbf2214..ef2b6dabc 100644 --- a/keepercommander/commands/pam_import/base.py +++ b/keepercommander/commands/pam_import/base.py @@ -1901,6 +1901,13 @@ class ConnectionProtocol(Enum): SQLSERVER = "sql-server" POSTGRESQL = "postgresql" MYSQL = "mysql" + MARIADB = "mariadb" + ORACLE = "oracle" + MONGODB = "mongodb" + REDIS = "redis" + ELASTICSEARCH = "elasticsearch" + CLICKHOUSE = "clickhouse" + DYNAMODB = "dynamodb" HTTP = "http" class RDPSecurity(Enum): @@ -2094,9 +2101,11 @@ def sftp_enabled(connection_settings: Union[PamConnectionSettings, ConnectionSet class TerminalDisplayConnectionSettings: fontSizes: List[int] = [8,9,10,11,12,14,18,24,30,36,48,60,72,96] - def __init__(self, colorScheme: Optional[str] = None, fontSize: Optional[int] = None): + def __init__(self, colorScheme: Optional[str] = None, fontSize: Optional[int] = None, + scrollback: Optional[int] = None): self.colorScheme = colorScheme self.fontSize = fontSize + self.scrollback = scrollback @classmethod def load(cls, data: Union[str, dict]): @@ -2116,6 +2125,18 @@ def load(cls, data: Union[str, dict]): if closest_number != font_size: logging.error(f"Terminal Display Connection Settings - adjusted invalid font_size from: {obj.fontSize} to: {closest_number}") obj.fontSize = closest_number + + # scrollback (Maximum Scrollback Size): must be a positive integer. + val = data.get("scrollback", None) + if val is not None: + try: + parsed = int(val) + if parsed <= 0: + logging.warning(f"Terminal Display Connection Settings: scrollback must be a positive integer, got: {parsed}") + else: + obj.scrollback = parsed + except (TypeError, ValueError): + logging.warning(f"Terminal Display Connection Settings: invalid scrollback value: {val!r}") return obj class BaseConnectionSettings: @@ -2303,7 +2324,14 @@ def __init__( allowedResourceUrlPatterns: Optional[str] = None, httpCredentials: Optional[List[str]] = None, # autofill_credentials: login|pamUser autofillConfiguration: Optional[str] = None, - ignoreInitialSslCert: Optional[bool] = None + ignoreInitialSslCert: Optional[bool] = None, + sessionPersistence: Optional[str] = None, + allowFileUploads: Optional[bool] = None, + allowFileDownloads: Optional[bool] = None, + disableAudio: Optional[bool] = None, + audioChannels: Optional[int] = None, + audioBps: Optional[int] = None, + audioSampleRate: Optional[int] = None ): BaseConnectionSettings.__init__(self, port, allowSupplyUser, userRecords, recordingIncludeKeys) ClipboardConnectionSettings.__init__(self, disableCopy, disablePaste) @@ -2314,6 +2342,13 @@ def __init__( self.autofillConfiguration = autofillConfiguration self.ignoreInitialSslCert = ignoreInitialSslCert self.httpCredentialsUid = None # resolved from httpCredentials + self.sessionPersistence = sessionPersistence + self.allowFileUploads = allowFileUploads + self.allowFileDownloads = allowFileDownloads + self.disableAudio = disableAudio + self.audioChannels = audioChannels + self.audioBps = audioBps + self.audioSampleRate = audioSampleRate @classmethod def load(cls, data: Union[str, dict]): @@ -2343,6 +2378,40 @@ def load(cls, data: Union[str, dict]): obj.autofillConfiguration = multiline_to_str(parse_multiline(data, "autofill_targets", "Error parsing autofill_targets")) obj.ignoreInitialSslCert = utils.value_to_boolean(data.get("ignore_server_cert", None)) + val = data.get("session_persistence", None) + if isinstance(val, str) and val.lower() in ("none", "user", "resource"): + obj.sessionPersistence = val.lower() + obj.allowFileUploads = utils.value_to_boolean(data.get("allow_file_uploads", None)) + obj.allowFileDownloads = utils.value_to_boolean(data.get("allow_file_downloads", None)) + obj.disableAudio = utils.value_to_boolean(data.get("disable_audio", None)) + val = data.get("audio_channels", None) + if val is not None: + try: + parsed = int(val) + if parsed not in (1, 2): + logging.warning(f"ConnectionSettingsHTTP: audio_channels must be 1 or 2, got: {parsed}") + else: + obj.audioChannels = parsed + except (TypeError, ValueError): logging.warning(f"ConnectionSettingsHTTP: invalid audio_channels value: {val!r}") + val = data.get("audio_bps", None) + if val is not None: + try: + parsed = int(val) + if parsed not in (8, 16): + logging.warning(f"ConnectionSettingsHTTP: audio_bps must be 8 or 16, got: {parsed}") + else: + obj.audioBps = parsed + except (TypeError, ValueError): logging.warning(f"ConnectionSettingsHTTP: invalid audio_bps value: {val!r}") + val = data.get("audio_sample_rate", None) + if val is not None: + try: + parsed = int(val) + if parsed < 0: + logging.warning(f"ConnectionSettingsHTTP: audio_sample_rate must be non-negative, got: {parsed}") + else: + obj.audioSampleRate = parsed + except (TypeError, ValueError): logging.warning(f"ConnectionSettingsHTTP: invalid audio_sample_rate value: {val!r}") + return obj def to_record_dict(self): @@ -2377,6 +2446,21 @@ def to_record_dict(self): if self.ignoreInitialSslCert is not None and isinstance(self.ignoreInitialSslCert, bool): kvp["ignoreInitialSslCert"] = self.ignoreInitialSslCert + if self.sessionPersistence and isinstance(self.sessionPersistence, str) and self.sessionPersistence in ("none", "user", "resource"): + kvp["sessionPersistence"] = self.sessionPersistence + if self.allowFileUploads is not None and isinstance(self.allowFileUploads, bool): + kvp["allowFileUploads"] = self.allowFileUploads + if self.allowFileDownloads is not None and isinstance(self.allowFileDownloads, bool): + kvp["allowFileDownloads"] = self.allowFileDownloads + if self.disableAudio is not None and isinstance(self.disableAudio, bool): + kvp["disableAudio"] = self.disableAudio + if self.audioChannels is not None and type(self.audioChannels) is int: + kvp["audioChannels"] = self.audioChannels + if self.audioBps is not None and type(self.audioBps) is int: + kvp["audioBps"] = self.audioBps + if self.audioSampleRate is not None and type(self.audioSampleRate) is int: + kvp["audioSampleRate"] = self.audioSampleRate + return kvp def to_record_json(self): @@ -2484,6 +2568,7 @@ def __init__( # pylint: disable=R0917 disablePaste: Optional[bool] = None, colorScheme: Optional[str] = None, fontSize: Optional[int] = None, + scrollback: Optional[int] = None, usernameRegex: Optional[str] = None, passwordRegex: Optional[str] = None, loginSuccessRegex: Optional[str] = None, @@ -2491,7 +2576,7 @@ def __init__( # pylint: disable=R0917 ): BaseConnectionSettings.__init__(self, port, allowSupplyUser, userRecords, recordingIncludeKeys) ClipboardConnectionSettings.__init__(self, disableCopy, disablePaste) - TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize) + TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize, scrollback) self.usernameRegex = usernameRegex self.passwordRegex = passwordRegex self.loginSuccessRegex = loginSuccessRegex @@ -2522,6 +2607,7 @@ def load(cls, data: Union[str, dict]): if tcs: obj.colorScheme = tcs.colorScheme obj.fontSize = tcs.fontSize + obj.scrollback = tcs.scrollback val = data.get("username_regex", None) if isinstance(val, str): obj.usernameRegex = val @@ -2558,6 +2644,8 @@ def to_record_dict(self): kvp["colorScheme"] = self.colorScheme.strip() if self.fontSize and type(self.fontSize) is int and self.fontSize > 4: kvp["fontSize"] = str(self.fontSize) + if self.scrollback is not None and type(self.scrollback) is int and self.scrollback > 0: + kvp["scrollback"] = self.scrollback if self.usernameRegex and isinstance(self.usernameRegex, str) and self.usernameRegex.strip(): kvp["usernameRegex"] = self.usernameRegex.strip() if self.passwordRegex and isinstance(self.passwordRegex, str) and self.passwordRegex.strip(): @@ -2586,13 +2674,14 @@ def __init__( # pylint: disable=R0917 disablePaste: Optional[bool] = None, colorScheme: Optional[str] = None, fontSize: Optional[int] = None, + scrollback: Optional[int] = None, hostKey: Optional[str] = None, command: Optional[str] = None, sftp: Optional[SFTPRootDirectorySettings] = None ): BaseConnectionSettings.__init__(self, port, allowSupplyUser, userRecords, recordingIncludeKeys) ClipboardConnectionSettings.__init__(self, disableCopy, disablePaste) - TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize) + TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize, scrollback) self.hostKey = hostKey self.command = command self.sftp = sftp if isinstance(sftp, SFTPRootDirectorySettings) else None @@ -2622,6 +2711,7 @@ def load(cls, data: Union[str, dict]): if tcs: obj.colorScheme = tcs.colorScheme obj.fontSize = tcs.fontSize + obj.scrollback = tcs.scrollback val = data.get("public_host_key", None) if isinstance(val, str): obj.hostKey = val @@ -2656,6 +2746,8 @@ def to_record_dict(self): kvp["colorScheme"] = self.colorScheme.strip() if self.fontSize and type(self.fontSize) is int and self.fontSize > 4: kvp["fontSize"] = str(self.fontSize) + if self.scrollback is not None and type(self.scrollback) is int and self.scrollback > 0: + kvp["scrollback"] = self.scrollback if self.hostKey and isinstance(self.hostKey, str) and self.hostKey.strip(): kvp["hostKey"] = self.hostKey.strip() if self.command and isinstance(self.command, str) and self.command.strip(): @@ -2683,6 +2775,7 @@ def __init__( # pylint: disable=R0917 recordingIncludeKeys: Optional[bool] = None, colorScheme: Optional[str] = None, fontSize: Optional[int] = None, + scrollback: Optional[int] = None, ignoreCert: Optional[bool] = None, caCert: Optional[str] = None, namespace: Optional[str] = None, @@ -2692,7 +2785,7 @@ def __init__( # pylint: disable=R0917 clientKey: Optional[str] = None ): BaseConnectionSettings.__init__(self, port, allowSupplyUser, userRecords, recordingIncludeKeys) - TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize) + TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize, scrollback) self.ignoreCert = ignoreCert self.caCert = caCert self.namespace = namespace @@ -2721,6 +2814,7 @@ def load(cls, data: Union[str, dict]): if tcs: obj.colorScheme = tcs.colorScheme obj.fontSize = tcs.fontSize + obj.scrollback = tcs.scrollback val = data.get("namespace", None) if isinstance(val, str): obj.namespace = val @@ -2754,6 +2848,8 @@ def to_record_dict(self): kvp["colorScheme"] = self.colorScheme.strip() if self.fontSize and type(self.fontSize) is int and self.fontSize > 4: kvp["fontSize"] = str(self.fontSize) + if self.scrollback is not None and type(self.scrollback) is int and self.scrollback > 0: + kvp["scrollback"] = self.scrollback if self.namespace and isinstance(self.namespace, str) and self.namespace.strip(): kvp["namespace"] = self.namespace.strip() if self.pod and isinstance(self.pod, str) and self.pod.strip(): @@ -2857,9 +2953,10 @@ def to_record_json(self): rec_json = json.dumps(dict) return rec_json -class ConnectionSettingsSqlServer(BaseDatabaseConnectionSettings): - protocol = ConnectionProtocol.SQLSERVER - def __init__( # pylint: disable=W0246 +class CliCapableDatabaseConnectionSettings(BaseDatabaseConnectionSettings, TerminalDisplayConnectionSettings): + """mysql/postgresql/sql-server: CLI-capable DB protocols with terminal display (mirrors WV).""" + + def __init__( # pylint: disable=R0917 self, port: Optional[str] = None, # Override port from host allowSupplyUser: Optional[bool] = None, @@ -2869,17 +2966,21 @@ def __init__( # pylint: disable=W0246 disablePaste: Optional[bool] = None, database: Optional[str] = None, disableCsvExport: Optional[bool] = None, - disableCsvImport: Optional[bool] = None + disableCsvImport: Optional[bool] = None, + colorScheme: Optional[str] = None, + fontSize: Optional[int] = None, + scrollback: Optional[int] = None, ): - super().__init__(port, allowSupplyUser, userRecords, recordingIncludeKeys, - disableCopy, disablePaste, database, - disableCsvExport, disableCsvImport) + BaseDatabaseConnectionSettings.__init__( + self, port, allowSupplyUser, userRecords, recordingIncludeKeys, + disableCopy, disablePaste, database, disableCsvExport, disableCsvImport) + TerminalDisplayConnectionSettings.__init__(self, colorScheme, fontSize, scrollback) @classmethod def load(cls, data: Union[str, dict]): obj = cls() try: data = json.loads(data) if isinstance(data, str) else data - except: logging.error(f"SQLServer Connection Settings failed to load from: {str(data)[:80]}") + except: logging.error(f"CLI-capable Database Connection Settings failed to load from: {str(data)[:80]}") if not isinstance(data, dict): return obj bdcs = BaseDatabaseConnectionSettings.load(data) @@ -2896,104 +2997,100 @@ def load(cls, data: Union[str, dict]): obj.launch_credentials = getattr(bdcs, "launch_credentials", None) obj.launchRecordUid = getattr(bdcs, "launchRecordUid", None) + tcs = TerminalDisplayConnectionSettings.load(data) + if tcs: + obj.colorScheme = tcs.colorScheme + obj.fontSize = tcs.fontSize + obj.scrollback = tcs.scrollback + return obj - def to_record_dict(self): - dict = super().to_record_dict() - dict["protocol"] = ConnectionProtocol.SQLSERVER.value # pylint: disable=E1101 - return dict + def _apply_terminal_display_to_record_dict(self, kvp: Dict[str, Any]): + if self.colorScheme and isinstance(self.colorScheme, str) and self.colorScheme.strip(): + kvp["colorScheme"] = self.colorScheme.strip() + if self.fontSize and type(self.fontSize) is int and self.fontSize > 4: + kvp["fontSize"] = str(self.fontSize) + if self.scrollback is not None and type(self.scrollback) is int and self.scrollback > 0: + kvp["scrollback"] = self.scrollback -class ConnectionSettingsPostgreSQL(BaseDatabaseConnectionSettings): - protocol = ConnectionProtocol.POSTGRESQL - def __init__( # pylint: disable=W0246,R0917 - self, - port: Optional[str] = None, # Override port from host - allowSupplyUser: Optional[bool] = None, - userRecords: Optional[List[str]] = None, - recordingIncludeKeys: Optional[bool] = None, - disableCopy: Optional[bool] = None, - disablePaste: Optional[bool] = None, - database: Optional[str] = None, - disableCsvExport: Optional[bool] = None, - disableCsvImport: Optional[bool] = None - ): - super().__init__(port, allowSupplyUser, userRecords, recordingIncludeKeys, - disableCopy, disablePaste, database, - disableCsvExport, disableCsvImport) + def to_record_dict(self): + kvp = super().to_record_dict() + self._apply_terminal_display_to_record_dict(kvp) + return kvp - @classmethod - def load(cls, data: Union[str, dict]): - obj = cls() - try: data = json.loads(data) if isinstance(data, str) else data - except: logging.error(f"PostgreSQL Connection Settings failed to load from: {str(data)[:80]}") - if not isinstance(data, dict): return obj +class ConnectionSettingsSqlServer(CliCapableDatabaseConnectionSettings): + protocol = ConnectionProtocol.SQLSERVER + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.SQLSERVER.value # pylint: disable=E1101 + return d - bdcs = BaseDatabaseConnectionSettings.load(data) - if bdcs: - obj.port = bdcs.port - obj.allowSupplyUser = bdcs.allowSupplyUser - obj.userRecords = bdcs.userRecords - obj.recordingIncludeKeys = bdcs.recordingIncludeKeys - obj.disableCopy = bdcs.disableCopy - obj.disablePaste = bdcs.disablePaste - obj.database = bdcs.database - obj.disableCsvExport = bdcs.disableCsvExport - obj.disableCsvImport = bdcs.disableCsvImport - obj.launch_credentials = getattr(bdcs, "launch_credentials", None) - obj.launchRecordUid = getattr(bdcs, "launchRecordUid", None) +class ConnectionSettingsPostgreSQL(CliCapableDatabaseConnectionSettings): + protocol = ConnectionProtocol.POSTGRESQL + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.POSTGRESQL.value # pylint: disable=E1101 + return d - return obj +class ConnectionSettingsMySQL(CliCapableDatabaseConnectionSettings): + protocol = ConnectionProtocol.MYSQL + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.MYSQL.value # pylint: disable=E1101 + return d + +# The remaining database protocols share BaseDatabaseConnectionSettings verbatim — they +# differ only by their wire protocol value. __init__ and load() are inherited unchanged +# (BaseDatabaseConnectionSettings.load uses cls(), so it builds the right subclass); +# only to_record_dict needs to stamp the protocol. +class ConnectionSettingsMariaDB(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.MARIADB + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.MARIADB.value # pylint: disable=E1101 + return d +class ConnectionSettingsOracle(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.ORACLE def to_record_dict(self): - dict = super().to_record_dict() - dict["protocol"] = ConnectionProtocol.POSTGRESQL.value # pylint: disable=E1101 - return dict + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.ORACLE.value # pylint: disable=E1101 + return d -class ConnectionSettingsMySQL(BaseDatabaseConnectionSettings): - protocol = ConnectionProtocol.MYSQL - def __init__( # pylint: disable=W0246,R0917 - self, - port: Optional[str] = None, # Override port from host - allowSupplyUser: Optional[bool] = None, - userRecords: Optional[List[str]] = None, - recordingIncludeKeys: Optional[bool] = None, - disableCopy: Optional[bool] = None, - disablePaste: Optional[bool] = None, - database: Optional[str] = None, - disableCsvExport: Optional[bool] = None, - disableCsvImport: Optional[bool] = None - ): - super().__init__(port, allowSupplyUser, userRecords, recordingIncludeKeys, - disableCopy, disablePaste, database, - disableCsvExport, disableCsvImport) +class ConnectionSettingsMongoDB(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.MONGODB + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.MONGODB.value # pylint: disable=E1101 + return d - @classmethod - def load(cls, data: Union[str, dict]): - obj = cls() - try: data = json.loads(data) if isinstance(data, str) else data - except: logging.error(f"MySQL Connection Settings failed to load from: {str(data)[:80]}") - if not isinstance(data, dict): return obj +class ConnectionSettingsRedis(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.REDIS + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.REDIS.value # pylint: disable=E1101 + return d - bdcs = BaseDatabaseConnectionSettings.load(data) - if bdcs: - obj.port = bdcs.port - obj.allowSupplyUser = bdcs.allowSupplyUser - obj.userRecords = bdcs.userRecords - obj.recordingIncludeKeys = bdcs.recordingIncludeKeys - obj.disableCopy = bdcs.disableCopy - obj.disablePaste = bdcs.disablePaste - obj.database = bdcs.database - obj.disableCsvExport = bdcs.disableCsvExport - obj.disableCsvImport = bdcs.disableCsvImport - obj.launch_credentials = getattr(bdcs, "launch_credentials", None) - obj.launchRecordUid = getattr(bdcs, "launchRecordUid", None) +class ConnectionSettingsElasticsearch(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.ELASTICSEARCH + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.ELASTICSEARCH.value # pylint: disable=E1101 + return d - return obj +class ConnectionSettingsClickHouse(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.CLICKHOUSE + def to_record_dict(self): + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.CLICKHOUSE.value # pylint: disable=E1101 + return d +class ConnectionSettingsDynamoDB(BaseDatabaseConnectionSettings): + protocol = ConnectionProtocol.DYNAMODB def to_record_dict(self): - dict = super().to_record_dict() - dict["protocol"] = ConnectionProtocol.MYSQL.value # pylint: disable=E1101 - return dict + d = super().to_record_dict() + d["protocol"] = ConnectionProtocol.DYNAMODB.value # pylint: disable=E1101 + return d PamConnectionSettings = Optional[ Union[ @@ -3004,7 +3101,14 @@ def to_record_dict(self): ConnectionSettingsKubernetes, ConnectionSettingsSqlServer, ConnectionSettingsPostgreSQL, - ConnectionSettingsMySQL + ConnectionSettingsMySQL, + ConnectionSettingsMariaDB, + ConnectionSettingsOracle, + ConnectionSettingsMongoDB, + ConnectionSettingsRedis, + ConnectionSettingsElasticsearch, + ConnectionSettingsClickHouse, + ConnectionSettingsDynamoDB ] ] @@ -3111,7 +3215,14 @@ def __init__( ConnectionSettingsKubernetes, ConnectionSettingsSqlServer, ConnectionSettingsPostgreSQL, - ConnectionSettingsMySQL + ConnectionSettingsMySQL, + ConnectionSettingsMariaDB, + ConnectionSettingsOracle, + ConnectionSettingsMongoDB, + ConnectionSettingsRedis, + ConnectionSettingsElasticsearch, + ConnectionSettingsClickHouse, + ConnectionSettingsDynamoDB ] @classmethod @@ -3194,16 +3305,30 @@ def is_blank_instance(obj, skiplist: Optional[List[str]] = None): return False return True +_CLI_CAPABLE_DB_MEMBERS = ( + ConnectionProtocol.MYSQL, + ConnectionProtocol.POSTGRESQL, + ConnectionProtocol.SQLSERVER, +) + def is_database_protocol(protocol): """ - Returns True if the protocol is one of the database protocols: MYSQL, POSTGRESQL, or SQLSERVER. + Returns True if the protocol is one of the database protocols: MYSQL, POSTGRESQL, + SQLSERVER, MARIADB, ORACLE, MONGODB, REDIS, ELASTICSEARCH, CLICKHOUSE, or DYNAMODB. - Accepts ConnectionProtocol or the string wire values (e.g. 'mysql', 'postgresql', 'sql-server'). + Accepts ConnectionProtocol or the string wire values (e.g. 'mysql', 'oracle', 'sql-server'). """ db_members = ( ConnectionProtocol.MYSQL, ConnectionProtocol.POSTGRESQL, ConnectionProtocol.SQLSERVER, + ConnectionProtocol.MARIADB, + ConnectionProtocol.ORACLE, + ConnectionProtocol.MONGODB, + ConnectionProtocol.REDIS, + ConnectionProtocol.ELASTICSEARCH, + ConnectionProtocol.CLICKHOUSE, + ConnectionProtocol.DYNAMODB, ) db_values = {m.value for m in db_members} if isinstance(protocol, ConnectionProtocol): @@ -3212,6 +3337,25 @@ def is_database_protocol(protocol): return str(protocol).strip().lower() in db_values return False +def is_cli_capable_db_protocol(protocol): + """ + Returns True for mysql/postgresql/sql-server — DB protocols with CLI/TTY sessions + (mirrors WV: terminal display is shown when !isKeeperDbOnlyProtocol). + """ + cli_values = {m.value for m in _CLI_CAPABLE_DB_MEMBERS} + if isinstance(protocol, ConnectionProtocol): + return protocol in _CLI_CAPABLE_DB_MEMBERS + if isinstance(protocol, str): + return str(protocol).strip().lower() in cli_values + return False + +def is_keeper_db_only_protocol(protocol): + """ + Returns True for keeperDb-only DB protocols (mariadb, oracle, mongodb, redis, + elasticsearch, clickhouse, dynamodb). Mirrors WV isKeeperDbOnlyProtocol. + """ + return is_database_protocol(protocol) and not is_cli_capable_db_protocol(protocol) + def get_sftp_attribute(obj, name: str) -> str: # Get one of pam_settings.connection.sftp.{sftpResource,sftpResourceUid,sftpUser,sftpUserUid} value: str = "" diff --git a/keepercommander/commands/pam_import/keeper_ai_settings.py b/keepercommander/commands/pam_import/keeper_ai_settings.py index 1b84cf7c3..7d4056c12 100644 --- a/keepercommander/commands/pam_import/keeper_ai_settings.py +++ b/keepercommander/commands/pam_import/keeper_ai_settings.py @@ -16,7 +16,7 @@ from ...keeper_dag import DAG, EdgeType from ...keeper_dag.exceptions import DAGPathException from ...keeper_dag.connection.commander import Connection -from ...keeper_dag.types import PamEndpoints +from ...keeper_dag.types import PamGraphId from ...vault import PasswordRecord from ... import vault from ...display import bcolors @@ -84,17 +84,15 @@ def list_resource_data_edges( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) # Load the DAG linking_dag = DAG( conn=conn, record=dag_record, - graph_id=0, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM + graph_id=PamGraphId.PAM.value ) try: linking_dag.load() @@ -103,7 +101,7 @@ def list_resource_data_edges( return [] # Get the resource vertex - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: logging.warning(f"Resource vertex {resource_uid} not found in DAG") return [] @@ -189,17 +187,15 @@ def get_resource_settings( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) # Load the DAG linking_dag = DAG( conn=conn, record=dag_record, - graph_id=0, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM + graph_id=PamGraphId.PAM.value ) try: linking_dag.load() @@ -211,7 +207,7 @@ def get_resource_settings( return None # Get the resource vertex - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: logging.warning(f"Resource vertex {resource_uid} not found in DAG") return None @@ -414,10 +410,11 @@ def set_resource_keeper_ai_settings( validates caller access then writes the `ai_settings` DAG DATA edge on the resource server-side. - Fallback (env var `KEEPER_DAG_LB_FALLBACK=1`, default ON): on + Fallback (env var `KEEPER_DAG_LB_FALLBACK`, default OFF / strict mode): on `RRC_NOT_ALLOWED*` from krouter, fall back to the legacy direct DAG-write path (`_set_resource_keeper_ai_settings_legacy`). Gateway then - enforces at runtime. Set the env var to `0` for strict mode (denials propagate). + enforces at runtime. Default (unset/`0`) propagates denials; set to `1` to opt + into fallback. Args: params: KeeperParams instance @@ -436,6 +433,17 @@ def set_resource_keeper_ai_settings( encrypted_content = encrypt_aes(json.dumps(settings).encode(), record_key) + # krouter's configure_resource only writes a settings edge when it loads the + # resource's existing edges (loopEdges), which it does only for requests that + # carry meta/jit/connection (UserRest.kt). A keeperAiSettings-only request + # leaves loopEdges null and the ai_settings write is silently dropped. The Web + # Vault avoids this by always sending meta alongside the AI settings, so mirror + # that: include the resource's current meta in the same request. + meta_bytes = None + current_meta = get_resource_settings(params, resource_uid, 'meta', resolved_config_uid) + if isinstance(current_meta, dict): + meta_bytes = json.dumps(current_meta).encode() + # Primary: Layer-B configure_resource (permission-checked). from ..pam.router_helper import router_configure_resource, get_router_url host = get_router_url(params) @@ -446,6 +454,8 @@ def set_resource_keeper_ai_settings( networkUid=url_safe_str_to_bytes(resolved_config_uid), keeperAiSettings=encrypted_content, ) + if meta_bytes is not None: + rq.meta = meta_bytes try: router_configure_resource(params, rq) logging.debug(f"Saved KeeperAI settings via configure_resource for {resource_uid}") @@ -527,20 +537,19 @@ def _set_resource_keeper_ai_settings_legacy( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True, + use_read_protobuf=False, + use_write_protobuf=False, ) linking_dag = DAG( conn=conn, record=dag_record, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM, + graph_id=PamGraphId.PAM.value, decrypt=True, ) linking_dag.load() - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: logging.warning(f"Resource vertex {resource_uid} not found in DAG") return False @@ -649,20 +658,19 @@ def _set_resource_jit_settings_legacy( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True, + use_read_protobuf=False, + use_write_protobuf=False, ) linking_dag = DAG( conn=conn, record=dag_record, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM, + graph_id=PamGraphId.PAM.value, decrypt=True, ) linking_dag.load() - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: logging.warning(f"Resource vertex {resource_uid} not found in DAG") return False @@ -725,19 +733,17 @@ def refresh_meta_to_latest( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) linking_dag = DAG( conn=conn, record=dag_record, - graph_id=0, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM, + graph_id=PamGraphId.PAM.value, decrypt=True ) linking_dag.load() - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: return False meta_edges = [e for e in (resource_vertex.edges or []) @@ -795,20 +801,18 @@ def refresh_link_to_config_to_latest( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) linking_dag = DAG( conn=conn, record=dag_record, - graph_id=0, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM, + graph_id=PamGraphId.PAM.value, decrypt=True ) linking_dag.load() - resource_vertex = linking_dag.get_vertex(resource_uid) - config_vertex = linking_dag.get_vertex(config_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) + config_vertex = linking_dag.get_vertex_by_uid(config_uid) if not resource_vertex or not config_vertex: return False # Re-add LINK (path empty, content {}) so it becomes latest, above KEY added by JIT/AI @@ -933,15 +937,13 @@ def inspect_resource_in_graph( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) linking_dag = DAG( conn=conn, record=dag_record, - graph_id=0, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM, + graph_id=PamGraphId.PAM.value, decrypt=not show_raw_content ) linking_dag.load() @@ -1035,19 +1037,17 @@ def get_resource_domain_dir_uid( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) linking_dag = DAG( conn=conn, record=dag_record, - graph_id=0, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM + graph_id=PamGraphId.PAM.value ) linking_dag.load() - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: return None @@ -1131,19 +1131,18 @@ def _set_resource_domain_dir_legacy( encrypted_transmission_key=encrypted_transmission_key, encrypted_session_token=encrypted_session_token, transmission_key=transmission_key, - use_read_protobuf=True, - use_write_protobuf=True, + use_read_protobuf=False, + use_write_protobuf=False, ) linking_dag = DAG( conn=conn, record=dag_record, - read_endpoint=PamEndpoints.PAM, - write_endpoint=PamEndpoints.PAM, + graph_id=PamGraphId.PAM.value, decrypt=True, ) linking_dag.load() - resource_vertex = linking_dag.get_vertex(resource_uid) + resource_vertex = linking_dag.get_vertex_by_uid(resource_uid) if not resource_vertex: logging.warning(f"Resource vertex {resource_uid} not found in DAG") return False @@ -1156,12 +1155,12 @@ def _set_resource_domain_dir_legacy( old_dir_uid = edge.head_uid break if old_dir_uid and old_dir_uid != dir_uid: - old_dir_vertex = linking_dag.get_vertex(old_dir_uid) + old_dir_vertex = linking_dag.get_vertex_by_uid(old_dir_uid) if old_dir_vertex: resource_vertex.disconnect_from(old_dir_vertex) logging.debug(f"Disconnected old domain LINK edge to {old_dir_uid}") - dir_vertex = linking_dag.get_vertex(dir_uid) + dir_vertex = linking_dag.get_vertex_by_uid(dir_uid) if not dir_vertex: logging.warning(f"Directory vertex {dir_uid} not found in DAG") return False diff --git a/keepercommander/commands/pam_service/add.py b/keepercommander/commands/pam_service/add.py index c03441e3b..e471a5c4f 100644 --- a/keepercommander/commands/pam_service/add.py +++ b/keepercommander/commands/pam_service/add.py @@ -60,9 +60,9 @@ def execute(self, params: KeeperParams, **kwargs): return user_service = UserService(record=gateway_context.configuration, params=params, fail_on_corrupt=False, - agent=f"Cmdr/{__version__}", use_per_graph_endpoints=True) + agent=f"Cmdr/{__version__}") record_link = RecordLink(record=gateway_context.configuration, params=params, fail_on_corrupt=False, - agent=f"Cmdr/{__version__}", use_per_graph_endpoints=True) + agent=f"Cmdr/{__version__}") ############### diff --git a/keepercommander/commands/pam_service/list.py b/keepercommander/commands/pam_service/list.py index b19b34283..995be7d1d 100644 --- a/keepercommander/commands/pam_service/list.py +++ b/keepercommander/commands/pam_service/list.py @@ -42,7 +42,7 @@ def execute(self, params: KeeperParams, **kwargs): return user_service = UserService(record=gateway_context.configuration, params=params, fail_on_corrupt=False, - agent=f"Cmdr/{__version__}", use_per_graph_endpoints=True) + agent=f"Cmdr/{__version__}") service_map = {} for resource_vertex in user_service.dag.get_root.has_vertices(edge_type=EdgeType.LINK): diff --git a/keepercommander/commands/pam_service/remove.py b/keepercommander/commands/pam_service/remove.py index 58f545592..e4b68d25f 100644 --- a/keepercommander/commands/pam_service/remove.py +++ b/keepercommander/commands/pam_service/remove.py @@ -57,7 +57,7 @@ def execute(self, params: KeeperParams, **kwargs): return user_service = UserService(record=gateway_context.configuration, params=params, fail_on_corrupt=False, - agent=f"Cmdr/{__version__}", use_per_graph_endpoints=True) + agent=f"Cmdr/{__version__}") machine_record = vault.KeeperRecord.load(params, machine_uid) # type: Optional[TypedRecord] if machine_record is None: diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 0c8d152de..7977d0ab5 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -20,6 +20,7 @@ import re from functools import reduce from typing import Dict, Any, List, Optional, Iterable, Tuple, Set +from .helpers.record import raise_if_unsafe_get_lookup_token from colorama import Fore, Back, Style @@ -294,6 +295,8 @@ def execute(self, params, **kwargs): if not uid: raise CommandError('get', 'UID parameter is required') + raise_if_unsafe_get_lookup_token(uid) + fmt = kwargs.get('format') or 'detail' # First try to interpret as UID @@ -632,6 +635,41 @@ def _format_expiration(expiration_value): 'value': r.notes, } fields.append(field) + # custom fields (v2 type in c['type']; v3 encode type as "type:label" in c['name']) + # Keep _MASKED_FIELD_TYPES in sync with _MASKED_TYPES in record.py. + # 'note' = Secured Note — sensitive, masked by design. + # 'passkey' omitted: early-exit handler renders only non-sensitive sub-fields. + _MASKED_FIELD_TYPES = frozenset({ + 'secret', 'pinCode', 'note', 'json', 'oneTimeCode', + 'paymentCard', 'bankAccount', 'keyPair', 'securityQuestion', + }) + unmask = kwargs.get('unmask') is True + for cf in (r.custom_fields or []): + cf_value = cf.get('value') + if not cf_value and cf_value != 0: + continue + cf_name = str(cf.get('name') or cf.get('type') or '') + is_sensitive = (cf.get('type') in _MASKED_FIELD_TYPES or + cf_name.split(':')[0] in _MASKED_FIELD_TYPES) + is_sq = (cf.get('type') == 'securityQuestion' or + cf_name == 'securityQuestion' or + cf_name.startswith('securityQuestion:')) + if is_sq: + val = cf_value + entry = val[0] if (isinstance(val, list) and val) else val + if isinstance(entry, dict): + display_val = { + 'question': entry.get('question') or '', + 'answer': (entry.get('answer') or '') if unmask else '********', + } + else: + display_val = cf_value if unmask else '********' + else: + display_val = cf_value if (unmask or not is_sensitive) else '********' + fields.append({ + 'name': cf_name, + 'value': display_val, + }) print(json.dumps(fields, indent=2)) else: diff --git a/keepercommander/commands/record_edit.py b/keepercommander/commands/record_edit.py index aceb2c301..9781bf431 100644 --- a/keepercommander/commands/record_edit.py +++ b/keepercommander/commands/record_edit.py @@ -42,6 +42,10 @@ record_add_parser.add_argument('-f', '--force', dest='force', action='store_true', help='ignore warnings') record_add_parser.add_argument('-t', '--title', dest='title', action='store', help='record title') record_add_parser.add_argument('-rt', '--record-type', dest='record_type', action='store', help='record type') +record_add_parser.add_argument('--labels', dest='labels', action='store', choices=['on', 'off'], + help='label fields in standard record-type definition. "on" (default) keeps legacy ' + 'labels; "off" omits them. "off" affects only RT-definition fields without their ' + 'own label; RT-definition custom labels and explicitly provided labels are preserved.') record_add_parser.add_argument('-n', '--notes', dest='notes', action='store', help='record notes') record_add_parser.add_argument('--folder', dest='folder', action='store', help='folder name or UID to store record') @@ -865,11 +869,15 @@ def execute(self, params, **kwargs): raise CommandError('record-add', f'Record type \"{record_type}\" cannot be found.') record = vault.TypedRecord() record.type_name = record_type + omit_labels = (kwargs.get('labels') or 'on').lower() == 'off' for rf in rt_fields: ref = rf.get('$ref') if not ref: continue - label = rf.get('label') or ref + # Use the label from the record-type definition when present (both modes). + # When the definition has none: legacy ("on") falls back to the field type; + # "off" leaves it empty so the redundant type-name label is omitted (matches Vault UI). + label = rf.get('label') or ('' if omit_labels else ref) required = rf.get('required', False) default_value = None if ref == 'appFiller': @@ -1027,7 +1035,7 @@ def _sync_password_to_pam(self, params: KeeperParams, record: vault.TypedRecord, if plugin_name == 'azureadpwd': # Import Azure AD plugin - from ...plugins.azureadpwd import azureadpwd + from ..plugins.azureadpwd import azureadpwd # Call the rotate function with PAM config record success = azureadpwd.rotate(pam_record, password) @@ -1039,7 +1047,7 @@ def _sync_password_to_pam(self, params: KeeperParams, record: vault.TypedRecord, elif plugin_name == 'awspswd': # Import AWS plugin and common rotator - from ...plugins.awspswd import aws_passwd + from ..plugins.awspswd import aws_passwd # Extract AWS credentials from PAM config aws_access_key = None @@ -1070,7 +1078,6 @@ def _sync_password_to_pam(self, params: KeeperParams, record: vault.TypedRecord, # Set AWS credentials in environment or use profile if aws_access_key and aws_secret_key: - import os original_access_key = os.environ.get('AWS_ACCESS_KEY_ID') original_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') @@ -1137,7 +1144,6 @@ def _send_onboarding_email( # Load email configuration from .email_commands import find_email_config_record, load_email_config_from_record from ..email_service import EmailSender, build_onboarding_email - from .helpers.timeout import parse_timeout config_uid = find_email_config_record(params, email_config_name) if not config_uid: @@ -1162,7 +1168,7 @@ def _send_onboarding_email( else: # minutes minutes = expire_seconds // 60 expiration_text = f"{minutes} minute{'s' if minutes > 1 else ''}" - except: + except Exception: expiration_text = expiration # fallback to original if parsing fails html_body = build_onboarding_email( diff --git a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py index b5a149760..7ac683d1c 100644 --- a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py +++ b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py @@ -3,7 +3,7 @@ from .tunnel_helpers import generate_random_bytes, get_config_uid from ....keeper_dag import DAG, EdgeType from ....keeper_dag.connection.commander import Connection -from ....keeper_dag.types import RefType, PamEndpoints +from ....keeper_dag.types import RefType, PamGraphId from ....keeper_dag.vertex import DAGVertex from ....display import bcolors from ....vault import PasswordRecord @@ -113,11 +113,11 @@ def __init__(self, params, encrypted_session_token, encrypted_transmission_key, encrypted_transmission_key=self.encrypted_transmission_key, encrypted_session_token=self.encrypted_session_token, transmission_key=self.transmission_key, - use_read_protobuf=True, - use_write_protobuf=True + use_read_protobuf=False, + use_write_protobuf=False ) self.linking_dag = DAG(conn=self.conn, record=self.record, - read_endpoint=PamEndpoints.PAM, write_endpoint=PamEndpoints.PAM) + graph_id=PamGraphId.PAM.value) try: self.linking_dag.load() except Exception as e: @@ -127,15 +127,15 @@ def __init__(self, params, encrypted_session_token, encrypted_transmission_key, def resource_belongs_to_config(self, resource_uid): if not self.linking_dag.has_graph: return False - resource_vertex = self.linking_dag.get_vertex(resource_uid) - config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + resource_vertex = self.linking_dag.get_vertex_by_uid(resource_uid) + config_vertex = self.linking_dag.get_vertex_by_uid(self.record.record_uid) return resource_vertex and config_vertex.has(resource_vertex, EdgeType.LINK) def user_belongs_to_config(self, user_uid): if not self.linking_dag.has_graph: return False - user_vertex = self.linking_dag.get_vertex(user_uid) - config_vertex = self.linking_dag.get_vertex(self.record.record_uid) + user_vertex = self.linking_dag.get_vertex_by_uid(user_uid) + config_vertex = self.linking_dag.get_vertex_by_uid(self.record.record_uid) res_content = False if user_vertex and config_vertex and config_vertex.has(user_vertex, EdgeType.ACL): acl_edge = user_vertex.get_edge(config_vertex, EdgeType.ACL) diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index d30fb6f73..0314248b6 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -2623,7 +2623,19 @@ def _latency_sampler(): class PAMConnectionEditCommand(Command): choices = ['on', 'off', 'default'] - protocols = ['', 'http', 'kubernetes', 'mysql', 'postgresql', 'rdp', 'sql-server', 'ssh', 'telnet', 'vnc'] + # Database connection protocols (pamDatabase). Kept as its own list so it is the single + # source of truth for DB-protocol checks (e.g. --scrollback) and so a future DB-vs-non-DB + # split or per-record-type gate is a one-line change. No per-type gating is applied today: + # any supported PAM resource may use any protocol (see validate_pam_connection in pam_import). + db_protocols = ['clickhouse', 'dynamodb', 'elasticsearch', 'mariadb', 'mongodb', + 'mysql', 'oracle', 'postgresql', 'redis', 'sql-server'] + # CLI-capable DB protocols (mysql/postgresql/sql-server): terminal display incl. scrollback. + # Mirrors WV isKeeperDbOnlyProtocol — keeperDb-only DBs have no TTY session settings. + cli_capable_db_protocols = ['mysql', 'postgresql', 'sql-server'] + # Non-database protocols (terminal/remote for pamMachine/pamDirectory, http for RBI). + non_db_protocols = ['http', 'kubernetes', 'rdp', 'ssh', 'telnet', 'vnc'] + # Protocols offered by --protocol ('' clears the protocol). + protocols = [''] + sorted(non_db_protocols + db_protocols) parser = argparse.ArgumentParser(prog='pam connection edit') parser.add_argument('record', type=str, action='store', help='The record UID or path of the PAM ' 'resource record with network information to use for connections') @@ -2653,8 +2665,8 @@ class PAMConnectionEditCommand(Command): help='Toggle Key Events settings') parser.add_argument('--scrollback', '-sb', required=False, dest='scrollback', action='store', help='Maximum Scrollback Size (terminal history). Integer to set, ' - 'empty string to remove. Supported only for pamDatabase (any DB protocol) and ' - 'pamMachine/pamDirectory (ssh/telnet/kubernetes).') + 'empty string to remove. Supported for pamDatabase (mysql/postgresql/sql-server) ' + 'and pamMachine/pamDirectory (ssh/telnet/kubernetes).') parser.add_argument('--rotate-on-termination', required=False, dest='rotate_on_termination', choices=['on', 'off'], help='Rotate launch credentials when the PAM session ends (DAG resource meta)') @@ -2708,8 +2720,7 @@ def execute(self, params, **kwargs): scrollback_clear = False scrollback_value = None # parsed int, or None to skip apply if scrollback_arg is not None: - db_scrollback_protocols = {'mysql', 'postgresql', 'sql-server', 'mariadb', 'oracle', - 'mongodb', 'redis', 'elasticsearch', 'clickhouse', 'dynamodb'} + db_scrollback_protocols = set(PAMConnectionEditCommand.cli_capable_db_protocols) terminal_scrollback_protocols = {'ssh', 'telnet', 'kubernetes'} if record_type == 'pamDatabase': allowed_protocols = db_scrollback_protocols @@ -3327,6 +3338,11 @@ class PAMRbiEditCommand(Command): parser.add_argument('--audio-sample-rate', '-sr', dest='audio_sample_rate', type=int, help='Audio sample rate in Hz (e.g., 44100, 48000)') + # Session Persistence + parser.add_argument('--session-persistence', '-sp', dest='session_persistence', + choices=['none', 'user', 'resource', 'default'], + help='RBI session persistence (none/user/resource; default = unset)') + # Utility parser.add_argument('--silent', '-s', required=False, dest='silent', action='store_true', help='Silent mode - don\'t print PAM User, PAM Config etc.') @@ -3357,6 +3373,7 @@ def execute(self, params, **kwargs): audio_channels = kwargs.get('audio_channels') # int or None audio_bit_depth = kwargs.get('audio_bit_depth') # int or None audio_sample_rate = kwargs.get('audio_sample_rate') # int or None + session_persistence = kwargs.get('session_persistence') # none/user/resource/default/None if not record_name: raise CommandError('pam rbi edit', 'Record parameter is required.') @@ -3375,7 +3392,8 @@ def execute(self, params, **kwargs): disable_audio is not None, audio_channels is not None, audio_bit_depth is not None, - audio_sample_rate is not None + audio_sample_rate is not None, + session_persistence is not None ]) if not (autofill or key_events or config_name or rbi or recording or has_new_settings): @@ -3542,6 +3560,31 @@ def update_connection_int(field_name, value): else: logging.debug(f'{field_name} is already set to {value} on record={record_uid}') + # Helper for enum string fields (e.g. sessionPersistence): set a literal value, + # or remove the key on 'default' so the gateway/vault applies its own default. + # Coerces 'connection' to a dict locally (idempotent) so this is safe even if + # 'connection' is missing/null/"" — no dependency on the earlier coercion. Removal + # keys on presence, so a present-but-null value is cleared too. + def update_connection_choice(field_name, value): + nonlocal dirty + rbs_fld = record.get_typed_field('pamRemoteBrowserSettings') + if rbs_fld and rbs_fld.value and isinstance(rbs_fld.value[0], dict): + _coerce_settings_subdicts(rbs_fld.value[0], 'connection') + connection = rbs_fld.value[0]['connection'] + if value == 'default': + if field_name in connection: + connection.pop(field_name, None) + dirty = True + logging.debug(f'Removed {field_name} (set to default) on record={record_uid}') + else: + logging.debug(f'{field_name} is already unset on record={record_uid}') + elif connection.get(field_name) != value: + connection[field_name] = value + dirty = True + logging.debug(f'Set {field_name}={value} on record={record_uid}') + else: + logging.debug(f'{field_name} is already set to {value} on record={record_uid}') + # Browser Settings - allowUrlManipulation (on/off/default) if allow_url_navigation: update_connection_toggle('allowUrlManipulation', allow_url_navigation) @@ -3594,6 +3637,10 @@ def update_connection_int(field_name, value): if audio_sample_rate is not None: update_connection_int('audioSampleRate', audio_sample_rate) + # Session Persistence - sessionPersistence (none/user/resource; default removes) + if session_persistence: + update_connection_choice('sessionPersistence', session_persistence) + if dirty: record_management.update_record(params, record) api.sync_down(params) diff --git a/keepercommander/discovery_common/__version__.py b/keepercommander/discovery_common/__version__.py index 87bb711d7..7025f5029 100644 --- a/keepercommander/discovery_common/__version__.py +++ b/keepercommander/discovery_common/__version__.py @@ -1 +1 @@ -__version__ = '1.1.14' +__version__ = '1.1.15' diff --git a/keepercommander/discovery_common/infrastructure.py b/keepercommander/discovery_common/infrastructure.py index 744c7951b..ee2675955 100644 --- a/keepercommander/discovery_common/infrastructure.py +++ b/keepercommander/discovery_common/infrastructure.py @@ -394,7 +394,9 @@ def to_dot(self, graph_format: str = "svg", show_hex_uid: bool = False, head_uids.append(edge.head_uid) def _render_edge(e): - + # _render_edge is invoked immediately within the loop below, so capturing the + # loop variables v/content is safe. + # pylint: disable=cell-var-from-loop edge_color = "grey" style = "solid" @@ -439,7 +441,7 @@ def _render_edge(e): tooltip=edge_tip) for head_uid in head_uids: - version, edge = v.get_highest_edge_version(head_uid) + _, edge = v.get_highest_edge_version(head_uid) _render_edge(edge) data_edge = v.get_data() diff --git a/keepercommander/discovery_common/process.py b/keepercommander/discovery_common/process.py index 836243c41..a99ee2b06 100644 --- a/keepercommander/discovery_common/process.py +++ b/keepercommander/discovery_common/process.py @@ -377,7 +377,7 @@ def _directory_exists(self, domain: str, directory_info_func: Callable, context: for provider_vertex in provider_vertices: content = DiscoveryObject.get_discovery_object(provider_vertex) found = False - for domain in domains: + for domain in domains: # pylint: disable=redefined-argument-from-local for provider_domain in content.item.info.get("domains", []): if domain.lower() in provider_domain.lower(): found = True @@ -453,7 +453,7 @@ def _find_directory_user(self, found_vertex = None if find_user is not None: - user, domain = split_user_and_domain(find_user) + user, _ = split_user_and_domain(find_user) if user_content.item.user.lower() == user.lower(): found_vertex = user_vertex elif user_content.item.user.lower() == find_user.lower(): @@ -1185,7 +1185,7 @@ def _process_admin_user(self, # We need to populate the id and uid of the content, now that we have data in the content. self.populate_admin_content_ids(admin_content, resource_vertex) - ad_user, ad_domain = split_user_and_domain(admin_content.item.user) + _, ad_domain = split_user_and_domain(admin_content.item.user) if ad_domain is not None and admin_content.item.source == LOCAL_USER: self.logger.debug("The admin is an directory user, but the source is set to a local user") diff --git a/keepercommander/discovery_common/rm_types.py b/keepercommander/discovery_common/rm_types.py index a647ca933..1aa513d88 100644 --- a/keepercommander/discovery_common/rm_types.py +++ b/keepercommander/discovery_common/rm_types.py @@ -465,7 +465,7 @@ class RmOracleUserAddMeta(RmMetaBase): class RmOracleRoleAddMeta(RmMetaBase): - not_identified: bool = False, + not_identified: bool = False identified_by_password: Optional[str] = None identified_using: Optional[str] = None identified_externally: bool = False diff --git a/keepercommander/discovery_common/types.py b/keepercommander/discovery_common/types.py index 5a79a9115..6942e87eb 100644 --- a/keepercommander/discovery_common/types.py +++ b/keepercommander/discovery_common/types.py @@ -736,7 +736,7 @@ def has_dn(self, user) -> bool: return False - + class PromptResult(BaseModel): # "add" and "ignore" are the only action diff --git a/keepercommander/discovery_common/verify.py b/keepercommander/discovery_common/verify.py index 5e90ab5a7..dd6e6a3fd 100644 --- a/keepercommander/discovery_common/verify.py +++ b/keepercommander/discovery_common/verify.py @@ -403,7 +403,7 @@ def _check(vertex: DAGVertex, indent: int = 0): # Get all the child vertices, allow self ref, so we can delete it if not already deleted. for next_vertex in vertex.has_vertices(allow_self_ref=True): if next_vertex.uid == vertex.uid: - version, edge = next_vertex.get_highest_edge_version(vertex.uid) + _, edge = next_vertex.get_highest_edge_version(vertex.uid) if edge.edge_type == EdgeType.DELETION: continue else: diff --git a/keepercommander/display.py b/keepercommander/display.py index 8e23f0e15..7958c0214 100644 --- a/keepercommander/display.py +++ b/keepercommander/display.py @@ -254,7 +254,17 @@ def print_record(params, record_uid): class Spinner: - """Animated spinner for long-running operations.""" + """Animated spinner for long-running operations. + + WARNING: every frame starts with '\\r' and overwrites the current console + row with the frame, message and padding spaces. A spinner that is still + (or again) ticking while other code prints can therefore erase chunks of + large multi-row output - especially long single-line blobs like base64 + KSM config tokens (`pam project import`/`pam gateway new` access_token), + silently corrupting what the user copies. Callers MUST guarantee stop() + via try/finally, and any change here must keep frames out of stopped + spinners and out of redirected/captured output. + """ # Claude-style spinner frames FRAMES = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] @@ -272,26 +282,49 @@ def _animate(self): message = self.message or '' visible_len = len(message) + 2 # frame + space + message pad = max(0, self._last_visible_len - visible_len) - sys.stdout.write(f'\r{Fore.CYAN}{frame}{Fore.RESET} {message}' + (' ' * pad)) - sys.stdout.flush() + # Re-check right before the write: a stale tick firing after + # stop() has returned (join timed out on a blocked console write) + # would '\r'-overwrite output printed in the meantime - erasing a + # row of large output such as a printed KSM config token. + # Skipping costs only one cosmetic frame. + if not self.running: + break + # Frames go to stderr (codebase convention for '\r' progress, see + # sox/aram/record_totp): on stdout they land inside redirected or + # captured command output, e.g. corrupting the base64 config in + # `pam project import ... > out.json`. + sys.stderr.write(f'\r{Fore.CYAN}{frame}{Fore.RESET} {message}' + (' ' * pad)) + sys.stderr.flush() self._last_visible_len = visible_len + pad idx += 1 time.sleep(0.08) - # Clear the line when done - clear_len = max(self._last_visible_len, len(self.message or '') + 2) - sys.stdout.write('\r' + ' ' * clear_len + '\r') - sys.stdout.flush() - self._last_visible_len = 0 def start(self): + # Spin only on a real console; in a redirected/captured stream the + # frames cannot animate and would pile up as '\r' noise in the data. + try: + if not sys.stderr.isatty(): + return + except Exception: + return self.running = True self.thread = threading.Thread(target=self._animate, daemon=True) self.thread.start() def stop(self): self.running = False - if self.thread: - self.thread.join(timeout=0.5) + if not self.thread: + return + self.thread.join(timeout=0.5) + self.thread = None + # Clear the spinner line from the calling thread after the animator + # has exited; clearing from the animator raced with output printed + # right after stop() and could blank part of it (erased lines in + # large output, e.g. ksm config tokens) with the padding spaces. + clear_len = max(self._last_visible_len, len(self.message or '') + 2) + sys.stderr.write('\r' + ' ' * clear_len + '\r') + sys.stderr.flush() + self._last_visible_len = 0 def post_login_summary(record_count=0, breachwatch_count=0, show_tips=True): diff --git a/keepercommander/importer/lastpass/fetcher.py b/keepercommander/importer/lastpass/fetcher.py index 30b1049e4..dc95ae42d 100644 --- a/keepercommander/importer/lastpass/fetcher.py +++ b/keepercommander/importer/lastpass/fetcher.py @@ -1,4 +1,5 @@ # coding: utf-8 +import codecs import hashlib import json import logging @@ -127,6 +128,7 @@ def request_login(username, password, key_iteration_count, multifactor_password= 'username': username, 'hash': make_hash(username, password, key_iteration_count), 'iterations': key_iteration_count, + 'includeprivatekeyenc': 1, } if multifactor_password: @@ -176,7 +178,14 @@ def create_session(parsed_response, key_iteration_count): if parsed_response.tag == 'ok': session_id = parsed_response.attrib.get('sessionid') if isinstance(session_id, str): - return Session(session_id, key_iteration_count) + session = Session(session_id, key_iteration_count) + try: + privatekeyenc = parsed_response.attrib.get('privatekeyenc') + if isinstance(privatekeyenc, str): + session.privatekeyenc = privatekeyenc.encode('utf-8') + except Exception: + pass + return session def login_error(parsed_response): diff --git a/keepercommander/importer/lastpass/parser.py b/keepercommander/importer/lastpass/parser.py index f7d0772ef..d75639ba5 100644 --- a/keepercommander/importer/lastpass/parser.py +++ b/keepercommander/importer/lastpass/parser.py @@ -129,12 +129,15 @@ def parse_ACCT(chunk, encryption_key, shared_folder): return account -def parse_PRIK(chunk, encryption_key): +def parse_PRIK(payload, encryption_key): """Parse PRIK chunk which contains private RSA key""" - decrypted = decode_aes256('cbc', - encryption_key[:16], - decode_hex(chunk.payload), - encryption_key) + if payload[0] == ord('!'): + decrypted = decode_aes256_base64_auto(payload, encryption_key) + else: + decrypted = decode_aes256('cbc', + encryption_key[:16], + decode_hex(payload), + encryption_key) hex_key = re.match(br'^LastPassPrivateKey<(?P.*)>LastPassPrivateKey$', decrypted).group('hex_key') return decode_hex(hex_key) diff --git a/keepercommander/importer/lastpass/session.py b/keepercommander/importer/lastpass/session.py index f68a0f487..dd564d265 100644 --- a/keepercommander/importer/lastpass/session.py +++ b/keepercommander/importer/lastpass/session.py @@ -3,6 +3,7 @@ class Session(object): def __init__(self, id, key_iteration_count): self.id = id self.key_iteration_count = key_iteration_count + self.privatekeyenc = None def __eq__(self, other): return self.id == other.id and self.key_iteration_count == other.key_iteration_count diff --git a/keepercommander/importer/lastpass/vault.py b/keepercommander/importer/lastpass/vault.py index 16dee4cc7..ac283c096 100644 --- a/keepercommander/importer/lastpass/vault.py +++ b/keepercommander/importer/lastpass/vault.py @@ -7,7 +7,7 @@ from tempfile import mkdtemp from typing import Optional -from . import fetcher +from . import fetcher, session from . import parser from .exceptions import InvalidResponseError from .shared_folder import LastpassSharedFolder @@ -46,7 +46,7 @@ def __init__(self, blob, encryption_key, session, tmpdir=None, shared_folder_det self.errors = set() self.shared_folders = [] self.attachments = [] - self.accounts = self.parse_accounts(chunks, encryption_key) + self.accounts = self.parse_accounts(chunks, encryption_key, session.privatekeyenc) self.tmpdir = None self.proxies = kwargs.get('proxies') self.certificate_check = kwargs.get('certificate_check') @@ -83,11 +83,16 @@ def __init__(self, blob, encryption_key, session, tmpdir=None, shared_folder_det def is_complete(self, chunks): return len(chunks) > 0 and chunks[-1].id == b'ENDM' and chunks[-1].payload == b'OK' - def parse_accounts(self, chunks, encryption_key): + def parse_accounts(self, chunks, encryption_key, privatekeyenc = None): accounts = [] key = encryption_key rsa_private_key = None # type: Optional[bytes] + if isinstance(privatekeyenc, bytes): + try: + rsa_private_key = parser.parse_PRIK(privatekeyenc, encryption_key) + except Exception: + pass shared_folder = None last_account = None for i in chunks: @@ -100,7 +105,8 @@ def parse_accounts(self, chunks, encryption_key): if last_account: accounts.append(last_account) elif i.id == b'PRIK': - rsa_private_key = parser.parse_PRIK(i, encryption_key) + if not rsa_private_key: + rsa_private_key = parser.parse_PRIK(i.payload, encryption_key) elif i.id == b'SHAR': # After SHAR chunk all the following accounts are encrypted with a new key share = parser.parse_SHAR(i, encryption_key, rsa_private_key) diff --git a/keepercommander/keeper_dag/__version__.py b/keepercommander/keeper_dag/__version__.py index 874042f3b..ed7133b36 100644 --- a/keepercommander/keeper_dag/__version__.py +++ b/keepercommander/keeper_dag/__version__.py @@ -1 +1 @@ -__version__ = '1.1.10' # pragma: no cover +__version__ = '1.1.11' # pragma: no cover diff --git a/keepercommander/keeper_dag/connection/__init__.py b/keepercommander/keeper_dag/connection/__init__.py index 762acde53..24ef3ba66 100644 --- a/keepercommander/keeper_dag/connection/__init__.py +++ b/keepercommander/keeper_dag/connection/__init__.py @@ -31,6 +31,8 @@ class ConnectionBase: ADD_DATA = "/add_data" SYNC = "/sync" + MULTI_SYNC = "/multi_sync" + GET_LEAFS = "/get_leafs" TIMEOUT = 30 @@ -59,7 +61,7 @@ def __init__(self, if self.log_transactions_dir is None: self.log_transactions_dir = "." - if self.log_transactions is True: + if self.log_transactions: self.logger.info("keeper-dag transaction logging is ENABLED; " f"write directory at {self.log_transactions_dir}") @@ -99,8 +101,11 @@ def get_encrypted_payload_data(encrypted_payload_data: bytes) -> bytes: @staticmethod def get_router_host(server_hostname: str): - if server_hostname and '://' in server_hostname: # accept URL-formatted inputs + # Defensive: accept URL-formatted inputs (e.g. "https://keepersecurity.com") + # and extract the bare hostname before the GovCloud subdomain check. + if server_hostname and '://' in server_hostname: server_hostname = server_hostname.split('://', 1)[1].split('/', 1)[0] + # Only PROD GovCloud strips the subdomain (workaround for prod infrastructure). # DEV/QA GOV (govcloud.dev.keepersecurity.us, govcloud.qa.keepersecurity.us) keep govcloud. if server_hostname == 'govcloud.keepersecurity.us': @@ -222,10 +227,10 @@ def sync(self, sync_query: Union[SyncQuery, gs_pb2.GraphSyncQuery], graph_id: Optional[int] = None, endpoint: Optional[str] = None, - agent: Optional[str] = None) -> bytes: + agent: Optional[str] = None) -> Optional[bytes]: if agent is None: - f"keeper-dag/{__version__}" + agent = f"keeper-dag/{__version__}" endpoint = self._endpoint(ConnectionBase.SYNC, endpoint) self.logger.debug(f"endpoint {endpoint}") @@ -238,13 +243,13 @@ def sync(self, headers=headers, payload=sync_query) - if self.use_read_protobuf: + if payload is not None and self.use_read_protobuf: try: self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}") payload = self.get_encrypted_payload_data(payload) payload = decrypt_aes(payload, self.transmission_key) except Exception as err: - self.logger.error(f"Could not decrypt protobuf graph sync response: {type(err)}, {err}") + self.logger.error(f"Could not decrypt protobuf graph sync response: {err}") self.write_transaction_log( graph_id=graph_id, @@ -290,7 +295,7 @@ def add_data(self, agent: Optional[str] = None): if agent is None: - f"keeper-dag/{__version__}" + agent = f"keeper-dag/{__version__}" endpoint = self._endpoint(ConnectionBase.ADD_DATA, endpoint) self.logger.debug(f"endpoint {endpoint}") @@ -331,3 +336,135 @@ def add_data(self, error=str(err) ) raise DAGException(f"Could not create a new DAG structure: {err}") + + def multi_sync(self, + multi_query: Union[BaseModel, gs_pb2.GraphSyncMultiQuery], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + """POST a GraphSyncMultiQuery to /multi_sync. + + Used by per-graph reads: after `get_leafs` discovers the stream refs + rooted at the graph's origin, `multi_sync` fetches sync data for all + those streams in one round-trip. Mirrors `sync()` in transport shape + (encrypt/headers, decrypt-on-read, transaction log, error handling). + """ + if agent is None: + agent = f"keeper-dag/{__version__}" + + endpoint = self._endpoint(ConnectionBase.MULTI_SYNC, endpoint) + self.logger.debug(f"endpoint {endpoint}") + + try: + multi_query, headers = self.payload_and_headers(multi_query) + payload = self.rest_call_to_router(http_method="POST", + endpoint=endpoint, + agent=agent, + headers=headers, + payload=multi_query) + + if self.use_read_protobuf: + try: + self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}") + payload = self.get_encrypted_payload_data(payload) + payload = decrypt_aes(payload, self.transmission_key) + except Exception as err: + self.logger.error(f"Could not decrypt protobuf graph multi-sync response: {type(err)}, {err}") + + self.write_transaction_log( + graph_id=graph_id, + request=multi_query, + response=payload, + agent=agent, + endpoint=endpoint, + error=None + ) + + return payload + + except DAGConnectionException as err: + self.write_transaction_log( + graph_id=graph_id, + request=multi_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise err + except Exception as err: + self.write_transaction_log( + graph_id=graph_id, + request=multi_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise DAGException(f"Could not load the DAG structure (multi_sync): {err}") + + def get_leafs(self, + leafs_query: Union[BaseModel, gs_pb2.GraphSyncLeafsQuery], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + """POST a GraphSyncLeafsQuery to /get_leafs. + + Returns the serialized GraphSyncRefsResult — the list of stream refs + rooted at the queried vertices. Used as the discovery step before a + `multi_sync` call (per the per-graph read pattern that Web Vault + already uses). + """ + if agent is None: + agent = f"keeper-dag/{__version__}" + + endpoint = self._endpoint(ConnectionBase.GET_LEAFS, endpoint) + self.logger.debug(f"endpoint {endpoint}") + + try: + leafs_query, headers = self.payload_and_headers(leafs_query) + payload = self.rest_call_to_router(http_method="POST", + endpoint=endpoint, + agent=agent, + headers=headers, + payload=leafs_query) + + if self.use_read_protobuf: + try: + self.logger.debug(f"decrypt payload with transmission key {kotlin_bytes(self.transmission_key)}") + payload = self.get_encrypted_payload_data(payload) + payload = decrypt_aes(payload, self.transmission_key) + except Exception as err: + self.logger.error(f"Could not decrypt protobuf get_leafs response: {type(err)}, {err}") + + self.write_transaction_log( + graph_id=graph_id, + request=leafs_query, + response=payload, + agent=agent, + endpoint=endpoint, + error=None + ) + + return payload + + except DAGConnectionException as err: + self.write_transaction_log( + graph_id=graph_id, + request=leafs_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise err + except Exception as err: + self.write_transaction_log( + graph_id=graph_id, + request=leafs_query, + response=None, + agent=agent, + endpoint=endpoint, + error=str(err) + ) + raise DAGException(f"Could not get leafs: {err}") diff --git a/keepercommander/keeper_dag/connection/ksm.py b/keepercommander/keeper_dag/connection/ksm.py index 58fa4fac0..b4ad0279f 100644 --- a/keepercommander/keeper_dag/connection/ksm.py +++ b/keepercommander/keeper_dag/connection/ksm.py @@ -57,7 +57,7 @@ def __init__(self, self.use_read_protobuf = False if self.use_write_protobuf: self.logger.info("KSM cannot use protobuf for writing to the graph, using JSON.") - self.use_read_protobuf = False + self.use_write_protobuf = False if InMemoryKeyValueStorage.is_base64(config): config = utils.base64_to_string(config) @@ -165,6 +165,7 @@ def authenticate(self, attempt = 0 while True: + err_msg = "no error message" try: attempt += 1 response = requests.get(url, diff --git a/keepercommander/keeper_dag/connection/local.py b/keepercommander/keeper_dag/connection/local.py index 6a0dac42b..0567860d5 100644 --- a/keepercommander/keeper_dag/connection/local.py +++ b/keepercommander/keeper_dag/connection/local.py @@ -582,6 +582,41 @@ def sync(self, hasMore=has_more ).model_dump_json().encode() + def multi_sync(self, + multi_query: Union[gs_pb2.GraphSyncMultiQuery, Any], + graph_id: Optional[int] = None, + endpoint: Optional[str] = None, + agent: Optional[str] = None) -> bytes: + """Local mirror of the network per-graph ``multi_sync``. + + The local SQLite store has no per-graph URL routing, so each sub-query + in the ``GraphSyncMultiQuery`` is run through this connection's own + ``sync()`` — identical stream / sync-point / graph-id semantics as the + single-stream read and the save path — and the per-stream results are + assembled into the same multi-stream envelope the network endpoint + returns: ``GraphSyncMultiResult`` (protobuf) or ``{"results": [...]}`` + (JSON). + """ + is_protobuf = isinstance(multi_query, gs_pb2.GraphSyncMultiQuery) + queries = list(multi_query.queries) + + if is_protobuf: + multi = gs_pb2.GraphSyncMultiResult() + for sub_query in queries: + single = self.sync(sub_query, graph_id=graph_id, + endpoint=endpoint, agent=agent) + result = gs_pb2.GraphSyncResult() + result.ParseFromString(single) + multi.results.add().CopyFrom(result) + return multi.SerializeToString() + + results = [] + for sub_query in queries: + single = self.sync(sub_query, graph_id=graph_id, + endpoint=endpoint, agent=agent) + results.append(json.loads(single)) + return json.dumps({"results": results}).encode() + def debug_dump(self) -> str: ret = "" diff --git a/keepercommander/keeper_dag/dag.py b/keepercommander/keeper_dag/dag.py index 85a2d1296..a9066528e 100644 --- a/keepercommander/keeper_dag/dag.py +++ b/keepercommander/keeper_dag/dag.py @@ -15,7 +15,7 @@ import importlib import traceback import sys -from typing import Optional, Union, List, Any, Tuple, TYPE_CHECKING +from typing import Optional, Union, List, Any, Tuple, Dict, TYPE_CHECKING if TYPE_CHECKING: from .connection import ConnectionBase @@ -225,7 +225,7 @@ def close(self): try: # Safely get the root vertex without creating a new one if hasattr(self, '_vertices') and hasattr(self, 'uid') and hasattr(self, '_uid_lookup'): - if len(self._vertices) > 0 and self.uid in self._uid_lookup: + if len(self._vertices) > 0 and self.uid is not None and self.uid in self._uid_lookup: idx = self._uid_lookup[self.uid] if idx < len(self._vertices): root = self._vertices[idx] @@ -298,7 +298,7 @@ def debug_stacktrace(self): trc = 'Traceback (most recent call last):\n' msg = trc + ''.join(traceback.format_list(stack)) if exc is not None: - msg += ' ' + traceback.format_exc().lstrip(trc) + msg += ' ' + traceback.format_exc().removeprefix(trc) self.debug(msg) def __str__(self): @@ -310,6 +310,8 @@ def __str__(self): for v in self.all_vertices: ret += f" * {v.uid}, Keys: {v.keychain}, Active: {v.active}\n" for e in v.edges: + if e is None: + continue if e.edge_type == EdgeType.DATA: ret += " + has a DATA edge" if e.content is not None: @@ -504,11 +506,28 @@ def get_vertices_by_path_value(self, path: str, inc_deleted: bool = False) -> Li for vertex in vertices: for edge in vertex.edges: - if edge.path == path: + if edge is not None and edge.path == path: results.append(vertex) return results def _sync(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + """Dispatch to legacy single-stream sync or per-graph multi-stream sync. + + When `read_endpoint` is set, the server uses the per-graph URL pattern + (`/api/user/graph-sync//...`). That model splits the graph across + multiple streams, so a single-stream `sync` returns only a fragment. + Web Vault uses `get_leafs` -> `multi_sync` to read the full graph; + this client follows the same pattern. + + When only `graph_id` is set (legacy single-endpoint transport), the + single-stream sync remains correct. + """ + if self.read_endpoint is not None: + return self._sync_per_graph(sync_point) + return self._sync_legacy(sync_point) + + def _sync_legacy(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + """Single-stream sync against the legacy `/sync` endpoint.""" # The web service will send 500 items, if there is more the 'has_more' flag is set to True. has_more = True @@ -543,6 +562,61 @@ def _sync(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: return all_data, sync_point + def _sync_per_graph(self, sync_point: int = 0) -> Tuple[List[DAGData], int]: + """Multi-stream read against the per-graph endpoints. + + The graph's data lives in a single stream keyed by the graph's origin + (e.g. the PAM Configuration record's UID for TunnelDAG). We multi_sync + that stream directly — no `get_leafs` discovery step needed for this + caller pattern. (`Connection.get_leafs` remains available for callers + that start from leaf vertices and need to discover stream roots.) + + Returns aggregated (data, max_sync_point) just like `_sync_legacy`. + """ + + origin_bytes = urlsafe_str_to_bytes(self.uid) + + # Stream keyed by the graph's origin (e.g. config_uid for PAM linking). + per_stream_sync_point: Dict[bytes, int] = {origin_bytes: sync_point} + all_data: List[DAGData] = [] + max_sync_point = sync_point + + while per_stream_sync_point: + stream_ids = list(per_stream_sync_point.keys()) + multi_query = self.read_struct_obj.multi_sync_query( + stream_ids=stream_ids, + origin=origin_bytes, + sync_point=sync_point, + ) + # Per-stream syncPoint adjustment so each stream advances + # independently across pagination rounds (proto variant only; + # JSON variant builds via SyncQuery which already carries syncPoint). + try: + for inner, sid in zip(multi_query.queries, stream_ids): + inner.syncPoint = per_stream_sync_point[sid] + except Exception: # pragma: no cover - JSON variant has no .queries + pass + + multi_response = self.conn.multi_sync( + multi_query=multi_query, + graph_id=self.graph_id, + endpoint=self.read_endpoint, + agent=self.agent, + ) + multi_results = self.read_struct_obj.get_multi_sync_result(multi_response) + + next_per_stream: Dict[bytes, int] = {} + for result in multi_results: + all_data += result.data + if result.syncPoint and result.syncPoint > max_sync_point: + max_sync_point = result.syncPoint + if result.hasMore and result.streamId is not None: + next_per_stream[bytes(result.streamId)] = result.syncPoint + + per_stream_sync_point = next_per_stream + + return all_data, max_sync_point + def _load(self, sync_point: int = 0): """ @@ -619,11 +693,12 @@ def _load(self, sync_point: int = 0): name=data.parentRef.name, vertex_type=RefType.GENERAL ) + # Get the head vertex, which will exist now. - head = self.get_vertex(head_uid) + head = self.get_vertex_by_uid(head_uid) if head is None or head == "": head = tail - head = self.get_vertex_by_uid(head_uid) + self.debug(f" * tail {tail_uid} belongs to {head_uid}, " f"edge type {edge_type}", level=3) diff --git a/keepercommander/keeper_dag/struct/__init__.py b/keepercommander/keeper_dag/struct/__init__.py index 53aa771da..ac87baa7b 100644 --- a/keepercommander/keeper_dag/struct/__init__.py +++ b/keepercommander/keeper_dag/struct/__init__.py @@ -54,3 +54,34 @@ def payload(origin_ref: Union[Ref, gs_pb2.GraphSyncRef], graph_id: Optional[int] = None) -> Union[DataPayload, gs_pb2.GraphSyncAddDataRequest]: pass + + # --- Per-graph multi-stream read transport --------------------------- + # Used by DAG._sync_per_graph when read_endpoint is set. Two-step pattern: + # 1. leafs_query(...) -> get_leafs_result(...) discovers stream refs. + # 2. multi_sync_query(...) -> get_multi_sync_result(...) fetches data. + + def leafs_query(self, + vertices: List[str]) -> Union[BaseModel, gs_pb2.GraphSyncLeafsQuery]: + """Build a GraphSyncLeafsQuery from a list of vertex UIDs (URL-safe str).""" + pass + + @staticmethod + def get_leafs_result(results: bytes) -> List[Ref]: + """Parse GraphSyncRefsResult bytes into a list of Ref objects. + Each Ref's `value` is the stream UID rooted under the queried vertex. + """ + pass + + def multi_sync_query(self, + stream_ids: List[bytes], + origin: bytes, + sync_point: int = 0) -> Union[BaseModel, gs_pb2.GraphSyncMultiQuery]: + """Build a GraphSyncMultiQuery wrapping one GraphSyncQuery per stream.""" + pass + + @staticmethod + def get_multi_sync_result(results: bytes): # -> List[SyncData] + """Parse GraphSyncMultiResult bytes into a list of SyncData, one per + inner GraphSyncResult (each carrying its own streamId/syncPoint/hasMore). + """ + pass diff --git a/keepercommander/keeper_dag/struct/default.py b/keepercommander/keeper_dag/struct/default.py index a58558bff..dd23a6256 100644 --- a/keepercommander/keeper_dag/struct/default.py +++ b/keepercommander/keeper_dag/struct/default.py @@ -1,8 +1,10 @@ from __future__ import annotations +import json from . import DataStructBase from ..types import SyncQuery, Ref, RefType, DAGData, DataPayload, EdgeType, SyncData -from ..crypto import generate_random_bytes, generate_uid_str, bytes_to_str +from ..crypto import generate_random_bytes, generate_uid_str, bytes_to_str, bytes_to_urlsafe_str import base64 +from pydantic import BaseModel from typing import Optional, List @@ -79,3 +81,62 @@ def payload(origin_ref: Ref, dataList=data_list, graphId=graph_id ) + + # --- Per-graph multi-stream read transport --------------------------- + + class _LeafsQuery(BaseModel): + vertices: List[str] + + class _MultiSyncQuery(BaseModel): + queries: List[SyncQuery] + + def leafs_query(self, vertices: List[str]) -> 'DataStruct._LeafsQuery': + return DataStruct._LeafsQuery(vertices=list(vertices)) + + @staticmethod + def get_leafs_result(results: bytes) -> List[Ref]: + try: + obj = json.loads(results) + except Exception as err: + raise Exception(f"Could not parse the leafs JSON result: {err}") + refs_list = obj.get("refs", []) if isinstance(obj, dict) else obj + out: List[Ref] = [] + for r in refs_list: + # Server may return either {type, value, name} or just a value str. + if isinstance(r, dict): + value = r.get("value") + if isinstance(value, bytes): + value = bytes_to_urlsafe_str(value) + out.append(Ref( + type=RefType(r["type"]) if r.get("type") is not None else RefType.GENERAL, + value=value, + name=r.get("name") or None, + )) + return out + + def multi_sync_query(self, + stream_ids: List[bytes], + origin: bytes, + sync_point: int = 0) -> 'DataStruct._MultiSyncQuery': + queries = [ + SyncQuery( + streamId=bytes_to_urlsafe_str(sid), + deviceId=bytes_to_urlsafe_str(origin), + syncPoint=sync_point, + graphId=None, + ) + for sid in stream_ids + ] + return DataStruct._MultiSyncQuery(queries=queries) + + @staticmethod + def get_multi_sync_result(results: bytes) -> List[SyncData]: + try: + obj = json.loads(results) + except Exception as err: + raise Exception(f"Could not parse the multi_sync JSON result: {err}") + items = obj.get("results", []) if isinstance(obj, dict) else obj + out: List[SyncData] = [] + for item in items: + out.append(SyncData.model_validate(item)) + return out diff --git a/keepercommander/keeper_dag/struct/protobuf.py b/keepercommander/keeper_dag/struct/protobuf.py index fcd123d5d..7a449f918 100644 --- a/keepercommander/keeper_dag/struct/protobuf.py +++ b/keepercommander/keeper_dag/struct/protobuf.py @@ -58,23 +58,16 @@ def sync_query(self, ) @staticmethod - def get_sync_result(results: bytes) -> SyncData: - - try: - result = gs_pb2.GraphSyncResult() - result.ParseFromString(results) - except Exception as err: - raise Exception(f"Could not parse the GraphSyncResult message: {err}") - - message = gs_pb2.GraphSyncResult() - message.ParseFromString(results) - + def _sync_data_from_result(message: gs_pb2.GraphSyncResult) -> SyncData: + """Convert a single GraphSyncResult protobuf into a SyncData pydantic + model. Extracted so both single-`sync` and multi_sync code paths share + identical per-result decoding. + """ data_list: List[SyncDataItem] = [] for item in message.data: data_list.append( SyncDataItem( type=DataStruct.PB_TO_DATA_MAP.get(item.data.type), - # content=bytes_to_str(item.data.content), content=item.data.content, content_is_base64=False, ref=Ref( @@ -92,9 +85,21 @@ def get_sync_result(results: bytes) -> SyncData: return SyncData( syncPoint=message.syncPoint, data=data_list, - hasMore=message.hasMore + hasMore=message.hasMore, + streamId=bytes(message.streamId) if message.streamId else None, ) + @staticmethod + def get_sync_result(results: bytes) -> SyncData: + + try: + message = gs_pb2.GraphSyncResult() + message.ParseFromString(results) + except Exception as err: + raise Exception(f"Could not parse the GraphSyncResult message: {err}") + + return DataStruct._sync_data_from_result(message) + @staticmethod def origin_ref(origin_ref_value: bytes, name: str) -> gs_pb2.GraphSyncRef: @@ -149,3 +154,49 @@ def payload(origin_ref: gs_pb2.GraphSyncRef, return gs_pb2.GraphSyncAddDataRequest( origin=origin_ref, data=data_list) + + # --- Per-graph multi-stream read transport --------------------------- + + def leafs_query(self, vertices: List[str]) -> gs_pb2.GraphSyncLeafsQuery: + return gs_pb2.GraphSyncLeafsQuery( + vertices=[urlsafe_str_to_bytes(v) for v in vertices] + ) + + @staticmethod + def get_leafs_result(results: bytes) -> List[Ref]: + msg = gs_pb2.GraphSyncRefsResult() + try: + msg.ParseFromString(results) + except Exception as err: + raise Exception(f"Could not parse the GraphSyncRefsResult message: {err}") + return [ + Ref( + type=DataStruct.PB_TO_REF_MAP.get(r.type), + value=bytes_to_urlsafe_str(r.value), + name=r.name or None, + ) + for r in msg.refs + ] + + def multi_sync_query(self, + stream_ids: List[bytes], + origin: bytes, + sync_point: int = 0) -> gs_pb2.GraphSyncMultiQuery: + return gs_pb2.GraphSyncMultiQuery(queries=[ + gs_pb2.GraphSyncQuery( + streamId=sid, + origin=origin, + syncPoint=sync_point, + maxCount=0, # let krouter default (currently 500) + ) + for sid in stream_ids + ]) + + @staticmethod + def get_multi_sync_result(results: bytes) -> List[SyncData]: + msg = gs_pb2.GraphSyncMultiResult() + try: + msg.ParseFromString(results) + except Exception as err: + raise Exception(f"Could not parse the GraphSyncMultiResult message: {err}") + return [DataStruct._sync_data_from_result(r) for r in msg.results] diff --git a/keepercommander/keeper_dag/types.py b/keepercommander/keeper_dag/types.py index 9ab242c88..4b614fbfe 100644 --- a/keepercommander/keeper_dag/types.py +++ b/keepercommander/keeper_dag/types.py @@ -124,6 +124,17 @@ class PamEndpoints(BaseEnum): PamGraphId.SERVICE_LINKS.value: PamEndpoints.SERVICE_LINKS, } +# Inverse map for callers that have a graph_id int and need the PamEndpoints enum +# to address the new /api/user/graph-sync// routes. +GRAPH_ID_TO_ENDPOINT = { + PamGraphId.PAM.value: PamEndpoints.PAM, + PamGraphId.DISCOVERY_RULES.value: PamEndpoints.DISCOVERY_RULES, + PamGraphId.DISCOVERY_JOBS.value: PamEndpoints.DISCOVERY_JOBS, + PamGraphId.INFRASTRUCTURE.value: PamEndpoints.INFRASTRUCTURE, + PamGraphId.SERVICE_LINKS.value: PamEndpoints.SERVICE_LINKS, +} + + class SyncQuery(BaseModel): streamId: Optional[str] = None # base64 of a user's ID who is syncing. deviceId: Optional[str] = None @@ -134,7 +145,10 @@ class SyncQuery(BaseModel): class SyncDataItem(BaseModel): ref: Ref parentRef: Optional[Ref] = None - content: Optional[str] = None + # Either a base64-encoded string (JSON wire format) or raw bytes + # (protobuf wire format). `content_is_base64` distinguishes them so the + # consumer can decode appropriately. + content: Optional[Union[str, bytes]] = None content_is_base64: bool = True type: Optional[str] = None path: Optional[str] = None @@ -145,6 +159,9 @@ class SyncData(BaseModel): syncPoint: int data: List[SyncDataItem] hasMore: bool + # Per-graph multi_sync: identifies which stream this result came from. + # None for single-stream `sync` results (backward compatible). + streamId: Optional[bytes] = None class Ref(BaseModel): diff --git a/keepercommander/keeper_dag/utils.py b/keepercommander/keeper_dag/utils.py index 43ad5a76e..51e33c921 100644 --- a/keepercommander/keeper_dag/utils.py +++ b/keepercommander/keeper_dag/utils.py @@ -55,5 +55,5 @@ def set_file_permissions(file_path): # type: (str) -> None check=False, capture_output=True) subprocess.run(["icacls", file_path, "/grant", f"{username}:M"], check=True, capture_output=True) logging.debug(f'Set secure permissions (owner Modify only) for Windows file: {file_path}') - except Exception: - logging.warning(f'Failed to set file permissions for {file_path}') + except (OSError, subprocess.SubprocessError) as err: + logging.warning(f'Failed to set file permissions for {file_path}: {err}') diff --git a/keepercommander/record.py b/keepercommander/record.py index 897180003..189e3ccf0 100644 --- a/keepercommander/record.py +++ b/keepercommander/record.py @@ -275,6 +275,15 @@ def display(self, unmask=False): # Strip type prefixes from field names (e.g., "text:Sign-In Address" -> "Sign-In Address") field_type_prefixes = ('text:', 'multiline:', 'url:', 'phone:', 'email:', 'secret:', 'date:', 'name:', 'host:', 'address:') display_name = field_name + # v2 records carry the real type in c['type']; v3 encode it as "type:label" in the name. + # Keep this list in sync with _MASKED_FIELD_TYPES in commands/record.py. + # 'note' = Secured Note — sensitive, masked by design. + # 'passkey' omitted: early-exit handler above renders only non-sensitive sub-fields. + _MASKED_TYPES = frozenset({ + 'secret', 'pinCode', 'note', 'json', 'oneTimeCode', + 'paymentCard', 'bankAccount', 'keyPair', 'securityQuestion', + }) + is_secret_field = c.get('type') in _MASKED_TYPES for prefix in field_type_prefixes: if field_name.lower().startswith(prefix): display_name = field_name[len(prefix):] @@ -293,8 +302,26 @@ def display(self, unmask=False): 'address:': 'Address', } display_name = type_friendly_names.get(prefix, prefix.rstrip(':').title()) + if prefix.rstrip(':') in _MASKED_TYPES: + is_secret_field = True break - print('{0:>20s}: {1:20s}: {1:20s}: {1: None """Sync full or partial data down to the client""" - params.sync_data = False - token = params.sync_down_token - - # Use spinner animation for full sync (only in interactive mode, not batch/automation) + # Use spinner animation for full sync (only in interactive mode, not batch/automation). + # WARNING: stop() MUST be guaranteed via finally. A leaked spinner thread keeps + # '\r'-overwriting the current console row for the rest of the session, erasing + # lines of any large output printed later - notably base64 KSM config tokens + # (`pam project import` / `pam gateway new` access_token), which then reach the + # user silently corrupted. spinner = None - if not token and not params.batch_mode: + if not params.sync_down_token and not params.batch_mode: spinner = Spinner('Syncing...') spinner.start() + try: + _sync_down_impl(params, record_types) + finally: + if spinner: + spinner.stop() + + +def _sync_down_impl(params, record_types=False): # type: (KeeperParams, bool) -> None + params.sync_data = False + token = params.sync_down_token for record in params.record_cache.values(): if 'shares' in record: @@ -1041,10 +1053,6 @@ def convert_user_folder_shared_folder(ufsf): type_id += rt.scope * 1000000 params.record_type_cache[type_id] = rt.content - # Stop spinner if running - if spinner: - spinner.stop() - if full_sync: convert_keys.change_key_types(params) diff --git a/keepercommander/vault_extensions.py b/keepercommander/vault_extensions.py index c13411f46..a1bd49034 100644 --- a/keepercommander/vault_extensions.py +++ b/keepercommander/vault_extensions.py @@ -447,9 +447,10 @@ def extract_typed_field(field): # type: (vault.TypedField) -> dict field_values.append(value) result = { 'type': field_type, - 'label': field.label or '', 'value': field_values } + if field.label: + result['label'] = field.label if field.required is True: result['required'] = True return result diff --git a/tests/test_tunnel_close_leak.py b/tests/test_tunnel_close_leak.py index 877677523..72696e6bb 100644 --- a/tests/test_tunnel_close_leak.py +++ b/tests/test_tunnel_close_leak.py @@ -19,6 +19,20 @@ import time import logging +# Linux-only e2e repro: shells out to `sshpass`/`ssh` with `/dev/null` paths and +# needs an SSH container on 127.0.0.1:2222. Bail before importing the native +# keeper_pam_connections module so it neither aborts pytest collection nor errors +# when run on Windows. +if sys.platform == "win32": + _skip_msg = ("tunnel close-leak e2e is Linux-only (needs sshpass + SSH " + "container on :2222); skipped on Windows") + if "pytest" in sys.modules: + import pytest + pytest.skip(_skip_msg, allow_module_level=True) + else: + print(f"SKIP: {_skip_msg}") + sys.exit(0) + import keeper_pam_connections from keepercommander.commands.tunnel.port_forward.tunnel_helpers import ( TunnelSignalHandler, diff --git a/unit-tests/pam/test_dag_layer_b_configure_resource.py b/unit-tests/pam/test_dag_layer_b_configure_resource.py index 13f36247e..c29e1b2d2 100644 --- a/unit-tests/pam/test_dag_layer_b_configure_resource.py +++ b/unit-tests/pam/test_dag_layer_b_configure_resource.py @@ -102,15 +102,25 @@ def test_configure_resource_hits_correct_url(): def test_configure_resource_sends_protobuf_body(): - """Body is the encrypted PAMResourceConfig protobuf, not JSON.""" + """Body is the AES-GCM-encrypted PAMResourceConfig protobuf — not plaintext proto, not JSON.""" from keepercommander.commands.pam.router_helper import router_configure_resource + from keepercommander import crypto, utils rq = pam_pb2.PAMResourceConfig(recordUid=RESOURCE_UID, networkUid=NETWORK_UID, adminUid=ADMIN_UID) + transmission_key = utils.generate_aes_key() with patch(REQUESTS_TARGET, return_value=_ok_router_response()) as mock_req: - router_configure_resource(_mock_params(), rq) + router_configure_resource(_mock_params(), rq, transmission_key=transmission_key) _, body = _capture_call(mock_req) assert isinstance(body, (bytes, bytearray)), f'body must be bytes, got {type(body)}' - # Body is encrypted; check that it's NOT a JSON-encoded payload (sanity). - assert not body.startswith(b'{'), 'body should be encrypted protobuf, not JSON' + # Must not be the raw (unencrypted) serialisation. + # (A first-byte heuristic like `not body.startswith(b'{')` is flaky: ~1/256 + # of ciphertexts legitimately start with 0x7b.) + assert body != rq.SerializeToString() + decrypted = crypto.decrypt_aes_v2(body, transmission_key) + parsed = pam_pb2.PAMResourceConfig() + parsed.ParseFromString(decrypted) + assert parsed.recordUid == RESOURCE_UID + assert parsed.networkUid == NETWORK_UID + assert parsed.adminUid == ADMIN_UID # --------------------------------------------------------------------------- # diff --git a/unit-tests/pam/test_dag_layer_b_migration.py b/unit-tests/pam/test_dag_layer_b_migration.py index c332841c0..e77c4ee9e 100644 --- a/unit-tests/pam/test_dag_layer_b_migration.py +++ b/unit-tests/pam/test_dag_layer_b_migration.py @@ -16,6 +16,7 @@ calling configure_resource. """ import json +import json import os import sys from unittest.mock import MagicMock, patch @@ -95,6 +96,37 @@ def _capture(params, rq): # Critical: must NOT be set on jitSettings field assert rq.jitSettings == b'' + def test_happy_path_bundles_current_meta_so_krouter_persists_ai_edge(self): + """Regression: krouter's configure_resource only writes a settings edge + when it loads loopEdges, which it does only for requests carrying + meta/jit/connection (UserRest.kt:497). A keeperAiSettings-only request + leaves loopEdges null and the ai_settings write is silently dropped. The + Web Vault always sends meta alongside AI settings; Commander must mirror + that by bundling the resource's current meta in the same request.""" + captured = {} + + def _capture(params, rq): + captured['rq'] = rq + return None + + meta_dict = {'version': 1, 'allowedSettings': {'aiEnabled': True}, 'rotateOnTermination': False} + with _patch_inputs(), \ + patch.object(ai_mod, 'encrypt_aes', return_value=b'CIPHER_BYTES'), \ + patch.object(ai_mod, 'get_resource_settings', return_value=meta_dict) as meta_mock, \ + patch('keepercommander.commands.pam.router_helper.router_configure_resource', side_effect=_capture): + ok = ai_mod.set_resource_keeper_ai_settings( + _mock_params(), RESOURCE_UID_STR, {'level': 'critical'}, config_uid=CONFIG_UID_STR + ) + assert ok is True + rq = captured['rq'] + assert rq.keeperAiSettings == b'CIPHER_BYTES' + # The fix: meta must be present so krouter fetches loopEdges and persists + # the ai_settings edge. Without it the write is a silent no-op. + assert rq.meta == json.dumps(meta_dict).encode() + # meta is read from the resource's current 'meta' DATA edge. + meta_mock.assert_called_once() + assert meta_mock.call_args.args[2] == 'meta' + def test_permission_denied_with_fallback_enabled_calls_legacy(self): legacy_called = {'count': 0} diff --git a/unit-tests/pam/test_pam_connection_edit_scrollback.py b/unit-tests/pam/test_pam_connection_edit_scrollback.py index 7f758652d..2ebd61573 100644 --- a/unit-tests/pam/test_pam_connection_edit_scrollback.py +++ b/unit-tests/pam/test_pam_connection_edit_scrollback.py @@ -47,6 +47,50 @@ def test_help_includes_scrollback(self): self.assertIn('-sb', help_text) +@unittest.skipIf(skip_tests, skip_reason) +class TestPamConnectionEditProtocolChoices(unittest.TestCase): + """--protocol choices: the full DB protocol set is now accepted, and the choices + list is composed from the db_protocols / non_db_protocols source-of-truth lists.""" + + NEW_DB_PROTOCOLS = ['mariadb', 'oracle', 'mongodb', 'redis', + 'elasticsearch', 'clickhouse', 'dynamodb'] + + def setUp(self): + self.parser = PAMConnectionEditCommand.parser + + def test_new_db_protocols_accepted(self): + for proto in self.NEW_DB_PROTOCOLS: + with self.subTest(protocol=proto): + args = self.parser.parse_args(['rec', '--protocol', proto]) + self.assertEqual(args.protocol, proto) + + def test_mariadb_and_oracle_accepted(self): + # The two protocols this change was specifically about. + self.assertEqual(self.parser.parse_args(['rec', '-p', 'mariadb']).protocol, 'mariadb') + self.assertEqual(self.parser.parse_args(['rec', '-p', 'oracle']).protocol, 'oracle') + + def test_existing_protocols_still_accepted(self): + for proto in ['', 'http', 'kubernetes', 'mysql', 'postgresql', 'rdp', 'sql-server', 'ssh', 'telnet', 'vnc']: + with self.subTest(protocol=proto): + self.assertEqual(self.parser.parse_args(['rec', '--protocol', proto]).protocol, proto) + + def test_invalid_protocol_rejected(self): + with self.assertRaises(SystemExit): + self.parser.parse_args(['rec', '--protocol', 'bogus']) + + def test_choices_composed_from_source_lists(self): + # protocols is the single source of truth: '' + sorted(non_db + db), no duplicates. + expected = [''] + sorted(PAMConnectionEditCommand.non_db_protocols + + PAMConnectionEditCommand.db_protocols) + self.assertEqual(PAMConnectionEditCommand.protocols, expected) + self.assertEqual(len(PAMConnectionEditCommand.protocols), + len(set(PAMConnectionEditCommand.protocols))) + + def test_all_db_protocols_present_in_choices(self): + for proto in PAMConnectionEditCommand.db_protocols: + self.assertIn(proto, PAMConnectionEditCommand.protocols) + + @unittest.skipIf(skip_tests, skip_reason) class TestPamConnectionEditScrollbackValidation(unittest.TestCase): """Validation runs before DAG / token operations, so we can drive execute() @@ -122,6 +166,18 @@ def test_pam_directory_http_rejected(self): self._execute(rec, scrollback='100') self.assertIn('not supported for protocol "http"', str(ctx.exception)) + def test_pam_database_mariadb_rejected(self): + rec = self._mock_record('pamDatabase', 'mariadb') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('not supported for protocol "mariadb"', str(ctx.exception)) + + def test_pam_database_mongodb_rejected(self): + rec = self._mock_record('pamDatabase', 'mongodb') + with self.assertRaises(CommandError) as ctx: + self._execute(rec, scrollback='100') + self.assertIn('not supported for protocol "mongodb"', str(ctx.exception)) + def test_non_numeric_rejected(self): rec = self._mock_record('pamMachine', 'ssh') with self.assertRaises(CommandError) as ctx: @@ -183,8 +239,9 @@ class TestPamConnectionEditScrollbackAllowedCombinations(unittest.TestCase): a scrollback-related error. We don't run the full execute path (which would require mocking the entire DAG layer), only verify validation passes.""" - DB_PROTOCOLS = ['mysql', 'postgresql', 'sql-server', 'mariadb', 'oracle', - 'mongodb', 'redis', 'elasticsearch', 'clickhouse', 'dynamodb'] + DB_PROTOCOLS = ['mysql', 'postgresql', 'sql-server'] + KEEPER_DB_ONLY_PROTOCOLS = ['mariadb', 'oracle', 'mongodb', 'redis', + 'elasticsearch', 'clickhouse', 'dynamodb'] TERMINAL_PROTOCOLS = ['ssh', 'telnet', 'kubernetes'] def _assert_validation_passes(self, record_type, protocol): @@ -211,11 +268,30 @@ def _assert_validation_passes(self, record_type, protocol): except Exception: pass # downstream DAG/token failures are not what we're testing - def test_pam_database_all_db_protocols(self): + def test_pam_database_cli_capable_db_protocols(self): for proto in self.DB_PROTOCOLS: with self.subTest(protocol=proto): self._assert_validation_passes('pamDatabase', proto) + def test_pam_database_keeper_db_only_rejected(self): + for proto in self.KEEPER_DB_ONLY_PROTOCOLS: + rec = mock.MagicMock(spec=vault.TypedRecord) + rec.record_uid = 'rec-uid' + rec.record_type = 'pamDatabase' + rec.version = 3 + ps_field = mock.MagicMock() + ps_field.value = [{'connection': {'protocol': proto}}] + rec.get_typed_field.side_effect = lambda name: ps_field if name == 'pamSettings' else None + cmd = PAMConnectionEditCommand() + params = mock.MagicMock() + with mock.patch( + 'keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record', + return_value=rec, + ): + with self.assertRaises(CommandError) as ctx: + cmd.execute(params, record='rec', scrollback='100') + self.assertIn(f'not supported for protocol "{proto}"', str(ctx.exception)) + def test_pam_machine_terminal_protocols(self): for proto in self.TERMINAL_PROTOCOLS: with self.subTest(protocol=proto): diff --git a/unit-tests/pam/test_pam_import_db_protocols.py b/unit-tests/pam/test_pam_import_db_protocols.py new file mode 100644 index 000000000..7a8da96d6 --- /dev/null +++ b/unit-tests/pam/test_pam_import_db_protocols.py @@ -0,0 +1,92 @@ +""" +Unit tests for the database connection protocols recognized by `pam project import` / +`pam project extend` (step 2): mariadb, oracle, mongodb, redis, elasticsearch, +clickhouse, dynamodb — added alongside the pre-existing mysql/postgresql/sql-server. + +Verifies the import model round-trips each protocol: protocol string -> connection class +-> record dict, that pamDatabase validation passes, and that is_database_protocol agrees. +""" + +import unittest + +skip_tests = False +skip_reason = "" +try: + from keepercommander.commands.pam_import.base import ( + PamSettingsFieldData, + validate_pam_connection, + is_database_protocol, + ConnectionProtocol, + ConnectionSettingsMariaDB, + ConnectionSettingsOracle, + ConnectionSettingsMongoDB, + ConnectionSettingsRedis, + ConnectionSettingsElasticsearch, + ConnectionSettingsClickHouse, + ConnectionSettingsDynamoDB, + ) +except ImportError as e: + skip_tests = True + skip_reason = f"Cannot import pam_import.base: {e}" + + +@unittest.skipIf(skip_tests, skip_reason) +class TestPamImportNewDbProtocols(unittest.TestCase): + # (wire value, expected connection class) + NEW_PROTOCOLS = [ + ('mariadb', 'ConnectionSettingsMariaDB'), + ('oracle', 'ConnectionSettingsOracle'), + ('mongodb', 'ConnectionSettingsMongoDB'), + ('redis', 'ConnectionSettingsRedis'), + ('elasticsearch', 'ConnectionSettingsElasticsearch'), + ('clickhouse', 'ConnectionSettingsClickHouse'), + ('dynamodb', 'ConnectionSettingsDynamoDB'), + ] + + def test_enum_values_present(self): + self.assertEqual(ConnectionProtocol.MARIADB.value, 'mariadb') + self.assertEqual(ConnectionProtocol.ORACLE.value, 'oracle') + self.assertEqual(ConnectionProtocol.MONGODB.value, 'mongodb') + self.assertEqual(ConnectionProtocol.REDIS.value, 'redis') + self.assertEqual(ConnectionProtocol.ELASTICSEARCH.value, 'elasticsearch') + self.assertEqual(ConnectionProtocol.CLICKHOUSE.value, 'clickhouse') + self.assertEqual(ConnectionProtocol.DYNAMODB.value, 'dynamodb') + + def test_registered_in_connection_classes(self): + registered = set(PamSettingsFieldData.pam_connection_classes) + for cls in (ConnectionSettingsMariaDB, ConnectionSettingsOracle, ConnectionSettingsMongoDB, + ConnectionSettingsRedis, ConnectionSettingsElasticsearch, + ConnectionSettingsClickHouse, ConnectionSettingsDynamoDB): + self.assertIn(cls, registered) + + def test_get_connection_class_resolves(self): + for proto, cls_name in self.NEW_PROTOCOLS: + with self.subTest(protocol=proto): + obj = PamSettingsFieldData.get_connection_class({'protocol': proto}) + self.assertIsNotNone(obj) + self.assertEqual(type(obj).__name__, cls_name) + + def test_round_trip_protocol_and_database(self): + for proto, _ in self.NEW_PROTOCOLS: + with self.subTest(protocol=proto): + obj = PamSettingsFieldData.get_connection_class( + {'protocol': proto, 'default_database': 'db1', 'port': '1234'}) + rd = obj.to_record_dict() + self.assertEqual(rd.get('protocol'), proto) + self.assertEqual(rd.get('database'), 'db1') + self.assertEqual(rd.get('port'), '1234') + + def test_validate_pam_connection_passes_for_pam_database(self): + for proto, _ in self.NEW_PROTOCOLS: + with self.subTest(protocol=proto): + obj = PamSettingsFieldData.get_connection_class({'protocol': proto}) + self.assertFalse(validate_pam_connection(obj, 'pamDatabase')) + + def test_is_database_protocol(self): + for proto, _ in self.NEW_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertTrue(is_database_protocol(proto)) + + +if __name__ == '__main__': + unittest.main() diff --git a/unit-tests/pam/test_pam_import_scrollback.py b/unit-tests/pam/test_pam_import_scrollback.py new file mode 100644 index 000000000..0a03a383e --- /dev/null +++ b/unit-tests/pam/test_pam_import_scrollback.py @@ -0,0 +1,90 @@ +""" +Unit tests for `scrollback` (Maximum Scrollback Size) in the `pam project import` / +`pam project extend` connection settings. + +scrollback lives in TerminalDisplayConnectionSettings (mirrors the Web Vault) and is +threaded through terminal protocols (SSH, Telnet, Kubernetes) and CLI-capable DB protocols +(mysql, postgresql, sql-server). It is validated as a positive integer (like audio_bps / +audio_sample_rate): zero, negative, and non-numeric values are rejected with a warning. +KeeperDb-only DB protocols have no terminal display, so they never carry scrollback. +""" + +import logging +import unittest + +skip_tests = False +skip_reason = "" +try: + from keepercommander.commands.pam_import.base import PamSettingsFieldData +except ImportError as e: + skip_tests = True + skip_reason = f"Cannot import pam_import.base: {e}" + + +def _record_scrollback(protocol, scrollback): + data = {'protocol': protocol} + if scrollback is not None: + data['scrollback'] = scrollback + obj = PamSettingsFieldData.get_connection_class(data) + return obj.to_record_dict().get('scrollback') if obj else 'NO_CLASS' + + +@unittest.skipIf(skip_tests, skip_reason) +class TestPamImportScrollback(unittest.TestCase): + TERMINAL_PROTOCOLS = ['ssh', 'telnet', 'kubernetes'] + CLI_CAPABLE_DB_PROTOCOLS = ['mysql', 'postgresql', 'sql-server'] + KEEPER_DB_ONLY_PROTOCOLS = [ + 'mariadb', 'oracle', 'mongodb', 'redis', 'elasticsearch', 'clickhouse', 'dynamodb', + ] + SCROLLBACK_PROTOCOLS = TERMINAL_PROTOCOLS + CLI_CAPABLE_DB_PROTOCOLS + + def setUp(self): + # Silence the expected validation warnings for invalid inputs. + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + def test_valid_int_round_trips(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertEqual(_record_scrollback(proto, 5000), 5000) + + def test_valid_string_int_parsed(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertEqual(_record_scrollback(proto, '4096'), 4096) + + def test_zero_rejected(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertIsNone(_record_scrollback(proto, 0)) + + def test_negative_rejected(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertIsNone(_record_scrollback(proto, -10)) + + def test_non_numeric_rejected(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertIsNone(_record_scrollback(proto, 'abc')) + + def test_float_string_rejected(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertIsNone(_record_scrollback(proto, '12.5')) + + def test_not_provided_absent(self): + for proto in self.SCROLLBACK_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertIsNone(_record_scrollback(proto, None)) + + def test_keeper_db_only_protocols_have_no_scrollback(self): + for proto in self.KEEPER_DB_ONLY_PROTOCOLS: + with self.subTest(protocol=proto): + self.assertIsNone(_record_scrollback(proto, 5000)) + + +if __name__ == '__main__': + unittest.main() diff --git a/unit-tests/pam/test_pam_rbi_edit.py b/unit-tests/pam/test_pam_rbi_edit.py index ce52be84d..d3a451b20 100644 --- a/unit-tests/pam/test_pam_rbi_edit.py +++ b/unit-tests/pam/test_pam_rbi_edit.py @@ -199,6 +199,34 @@ def test_combine_new_args_with_existing(self): self.assertEqual(args.allow_url_navigation, 'on') self.assertEqual(args.allowed_urls, ['*.example.com']) + def test_session_persistence_none(self): + args = self.parser.parse_args(['--record', 'test-record', '--session-persistence', 'none']) + self.assertEqual(args.session_persistence, 'none') + + def test_session_persistence_user(self): + args = self.parser.parse_args(['--record', 'test-record', '--session-persistence', 'user']) + self.assertEqual(args.session_persistence, 'user') + + def test_session_persistence_resource(self): + args = self.parser.parse_args(['--record', 'test-record', '--session-persistence', 'resource']) + self.assertEqual(args.session_persistence, 'resource') + + def test_session_persistence_default(self): + args = self.parser.parse_args(['--record', 'test-record', '--session-persistence', 'default']) + self.assertEqual(args.session_persistence, 'default') + + def test_session_persistence_short_flag(self): + args = self.parser.parse_args(['--record', 'test-record', '-sp', 'user']) + self.assertEqual(args.session_persistence, 'user') + + def test_session_persistence_invalid(self): + with self.assertRaises(SystemExit): + self.parser.parse_args(['--record', 'test-record', '--session-persistence', 'invalid']) + + def test_session_persistence_not_provided(self): + args = self.parser.parse_args(['--record', 'test-record', '--key-events', 'on']) + self.assertIsNone(args.session_persistence) + @unittest.skipIf(skip_tests, skip_reason) class TestPamRbiEditExecute(unittest.TestCase): @@ -328,6 +356,42 @@ def test_autofill_targets_joins_with_newlines(self, mock_sync, mock_update, mock self.command.execute(self.mock_params, record='test-record', autofill_targets=['#username', '#password']) self.assertEqual(self.pam_settings['connection'].get('autofillConfiguration'), '#username\n#password') + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_session_persistence_sets_value(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.command.execute(self.mock_params, record='test-record', session_persistence='user') + self.assertEqual(self.pam_settings['connection'].get('sessionPersistence'), 'user') + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_session_persistence_none_sets_literal(self, mock_sync, mock_update, mock_resolve): + # 'none' is a real enum value (no persistence), not a removal sentinel + mock_resolve.return_value = self.mock_record + self.command.execute(self.mock_params, record='test-record', session_persistence='none') + self.assertEqual(self.pam_settings['connection'].get('sessionPersistence'), 'none') + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_session_persistence_default_removes_field(self, mock_sync, mock_update, mock_resolve): + mock_resolve.return_value = self.mock_record + self.pam_settings['connection']['sessionPersistence'] = 'user' + self.command.execute(self.mock_params, record='test-record', session_persistence='default') + self.assertNotIn('sessionPersistence', self.pam_settings['connection']) + + @mock.patch('keepercommander.commands.tunnel_and_connections.RecordMixin.resolve_single_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.record_management.update_record') + @mock.patch('keepercommander.commands.tunnel_and_connections.api.sync_down') + def test_session_persistence_default_removes_present_but_null(self, mock_sync, mock_update, mock_resolve): + # A present-but-null value must still be removed (membership check, not None check) + mock_resolve.return_value = self.mock_record + self.pam_settings['connection']['sessionPersistence'] = None + self.command.execute(self.mock_params, record='test-record', session_persistence='default') + self.assertNotIn('sessionPersistence', self.pam_settings['connection']) + @unittest.skipIf(skip_tests, skip_reason) class TestPamRbiEditClipboardInversion(unittest.TestCase): diff --git a/unit-tests/test_command_record.py b/unit-tests/test_command_record.py index fc61833c9..e326e2abe 100644 --- a/unit-tests/test_command_record.py +++ b/unit-tests/test_command_record.py @@ -8,7 +8,7 @@ from data_vault import get_synced_params, VaultEnvironment from helper import KeeperApiHelper -from keepercommander import api, utils, crypto, attachment, vault +from keepercommander import api, utils, crypto, attachment, vault, vault_extensions from keepercommander.commands import record, record_edit from keepercommander.error import CommandError @@ -97,6 +97,78 @@ def artf(p, r, f): self.assertIsNotNone(field) self.assertEqual(field.get_default_value(str), 'BBB') + def _run_add(self, params, **kwargs): + """Run record-add with the API mocked; return the TypedRecord that would be saved.""" + cmd = record_edit.RecordAddCommand() + captured = {} + with mock.patch('keepercommander.api.sync_down'), \ + mock.patch('keepercommander.record_management.add_record_to_folder') as ar: + def artf(p, r, f): + captured['record'] = r + r.record_uid = utils.generate_uid() + ar.side_effect = artf + cmd.execute(params, **kwargs) + return captured.get('record') + + # RT schema with a label-less field (synthesized fallback) and one with a real definition label. + _RT_SCHEMA = [ + {"$ref": "login"}, # no label in RT definition + {"$ref": "password"}, # no label in RT definition + {"$ref": "script", "label": "rotationScripts"}, # real RT-definition label + ] + + def test_add_command_labels_default_is_legacy(self): + # No --labels (and explicit --labels=on): fields with no label in the RT definition fall + # back to the field type as the label; real definition labels are kept. + params = get_synced_params() + for labels in (None, 'on'): + kwargs = dict(force=True, title='L', record_type='login', + fields=['login=user@company.com', 'password=secret']) + if labels is not None: + kwargs['labels'] = labels + with mock.patch.object(record_edit.RecordAddCommand, 'get_record_type_fields', + return_value=list(self._RT_SCHEMA)): + record = self._run_add(params, **kwargs) + self.assertIsInstance(record, vault.TypedRecord) + self.assertEqual(record.get_typed_field('login').label, 'login') # synthesized + self.assertEqual(record.get_typed_field('password').label, 'password') # synthesized + self.assertEqual(record.get_typed_field('script').label, 'rotationScripts') # real, kept + data = vault_extensions.extract_typed_record_data(record) + login_data = next(x for x in data['fields'] if x['type'] == 'login') + self.assertEqual(login_data.get('label'), 'login') + + def test_add_command_labels_off_matches_vault(self): + # --labels=off: omit the synthesized type-name labels (login, password) but KEEP real + # RT-definition labels (script->rotationScripts), matching the Vault UI; an explicitly + # provided cmdline label is always preserved. + params = get_synced_params() + with mock.patch.object(record_edit.RecordAddCommand, 'get_record_type_fields', + return_value=list(self._RT_SCHEMA)): + record = self._run_add(params, force=True, title='L', record_type='login', labels='off', + fields=['login=user@company.com', 'text.MyLabel=val']) + self.assertIsInstance(record, vault.TypedRecord) + self.assertFalse(record.get_typed_field('login').label) # synthesized -> dropped + self.assertFalse(record.get_typed_field('password').label) # synthesized -> dropped + self.assertEqual(record.get_typed_field('script').label, 'rotationScripts') # real -> kept + self.assertEqual(record.get_typed_field('text', 'MyLabel').label, 'MyLabel') # explicit -> kept + + data = vault_extensions.extract_typed_record_data(record) + login_d = next(x for x in data['fields'] if x['type'] == 'login') + script_d = next(x for x in data['fields'] if x['type'] == 'script') + self.assertNotIn('label', login_d) # synthesized label omitted + self.assertEqual(script_d.get('label'), 'rotationScripts') # real label serialized + custom_d = next(x for x in data['custom'] if x.get('label') == 'MyLabel') + self.assertEqual(custom_d['label'], 'MyLabel') + + def test_extract_typed_field_omits_empty_label(self): + # Serializer omits the label key when falsy; keeps it when present. + self.assertNotIn('label', vault_extensions.extract_typed_field( + vault.TypedField.new_field('login', 'admin', ''))) + self.assertNotIn('label', vault_extensions.extract_typed_field( + vault.TypedField.new_field('login', 'admin', None))) + kept = vault_extensions.extract_typed_field(vault.TypedField.new_field('text', 'v', 'MyLabel')) + self.assertEqual(kept.get('label'), 'MyLabel') + def test_remove_command_from_root(self): params = get_synced_params() cmd = record.RecordRemoveCommand() @@ -248,6 +320,14 @@ def test_get_invalid_uid(self): with self.assertRaises(CommandError): cmd.execute(params, uid='invalid') + def test_get_rejects_shell_metacharacters_in_lookup_token(self): + params = get_synced_params() + cmd = record.RecordGetUidCommand() + + with self.assertRaises(CommandError) as context: + cmd.execute(params, uid='x;cd $HOME && id > pwned_keeper_rce.txt;#"unclosed') + self.assertIn('forbidden characters', context.exception.message) + def test_append_notes_command(self): params = get_synced_params() cmd = record_edit.RecordAppendNotesCommand() @@ -343,3 +423,219 @@ def communicate_rest_success(params, request, endpoint, **kwargs): raise Exception() + +class TestGetCommandMasking(TestCase): + """Sensitive fields are masked in detail/fields output; --unmask reveals them.""" + + def setUp(self): + mock.patch('keepercommander.api.communicate').start() + mock.patch('keepercommander.api.communicate_rest').start() + + def tearDown(self): + mock.patch.stopall() + + def _printed(self, mock_print): + return ' '.join(str(a) for call in mock_print.call_args_list for a in call[0]) + + # ── Record.display() (v2 / v3-via-Record.load) ───────────────────────── + + def _v2_record(self, custom_fields): + from keepercommander.record import Record + r = Record(utils.generate_uid()) + r.title = 'Test' + r.custom_fields = list(custom_fields) + return r + + def test_detail_masks_secret_type(self): + r = self._v2_record([{'type': 'secret', 'name': 'Token', 'value': 'top-secret'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + out = self._printed(p) + self.assertNotIn('top-secret', out) + self.assertIn('********', out) + + def test_detail_unmask_reveals_secret(self): + r = self._v2_record([{'type': 'secret', 'name': 'Token', 'value': 'top-secret'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=True) + self.assertIn('top-secret', self._printed(p)) + + def test_detail_masks_pincode_type(self): + r = self._v2_record([{'type': 'pinCode', 'name': 'PIN', 'value': '1234'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + out = self._printed(p) + self.assertNotIn('1234', out) + self.assertIn('********', out) + + def test_detail_masks_v3_secret_prefix(self): + # v3 with label: Record.load() encodes type in name as "secret:Label" + r = self._v2_record([{'type': 'text', 'name': 'secret:Token', 'value': 'top-secret'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + out = self._printed(p) + self.assertNotIn('top-secret', out) + self.assertIn('********', out) + + def test_detail_masks_v3_secret_no_label(self): + # v3 without label: Record.load() stores just the type as the name, no colon + r = self._v2_record([{'type': 'text', 'name': 'secret', 'value': 'top-secret'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + out = self._printed(p) + self.assertNotIn('top-secret', out) + self.assertIn('********', out) + + def test_detail_masks_v3_pincode_no_label(self): + r = self._v2_record([{'type': 'text', 'name': 'pinCode', 'value': '9999'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + out = self._printed(p) + self.assertNotIn('9999', out) + self.assertIn('********', out) + + def test_detail_does_not_mask_text_field(self): + r = self._v2_record([{'type': 'text', 'name': 'Note', 'value': 'public info'}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + self.assertIn('public info', self._printed(p)) + + def test_detail_masks_security_question_answer_only(self): + # v3 custom field: Record.load() stores type='text', name='securityQuestion', value=dict + r = self._v2_record([{'type': 'text', 'name': 'securityQuestion', + 'value': {'question': 'MyQuestion', 'answer': 'MyAnswer'}}]) + with mock.patch('builtins.print') as p: + r.display(unmask=False) + out = self._printed(p) + self.assertNotIn('MyAnswer', out) + self.assertIn('MyQuestion', out) + self.assertIn('********', out) + + def test_detail_unmask_reveals_security_question_answer(self): + r = self._v2_record([{'type': 'text', 'name': 'securityQuestion', + 'value': {'question': 'MyQuestion', 'answer': 'MyAnswer'}}]) + with mock.patch('builtins.print') as p: + r.display(unmask=True) + out = self._printed(p) + self.assertIn('MyQuestion', out) + self.assertIn('MyAnswer', out) + self.assertNotIn('********', out) + + # ── RecordV3.display() ────────────────────────────────────────────────── + + def _v3_cache_entry(self, fields=None, custom=None): + data = json.dumps({ + 'type': 'login', 'title': 'Test', + 'fields': fields or [], + 'custom': custom or [], + }).encode() + return {'record_uid': utils.generate_uid(), 'data_unencrypted': data} + + def test_v3_detail_masks_json_field(self): + from keepercommander.recordv3 import RecordV3 + rec = self._v3_cache_entry( + fields=[{'type': 'json', 'label': 'Config', 'value': ['{"k":"v"}']}]) + with mock.patch('builtins.print') as p: + RecordV3.display(rec, unmask=False, params=None) + out = self._printed(p) + self.assertNotIn('"k"', out) + self.assertIn('********', out) + + def test_v3_detail_masks_security_question_answer(self): + from keepercommander.recordv3 import RecordV3 + rec = self._v3_cache_entry(fields=[{ + 'type': 'securityQuestion', 'label': 'SQ', + 'value': [{'question': 'Mothers maiden name', 'answer': 'Smith'}], + }]) + with mock.patch('builtins.print') as p: + RecordV3.display(rec, unmask=False, params=None) + out = self._printed(p) + self.assertNotIn('Smith', out) + self.assertIn('********', out) + self.assertIn('Mothers maiden name', out) + + def test_v3_detail_unmask_reveals_security_answer(self): + from keepercommander.recordv3 import RecordV3 + rec = self._v3_cache_entry(fields=[{ + 'type': 'securityQuestion', 'label': 'SQ', + 'value': [{'question': 'Mothers maiden name', 'answer': 'Smith'}], + }]) + with mock.patch('builtins.print') as p: + RecordV3.display(rec, unmask=True, params=None) + self.assertIn('Smith', self._printed(p)) + + # ── fields format ────────────────────────────────────────────────────── + + def _run_fields(self, custom_fields, unmask=False): + from keepercommander.record import Record as LegacyRecord + params = get_synced_params() + r = LegacyRecord(utils.generate_uid()) + r.title = 'Test' + r.custom_fields = list(custom_fields) + params.record_cache[r.record_uid] = {'version': 2, 'shared': False} + captured = [] + cmd = record.RecordGetUidCommand() + with mock.patch('builtins.print', side_effect=captured.append), \ + mock.patch('keepercommander.api.get_record', return_value=r), \ + mock.patch('keepercommander.api.get_record_shares'), \ + mock.patch('keepercommander.api.get_share_admins_for_record', return_value=[]): + cmd.execute(params, uid=r.record_uid, format='fields', unmask=unmask) + return json.loads(captured[-1]) + + def test_fields_includes_secret_custom_field_masked(self): + fields = self._run_fields([{'type': 'secret', 'name': 'Token', 'value': 'top-secret'}]) + f = next((x for x in fields if x['name'] == 'Token'), None) + self.assertIsNotNone(f) + self.assertEqual(f['value'], '********') + + def test_fields_masks_v3_secret_no_label(self): + # v3 custom secret without label: type='text', name='secret' + fields = self._run_fields([{'type': 'text', 'name': 'secret', 'value': 'top-secret'}]) + f = next((x for x in fields if x['name'] == 'secret'), None) + self.assertIsNotNone(f) + self.assertEqual(f['value'], '********') + + def test_fields_masks_v3_pincode_no_label(self): + fields = self._run_fields([{'type': 'text', 'name': 'pinCode', 'value': '9999'}]) + f = next((x for x in fields if x['name'] == 'pinCode'), None) + self.assertIsNotNone(f) + self.assertEqual(f['value'], '********') + + def test_fields_unmask_reveals_secret(self): + fields = self._run_fields( + [{'type': 'secret', 'name': 'Token', 'value': 'top-secret'}], unmask=True) + f = next((x for x in fields if x['name'] == 'Token'), None) + self.assertIsNotNone(f) + self.assertEqual(f['value'], 'top-secret') + + def test_fields_excludes_empty_custom_fields(self): + fields = self._run_fields([ + {'type': 'text', 'name': 'Empty', 'value': ''}, + {'type': 'text', 'name': 'Present', 'value': 'hello'}, + ]) + names = [x['name'] for x in fields] + self.assertNotIn('Empty', names) + self.assertIn('Present', names) + + def test_fields_security_question_masks_answer_only(self): + fields = self._run_fields([{ + 'type': 'securityQuestion', 'name': 'securityQuestion', + 'value': [{'question': 'MyQuestion', 'answer': 'MyAnswer'}], + }]) + f = next((x for x in fields if x['name'] == 'securityQuestion'), None) + self.assertIsNotNone(f) + self.assertIsInstance(f['value'], dict) + self.assertEqual(f['value']['question'], 'MyQuestion') + self.assertEqual(f['value']['answer'], '********') + + def test_fields_security_question_unmask_reveals_answer(self): + fields = self._run_fields([{ + 'type': 'securityQuestion', 'name': 'securityQuestion', + 'value': [{'question': 'MyQuestion', 'answer': 'MyAnswer'}], + }], unmask=True) + f = next((x for x in fields if x['name'] == 'securityQuestion'), None) + self.assertIsNotNone(f) + self.assertIsInstance(f['value'], dict) + self.assertEqual(f['value']['question'], 'MyQuestion') + self.assertEqual(f['value']['answer'], 'MyAnswer') +