diff --git a/Cargo.lock b/Cargo.lock index cf6d2a6..217ee09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1527,6 +1527,16 @@ dependencies = [ "syn", ] +[[package]] +name = "pythonize" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffd1c3ef39c725d63db5f9bc455461bafd80540cb7824c61afb823501921a850" +dependencies = [ + "pyo3", + "serde", +] + [[package]] name = "quote" version = "1.0.45" @@ -1691,6 +1701,7 @@ dependencies = [ "futures-util", "parking_lot 0.12.5", "pyo3", + "pythonize", "serde", "serde_json", "tokio", diff --git a/Makefile b/Makefile index 71a583d..c8b0c27 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,11 @@ docker-down: lint: $(RUFF) $(RUFF) check . +.PHONY: format +format: $(RUFF) + $(RUFF) check --fix . + $(RUFF) format . + $(RUFF): @echo "Installing ruff..." @python3 -m venv .venv || true diff --git a/docs/proposals/SNAPSHOTS.md b/docs/proposals/SNAPSHOTS.md index 00e1459..e103df6 100644 --- a/docs/proposals/SNAPSHOTS.md +++ b/docs/proposals/SNAPSHOTS.md @@ -192,7 +192,7 @@ impl SnapshotManager { /// Find snapshot by tag (O(1) lookup via tag ref) /// Returns single snapshot since tags are immutable - pub async fn find_by_tag(&self, tag: &str) -> Result>; + pub async fn find_snapshot_by_tag(&self, tag: &str) -> Result>; /// Get snapshot by ID pub async fn get_snapshot(&self, id: &str) -> Result; @@ -206,45 +206,84 @@ impl SnapshotManager { ## Protocol Integration -**Note:** See [Protocol Specification](PROTOCOL.md) for complete message format details. +Snapshot operations are exposed via WebSocket protocol messages. All operations include a `request_id` for matching requests with responses. -**New message types:** +**Message types:** ```rust -pub enum Request { +pub enum Message { + // Create snapshot CreateSnapshot { - daemon_id: String, - workspace_path: String, - message: String, - tags: Vec, + request_id: String, + workspace: String, // Path to workspace directory + message: Option, // Optional description + tags: Option>, // Optional tags (must be unique) + }, + SnapshotCreated { + request_id: String, + snapshot_id: String, // UUID of created snapshot + file_count: usize, // Number of files captured + total_size: u64, // Total size in bytes }, + // Restore snapshot RestoreSnapshot { - daemon_id: String, - snapshot_id: String, - destination: String, + request_id: String, + snapshot_id: String, // Snapshot ID + destination: String, // Path to restore to + }, + SnapshotRestored { + request_id: String, + file_count: usize, // Number of files restored }, - ListSnapshots { daemon_id: String }, - DeleteSnapshot { daemon_id: String, snapshot_id: String }, - GarbageCollect { daemon_id: String }, -} + // List snapshots (with optional tag filter) + ListSnapshots { + request_id: String, + tags: Option>, // OR filter: snapshots with any of these tags + }, + SnapshotList { + request_id: String, + snapshots: Vec, // Sorted by creation time (newest first) + }, -pub enum Response { - SnapshotCreated { + // Find snapshot by tag (O(1) lookup) + FindSnapshotByTag { + request_id: String, + tag: String, // Tag name (immutable) + }, + // Get snapshot details + GetSnapshot { + request_id: String, + snapshot_id: String, + }, + SnapshotDetails { + request_id: String, + snapshot: Option, // Snapshot metadata, None if doesn't exist + }, + + // Delete snapshot (also removes tag refs) + DeleteSnapshot { + request_id: String, snapshot_id: String, - file_count: usize, - total_size: u64, - duration_ms: u64, + }, + SnapshotDeleted { + request_id: String, }, - SnapshotRestored { file_count: usize, duration_ms: u64 }, - Snapshots { snapshots: Vec }, - SnapshotDeleted { freed_bytes: u64 }, - GarbageCollected { objects_deleted: usize, bytes_freed: u64 }, + // Error response + SnapshotError { + request_id: String, + error: String, // Error message (e.g., "Tag 'v1.0.0' already exists") + }, } ``` +**Error cases:** +- `CreateSnapshot` with existing tag → `SnapshotError` +- `RestoreSnapshot`/`GetSnapshot`/`DeleteSnapshot` with non-existent ID → `SnapshotError` +- File I/O errors → `SnapshotError` + --- @@ -275,7 +314,7 @@ async fn main() -> Result<()> { } // Find snapshot by tag (O(1) lookup) - if let Some(snapshot) = manager.find_by_tag("pre-task").await? { + if let Some(snapshot) = manager.find_snapshot_by_tag("pre-task").await? { println!("Found: {}", snapshot.id); } diff --git a/examples/exec_commands.py b/examples/exec_commands.py index fa72c94..e844084 100755 --- a/examples/exec_commands.py +++ b/examples/exec_commands.py @@ -17,7 +17,11 @@ daemons = server.list_daemons() stats = server.get_stats() - print(f"\rConnected: {stats.total_daemons} | Platforms: {stats.by_platform}", end="", flush=True) + print( + f"\rConnected: {stats.total_daemons} | Platforms: {stats.by_platform}", + end="", + flush=True, + ) if daemons and len(daemons) > 0: for daemon in daemons: @@ -27,28 +31,38 @@ result = server.exec( daemon_id, "python3 -c 'import sys; print(f\"Python {sys.version_info.major}.{sys.version_info.minor}\")'", - timeout=5 + timeout=5, ) if result.success: - print(f"\n✓ Python test passed on {daemon_id}: {result.stdout.strip()}") + print( + f"\n✓ Python test passed on {daemon_id}: {result.stdout.strip()}" + ) else: - print(f"\n✗ Python test failed on {daemon_id}: exit_code={result.exit_code}") + print( + f"\n✗ Python test failed on {daemon_id}: exit_code={result.exit_code}" + ) # Test 2: Wrong Python script (intentional error) result = server.exec( - daemon_id, - "python3 -c 'undefined_variable'", - timeout=5 + daemon_id, "python3 -c 'undefined_variable'", timeout=5 ) if not result.success: - print(f"✓ Error handling test passed on {daemon_id}, stderr: {result.stderr.strip()}") + print( + f"✓ Error handling test passed on {daemon_id}, stderr: {result.stderr.strip()}" + ) else: - print(f"✗ Error handling test failed on {daemon_id}: expected error but got success") + print( + f"✗ Error handling test failed on {daemon_id}: expected error but got success" + ) # Test 3: Echo command - result = server.exec(daemon_id, "echo 'Hello from daemon!'", timeout=5) + result = server.exec( + daemon_id, "echo 'Hello from daemon!'", timeout=5 + ) if result.success: - print(f"✓ Echo test passed on {daemon_id}: {result.stdout.strip()}") + print( + f"✓ Echo test passed on {daemon_id}: {result.stdout.strip()}" + ) except Exception as e: print(f"\n✗ Command failed on {daemon_id}: {e}") diff --git a/examples/install_htop.py b/examples/install_htop.py index 81f8102..06b2485 100755 --- a/examples/install_htop.py +++ b/examples/install_htop.py @@ -39,9 +39,7 @@ def install_htop(server, daemon_id): else: # Linux - detect distribution distro_result = server.exec( - daemon_id, - "cat /etc/os-release 2>/dev/null || echo 'unknown'", - timeout=5 + daemon_id, "cat /etc/os-release 2>/dev/null || echo 'unknown'", timeout=5 ) if not distro_result.success: @@ -56,7 +54,9 @@ def install_htop(server, daemon_id): elif "ubuntu" in distro or "debian" in distro: cmd = "apt-get update && apt-get install -y htop" elif "rocky" in distro or "rhel" in distro or "centos" in distro: - cmd = "microdnf install -y htop || dnf install -y htop || yum install -y htop" + cmd = ( + "microdnf install -y htop || dnf install -y htop || yum install -y htop" + ) elif "fedora" in distro: cmd = "dnf install -y htop" else: @@ -84,7 +84,9 @@ def main(): # Wait for at least one daemon print("Waiting for daemons to connect...") - print("(Start a daemon with: ./target/release/sandd --server-url ws://127.0.0.1:8765/ws)") + print( + "(Start a daemon with: ./target/release/sandd --server-url ws://127.0.0.1:8765/ws)" + ) daemons = server.list_daemons() while not daemons: time.sleep(1) @@ -103,7 +105,7 @@ def main(): result = server.exec(daemon_id, "htop --version", timeout=5) if result.success: # htop version is usually first line - version_line = result.stdout.split('\n')[0] + version_line = result.stdout.split("\n")[0] print(f" {version_line}") else: print("✗ htop is not installed") @@ -114,7 +116,7 @@ def main(): # Verify installation result = server.exec(daemon_id, "htop --version", timeout=5) if result.success: - version_line = result.stdout.split('\n')[0] + version_line = result.stdout.split("\n")[0] print(f" {version_line}") else: print("Failed to install htop") @@ -139,10 +141,10 @@ def main(): if result.success: print(f"✓ {description}") # Show first few lines of output - output_lines = result.stdout.strip().split('\n')[:3] + output_lines = result.stdout.strip().split("\n")[:3] for line in output_lines: print(f" {line}") - if len(result.stdout.strip().split('\n')) > 3: + if len(result.stdout.strip().split("\n")) > 3: print(" ...") print() else: diff --git a/examples/programmatic_session.py b/examples/programmatic_session.py index 825366e..94497f9 100755 --- a/examples/programmatic_session.py +++ b/examples/programmatic_session.py @@ -78,14 +78,14 @@ def main(): # Example 4: Create new session for long-running task print("\n=== Example 4: Long-Running Task ===") session2 = server.new_session(daemon_id) - session2.write(b"for i in 1 2 3; do echo \"Step $i\"; sleep 1; done\n") + session2.write(b'for i in 1 2 3; do echo "Step $i"; sleep 1; done\n') # Stream output as it arrives start = time.time() while time.time() - start < 5: output = session2.read(timeout=0.5) if output: - print(output.decode(), end='', flush=True) + print(output.decode(), end="", flush=True) else: break diff --git a/examples/snapshot_example.py b/examples/snapshot_example.py new file mode 100755 index 0000000..a113984 --- /dev/null +++ b/examples/snapshot_example.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +""" +Snapshot Management Example + +Demonstrates how to: +1. Create snapshots with tags +2. List and filter snapshots +3. Find snapshots by tag +4. Restore snapshots +5. Delete snapshots + +Usage: + python examples/snapshot_example.py +""" + +import sys +import tempfile +from pathlib import Path + +from sandd import Server + + +def main(): + # Create a temporary workspace + with tempfile.TemporaryDirectory() as workspace: + workspace_path = Path(workspace) + print(f"📁 Using workspace: {workspace_path}\n") + + # Create some test files + (workspace_path / "file1.txt").write_text("Hello World") + (workspace_path / "file2.txt").write_text("Python Snapshot Example") + (workspace_path / "subdir").mkdir() + (workspace_path / "subdir" / "file3.txt").write_text("Nested file") + + # Start server + server = Server(host="0.0.0.0", port=8765) + print("✅ Server started on 0.0.0.0:8765") + print("⏳ Waiting for daemon to connect...\n") + + # Wait for at least one daemon to connect + import time + + while server.daemon_count() == 0: + time.sleep(0.5) + + daemons = server.list_daemons() + daemon_id = daemons[0].id + print(f"✅ Connected to daemon: {daemon_id}\n") + + # ========== 1. Create Snapshots ========== + print("=" * 60) + print("1️⃣ Creating Snapshots") + print("=" * 60) + + snapshot_id1 = server.create_snapshot( + daemon_id=daemon_id, + workspace=str(workspace_path), + message="Initial snapshot with 3 files", + tags=["v1.0", "initial"], + ) + print(f"✅ Created snapshot 1: {snapshot_id1}") + print(" Tags: v1.0, initial\n") + + # Modify workspace + (workspace_path / "file2.txt").write_text("Modified content") + (workspace_path / "file4.txt").write_text("New file") + + snapshot_id2 = server.create_snapshot( + daemon_id=daemon_id, + workspace=str(workspace_path), + message="After modifications", + tags=["v1.1"], + ) + print(f"✅ Created snapshot 2: {snapshot_id2}") + print(" Tags: v1.1\n") + + # ========== 2. List All Snapshots ========== + print("=" * 60) + print("2️⃣ Listing All Snapshots") + print("=" * 60) + + snapshots = server.list_snapshots(daemon_id=daemon_id) + for i, snap in enumerate(snapshots, 1): + print(f"\nSnapshot {i}:") + print(f" ID: {snap.id}") + print(f" Message: {snap.message}") + print(f" Tags: {', '.join(snap.tags)}") + print(f" Files: {snap.file_count}") + print(f" Size: {snap.total_size} bytes") + print(f" Created: {snap.created_at}") + + # ========== 3. Filter by Tags ========== + print("\n" + "=" * 60) + print("3️⃣ Filtering Snapshots by Tag") + print("=" * 60) + + v1_snapshots = server.list_snapshots(daemon_id=daemon_id, tags=["v1.0"]) + print(f"\n📌 Snapshots with tag 'v1.0': {len(v1_snapshots)}") + for snap in v1_snapshots: + print(f" - {snap.message} ({snap.id})") + + # ========== 4. Find by Tag ========== + print("\n" + "=" * 60) + print("4️⃣ Finding Snapshot by Tag") + print("=" * 60) + + initial_snapshot = server.find_snapshot_by_tag( + daemon_id=daemon_id, tag="initial" + ) + if initial_snapshot: + print("\n✅ Found snapshot with tag 'initial':") + print(f" ID: {initial_snapshot.id}") + print(f" Message: {initial_snapshot.message}") + else: + print("❌ No snapshot found with tag 'initial'") + + # ========== 5. Get Snapshot Details ========== + print("\n" + "=" * 60) + print("5️⃣ Getting Snapshot Details") + print("=" * 60) + + snapshot = server.get_snapshot(daemon_id=daemon_id, snapshot_id=snapshot_id1) + if snapshot: + print("\n✅ Snapshot details:") + print(f" ID: {snapshot.id}") + print(f" Message: {snapshot.message}") + print(f" Tags: {snapshot.tags}") + print(f" File count: {snapshot.file_count}") + print(f" Total size: {snapshot.total_size} bytes") + else: + print("❌ Snapshot not found") + + # ========== 6. Restore Snapshot ========== + print("\n" + "=" * 60) + print("6️⃣ Restoring Snapshot") + print("=" * 60) + + restore_path = workspace_path / "restored" + restore_path.mkdir() + + file_count = server.restore_snapshot( + daemon_id=daemon_id, + snapshot_id=snapshot_id1, + destination=str(restore_path), + ) + print(f"\n✅ Restored {file_count} files to: {restore_path}") + + # Verify restored files + restored_files = list(restore_path.rglob("*")) + print(" Files in restored directory:") + for f in restored_files: + if f.is_file(): + print(f" - {f.relative_to(restore_path)}") + + # ========== 7. Delete Snapshot ========== + print("\n" + "=" * 60) + print("7️⃣ Deleting Snapshot") + print("=" * 60) + + server.delete_snapshot(daemon_id=daemon_id, snapshot_id=snapshot_id2) + print(f"\n✅ Deleted snapshot: {snapshot_id2}") + + # Verify deletion + remaining = server.list_snapshots(daemon_id=daemon_id) + print(f" Remaining snapshots: {len(remaining)}") + + print("\n" + "=" * 60) + print("✅ Snapshot example completed successfully!") + print("=" * 60) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n\n👋 Interrupted by user") + sys.exit(0) + except Exception as e: + print(f"\n❌ Error: {e}", file=sys.stderr) + sys.exit(1) diff --git a/examples/snapshot_simple.rs b/examples/snapshot_simple.rs index 22490e5..70361f6 100644 --- a/examples/snapshot_simple.rs +++ b/examples/snapshot_simple.rs @@ -60,12 +60,12 @@ async fn main() -> Result<()> { // 6. Find by tag (returns single snapshot since tags are immutable) println!("\n6. Finding snapshot with 'init' tag:"); - if let Some(snap) = manager.find_by_tag("init").await? { + if let Some(snap) = manager.find_snapshot_by_tag("init").await? { println!(" {} - {}", snap.id, snap.message); } println!("\n7. Finding snapshot with 'feature' tag:"); - if let Some(snap) = manager.find_by_tag("feature").await? { + if let Some(snap) = manager.find_snapshot_by_tag("feature").await? { println!(" {} - {}", snap.id, snap.message); } diff --git a/python/sandd/__init__.py b/python/sandd/__init__.py index c238e56..3897ac3 100644 --- a/python/sandd/__init__.py +++ b/python/sandd/__init__.py @@ -47,8 +47,7 @@ from ._core import Session, TunnelConfig except ImportError as e: raise ImportError( - "Failed to import Rust extension. " - "Please build the package with: make install" + "Failed to import Rust extension. Please build the package with: make install" ) from e __all__ = [ diff --git a/python/sandd/async_server.py b/python/sandd/async_server.py index bb01fc9..3d0f66b 100644 --- a/python/sandd/async_server.py +++ b/python/sandd/async_server.py @@ -157,12 +157,7 @@ def get_stats(self) -> ServerStats: """ raise NotImplementedError("AsyncServer.get_stats() not yet implemented") - async def upload_file( - self, - daemon_id: str, - remote_path: str, - data: bytes - ) -> None: + async def upload_file(self, daemon_id: str, remote_path: str, data: bytes) -> None: """Upload file to daemon (async) Args: @@ -172,11 +167,7 @@ async def upload_file( """ raise NotImplementedError("AsyncServer.upload_file() not yet implemented") - async def download_file( - self, - daemon_id: str, - remote_path: str - ) -> bytes: + async def download_file(self, daemon_id: str, remote_path: str) -> bytes: """Download file from daemon (async) Args: diff --git a/python/sandd/models.py b/python/sandd/models.py index 5bcd8af..837fb51 100644 --- a/python/sandd/models.py +++ b/python/sandd/models.py @@ -1,13 +1,13 @@ """Data models for SandD""" -from typing import Dict +from typing import Dict, List +from datetime import datetime try: from ._core import PyCommandResult, PyStats except ImportError as e: raise ImportError( - "Failed to import Rust extension. " - "Please build the package with: make install" + "Failed to import Rust extension. Please build the package with: make install" ) from e @@ -114,8 +114,40 @@ def oldest_connection_secs(self) -> int: """Age of oldest connection in seconds""" return self._stats.oldest_connection_secs + def __repr__(self) -> str: + return f"ServerStats(total={self.total_daemons}, platforms={self.by_platform})" + + +class SnapshotInfo: + """Snapshot metadata + + Attributes: + id: Snapshot ID (UUID) + created_at: Creation timestamp + message: Snapshot description + tags: List of tags (immutable) + file_count: Number of files in snapshot + total_size: Total size in bytes + """ + + def __init__( + self, + id: str, + created_at: int, # Unix timestamp + message: str, + tags: List[str], + file_count: int, + total_size: int, + ): + self.id = id + self.created_at = datetime.fromtimestamp(created_at) + self.message = message + self.tags = tags + self.file_count = file_count + self.total_size = total_size + def __repr__(self) -> str: return ( - f"ServerStats(total={self.total_daemons}, " - f"platforms={self.by_platform})" + f"SnapshotInfo(id={self.id!r}, message={self.message!r}, " + f"tags={self.tags}, files={self.file_count})" ) diff --git a/python/sandd/server.py b/python/sandd/server.py index 23c0aa0..9392bb0 100644 --- a/python/sandd/server.py +++ b/python/sandd/server.py @@ -5,14 +5,13 @@ import sys import select -from .models import CommandResult, ServerStats, DaemonInfo +from .models import CommandResult, ServerStats, DaemonInfo, SnapshotInfo try: from ._core import Server as _RustServer, Session, TunnelConfig except ImportError as e: raise ImportError( - "Failed to import Rust extension. " - "Please build the package with: make install" + "Failed to import Rust extension. Please build the package with: make install" ) from e @@ -54,12 +53,10 @@ def __init__( port: int = 8765, connect: str = "direct", tunnel_config: Optional[TunnelConfig] = None, - verbose: bool = True + verbose: bool = True, ): if connect not in ["direct", "tunnel"]: - raise ValueError( - f"connect must be 'direct' or 'tunnel', got '{connect}'" - ) + raise ValueError(f"connect must be 'direct' or 'tunnel', got '{connect}'") if connect == "tunnel" and tunnel_config is None: raise ValueError( @@ -118,9 +115,7 @@ def exec( Each daemon processes commands sequentially to ensure predictable execution order and avoid resource conflicts. """ - result = self._server.exec( - daemon_id, command, timeout, env, cwd - ) + result = self._server.exec(daemon_id, command, timeout, env, cwd) return CommandResult(result) def new_session( @@ -314,6 +309,7 @@ def _run_interactive(self, session: Session) -> None: if sys.platform != "win32": import tty import termios + old_settings = termios.tcgetattr(sys.stdin) try: tty.setraw(sys.stdin.fileno()) @@ -341,15 +337,16 @@ def _interactive_loop(self, session: Session) -> None: rlist, _, _ = select.select([sys.stdin], [], [], 0.01) if rlist: data = sys.stdin.read(1) - if not data or data == '\x04': # Ctrl+D + if not data or data == "\x04": # Ctrl+D break session.write(data.encode()) else: # Windows - simple blocking read import msvcrt + if msvcrt.kbhit(): data = msvcrt.getch() - if data == b'\x04': # Ctrl+D + if data == b"\x04": # Ctrl+D break session.write(data) @@ -421,15 +418,23 @@ def run_command(daemon_id): return self.exec(daemon_id, command, timeout, env, cwd) except Exception as e: # Create error result - return CommandResult(type('obj', (object,), { - 'stdout': '', - 'stderr': str(e), - 'exit_code': -1, - 'duration_ms': 0, - })()) + return CommandResult( + type( + "obj", + (object,), + { + "stdout": "", + "stderr": str(e), + "exit_code": -1, + "duration_ms": 0, + }, + )() + ) results = {} - with concurrent.futures.ThreadPoolExecutor(max_workers=len(daemon_ids)) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(daemon_ids) + ) as executor: futures = {executor.submit(run_command, did): did for did in daemon_ids} for future in concurrent.futures.as_completed(futures): daemon_id = futures[future] @@ -467,13 +472,207 @@ def wait_for_daemon( time.sleep(poll_interval) return False + def create_snapshot( + self, + daemon_id: str, + workspace: str, + message: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> str: + """Create a snapshot of workspace on daemon + + Args: + daemon_id: Target daemon ID + workspace: Path to workspace directory on daemon + message: Optional snapshot description + tags: Optional list of tags (must be unique, immutable) + + Returns: + Snapshot ID (UUID) + + Raises: + ValueError: If daemon not found + RuntimeError: If tag already exists or other error + + Example: + >>> snapshot_id = server.create_snapshot( + ... "daemon-1", + ... "/workspace", + ... message="Before deployment", + ... tags=["v1.0.0", "stable"] + ... ) + >>> print(f"Created: {snapshot_id}") + """ + return self._server.create_snapshot(daemon_id, workspace, message, tags) + + def restore_snapshot( + self, + daemon_id: str, + snapshot_id: str, + destination: str, + ) -> int: + """Restore snapshot on daemon + + Args: + daemon_id: Target daemon ID + snapshot_id: Snapshot ID to restore + destination: Path to restore to on daemon + + Returns: + Number of files restored + + Raises: + ValueError: If daemon or snapshot not found + RuntimeError: If restore fails + + Example: + >>> file_count = server.restore_snapshot( + ... "daemon-1", + ... "snap-abc-123", + ... "/tmp/restored" + ... ) + >>> print(f"Restored {file_count} files") + """ + return self._server.restore_snapshot(daemon_id, snapshot_id, destination) + + def list_snapshots( + self, + daemon_id: str, + tags: Optional[List[str]] = None, + ) -> List[SnapshotInfo]: + """List snapshots on daemon (optionally filtered by tags) + + Args: + daemon_id: Target daemon ID + tags: Optional list of tags to filter by (OR logic) + + Returns: + List of snapshot info (sorted by creation time, newest first) + + Raises: + ValueError: If daemon not found + RuntimeError: If tag doesn't exist + + Example: + >>> # List all snapshots + >>> snapshots = server.list_snapshots("daemon-1") + >>> for snap in snapshots: + ... print(f"{snap.id}: {snap.message} (tags: {snap.tags})") + >>> + >>> # Filter by tags + >>> snapshots = server.list_snapshots("daemon-1", tags=["stable"]) + """ + result = self._server.list_snapshots(daemon_id, tags) + return [ + SnapshotInfo( + id=snap["id"], + created_at=snap["created_at"], + message=snap["message"], + tags=snap["tags"], + file_count=snap["file_count"], + total_size=snap["total_size"], + ) + for snap in result + ] + + def find_snapshot_by_tag( + self, + daemon_id: str, + tag: str, + ) -> Optional[SnapshotInfo]: + """Find snapshot by tag on daemon (O(1) lookup) + + Args: + daemon_id: Target daemon ID + tag: Tag name to search for + + Returns: + Snapshot info if found, None otherwise + + Raises: + ValueError: If daemon not found + RuntimeError: If tag name is invalid + + Example: + >>> snapshot = server.find_snapshot_by_tag("daemon-1", "v1.0.0") + >>> if snapshot: + ... print(f"Found: {snapshot.id}") + """ + result = self._server.find_snapshot_by_tag(daemon_id, tag) + if result is None: + return None + + return SnapshotInfo( + id=result["id"], + created_at=result["created_at"], + message=result["message"], + tags=result["tags"], + file_count=result["file_count"], + total_size=result["total_size"], + ) + + def get_snapshot( + self, + daemon_id: str, + snapshot_id: str, + ) -> Optional[SnapshotInfo]: + """Get snapshot info from daemon + + Args: + daemon_id: Target daemon ID + snapshot_id: Snapshot ID + + Returns: + Snapshot info, or None if not found + + Raises: + ValueError: If daemon not found + RuntimeError: If communication error + + Example: + >>> snapshot = server.get_snapshot("daemon-1", "snap-abc-123") + >>> if snapshot: + ... print(f"Message: {snapshot.message}") + ... print(f"Tags: {snapshot.tags}") + ... else: + ... print("Snapshot not found") + """ + result = self._server.get_snapshot(daemon_id, snapshot_id) + if result is None: + return None + return SnapshotInfo( + id=result["id"], + created_at=result["created_at"], + message=result["message"], + tags=result["tags"], + file_count=result["file_count"], + total_size=result["total_size"], + ) + + def delete_snapshot( + self, + daemon_id: str, + snapshot_id: str, + ) -> None: + """Delete snapshot from daemon (also removes tag refs) + + Args: + daemon_id: Target daemon ID + snapshot_id: Snapshot ID to delete + + Raises: + ValueError: If daemon or snapshot not found + + Example: + >>> server.delete_snapshot("daemon-1", "snap-abc-123") + >>> print("Snapshot deleted") + """ + self._server.delete_snapshot(daemon_id, snapshot_id) + @property def address(self) -> str: """Server address (host:port)""" return f"{self._host}:{self._port}" def __repr__(self) -> str: - return ( - f"Server(address={self.address}, " - f"daemons={self.daemon_count()})" - ) + return f"Server(address={self.address}, daemons={self.daemon_count()})" diff --git a/python/tests/test_e2e.py b/python/tests/test_e2e.py index 98c5da6..de9e927 100644 --- a/python/tests/test_e2e.py +++ b/python/tests/test_e2e.py @@ -5,6 +5,7 @@ These tests are marked as 'e2e' and skipped by default in 'make test'. Use 'make test-e2e' to run them explicitly. """ + import pytest import time import subprocess @@ -22,21 +23,20 @@ def docker_daemons(): subprocess.run( ["docker", "compose", "-f", compose_file, "build"], check=True, - capture_output=True + capture_output=True, ) subprocess.run( ["docker", "compose", "-f", compose_file, "up", "-d"], check=True, - capture_output=True + capture_output=True, ) yield # Cleanup subprocess.run( - ["docker", "compose", "-f", compose_file, "down"], - capture_output=True + ["docker", "compose", "-f", compose_file, "down"], capture_output=True ) @@ -47,9 +47,12 @@ def server(docker_daemons): # Wait for all daemons to connect (2 debian + 2 alpine + 2 rocky) daemon_ids = [ - "daemon-debian-1", "daemon-debian-2", - "daemon-alpine-1", "daemon-alpine-2", - "daemon-rocky-1", "daemon-rocky-2" + "daemon-debian-1", + "daemon-debian-2", + "daemon-alpine-1", + "daemon-alpine-2", + "daemon-rocky-1", + "daemon-rocky-2", ] for daemon_id in daemon_ids: connected = srv.wait_for_daemon(daemon_id, timeout=15.0) @@ -67,9 +70,12 @@ def test_all_daemons_connected(self, server): daemons = server.list_daemons() daemon_ids = [d.id for d in daemons] expected = [ - "daemon-debian-1", "daemon-debian-2", - "daemon-alpine-1", "daemon-alpine-2", - "daemon-rocky-1", "daemon-rocky-2" + "daemon-debian-1", + "daemon-debian-2", + "daemon-alpine-1", + "daemon-alpine-2", + "daemon-rocky-1", + "daemon-rocky-2", ] for daemon_id in expected: assert daemon_id in daemon_ids @@ -78,16 +84,15 @@ def test_all_daemons_connected(self, server): def test_execute_on_each_daemon(self, server): """Execute commands on each daemon across all distributions""" daemon_ids = [ - "daemon-debian-1", "daemon-debian-2", - "daemon-alpine-1", "daemon-alpine-2", - "daemon-rocky-1", "daemon-rocky-2" + "daemon-debian-1", + "daemon-debian-2", + "daemon-alpine-1", + "daemon-alpine-2", + "daemon-rocky-1", + "daemon-rocky-2", ] for daemon_id in daemon_ids: - result = server.exec( - daemon_id, - "echo 'Hello from container'", - timeout=5 - ) + result = server.exec(daemon_id, "echo 'Hello from container'", timeout=5) assert result.success assert "Hello from container" in result.stdout @@ -95,15 +100,11 @@ def test_concurrent_execution(self, server): """Execute commands concurrently on multiple daemons""" import concurrent.futures - daemon_ids = [ - "daemon-debian-1", "daemon-alpine-1", "daemon-rocky-1" - ] + daemon_ids = ["daemon-debian-1", "daemon-alpine-1", "daemon-rocky-1"] def run_cmd(daemon_id): return server.exec( - daemon_id, - f"echo 'Response from {daemon_id}'", - timeout=5 + daemon_id, f"echo 'Response from {daemon_id}'", timeout=5 ) with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: @@ -120,7 +121,9 @@ def test_concurrent_execution_same_daemon(self, server): daemon_id = "daemon-debian-1" def run_sleep(n): - result = server.exec(daemon_id, f"sleep {n} && echo 'slept {n}s'", timeout=10) + result = server.exec( + daemon_id, f"sleep {n} && echo 'slept {n}s'", timeout=10 + ) return result def run_fast(): @@ -153,8 +156,7 @@ class TestE2EBroadcast: def test_broadcast_simple_command(self, server): """Broadcast a simple command to multiple daemons""" results = server.broadcast( - labels={"env": "test"}, - command="echo 'hello from broadcast'" + labels={"env": "test"}, command="echo 'hello from broadcast'" ) # Should have 4 test daemons @@ -168,8 +170,7 @@ def test_broadcast_simple_command(self, server): def test_broadcast_with_multiple_labels(self, server): """Broadcast with multiple label filters (AND logic)""" results = server.broadcast( - labels={"env": "test", "distro": "debian"}, - command="hostname" + labels={"env": "test", "distro": "debian"}, command="hostname" ) # Should match only debian test daemons @@ -182,20 +183,14 @@ def test_broadcast_with_multiple_labels(self, server): def test_broadcast_no_matching_daemons(self, server): """Broadcast with labels that match no daemons""" - results = server.broadcast( - labels={"env": "nonexistent"}, - command="hostname" - ) + results = server.broadcast(labels={"env": "nonexistent"}, command="hostname") # Should return empty dict assert len(results) == 0 def test_broadcast_with_failure(self, server): """Broadcast command that fails on some daemons""" - results = server.broadcast( - labels={"env": "prod"}, - command="exit 1" - ) + results = server.broadcast(labels={"env": "prod"}, command="exit 1") # Should have results for prod daemons assert len(results) == 2 @@ -210,10 +205,7 @@ def test_broadcast_concurrent_execution(self, server): # Broadcast a 2-second sleep to test daemons start = time.time() - results = server.broadcast( - labels={"env": "test"}, - command="sleep 2" - ) + results = server.broadcast(labels={"env": "test"}, command="sleep 2") duration = time.time() - start # Should complete in ~2-3 seconds (concurrent), not 8+ seconds (serial) @@ -273,7 +265,7 @@ def test_daemon_restart(self, server): subprocess.run( ["docker", "restart", "sandd-daemon-debian-1"], check=True, - capture_output=True + capture_output=True, ) # Wait for reconnection @@ -303,9 +295,7 @@ class TestE2EDistributionSpecific: def test_package_manager_debian(self, server): """Test apt package manager on Debian daemons""" result = server.exec( - "daemon-debian-1", - "apt-get update && apt-get install -y curl", - timeout=60 + "daemon-debian-1", "apt-get update && apt-get install -y curl", timeout=60 ) assert result.success @@ -316,9 +306,7 @@ def test_package_manager_debian(self, server): def test_package_manager_alpine(self, server): """Test apk package manager on Alpine daemons""" result = server.exec( - "daemon-alpine-1", - "apk update && apk add curl", - timeout=60 + "daemon-alpine-1", "apk update && apk add curl", timeout=60 ) assert result.success @@ -328,11 +316,7 @@ def test_package_manager_alpine(self, server): def test_package_manager_rocky(self, server): """Test dnf package manager on Rocky daemons""" - result = server.exec( - "daemon-rocky-1", - "microdnf install -y curl", - timeout=60 - ) + result = server.exec("daemon-rocky-1", "microdnf install -y curl", timeout=60) assert result.success result = server.exec("daemon-rocky-1", "curl --version", timeout=5) @@ -341,11 +325,7 @@ def test_package_manager_rocky(self, server): def test_all_distros_run_same_command(self, server): """Verify all distributions can run common commands""" - daemon_ids = [ - "daemon-debian-1", - "daemon-alpine-1", - "daemon-rocky-1" - ] + daemon_ids = ["daemon-debian-1", "daemon-alpine-1", "daemon-rocky-1"] for daemon_id in daemon_ids: result = server.exec(daemon_id, "uname -s", timeout=5) assert result.success @@ -370,19 +350,15 @@ def test_session_basic_commands(self, server): # Read output output = session.read(timeout=2.0) assert output is not None - output_str = output.decode('utf-8', errors='ignore') - assert 'Hello from session' in output_str + output_str = output.decode("utf-8", errors="ignore") + assert "Hello from session" in output_str finally: session.close() def test_session_across_distributions(self, server): """Test session works on all distributions""" - daemon_ids = [ - "daemon-debian-1", - "daemon-alpine-1", - "daemon-rocky-1" - ] + daemon_ids = ["daemon-debian-1", "daemon-alpine-1", "daemon-rocky-1"] for daemon_id in daemon_ids: session = server.new_session(daemon_id) @@ -414,7 +390,7 @@ def test_session_multiline_commands(self, server): time.sleep(0.5) # Read all output chunks - all_output = b'' + all_output = b"" for _ in range(5): output = session.read(timeout=0.5) if output: @@ -423,9 +399,9 @@ def test_session_multiline_commands(self, server): break assert all_output - output_str = all_output.decode('utf-8', errors='ignore') + output_str = all_output.decode("utf-8", errors="ignore") # Should see the numbers - assert '1' in output_str and '2' in output_str and '3' in output_str + assert "1" in output_str and "2" in output_str and "3" in output_str finally: session.close() @@ -448,8 +424,8 @@ def test_session_environment_variables(self, server): output = session.read(timeout=2.0) assert output is not None - output_str = output.decode('utf-8', errors='ignore') - assert 'test123' in output_str + output_str = output.decode("utf-8", errors="ignore") + assert "test123" in output_str finally: session.close() @@ -472,12 +448,329 @@ def test_session_cd_persistence(self, server): output = session.read(timeout=2.0) assert output is not None - output_str = output.decode('utf-8', errors='ignore') - assert '/tmp' in output_str + output_str = output.decode("utf-8", errors="ignore") + assert "/tmp" in output_str finally: session.close() +class TestE2ESnapshots: + """E2E snapshot operations""" + + def test_create_and_list_snapshot(self, server): + """Create snapshot and list it""" + daemon_id = "daemon-debian-1" + + # Create a test workspace + server.exec(daemon_id, "mkdir -p /tmp/test-workspace", timeout=5) + server.exec( + daemon_id, "echo 'test content' > /tmp/test-workspace/file.txt", timeout=5 + ) + + # Create snapshot + snapshot_id = server.create_snapshot( + daemon_id, "/tmp/test-workspace", message="Test snapshot", tags=["test"] + ) + assert snapshot_id is not None + assert len(snapshot_id) > 0 + + # List snapshots + snapshots = server.list_snapshots(daemon_id) + assert len(snapshots) > 0 + assert any(s.id == snapshot_id for s in snapshots) + + # Find by tag + found = server.list_snapshots(daemon_id, tags=["test"]) + assert len(found) > 0 + assert found[0].id == snapshot_id + assert found[0].message == "Test snapshot" + assert "test" in found[0].tags + + def test_create_and_restore_snapshot(self, server): + """Create snapshot and restore it""" + daemon_id = "daemon-alpine-1" + + # Create test workspace + server.exec(daemon_id, "mkdir -p /tmp/source", timeout=5) + server.exec(daemon_id, "echo 'original' > /tmp/source/data.txt", timeout=5) + + # Create snapshot + snapshot_id = server.create_snapshot( + daemon_id, "/tmp/source", message="Original state" + ) + + # Verify snapshot created + snapshots = server.list_snapshots(daemon_id) + assert any(s.id == snapshot_id for s in snapshots) + + # Restore to different location + file_count = server.restore_snapshot(daemon_id, snapshot_id, "/tmp/restored") + assert file_count > 0 + + # Verify restored content + result = server.exec(daemon_id, "cat /tmp/restored/data.txt", timeout=5) + assert result.success + assert "original" in result.stdout + + def test_snapshot_with_multiple_tags(self, server): + """Create snapshot with multiple tags""" + daemon_id = "daemon-rocky-1" + + # Create workspace + server.exec(daemon_id, "mkdir -p /tmp/multi-tag", timeout=5) + server.exec(daemon_id, "echo 'tagged' > /tmp/multi-tag/file.txt", timeout=5) + + # Create snapshot with multiple tags + snapshot_id = server.create_snapshot( + daemon_id, + "/tmp/multi-tag", + message="Multi-tagged", + tags=["v1.0.0", "stable", "production"], + ) + + # List by different tags + by_v1 = server.list_snapshots(daemon_id, tags=["v1.0.0"]) + by_stable = server.list_snapshots(daemon_id, tags=["stable"]) + by_production = server.list_snapshots(daemon_id, tags=["production"]) + + assert len(by_v1) > 0 and by_v1[0].id == snapshot_id + assert len(by_stable) > 0 and by_stable[0].id == snapshot_id + assert len(by_production) > 0 and by_production[0].id == snapshot_id + + def test_snapshot_immutable_tags(self, server): + """Verify tags are immutable (duplicate tag should fail)""" + daemon_id = "daemon-debian-2" + + # Create workspace + server.exec(daemon_id, "mkdir -p /tmp/immutable-tag", timeout=5) + server.exec(daemon_id, "echo 'first' > /tmp/immutable-tag/data.txt", timeout=5) + + # Create first snapshot with tag + snapshot_id1 = server.create_snapshot( + daemon_id, "/tmp/immutable-tag", tags=["unique-tag"] + ) + assert snapshot_id1 is not None + + # Try to create second snapshot with same tag (should fail) + server.exec(daemon_id, "echo 'second' > /tmp/immutable-tag/data.txt", timeout=5) + + with pytest.raises(Exception) as exc_info: + server.create_snapshot(daemon_id, "/tmp/immutable-tag", tags=["unique-tag"]) + assert "already exists" in str(exc_info.value).lower() + + def test_delete_snapshot(self, server): + """Delete snapshot and verify it's removed""" + daemon_id = "daemon-alpine-2" + + # Create workspace + server.exec(daemon_id, "mkdir -p /tmp/delete-test", timeout=5) + server.exec( + daemon_id, "echo 'to delete' > /tmp/delete-test/file.txt", timeout=5 + ) + + # Create snapshot with tag + snapshot_id = server.create_snapshot( + daemon_id, "/tmp/delete-test", message="Will be deleted", tags=["delete-me"] + ) + + # Verify snapshot exists + snapshots_before = server.list_snapshots(daemon_id) + assert any(s.id == snapshot_id for s in snapshots_before) + + # Delete snapshot + server.delete_snapshot(daemon_id, snapshot_id) + + # Verify snapshot removed + snapshots_after = server.list_snapshots(daemon_id) + assert not any(s.id == snapshot_id for s in snapshots_after) + + # Verify tag can be reused after deletion + snapshot_id2 = server.create_snapshot( + daemon_id, + "/tmp/delete-test", + tags=["delete-me"], # Should work now + ) + assert snapshot_id2 is not None + assert snapshot_id2 != snapshot_id + + def test_find_snapshot_by_tag(self, server): + """Find snapshot by tag (O(1) lookup)""" + daemon_id = "daemon-rocky-2" + + # Create workspace + server.exec(daemon_id, "mkdir -p /tmp/find-test", timeout=5) + server.exec(daemon_id, "echo 'findme' > /tmp/find-test/data.txt", timeout=5) + + # Create snapshot with unique tag + snapshot_id = server.create_snapshot( + daemon_id, + "/tmp/find-test", + message="Find me by tag", + tags=["unique-find-tag"], + ) + + # Find by tag + found = server.find_snapshot_by_tag(daemon_id, "unique-find-tag") + assert found is not None + assert found.id == snapshot_id + assert found.message == "Find me by tag" + assert "unique-find-tag" in found.tags + + # Try to find non-existent tag + not_found = server.find_snapshot_by_tag(daemon_id, "non-existent-tag") + assert not_found is None + + def test_get_snapshot(self, server): + """Get snapshot details by ID""" + daemon_id = "daemon-debian-1" + + # Create workspace + server.exec(daemon_id, "mkdir -p /tmp/get-test", timeout=5) + server.exec(daemon_id, "echo 'data1' > /tmp/get-test/file1.txt", timeout=5) + server.exec(daemon_id, "echo 'data2' > /tmp/get-test/file2.txt", timeout=5) + + # Create snapshot + snapshot_id = server.create_snapshot( + daemon_id, + "/tmp/get-test", + message="Get test snapshot", + tags=["get-tag-1", "get-tag-2"], + ) + + # Get snapshot details + snapshot = server.get_snapshot(daemon_id, snapshot_id) + assert snapshot.id == snapshot_id + assert snapshot.message == "Get test snapshot" + assert snapshot.tags == ["get-tag-1", "get-tag-2"] + assert snapshot.file_count == 2 + assert snapshot.total_size > 0 + + # Try to get non-existent snapshot + with pytest.raises(Exception): + server.get_snapshot(daemon_id, "non-existent-id") + + def test_snapshot_nested_directories(self, server): + """Verify nested directory structure is preserved""" + daemon_id = "daemon-debian-2" + + # Create nested directory structure + server.exec(daemon_id, "mkdir -p /tmp/nested/a/b/c", timeout=5) + server.exec(daemon_id, "echo 'file1' > /tmp/nested/file1.txt", timeout=5) + server.exec(daemon_id, "echo 'file2' > /tmp/nested/a/file2.txt", timeout=5) + server.exec(daemon_id, "echo 'file3' > /tmp/nested/a/b/file3.txt", timeout=5) + server.exec(daemon_id, "echo 'file4' > /tmp/nested/a/b/c/file4.txt", timeout=5) + + # Create snapshot + snapshot_id = server.create_snapshot( + daemon_id, "/tmp/nested", message="Nested structure" + ) + + # Restore + server.restore_snapshot(daemon_id, snapshot_id, "/tmp/restored-nested") + + # Verify all files and structure + result1 = server.exec( + daemon_id, "cat /tmp/restored-nested/file1.txt", timeout=5 + ) + assert result1.success and "file1" in result1.stdout + + result2 = server.exec( + daemon_id, "cat /tmp/restored-nested/a/file2.txt", timeout=5 + ) + assert result2.success and "file2" in result2.stdout + + result3 = server.exec( + daemon_id, "cat /tmp/restored-nested/a/b/file3.txt", timeout=5 + ) + assert result3.success and "file3" in result3.stdout + + result4 = server.exec( + daemon_id, "cat /tmp/restored-nested/a/b/c/file4.txt", timeout=5 + ) + assert result4.success and "file4" in result4.stdout + + def test_snapshot_binary_files(self, server): + """Verify binary files are correctly captured and restored""" + daemon_id = "daemon-alpine-1" + + # Create workspace with binary file + server.exec(daemon_id, "mkdir -p /tmp/binary-test", timeout=5) + # Create a small binary file + server.exec( + daemon_id, + "dd if=/dev/urandom of=/tmp/binary-test/random.bin bs=1024 count=10", + timeout=5, + ) + + # Get checksum before snapshot + result_before = server.exec( + daemon_id, "md5sum /tmp/binary-test/random.bin", timeout=5 + ) + assert result_before.success + checksum_before = result_before.stdout.split()[0] + + # Create snapshot + snapshot_id = server.create_snapshot( + daemon_id, "/tmp/binary-test", message="Binary file test" + ) + + # Restore + server.restore_snapshot(daemon_id, snapshot_id, "/tmp/restored-binary") + + # Verify checksum matches + result_after = server.exec( + daemon_id, "md5sum /tmp/restored-binary/random.bin", timeout=5 + ) + assert result_after.success + checksum_after = result_after.stdout.split()[0] + + assert checksum_before == checksum_after, ( + "Binary file corrupted during snapshot/restore" + ) + + def test_snapshot_deduplication(self, server): + """Verify deduplication works (same content = same storage)""" + daemon_id = "daemon-rocky-1" + + # Create workspace with duplicate content + server.exec(daemon_id, "mkdir -p /tmp/dedup-test", timeout=5) + server.exec( + daemon_id, "echo 'same content' > /tmp/dedup-test/file1.txt", timeout=5 + ) + server.exec( + daemon_id, "echo 'same content' > /tmp/dedup-test/file2.txt", timeout=5 + ) + server.exec( + daemon_id, "echo 'same content' > /tmp/dedup-test/file3.txt", timeout=5 + ) + + # Create snapshot + snapshot_id = server.create_snapshot( + daemon_id, "/tmp/dedup-test", message="Dedup test" + ) + + # Get snapshot info + snapshot = server.get_snapshot(daemon_id, snapshot_id) + + # Total size should be much less than 3x file size (due to deduplication) + # Each file has "same content\n" (13 bytes), but stored only once + assert snapshot.file_count == 3 + # Size should be close to 13 bytes (one copy), not 39 bytes (three copies) + # Allow some overhead for tree structures + assert snapshot.total_size < 100, ( + f"Expected deduplication, got {snapshot.total_size} bytes" + ) + + # Verify all files restored correctly + server.restore_snapshot(daemon_id, snapshot_id, "/tmp/restored-dedup") + for i in range(1, 4): + result = server.exec( + daemon_id, f"cat /tmp/restored-dedup/file{i}.txt", timeout=5 + ) + assert result.success + assert "same content" in result.stdout + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index 429e1f8..502cc08 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -5,6 +5,7 @@ Run with: pytest python/tests/test_integration.py -v -s """ + import pytest import subprocess import time @@ -34,6 +35,7 @@ def server(): """Create a server instance on a unique port""" # Use a different port for each test to avoid conflicts import random + port = random.randint(9000, 9999) server = Server(host="127.0.0.1", port=port) yield server @@ -98,7 +100,13 @@ def test_multiple_daemons_connect(self, server, sandd_binary): server_url = f"ws://127.0.0.1:{server.address.split(':')[1]}/ws" proc = subprocess.Popen( - [sandd_binary, "--server-url", server_url, "--daemon-id", daemon_id], + [ + sandd_binary, + "--server-url", + server_url, + "--daemon-id", + daemon_id, + ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) @@ -133,10 +141,14 @@ def test_daemon_with_labels(self, server, sandd_binary): proc_prod = subprocess.Popen( [ sandd_binary, - "--server-url", server_url, - "--daemon-id", daemon_id_prod, - "--label", "env=prod", - "--label", "region=us-west", + "--server-url", + server_url, + "--daemon-id", + daemon_id_prod, + "--label", + "env=prod", + "--label", + "region=us-west", ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, @@ -147,9 +159,12 @@ def test_daemon_with_labels(self, server, sandd_binary): proc_dev = subprocess.Popen( [ sandd_binary, - "--server-url", server_url, - "--daemon-id", daemon_id_dev, - "--label", "env=dev", + "--server-url", + server_url, + "--daemon-id", + daemon_id_dev, + "--label", + "env=dev", ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, @@ -229,7 +244,7 @@ def test_exec_with_env(self, server, daemon_process): daemon_id, _ = daemon_process env = {"TEST_VAR": "test_value_123"} - if os.name == 'nt': # Windows + if os.name == "nt": # Windows cmd = "echo %TEST_VAR%" else: # Unix cmd = "echo $TEST_VAR" @@ -258,7 +273,7 @@ def test_execute_long_output(self, server, daemon_process): result = server.exec(daemon_id, cmd, timeout=10) assert result.success - assert result.stdout.count('\n') >= 1000 + assert result.stdout.count("\n") >= 1000 def test_command_timeout(self, server, daemon_process): """Test command timeout handling""" @@ -309,10 +324,14 @@ def test_get_daemon_with_labels(self, server, sandd_binary): proc = subprocess.Popen( [ sandd_binary, - "--server-url", server_url, - "--daemon-id", daemon_id, - "--label", "env=staging", - "--label", "team=backend", + "--server-url", + server_url, + "--daemon-id", + daemon_id, + "--label", + "env=staging", + "--label", + "team=backend", ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, @@ -418,6 +437,7 @@ def run_long_command(): assert daemon_after is not None assert daemon_after.is_busy is False + class TestServerStats: """Test server statistics with real connections""" @@ -443,8 +463,10 @@ def test_stats_platform_reporting(self, server, daemon_process): assert len(platforms) > 0 # Common platform names (Rust's std::env::consts::OS values) - assert any(p in ["linux", "macos", "windows", "Linux", "Darwin", "Windows"] - for p in platforms) + assert any( + p in ["linux", "macos", "windows", "Linux", "Darwin", "Windows"] + for p in platforms + ) # class TestFileTransfer: @@ -518,6 +540,7 @@ def test_wait_for_new_daemon(self, server, sandd_binary): # Start waiting in one "thread" (we'll simulate with timing) import threading + result_holder = {"connected": False} def wait_thread(): @@ -548,8 +571,7 @@ def wait_thread(): @pytest.mark.skipif( - not DAEMON_BINARY.exists(), - reason="Requires compiled daemon binary" + not DAEMON_BINARY.exists(), reason="Requires compiled daemon binary" ) class TestSession: """Test interactive sessions""" @@ -569,8 +591,8 @@ def test_session(self, server, daemon_process): # Should contain our echo if output: - output_str = output.decode('utf-8', errors='ignore') - assert 'test123' in output_str or 'echo' in output_str + output_str = output.decode("utf-8", errors="ignore") + assert "test123" in output_str or "echo" in output_str # Close session session.close() diff --git a/python/tests/test_unit.py b/python/tests/test_unit.py index bc19434..fbc3c75 100644 --- a/python/tests/test_unit.py +++ b/python/tests/test_unit.py @@ -3,6 +3,7 @@ These tests verify the Python API without requiring real daemon connections. For integration tests with real daemons, see test_integration.py """ + import pytest from sandd import Server, ServerStats, TunnelConfig diff --git a/sandd/src/main.rs b/sandd/src/main.rs index 370aecb..377d18d 100644 --- a/sandd/src/main.rs +++ b/sandd/src/main.rs @@ -188,6 +188,16 @@ async fn connect_and_serve( let executor = Arc::new(CommandExecutor::new()); let session_manager = Arc::new(tokio::sync::Mutex::new(session::SessionManager::new())); + // Initialize sandd root (default: ~/.sandd) + let sandd_root = std::env::var("SANDD_ROOT").unwrap_or_else(|_| { + let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); + format!("{}/.sandd", home) + }); + let snapshot_manager = Arc::new( + snapshot::SnapshotManager::new(std::path::PathBuf::from(&sandd_root)) + .context("Failed to initialize snapshot manager")?, + ); + // Spawn heartbeat task let ws_tx_clone = Arc::new(tokio::sync::Mutex::new(ws_tx)); let ws_tx_heartbeat = ws_tx_clone.clone(); @@ -234,6 +244,7 @@ async fn connect_and_serve( ws_tx_clone.clone(), executor.clone(), session_manager.clone(), + snapshot_manager.clone(), ) .await { @@ -252,6 +263,7 @@ async fn handle_message( ws_tx: Arc>, executor: Arc, session_manager: Arc>, + snapshot_manager: Arc, ) -> Result<()> where T: SinkExt + Unpin + Send + 'static, @@ -410,6 +422,162 @@ where } } + Message::CreateSnapshot { + request_id, + workspace, + message, + tags, + } => { + debug!("Creating snapshot of {}", workspace); + let result = snapshot_manager + .create_snapshot(std::path::Path::new(&workspace), message, tags) + .await; + + let response = match result { + Ok(snapshot_id) => { + let snapshot = snapshot_manager.get_snapshot(&snapshot_id).await?; + Message::SnapshotCreated { + request_id, + snapshot_id, + file_count: snapshot.file_count, + total_size: snapshot.total_size, + } + } + Err(e) => Message::SnapshotError { + request_id, + error: e.to_string(), + }, + }; + + let json = serde_json::to_string(&response)?; + let mut tx = ws_tx.lock().await; + tx.send(WsMessage::Text(json)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + } + + Message::RestoreSnapshot { + request_id, + snapshot_id, + destination, + } => { + debug!("Restoring snapshot {} to {}", snapshot_id, destination); + let result = snapshot_manager + .restore_snapshot(&snapshot_id, std::path::Path::new(&destination)) + .await; + + let response = match result { + Ok(()) => { + let snapshot = snapshot_manager.get_snapshot(&snapshot_id).await?; + Message::SnapshotRestored { + request_id, + file_count: snapshot.file_count, + } + } + Err(e) => Message::SnapshotError { + request_id, + error: e.to_string(), + }, + }; + + let json = serde_json::to_string(&response)?; + let mut tx = ws_tx.lock().await; + tx.send(WsMessage::Text(json)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + } + + Message::ListSnapshots { request_id, tags } => { + debug!("Listing snapshots"); + let result = snapshot_manager.list_snapshots(tags).await; + + let response = match result { + Ok(snapshots) => Message::SnapshotList { + request_id, + snapshots, + }, + Err(e) => Message::SnapshotError { + request_id, + error: e.to_string(), + }, + }; + + let json = serde_json::to_string(&response)?; + let mut tx = ws_tx.lock().await; + tx.send(WsMessage::Text(json)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + } + + Message::FindSnapshotByTag { request_id, tag } => { + debug!("Finding snapshot by tag: {}", tag); + let result = snapshot_manager.find_snapshot_by_tag(&tag).await; + + let response = match result { + Ok(snapshot) => Message::SnapshotDetails { + request_id, + snapshot, + }, + Err(e) => Message::SnapshotError { + request_id, + error: e.to_string(), + }, + }; + + let json = serde_json::to_string(&response)?; + let mut tx = ws_tx.lock().await; + tx.send(WsMessage::Text(json)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + } + + Message::GetSnapshot { + request_id, + snapshot_id, + } => { + debug!("Getting snapshot: {}", snapshot_id); + let result = snapshot_manager.get_snapshot(&snapshot_id).await; + + let response = match result { + Ok(snapshot_info) => Message::SnapshotDetails { + request_id, + snapshot: Some(snapshot_info), + }, + Err(e) => Message::SnapshotError { + request_id, + error: e.to_string(), + }, + }; + + let json = serde_json::to_string(&response)?; + let mut tx = ws_tx.lock().await; + tx.send(WsMessage::Text(json)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + } + + Message::DeleteSnapshot { + request_id, + snapshot_id, + } => { + debug!("Deleting snapshot: {}", snapshot_id); + let result = snapshot_manager.delete_snapshot(&snapshot_id).await; + + let response = match result { + Ok(()) => Message::SnapshotDeleted { request_id }, + Err(e) => Message::SnapshotError { + request_id, + error: e.to_string(), + }, + }; + + let json = serde_json::to_string(&response)?; + let mut tx = ws_tx.lock().await; + tx.send(WsMessage::Text(json)) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; + } + _ => { debug!("Received unhandled message type"); } diff --git a/sandd/src/protocol.rs b/sandd/src/protocol.rs index cf385e2..b49c41b 100644 --- a/sandd/src/protocol.rs +++ b/sandd/src/protocol.rs @@ -104,6 +104,59 @@ pub enum Message { request_id: String, error: String, }, + // Snapshot operations + CreateSnapshot { + request_id: String, + workspace: String, + message: Option, + tags: Option>, + }, + SnapshotCreated { + request_id: String, + snapshot_id: String, + file_count: usize, + total_size: u64, + }, + RestoreSnapshot { + request_id: String, + snapshot_id: String, + destination: String, + }, + SnapshotRestored { + request_id: String, + file_count: usize, + }, + ListSnapshots { + request_id: String, + tags: Option>, + }, + SnapshotList { + request_id: String, + snapshots: Vec, + }, + FindSnapshotByTag { + request_id: String, + tag: String, + }, + GetSnapshot { + request_id: String, + snapshot_id: String, + }, + SnapshotDetails { + request_id: String, + snapshot: Option, + }, + DeleteSnapshot { + request_id: String, + snapshot_id: String, + }, + SnapshotDeleted { + request_id: String, + }, + SnapshotError { + request_id: String, + error: String, + }, Error { message: String, #[serde(default)] diff --git a/sandd/src/snapshot/manager.rs b/sandd/src/snapshot/manager.rs index 667e12c..01ab088 100644 --- a/sandd/src/snapshot/manager.rs +++ b/sandd/src/snapshot/manager.rs @@ -54,6 +54,12 @@ impl SnapshotManager { Self::validate_tag_name(tag)?; } + // Create snapshot metadata (store tags in snapshot file) + let created_at = SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_err(|e| anyhow::anyhow!("System time error: {}", e))? + .as_secs(); + // Check tag existence upfront (still has TOCTOU, but add_tag uses atomic write) for tag in &tags { let tag_file = self.tags_dir.join(tag); @@ -67,14 +73,13 @@ impl SnapshotManager { // Build tree recursively let (tree_hash, file_count, total_size) = self.build_tree(workspace).await?; - // Create snapshot metadata (store tags in snapshot file) let snapshot = Snapshot { id: snapshot_id.clone(), - created_at: SystemTime::now(), + created_at, tree: tree_hash, message: message.unwrap_or_else(|| format!("Snapshot {}", snapshot_id)), tags: tags.clone(), // Store in snapshot for fast access - workspace_path: workspace.to_path_buf(), + workspace: workspace.to_path_buf(), file_count, total_size, }; @@ -297,9 +302,6 @@ impl SnapshotManager { None => anyhow::bail!("Tag '{}' does not exist", tag), } } - // Deduplicate: multiple tags may point to same snapshot - ids.sort(); - ids.dedup(); ids } else { // No filter: load all snapshots @@ -324,15 +326,23 @@ impl SnapshotManager { snapshots.push(snapshot.into()); } - // Sort by creation time (newest first) - snapshots.sort_by(|a: &SnapshotInfo, b: &SnapshotInfo| b.created_at.cmp(&a.created_at)); + // Sort by creation time (newest first), then deduplicate by ID + // Note: when filtering by tags, multiple tags may point to same snapshot + snapshots.sort_by(|a: &SnapshotInfo, b: &SnapshotInfo| { + b.created_at + .cmp(&a.created_at) + .then_with(|| a.id.cmp(&b.id)) + }); + + // Deduplicate by ID (keep first occurrence = newest due to sort) + snapshots.dedup_by(|a, b| a.id == b.id); Ok(snapshots) } /// Find snapshot by tag (O(1) lookup via tag ref) /// Returns single snapshot since tags are immutable - pub async fn find_by_tag(&self, tag: &str) -> Result> { + pub async fn find_snapshot_by_tag(&self, tag: &str) -> Result> { // Validate tag name (security) Self::validate_tag_name(tag)?; @@ -345,13 +355,13 @@ impl SnapshotManager { } /// Get snapshot by ID - pub async fn get_snapshot(&self, id: &str) -> Result { + pub async fn get_snapshot(&self, id: &str) -> Result { let snapshot_file = self.snapshots_dir.join(format!("{}.json", id)); let json = fs::read_to_string(snapshot_file) .await .with_context(|| format!("Snapshot {} not found", id))?; let snapshot: Snapshot = serde_json::from_str(&json)?; - Ok(snapshot) + Ok(snapshot.into()) } /// Delete snapshot and its tag refs @@ -572,6 +582,9 @@ mod tests { .await .unwrap(); + // sleep for a while to ensure different timestamps + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let _id2 = manager .create_snapshot( &workspace, @@ -596,7 +609,7 @@ mod tests { assert_eq!(tag1_snapshots[0].message, "First"); // Find by tag (returns single snapshot since tags are immutable) - let tag2_snapshot = manager.find_by_tag("tag2").await.unwrap(); + let tag2_snapshot = manager.find_snapshot_by_tag("tag2").await.unwrap(); assert!(tag2_snapshot.is_some()); assert_eq!(tag2_snapshot.unwrap().message, "Second"); } @@ -910,7 +923,6 @@ mod tests { assert_eq!(snapshot.message, "Test message"); assert_eq!(snapshot.tags, vec!["tag1", "tag2"]); assert_eq!(snapshot.file_count, 1); - assert_eq!(snapshot.workspace_path, workspace); // Try getting non-existent snapshot let result = manager.get_snapshot("non-existent-id").await; @@ -1063,7 +1075,11 @@ mod tests { .create_snapshot( &workspace, Some("Test".to_string()), - Some(vec!["v1.0.0".to_string(), "stable".to_string(), "latest".to_string()]), + Some(vec![ + "v1.0.0".to_string(), + "stable".to_string(), + "latest".to_string(), + ]), ) .await .unwrap(); diff --git a/sandd/src/snapshot/types.rs b/sandd/src/snapshot/types.rs index dd8d2db..27697f9 100644 --- a/sandd/src/snapshot/types.rs +++ b/sandd/src/snapshot/types.rs @@ -1,17 +1,16 @@ use serde::{Deserialize, Serialize}; use std::path::PathBuf; -use std::time::SystemTime; pub type SnapshotId = String; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Snapshot { pub id: SnapshotId, - pub created_at: SystemTime, + pub created_at: u64, // Unix timestamp in seconds pub tree: String, pub message: String, pub tags: Vec, - pub workspace_path: PathBuf, + pub workspace: PathBuf, pub file_count: usize, pub total_size: u64, } @@ -19,7 +18,7 @@ pub struct Snapshot { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SnapshotInfo { pub id: SnapshotId, - pub created_at: SystemTime, + pub created_at: u64, // Unix timestamp in seconds pub message: String, pub tags: Vec, pub file_count: usize, diff --git a/server/Cargo.toml b/server/Cargo.toml index 190cc9c..bad910b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -36,6 +36,7 @@ parking_lot = "0.12" # Python bindings pyo3 = { version = "0.20", features = ["extension-module", "anyhow"] } +pythonize = "0.20" # Base64 for protocol base64 = "0.22" diff --git a/server/src/lib.rs b/server/src/lib.rs index 22c898a..0abf94b 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -170,7 +170,7 @@ impl Server { let request_id = Uuid::new_v4().to_string(); let (tx, rx) = oneshot::channel(); - conn.register_command(request_id.clone(), tx); + conn.register_request(request_id.clone(), tx); // Send command to daemon let msg = Message::ExecuteCommand { @@ -190,12 +190,22 @@ impl Server { self.runtime.block_on(async { // Wait for result with timeout match tokio::time::timeout(Duration::from_secs(timeout), rx).await { - Ok(Ok(result)) => Ok(PyCommandResult { - stdout: result.stdout, - stderr: result.stderr, - exit_code: result.exit_code, - duration_ms: result.duration_ms, + Ok(Ok(Message::CommandOutput { + stdout, + stderr, + exit_code, + duration_ms, + .. + })) => Ok(PyCommandResult { + stdout, + stderr, + exit_code, + duration_ms, }), + Ok(Ok(Message::CommandError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("Command error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), Ok(Err(_)) => Err(PyRuntimeError::new_err("Command channel closed")), Err(_) => Err(PyTimeoutError::new_err("Command execution timed out")), } @@ -349,6 +359,276 @@ impl Server { is_busy: conn.is_busy(), })) } + + /// Create snapshot on daemon + #[pyo3(signature = (daemon_id, workspace, message=None, tags=None))] + fn create_snapshot( + &self, + py: Python, + daemon_id: String, + workspace: String, + message: Option, + tags: Option>, + ) -> PyResult { + let conn = self + .registry + .get(&daemon_id) + .ok_or_else(|| PyValueError::new_err(format!("Daemon {} not found", daemon_id)))?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + conn.register_request(request_id.clone(), tx); + + let msg = Message::CreateSnapshot { + request_id: request_id.clone(), + workspace, + message, + tags, + }; + + conn.send_message(msg) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to send snapshot request: {}", e)))?; + + py.allow_threads(|| { + self.runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(300), rx).await { + Ok(Ok(Message::SnapshotCreated { snapshot_id, .. })) => Ok(snapshot_id), + Ok(Ok(Message::SnapshotError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("Snapshot error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), + Ok(Err(_)) => Err(PyRuntimeError::new_err("Snapshot channel closed")), + Err(_) => Err(PyTimeoutError::new_err("Snapshot creation timed out")), + } + }) + }) + } + + /// Restore snapshot on daemon + fn restore_snapshot( + &self, + py: Python, + daemon_id: String, + snapshot_id: String, + destination: String, + ) -> PyResult { + let conn = self + .registry + .get(&daemon_id) + .ok_or_else(|| PyValueError::new_err(format!("Daemon {} not found", daemon_id)))?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + conn.register_request(request_id.clone(), tx); + + let msg = Message::RestoreSnapshot { + request_id: request_id.clone(), + snapshot_id, + destination, + }; + + conn.send_message(msg) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to send restore request: {}", e)))?; + + py.allow_threads(|| { + self.runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(300), rx).await { + Ok(Ok(Message::SnapshotRestored { file_count, .. })) => Ok(file_count), + Ok(Ok(Message::SnapshotError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("Restore error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), + Ok(Err(_)) => Err(PyRuntimeError::new_err("Restore channel closed")), + Err(_) => Err(PyTimeoutError::new_err("Restore timed out")), + } + }) + }) + } + + /// List snapshots on daemon + #[pyo3(signature = (daemon_id, tags=None))] + fn list_snapshots( + &self, + py: Python, + daemon_id: String, + tags: Option>, + ) -> PyResult> { + let conn = self + .registry + .get(&daemon_id) + .ok_or_else(|| PyValueError::new_err(format!("Daemon {} not found", daemon_id)))?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + conn.register_request(request_id.clone(), tx); + + let msg = Message::ListSnapshots { + request_id: request_id.clone(), + tags, + }; + + conn.send_message(msg) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to send list request: {}", e)))?; + + py.allow_threads(|| { + self.runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(60), rx).await { + Ok(Ok(Message::SnapshotList { snapshots, .. })) => { + Python::with_gil(|py| { + snapshots.into_iter() + .map(|s| pythonize::pythonize(py, &s).map_err(|e| PyRuntimeError::new_err(e.to_string()))) + .collect() + }) + } + Ok(Ok(Message::SnapshotError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("List error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), + Ok(Err(_)) => Err(PyRuntimeError::new_err("List channel closed")), + Err(_) => Err(PyTimeoutError::new_err("List timed out")), + } + }) + }) + } + + /// Find snapshot by tag + fn find_snapshot_by_tag( + &self, + py: Python, + daemon_id: String, + tag: String, + ) -> PyResult> { + let conn = self + .registry + .get(&daemon_id) + .ok_or_else(|| PyValueError::new_err(format!("Daemon {} not found", daemon_id)))?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + conn.register_request(request_id.clone(), tx); + + let msg = Message::FindSnapshotByTag { + request_id: request_id.clone(), + tag, + }; + + conn.send_message(msg) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to send find request: {}", e)))?; + + py.allow_threads(|| { + self.runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(60), rx).await { + Ok(Ok(Message::SnapshotDetails { snapshot: None, .. })) => Ok(None), + Ok(Ok(Message::SnapshotDetails { snapshot: Some(snapshot), .. })) => { + Python::with_gil(|py| { + pythonize::pythonize(py, &snapshot) + .map(Some) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } + Ok(Ok(Message::SnapshotError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("Find error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), + Ok(Err(_)) => Err(PyRuntimeError::new_err("Find channel closed")), + Err(_) => Err(PyTimeoutError::new_err("Find timed out")), + } + }) + }) + } + + /// Get snapshot details (returns None if not found) + fn get_snapshot( + &self, + py: Python, + daemon_id: String, + snapshot_id: String, + ) -> PyResult> { + let conn = self + .registry + .get(&daemon_id) + .ok_or_else(|| PyValueError::new_err(format!("Daemon {} not found", daemon_id)))?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + conn.register_request(request_id.clone(), tx); + + let msg = Message::GetSnapshot { + request_id: request_id.clone(), + snapshot_id, + }; + + conn.send_message(msg) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to send get request: {}", e)))?; + + py.allow_threads(|| { + self.runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(60), rx).await { + Ok(Ok(Message::SnapshotDetails { snapshot: Some(snapshot), .. })) => { + Python::with_gil(|py| { + pythonize::pythonize(py, &snapshot) + .map(Some) + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } + Ok(Ok(Message::SnapshotDetails { snapshot: None, .. })) => { + Ok(None) + } + Ok(Ok(Message::SnapshotError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("Get error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), + Ok(Err(_)) => Err(PyRuntimeError::new_err("Get channel closed")), + Err(_) => Err(PyTimeoutError::new_err("Get timed out")), + } + }) + }) + } + + /// Delete snapshot + fn delete_snapshot( + &self, + py: Python, + daemon_id: String, + snapshot_id: String, + ) -> PyResult<()> { + let conn = self + .registry + .get(&daemon_id) + .ok_or_else(|| PyValueError::new_err(format!("Daemon {} not found", daemon_id)))?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + conn.register_request(request_id.clone(), tx); + + let msg = Message::DeleteSnapshot { + request_id: request_id.clone(), + snapshot_id, + }; + + conn.send_message(msg) + .map_err(|e| PyRuntimeError::new_err(format!("Failed to send delete request: {}", e)))?; + + py.allow_threads(|| { + self.runtime.block_on(async { + match tokio::time::timeout(Duration::from_secs(60), rx).await { + Ok(Ok(Message::SnapshotDeleted { .. })) => Ok(()), + Ok(Ok(Message::SnapshotError { error, .. })) => { + Err(PyRuntimeError::new_err(format!("Delete error: {}", error))) + } + Ok(Ok(_)) => Err(PyRuntimeError::new_err("Unexpected response type")), + Ok(Err(_)) => Err(PyRuntimeError::new_err("Delete channel closed")), + Err(_) => Err(PyTimeoutError::new_err("Delete timed out")), + } + }) + }) + } } /// Session handle diff --git a/server/src/protocol.rs b/server/src/protocol.rs index 2b47147..3cc38ca 100644 --- a/server/src/protocol.rs +++ b/server/src/protocol.rs @@ -110,6 +110,60 @@ pub enum Message { error: String, }, + // Snapshot operations + CreateSnapshot { + request_id: String, + workspace: String, + message: Option, + tags: Option>, + }, + SnapshotCreated { + request_id: String, + snapshot_id: String, + file_count: usize, + total_size: u64, + }, + RestoreSnapshot { + request_id: String, + snapshot_id: String, + destination: String, + }, + SnapshotRestored { + request_id: String, + file_count: usize, + }, + ListSnapshots { + request_id: String, + tags: Option>, + }, + SnapshotList { + request_id: String, + snapshots: Vec, + }, + FindSnapshotByTag { + request_id: String, + tag: String, + }, + GetSnapshot { + request_id: String, + snapshot_id: String, + }, + SnapshotDetails { + request_id: String, + snapshot: Option, + }, + DeleteSnapshot { + request_id: String, + snapshot_id: String, + }, + SnapshotDeleted { + request_id: String, + }, + SnapshotError { + request_id: String, + error: String, + }, + // Error handling Error { message: String, diff --git a/server/src/registry.rs b/server/src/registry.rs index 232e1c9..fd8cda2 100644 --- a/server/src/registry.rs +++ b/server/src/registry.rs @@ -17,19 +17,20 @@ pub struct DaemonConnection { // ═══════════════════════════════════════════════════════════════════ // Outgoing: Python → Daemon // ═══════════════════════════════════════════════════════════════════ - /// Channel to send commands to daemon (Python → handle_websocket → Daemon) + /// Channel to send requests to daemon (Python → handle_websocket → Daemon) + /// Handles all message types: ExecuteCommand, CreateSnapshot, etc. /// This is the bridge from Python API to the WebSocket handler. /// Multiple Python threads can send concurrently (lock-free). - command_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, // ═══════════════════════════════════════════════════════════════════ // Incoming: Daemon → Python (Request/Response Pattern) // ═══════════════════════════════════════════════════════════════════ - /// Maps request_id → response channel for exec() calls - /// When Python sends a command, it registers a oneshot channel here and waits. - /// When daemon responds with CommandOutput, we look up and send result back. - /// Pattern: Request/Response (each command gets exactly one response) - pending_commands: Arc>>, + /// Maps request_id → response channel for ALL request/response operations + /// Handles: ExecuteCommand, CreateSnapshot, ListSnapshots, FindSnapshotByTag, + /// GetSnapshot, DeleteSnapshot, RestoreSnapshot, and future operations + /// Pattern: Request/Response (each request gets exactly one response Message) + pending_requests: Arc>>, // ═══════════════════════════════════════════════════════════════════ // Incoming: Daemon → Python (Streaming Pattern) @@ -48,14 +49,6 @@ pub struct DaemonConnection { file_transfers: Arc>, } -#[derive(Debug, Clone)] -pub struct CommandResult { - pub stdout: String, - pub stderr: String, - pub exit_code: i32, - pub duration_ms: u64, -} - #[derive(Debug)] pub struct FileTransfer { pub path: String, @@ -68,7 +61,7 @@ impl DaemonConnection { pub fn new( id: String, metadata: DaemonMetadata, - command_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, ) -> Self { let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -80,8 +73,8 @@ impl DaemonConnection { metadata, last_heartbeat: AtomicU64::new(now), connected_at: now, - command_tx, - pending_commands: Arc::new(DashMap::new()), + request_tx, + pending_requests: Arc::new(DashMap::new()), sessions: Arc::new(DashMap::new()), file_transfers: Arc::new(DashMap::new()), } @@ -105,24 +98,24 @@ impl DaemonConnection { } pub fn send_message(&self, msg: Message) -> Result<()> { - self.command_tx + self.request_tx .send(msg) .map_err(|_| anyhow!("Daemon channel closed"))?; Ok(()) } - pub fn register_command(&self, command_id: String, tx: oneshot::Sender) { - self.pending_commands.insert(command_id, tx); + pub fn register_request(&self, request_id: String, tx: oneshot::Sender) { + self.pending_requests.insert(request_id, tx); } - pub fn complete_command(&self, command_id: &str, result: CommandResult) { - if let Some((_, tx)) = self.pending_commands.remove(command_id) { - let _ = tx.send(result); + pub fn complete_request(&self, request_id: &str, response: Message) { + if let Some((_, tx)) = self.pending_requests.remove(request_id) { + let _ = tx.send(response); } } pub fn is_busy(&self) -> bool { - !self.pending_commands.is_empty() + !self.pending_requests.is_empty() } pub fn register_session(&self, session_id: String, tx: mpsc::UnboundedSender>) { diff --git a/server/src/server.rs b/server/src/server.rs index 2babef2..324145b 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,5 +1,5 @@ use crate::protocol::Message; -use crate::registry::{CommandResult, DaemonConnection, DaemonRegistry}; +use crate::registry::{DaemonConnection, DaemonRegistry}; use anyhow::{Context, Result}; use axum::{ extract::{ @@ -97,8 +97,8 @@ async fn websocket_handler( async fn handle_websocket(ws: WebSocket, registry: Arc) { let (mut ws_tx, mut ws_rx) = ws.split(); - // Create channel for outgoing commands - let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel(); + // Create channel for outgoing requests (Python → Daemon) + let (request_tx, mut request_rx) = tokio::sync::mpsc::unbounded_channel(); let mut daemon_id: Option = None; @@ -131,21 +131,21 @@ async fn handle_websocket(ws: WebSocket, registry: Arc) { } }; - handle_daemon_message(message, &mut daemon_id, ®istry, &mut ws_tx, &cmd_tx).await; + handle_daemon_message(message, &mut daemon_id, ®istry, &mut ws_tx, &request_tx).await; } - // Receive commands from Python (via channel) - Some(cmd) = cmd_rx.recv() => { - let json = match serde_json::to_string(&cmd) { + // Receive requests from Python (via channel) + Some(request) = request_rx.recv() => { + let json = match serde_json::to_string(&request) { Ok(j) => j, Err(e) => { - error!("Failed to serialize command: {}", e); + error!("Failed to serialize request: {}", e); continue; } }; if let Err(e) = ws_tx.send(axum::extract::ws::Message::Text(json)).await { - error!("Failed to send command to daemon: {}", e); + error!("Failed to send request to daemon: {}", e); break; } } @@ -165,7 +165,7 @@ async fn handle_daemon_message( daemon_id: &mut Option, registry: &Arc, ws_tx: &mut futures_util::stream::SplitSink, - cmd_tx: &mpsc::UnboundedSender, + request_tx: &mpsc::UnboundedSender, ) { use futures_util::SinkExt; @@ -183,7 +183,7 @@ async fn handle_daemon_message( ); // Create and register connection with channel - let new_conn = DaemonConnection::new(id.clone(), metadata, cmd_tx.clone()); + let new_conn = DaemonConnection::new(id.clone(), metadata, request_tx.clone()); registry.register(new_conn); // Send ack @@ -206,25 +206,31 @@ async fn handle_daemon_message( } } - Message::CommandOutput { - request_id, - stdout, - stderr, - exit_code, - duration_ms, - } => { + // All response messages for request/response pattern + response @ (Message::CommandOutput { .. } + | Message::CommandError { .. } + | Message::SnapshotCreated { .. } + | Message::SnapshotRestored { .. } + | Message::SnapshotList { .. } + | Message::SnapshotDetails { .. } + | Message::SnapshotDeleted { .. } + | Message::SnapshotError { .. }) => { if let Some(ref id) = daemon_id { if let Some(conn) = registry.get(id) { - debug!("Command {} completed on daemon {}", request_id, id); - conn.complete_command( - &request_id, - CommandResult { - stdout, - stderr, - exit_code, - duration_ms, - }, - ); + // Helper to extract request_id without moving + let request_id = match &response { + Message::CommandOutput { request_id, .. } + | Message::CommandError { request_id, .. } + | Message::SnapshotCreated { request_id, .. } + | Message::SnapshotRestored { request_id, .. } + | Message::SnapshotList { request_id, .. } + | Message::SnapshotDetails { request_id, .. } + | Message::SnapshotDeleted { request_id, .. } + | Message::SnapshotError { request_id, .. } => request_id.clone(), + _ => unreachable!(), + }; + debug!("Request {} completed on daemon {}", request_id, id); + conn.complete_request(&request_id, response); } } }