diff --git a/README.md b/README.md index 6ccf96c..d9b3fa7 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,9 @@ Rust-powered WebSocket server with Python API for remote command execution and i - **Command Execution** - Run shell commands on remote machines with timeout control - **Interactive Sessions** - Full PTY sessions with bash for manual work -- **File Transfer** - Upload/download files between controller and workers +- **File Transfer** - Upload/download files between server and daemons - **High Performance** - Rust async runtime handles high-concurrency workloads -- **Auto Reconnection** - Workers reconnect automatically on network failures +- **Auto Reconnection** - Daemons reconnect automatically on network failures - **Cross-Platform** - Linux, macOS, Windows support ## Architecture @@ -57,34 +57,32 @@ Rust-powered WebSocket server with Python API for remote command execution and i └───────┘ └───────┘ └───────┘ ``` -**Key Design**: Daemons connect **TO** the agent (not the other way around), so no ports need to be exposed on the execution plane. - ## Quick Start ```bash # Build make install # Python package -make daemon-release # Worker binary +make daemon-release # Daemon binary ``` -**Start controller:** +**Start server:** ```python from sandd import Server server = Server("0.0.0.0", 8765) -server.wait_for_daemon("worker-1", timeout=30) +server.wait_for_daemon("daemon-1", timeout=30) -result = server.exec("worker-1", "hostname") +result = server.exec("daemon-1", "hostname") print(result.stdout) ``` -**Start worker:** +**Start daemon:** ```bash ./target/release/sandd \ - --server-url ws://controller:8765/ws \ - --daemon-id worker-1 + --server-url ws://:8765/ws \ + --daemon-id daemon-1 ``` ## Documentation @@ -101,7 +99,7 @@ print(result.stdout) - Use `wss://` (TLS) instead of plain `ws://` - Add authentication (tokens, mTLS) -- Run workers in containers +- Run daemons in containers - Validate commands before execution - Audit log all commands diff --git a/examples/programmatic_session.py b/examples/programmatic_session.py new file mode 100755 index 0000000..825366e --- /dev/null +++ b/examples/programmatic_session.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Example: Programmatic Session Control + +Demonstrates how to use non-interactive sessions for: +- Multi-step command sequences +- Inspecting session output programmatically +- Handling session state and errors +- Building automation scripts +""" + +import sys +import time +from sandd import Server + + +def main(): + # Start server + server = Server("0.0.0.0", 8765) + print(f"Server listening on {server.address}") + + # Wait for daemon + daemon_id = "daemon-1" + print(f"\nWaiting for daemon '{daemon_id}'...") + if not server.wait_for_daemon(daemon_id, timeout=30): + print(f"Daemon '{daemon_id}' did not connect") + sys.exit(1) + + print(f"Daemon '{daemon_id}' connected!\n") + + # Create a non-interactive session + print("=== Creating Session ===") + session = server.new_session(daemon_id, rows=24, cols=80) + print("Session created\n") + + # Example 1: Execute command and capture output + print("=== Example 1: Basic Command ===") + session.write(b"echo 'Hello from session'\n") + time.sleep(0.2) + output = session.read(timeout=1.0) + if output: + print(f"Output: {output.decode()}") + + # Example 2: Multi-step workflow + print("\n=== Example 2: Multi-Step Workflow ===") + steps = [ + ("mkdir -p /tmp/test", "Creating directory"), + ("cd /tmp/test", "Changing directory"), + ("pwd", "Verifying location"), + ("echo 'test' > file.txt", "Creating file"), + ("cat file.txt", "Reading file"), + ] + + for cmd, description in steps: + print(f"{description}: {cmd}") + session.write(f"{cmd}\n".encode()) + time.sleep(0.1) + output = session.read(timeout=1.0) + if output: + result = output.decode().strip() + if result: + print(f" → {result}") + + # Example 3: Error handling + print("\n=== Example 3: Error Handling ===") + session.write(b"exit 42\n") # Exit with non-zero code + time.sleep(0.2) + + # Try to write after exit - should fail gracefully + try: + session.write(b"echo 'after exit'\n") + output = session.read(timeout=1.0) + if output: + print(f"Output: {output.decode()}") + except Exception as e: + print(f"Session closed (expected): {e}") + + # Example 4: Create new session for long-running task + print("\n=== Example 4: Long-Running Task ===") + session2 = server.new_session(daemon_id) + session2.write(b"for i in 1 2 3; do echo \"Step $i\"; sleep 1; done\n") + + # Stream output as it arrives + start = time.time() + while time.time() - start < 5: + output = session2.read(timeout=0.5) + if output: + print(output.decode(), end='', flush=True) + else: + break + + session2.close() + print("\n\nSession closed") + + +if __name__ == "__main__": + main() diff --git a/python/sandd/__init__.py b/python/sandd/__init__.py index 26a0594..aeca31c 100644 --- a/python/sandd/__init__.py +++ b/python/sandd/__init__.py @@ -283,14 +283,13 @@ def download_file( def list_daemons( self, - label_key: Optional[str] = None, - label_value: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, ) -> List[str]: - """List all connected daemon IDs, optionally filtered by label + """List all connected daemon IDs, optionally filtered by labels Args: - label_key: Label key to filter by (only applied when label_value is also provided) - label_value: Label value to filter by (only applied when label_key is also provided) + labels: Dictionary of label key-value pairs to filter by (AND logic) + All specified labels must match for a daemon to be included Returns: List of daemon IDs @@ -300,12 +299,15 @@ def list_daemons( >>> daemons = server.list_daemons() >>> print(f"Connected: {len(daemons)} daemons") >>> - >>> # List daemons with env=prod label - >>> prod_daemons = server.list_daemons(label_key="env", label_value="prod") - >>> for daemon_id in prod_daemons: + >>> # List daemons with single label + >>> prod_daemons = server.list_daemons(labels={"env": "prod"}) + >>> + >>> # List daemons with multiple labels (AND logic) + >>> west_prod = server.list_daemons(labels={"env": "prod", "region": "us-west"}) + >>> for daemon_id in west_prod: ... print(f" - {daemon_id}") """ - return self._server.list_daemons(label_key, label_value) + return self._server.list_daemons(labels) def daemon_count(self) -> int: """Get number of connected daemons diff --git a/python/tests/test_e2e.py b/python/tests/test_e2e.py index 78d1c7f..61c5ac1 100644 --- a/python/tests/test_e2e.py +++ b/python/tests/test_e2e.py @@ -107,34 +107,70 @@ def run_cmd(daemon_id): assert all(r.success for r in results) assert all("Response from" in r.stdout for r in results) + def test_concurrent_execution_same_daemon(self, server): + """Execute multiple commands concurrently on the same daemon""" + import concurrent.futures + import time + + daemon_id = "daemon-debian-1" + + def run_sleep(n): + start = time.time() + result = server.exec(daemon_id, f"sleep {n} && echo 'slept {n}s'", timeout=10) + duration = time.time() - start + return result, duration + + def run_fast(): + start = time.time() + result = server.exec(daemon_id, "echo 'fast command'", timeout=5) + duration = time.time() - start + return result, duration + + # Start slow command (3s) and fast command concurrently + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + slow_future = executor.submit(run_sleep, 3) + fast_future = executor.submit(run_fast) + + # Fast command should complete quickly, not wait for slow one + fast_result, fast_duration = fast_future.result() + assert fast_result.success + assert "fast command" in fast_result.stdout + assert fast_duration < 1.0 # Should finish in <1s, not wait for 3s sleep + + # Slow command completes independently + slow_result, slow_duration = slow_future.result() + assert slow_result.success + assert "slept 3s" in slow_result.stdout + assert 2.5 < slow_duration < 4.0 + class TestE2ELabels: """Test label-based filtering in E2E""" def test_filter_by_env_label(self, server): """Filter daemons by env label""" - test_daemons = server.list_daemons(label_key="env", label_value="test") + test_daemons = server.list_daemons(labels={"env": "test"}) assert "daemon-debian-1" in test_daemons assert "daemon-debian-2" in test_daemons assert "daemon-alpine-1" in test_daemons assert "daemon-rocky-2" in test_daemons - prod_daemons = server.list_daemons(label_key="env", label_value="prod") + prod_daemons = server.list_daemons(labels={"env": "prod"}) assert "daemon-alpine-2" in prod_daemons assert "daemon-rocky-1" in prod_daemons def test_filter_by_distro_label(self, server): """Filter daemons by distribution""" - debian_daemons = server.list_daemons(label_key="distro", label_value="debian") + debian_daemons = server.list_daemons(labels={"distro": "debian"}) assert "daemon-debian-1" in debian_daemons assert "daemon-debian-2" in debian_daemons assert len(debian_daemons) >= 2 - alpine_daemons = server.list_daemons(label_key="distro", label_value="alpine") + alpine_daemons = server.list_daemons(labels={"distro": "alpine"}) assert "daemon-alpine-1" in alpine_daemons assert "daemon-alpine-2" in alpine_daemons - rocky_daemons = server.list_daemons(label_key="distro", label_value="rocky") + rocky_daemons = server.list_daemons(labels={"distro": "rocky"}) assert "daemon-rocky-1" in rocky_daemons assert "daemon-rocky-2" in rocky_daemons @@ -289,18 +325,22 @@ def test_session_multiline_commands(self, server): try: # Send multi-line command - session.write(b"for i in 1 2 3; do\n") - time.sleep(0.2) - session.write(b"echo $i\n") - time.sleep(0.2) - session.write(b"done\n") + session.write(b"for i in 1 2 3; do echo $i; done\n") time.sleep(0.5) - output = session.read(timeout=2.0) - assert output is not None - output_str = output.decode('utf-8', errors='ignore') + # Read all output chunks + all_output = b'' + for _ in range(5): + output = session.read(timeout=0.5) + if output: + all_output += output + else: + break + + assert all_output + output_str = all_output.decode('utf-8', errors='ignore') # Should see the numbers - assert '1' in output_str and '2' in output_str + assert '1' in output_str and '2' in output_str and '3' in output_str finally: session.close() diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index ad2c4ac..3a661fc 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -165,22 +165,27 @@ def test_daemon_with_labels(self, server, sandd_binary): assert len(all_daemons) >= 2 # Test: filter by env=prod - prod_daemons = server.list_daemons(label_key="env", label_value="prod") + prod_daemons = server.list_daemons(labels={"env": "prod"}) assert daemon_id_prod in prod_daemons assert daemon_id_dev not in prod_daemons # Test: filter by env=dev - dev_daemons = server.list_daemons(label_key="env", label_value="dev") + dev_daemons = server.list_daemons(labels={"env": "dev"}) assert daemon_id_dev in dev_daemons assert daemon_id_prod not in dev_daemons # Test: filter by region=us-west - region_daemons = server.list_daemons(label_key="region", label_value="us-west") + region_daemons = server.list_daemons(labels={"region": "us-west"}) assert daemon_id_prod in region_daemons assert daemon_id_dev not in region_daemons + # Test: filter by multiple labels (AND logic) + west_prod = server.list_daemons(labels={"env": "prod", "region": "us-west"}) + assert daemon_id_prod in west_prod + assert daemon_id_dev not in west_prod + # Test: filter by non-existent label - none_daemons = server.list_daemons(label_key="env", label_value="staging") + none_daemons = server.list_daemons(labels={"env": "staging"}) assert daemon_id_prod not in none_daemons assert daemon_id_dev not in none_daemons diff --git a/python/tests/test_unit.py b/python/tests/test_unit.py index 528984c..e7becc7 100644 --- a/python/tests/test_unit.py +++ b/python/tests/test_unit.py @@ -54,13 +54,13 @@ def test_empty_when_no_daemons(self): def test_with_label_filters(self): server = Server() - result = server.list_daemons(label_key="env", label_value="prod") + result = server.list_daemons(labels={"env": "prod"}) assert isinstance(result, list) - def test_with_partial_filters(self): + def test_with_multiple_labels(self): server = Server() - assert isinstance(server.list_daemons(label_key="env"), list) - assert isinstance(server.list_daemons(label_value="prod"), list) + result = server.list_daemons(labels={"env": "prod", "region": "us-west"}) + assert isinstance(result, list) class TestDaemonCount: diff --git a/server/src/lib.rs b/server/src/lib.rs index 1d93c0b..139aacc 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -218,16 +218,13 @@ impl Server { }) } - /// List all connected daemons, optionally filtered by label - #[pyo3(signature = (label_key=None, label_value=None))] + /// List all connected daemons, optionally filtered by labels + #[pyo3(signature = (labels=None))] fn list_daemons( &self, - label_key: Option, - label_value: Option, + labels: Option>, ) -> PyResult> { - let key_ref = label_key.as_deref(); - let value_ref = label_value.as_deref(); - Ok(self.registry.list_all(key_ref, value_ref)) + Ok(self.registry.list_all(labels.as_ref())) } /// Get daemon count diff --git a/server/src/registry.rs b/server/src/registry.rs index f754190..7ad71e9 100644 --- a/server/src/registry.rs +++ b/server/src/registry.rs @@ -196,20 +196,22 @@ impl DaemonRegistry { } } - pub fn list_all(&self, label_key: Option<&str>, label_value: Option<&str>) -> Vec { + pub fn list_all(&self, labels: Option<&std::collections::HashMap>) -> Vec { self.connections .iter() .filter(|entry| { - match (label_key, label_value) { - (Some(key), Some(value)) => { - // Filter by label - entry - .value() - .metadata - .labels - .get(key) - .map(|v| v == value) - .unwrap_or(false) + match labels { + Some(filter_labels) if !filter_labels.is_empty() => { + // Check if daemon has ALL specified labels (AND logic) + filter_labels.iter().all(|(key, value)| { + entry + .value() + .metadata + .labels + .get(key) + .map(|v| v == value) + .unwrap_or(false) + }) } _ => true, // No filter, include all } @@ -426,7 +428,7 @@ mod tests { registry.register(conn); } - let daemons = registry.list_all(None, None); + let daemons = registry.list_all(None); assert_eq!(daemons.len(), 3); } @@ -434,18 +436,20 @@ mod tests { fn test_list_all_with_label_filter() { let registry = DaemonRegistry::new(); - // Daemon with env=prod label + // Daemon with env=prod, region=us-west let (tx1, _rx1) = mpsc::unbounded_channel(); let mut labels1 = HashMap::new(); labels1.insert("env".to_string(), "prod".to_string()); + labels1.insert("region".to_string(), "us-west".to_string()); let metadata1 = create_test_metadata_with_labels("host1", "linux", labels1); let conn1 = DaemonConnection::new("daemon-1".to_string(), metadata1, tx1); registry.register(conn1); - // Daemon with env=dev label + // Daemon with env=dev, region=us-east let (tx2, _rx2) = mpsc::unbounded_channel(); let mut labels2 = HashMap::new(); labels2.insert("env".to_string(), "dev".to_string()); + labels2.insert("region".to_string(), "us-east".to_string()); let metadata2 = create_test_metadata_with_labels("host2", "linux", labels2); let conn2 = DaemonConnection::new("daemon-2".to_string(), metadata2, tx2); registry.register(conn2); @@ -456,22 +460,48 @@ mod tests { let conn3 = DaemonConnection::new("daemon-3".to_string(), metadata3, tx3); registry.register(conn3); - // Filter by env=prod - let prod_daemons = registry.list_all(Some("env"), Some("prod")); + // Filter by single label: env=prod + let mut filter = HashMap::new(); + filter.insert("env".to_string(), "prod".to_string()); + let prod_daemons = registry.list_all(Some(&filter)); assert_eq!(prod_daemons.len(), 1); assert_eq!(prod_daemons[0], "daemon-1"); - // Filter by env=dev - let dev_daemons = registry.list_all(Some("env"), Some("dev")); + // Filter by single label: env=dev + let mut filter = HashMap::new(); + filter.insert("env".to_string(), "dev".to_string()); + let dev_daemons = registry.list_all(Some(&filter)); assert_eq!(dev_daemons.len(), 1); assert_eq!(dev_daemons[0], "daemon-2"); + // Filter by multiple labels: env=prod AND region=us-west (match) + let mut filter = HashMap::new(); + filter.insert("env".to_string(), "prod".to_string()); + filter.insert("region".to_string(), "us-west".to_string()); + let multi_match = registry.list_all(Some(&filter)); + assert_eq!(multi_match.len(), 1); + assert_eq!(multi_match[0], "daemon-1"); + + // Filter by multiple labels: env=prod AND region=us-east (no match) + let mut filter = HashMap::new(); + filter.insert("env".to_string(), "prod".to_string()); + filter.insert("region".to_string(), "us-east".to_string()); + let no_match = registry.list_all(Some(&filter)); + assert_eq!(no_match.len(), 0); + // Filter by nonexistent label - let none_daemons = registry.list_all(Some("env"), Some("staging")); + let mut filter = HashMap::new(); + filter.insert("env".to_string(), "staging".to_string()); + let none_daemons = registry.list_all(Some(&filter)); assert_eq!(none_daemons.len(), 0); + // Empty filter returns all + let empty_filter = HashMap::new(); + let all_daemons = registry.list_all(Some(&empty_filter)); + assert_eq!(all_daemons.len(), 3); + // No filter returns all - let all_daemons = registry.list_all(None, None); + let all_daemons = registry.list_all(None); assert_eq!(all_daemons.len(), 3); }