diff --git a/.github/scripts/detect-changes.mjs b/.github/scripts/detect-changes.mjs index e3b05504f99..2f27c3c6ffc 100644 --- a/.github/scripts/detect-changes.mjs +++ b/.github/scripts/detect-changes.mjs @@ -2,7 +2,7 @@ import { execSync } from 'child_process'; import { appendFileSync } from 'fs'; -const ALL_PACKAGES = ['wasm-bip32', 'wasm-mps', 'wasm-utxo', 'wasm-solana', 'wasm-dot', 'wasm-ton']; +const ALL_PACKAGES = ['wasm-bip32', 'wasm-mps', 'wasm-utxo', 'wasm-solana', 'wasm-dot', 'wasm-ton', 'wasm-privacy-coin']; function setOutput(packages) { const value = JSON.stringify(packages); diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml index 593053b467e..9c3b16b059c 100644 --- a/.github/workflows/build-and-test.yaml +++ b/.github/workflows/build-and-test.yaml @@ -64,6 +64,7 @@ jobs: packages/wasm-solana packages/wasm-dot packages/wasm-ton + packages/wasm-privacy-coin cache-on-failure: true - name: Setup Node @@ -101,6 +102,16 @@ jobs: - name: Build packages run: npm --workspaces run build + - name: Setup JDK 17 (for wasm-privacy-coin JAR) + uses: actions/setup-java@v4 + with: + distribution: corretto + java-version: '17' + + - name: Build wasm-privacy-coin JAR + working-directory: packages/wasm-privacy-coin + run: make jar + - name: Check Source Code Formatting run: npm run check-fmt @@ -121,6 +132,8 @@ jobs: packages/wasm-dot/js/wasm/ packages/wasm-ton/dist/ packages/wasm-ton/js/wasm/ + packages/wasm-privacy-coin/dist/ + packages/wasm-privacy-coin/js/wasm/ retention-days: 1 - name: Upload webui artifact @@ -156,6 +169,9 @@ jobs: - package: wasm-ton needs-wasm-pack: false has-wasm-pack-tests: false + - package: wasm-privacy-coin + needs-wasm-pack: false + has-wasm-pack-tests: false steps: - uses: actions/checkout@v4 with: @@ -298,6 +314,16 @@ jobs: packages/wasm-ton/dist/ retention-days: 1 + - name: Upload wasm-privacy-coin build artifacts + if: inputs.upload-artifacts + uses: actions/upload-artifact@v4 + with: + name: wasm-privacy-coin-build + path: | + packages/wasm-privacy-coin/pkg/ + packages/wasm-privacy-coin/dist/ + retention-days: 1 + # This job provides a stable "test / Test" status check for branch protection. # It runs after all other jobs complete successfully. gate: diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 519416172ca..304ef30c4e0 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -80,3 +80,65 @@ jobs: - name: Release (multi-semantic-release) run: npx multi-semantic-release --ignore-private-packages + + publish-maven: + name: Publish Maven Artifact + needs: test + runs-on: ubuntu-latest + environment: publish + permissions: + id-token: write + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup JDK 17 + uses: actions/setup-java@v4 + with: + distribution: corretto + java-version: '17' + + - name: Download wasm-privacy-coin build artifacts + uses: actions/download-artifact@v4 + with: + name: wasm-privacy-coin-build + path: packages/wasm-privacy-coin/dist/ + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_ROLE_ARN }} + aws-region: us-west-2 + + - name: Get CodeArtifact token + run: | + TOKEN=$(aws codeartifact get-authorization-token \ + --domain private \ + --domain-owner 199765120567 \ + --query authorizationToken \ + --output text) + echo "::add-mask::$TOKEN" + echo "AWS_CODEARTIFACT_TOKEN=$TOKEN" >> "$GITHUB_ENV" + + - name: Read version from pom.xml + id: version + run: | + VERSION=$(grep '' packages/wasm-privacy-coin/pom.xml | head -1 | sed 's/.*\(.*\)<\/version>.*/\1/') + if [ -z "$VERSION" ]; then echo "Failed to parse version from pom.xml" && exit 1; fi + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + + - name: Deploy JAR to CodeArtifact + working-directory: packages/wasm-privacy-coin + run: | + mvn deploy:deploy-file \ + -s .settings.xml \ + -Pcodeartifact-deploy \ + -DgroupId=com.bitgo \ + -DartifactId=wasm-privacy-coin \ + -Dversion=${{ steps.version.outputs.version }} \ + -Dpackaging=jar \ + -Dfile=dist/wasm-privacy-coin.jar \ + -DrepositoryId=codeartifact-central \ + -Durl=https://private-199765120567.d.codeartifact.us-west-2.amazonaws.com/maven/bitgo-maven-libs-release/ diff --git a/packages/wasm-privacy-coin/.settings.xml b/packages/wasm-privacy-coin/.settings.xml new file mode 100644 index 00000000000..c313e82687a --- /dev/null +++ b/packages/wasm-privacy-coin/.settings.xml @@ -0,0 +1,13 @@ + + + + + codeartifact-central + aws + ${env.AWS_CODEARTIFACT_TOKEN} + + + diff --git a/packages/wasm-privacy-coin/Cargo.toml b/packages/wasm-privacy-coin/Cargo.toml new file mode 100644 index 00000000000..26dc00b79cf --- /dev/null +++ b/packages/wasm-privacy-coin/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "wasm-privacy-coin" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib", "lib"] + +[dependencies] +# Core tree crates — pinned to exact versions for deterministic builds. +# These must stay in sync with the zcash ecosystem (NU6-compatible). +shardtree = "=0.6.2" +incrementalmerkletree = "=0.8.2" +orchard = { version = "=0.14.0", default-features = false } + +# IPC protocol +serde = { version = "1", features = ["derive"] } +serde_json = "1" +hex = "0.4" + +[profile.release] +opt-level = 3 +lto = true +strip = true diff --git a/packages/wasm-privacy-coin/Makefile b/packages/wasm-privacy-coin/Makefile new file mode 100644 index 00000000000..552f820e032 --- /dev/null +++ b/packages/wasm-privacy-coin/Makefile @@ -0,0 +1,19 @@ +WASM_TARGET = wasm32-unknown-unknown + +.PHONY: build +build: + cargo build --release --target $(WASM_TARGET) + mkdir -p dist + cp target/$(WASM_TARGET)/release/wasm_privacy_coin.wasm dist/wasm-privacy-coin.wasm + +.PHONY: jar +jar: build + mkdir -p jar-staging/wasm + cp dist/wasm-privacy-coin.wasm jar-staging/wasm/privacy_coin.wasm + cd jar-staging && jar cf ../dist/wasm-privacy-coin.jar wasm/ + rm -rf jar-staging + +.PHONY: clean +clean: + cargo clean + rm -rf dist jar-staging diff --git a/packages/wasm-privacy-coin/README.md b/packages/wasm-privacy-coin/README.md new file mode 100644 index 00000000000..77af18d5486 --- /dev/null +++ b/packages/wasm-privacy-coin/README.md @@ -0,0 +1,41 @@ +# wasm-privacy-coin + +Orchard commitment tree (Zcash NU6) compiled to WebAssembly for use by the +indexer-utxo Java service via the Chicory WASM runtime. + +## Building + +```bash +rustup target add wasm32-unknown-unknown +make build +``` + +The compiled WASM binary will be at `dist/wasm-privacy-coin.wasm`. + +## Architecture + +This module exposes a C-style FFI interface (no wasm-bindgen, no WASI). +The host allocates memory via `alloc()`, writes JSON input, calls a function, +then reads JSON output via `last_result_ptr()`/`last_result_len()`. + +## Exported Functions + +| Function | Signature | Description | +|---|---|---| +| `alloc` | `(len: u32) -> *mut u8` | Allocate buffer in WASM memory | +| `dealloc` | `(ptr: *mut u8, len: u32)` | Free allocated buffer | +| `ping` | `() -> i32` | Health check | +| `init_from_frontier` | `(ptr, len) -> i32` | Initialize from z_gettreestate frontier | +| `load_state` | `(ptr, len) -> i32` | Load persisted JSON state | +| `save` | `() -> i32` | Serialize tree to JSON | +| `get_info` | `() -> i32` | Return tip height, leaf count, checkpoint count | +| `append_commitments` | `(ptr, len) -> i32` | Append Orchard commitments, verify root | +| `truncate_to_checkpoint` | `(ptr, len) -> i32` | Reorg handling | +| `last_result_ptr` | `() -> *const u8` | Pointer to last result buffer | +| `last_result_len` | `() -> u32` | Length of last result buffer | + +## Pinned Dependencies + +The `shardtree`, `incrementalmerkletree`, and `orchard` crate versions are +pinned to exact versions to ensure deterministic builds and compatibility with +the NU6 Zcash protocol upgrade. diff --git a/packages/wasm-privacy-coin/package.json b/packages/wasm-privacy-coin/package.json new file mode 100644 index 00000000000..d9dfcf66480 --- /dev/null +++ b/packages/wasm-privacy-coin/package.json @@ -0,0 +1,14 @@ +{ + "name": "@bitgo/wasm-privacy-coin", + "version": "0.1.0", + "private": true, + "scripts": { + "build": "make build", + "lint": "cargo fmt --check && cargo clippy --all-targets --all-features -- -D warnings", + "check-fmt": "cargo fmt -- --check", + "test": "cargo test --workspace" + }, + "files": [ + "dist/wasm-privacy-coin.wasm" + ] +} diff --git a/packages/wasm-privacy-coin/pom.xml b/packages/wasm-privacy-coin/pom.xml new file mode 100644 index 00000000000..81a216e2ab5 --- /dev/null +++ b/packages/wasm-privacy-coin/pom.xml @@ -0,0 +1,26 @@ + + + 4.0.0 + + com.bitgo + wasm-privacy-coin + 0.1.0 + jar + + WASM module for Orchard merkle tree operations + + + + codeartifact-deploy + + + codeartifact-central + BitGo CodeArtifact Release Repository + https://private-199765120567.d.codeartifact.us-west-2.amazonaws.com/maven/bitgo-maven-libs-release/ + + + + + diff --git a/packages/wasm-privacy-coin/src/lib.rs b/packages/wasm-privacy-coin/src/lib.rs new file mode 100644 index 00000000000..fbb1dc156a1 --- /dev/null +++ b/packages/wasm-privacy-coin/src/lib.rs @@ -0,0 +1,216 @@ +mod protocol; +mod tree; + +use protocol::{AppendCommitmentsParams, InitFromFrontierParams, LoadParams, TruncateParams}; +use std::cell::RefCell; + +// --------------------------------------------------------------------------- +// WASM memory management +// --------------------------------------------------------------------------- + +/// Allocate a buffer in WASM linear memory for the host to write input data into. +#[no_mangle] +pub extern "C" fn alloc(len: u32) -> *mut u8 { + let mut buf = Vec::with_capacity(len as usize); + let ptr = buf.as_mut_ptr(); + std::mem::forget(buf); + ptr +} + +/// Free a buffer previously allocated by `alloc` or returned as output. +#[no_mangle] +pub extern "C" fn dealloc(ptr: *mut u8, len: u32) { + unsafe { + let _ = Vec::from_raw_parts(ptr, len as usize, len as usize); + } +} + +// --------------------------------------------------------------------------- +// Result passing: the guest writes JSON output to a thread-local buffer, +// and the host reads pointer+length via last_result_ptr / last_result_len. +// --------------------------------------------------------------------------- + +thread_local! { + static LAST_RESULT: RefCell> = RefCell::new(Vec::new()); +} + +/// Returns a pointer to the last result buffer. Valid until the next call. +#[no_mangle] +pub extern "C" fn last_result_ptr() -> *const u8 { + LAST_RESULT.with(|r| r.borrow().as_ptr()) +} + +/// Returns the length of the last result buffer. +#[no_mangle] +pub extern "C" fn last_result_len() -> u32 { + LAST_RESULT.with(|r| r.borrow().len() as u32) +} + +fn set_result(json: &str) { + LAST_RESULT.with(|r| { + *r.borrow_mut() = json.as_bytes().to_vec(); + }); +} + +fn ok_result() -> i32 { + set_result(r#"{"success":true}"#); + 0 +} + +fn success_result(json: &str) -> i32 { + set_result(json); + 0 +} + +fn error_result(code: &str, message: &str) -> i32 { + let escaped_msg = message.replace('\\', "\\\\").replace('"', "\\\""); + set_result(&format!( + r#"{{"success":false,"error":{{"code":"{}","message":"{}"}}}}"#, + code, escaped_msg + )); + 1 +} + +/// Read JSON bytes from WASM memory at the given pointer+length. +unsafe fn read_input(ptr: *const u8, len: u32) -> Result { + let slice = std::slice::from_raw_parts(ptr, len as usize); + String::from_utf8(slice.to_vec()).map_err(|_| { + error_result("INVALID_UTF8", "input is not valid UTF-8") + }) +} + +// --------------------------------------------------------------------------- +// Exported WASM functions +// --------------------------------------------------------------------------- + +#[no_mangle] +pub extern "C" fn ping() -> i32 { + ok_result() +} + +#[no_mangle] +pub extern "C" fn init_from_frontier(ptr: *const u8, len: u32) -> i32 { + let input = match unsafe { read_input(ptr, len) } { + Ok(s) => s, + Err(code) => return code, + }; + + let params: InitFromFrontierParams = match serde_json::from_str(&input) { + Ok(p) => p, + Err(e) => return error_result("INVALID_PARAMS", &format!("failed to parse params: {}", e)), + }; + + match tree::init_from_frontier(¶ms.frontier_hex, params.block_height) { + Ok(root) => success_result(&format!( + r#"{{"success":true,"result":{{"root":"{}"}}}}"#, + root + )), + Err(e) => error_result("FRONTIER_ERROR", &e), + } +} + +#[no_mangle] +pub extern "C" fn load_state(ptr: *const u8, len: u32) -> i32 { + let input = match unsafe { read_input(ptr, len) } { + Ok(s) => s, + Err(code) => return code, + }; + + let params: LoadParams = match serde_json::from_str(&input) { + Ok(p) => p, + Err(e) => return error_result("INVALID_PARAMS", &format!("failed to parse params: {}", e)), + }; + + match tree::load(¶ms.state) { + Ok(()) => ok_result(), + Err(e) => error_result("LOAD_ERROR", &e), + } +} + +#[no_mangle] +pub extern "C" fn save() -> i32 { + match tree::save() { + Ok(state) => success_result(&format!( + r#"{{"success":true,"result":{{"state":{}}}}}"#, + serde_json::to_string(&state).unwrap_or_else(|_| "null".to_string()) + )), + Err(e) => error_result("SAVE_ERROR", &e), + } +} + +#[no_mangle] +pub extern "C" fn get_info() -> i32 { + match tree::get_info() { + Ok((tip_height, leaf_count, checkpoint_count)) => { + let tip = match tip_height { + Some(h) => h.to_string(), + None => "null".to_string(), + }; + success_result(&format!( + r#"{{"success":true,"result":{{"info":{{"tip_height":{},"leaf_count":{},"checkpoint_count":{}}}}}}}"#, + tip, leaf_count, checkpoint_count + )) + } + Err(e) => error_result("INFO_ERROR", &e), + } +} + +#[no_mangle] +pub extern "C" fn append_commitments(ptr: *const u8, len: u32) -> i32 { + let input = match unsafe { read_input(ptr, len) } { + Ok(s) => s, + Err(code) => return code, + }; + + let params: AppendCommitmentsParams = match serde_json::from_str(&input) { + Ok(p) => p, + Err(e) => return error_result("INVALID_PARAMS", &format!("failed to parse params: {}", e)), + }; + + match tree::append_commitments( + params.block_height, + ¶ms.commitments, + params.expected_root.as_deref(), + ) { + Ok(root) => success_result(&format!( + r#"{{"success":true,"result":{{"root":"{}"}}}}"#, + root + )), + Err(e) => { + let code = if e.contains("ROOT_MISMATCH") { + "ROOT_MISMATCH" + } else { + "TREE_ERROR" + }; + error_result(code, &e) + } + } +} + +#[no_mangle] +pub extern "C" fn truncate_to_checkpoint(ptr: *const u8, len: u32) -> i32 { + let input = match unsafe { read_input(ptr, len) } { + Ok(s) => s, + Err(code) => return code, + }; + + let params: TruncateParams = match serde_json::from_str(&input) { + Ok(p) => p, + Err(e) => return error_result("INVALID_PARAMS", &format!("failed to parse params: {}", e)), + }; + + match tree::truncate_to_checkpoint(params.block_height) { + Ok(root) => success_result(&format!( + r#"{{"success":true,"result":{{"root":"{}"}}}}"#, + root + )), + Err(e) => { + let code = if e.contains("CHECKPOINT_NOT_FOUND") { + "CHECKPOINT_NOT_FOUND" + } else { + "TREE_ERROR" + }; + error_result(code, &e) + } + } +} diff --git a/packages/wasm-privacy-coin/src/protocol.rs b/packages/wasm-privacy-coin/src/protocol.rs new file mode 100644 index 00000000000..75e4cab9647 --- /dev/null +++ b/packages/wasm-privacy-coin/src/protocol.rs @@ -0,0 +1,66 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize, Default)] +pub struct AppendCommitmentsParams { + /// Block height for checkpoint + pub block_height: u32, + /// Hex-encoded commitment values (cmx) + pub commitments: Vec, + /// Optional hex-encoded expected root for verification + pub expected_root: Option, +} + +#[derive(Debug, Deserialize, Default)] +pub struct TruncateParams { + /// Block height to truncate to + pub block_height: u32, +} + +#[derive(Debug, Deserialize, Default)] +pub struct LoadParams { + /// JSON-serialized tree state + pub state: String, +} + +#[derive(Debug, Deserialize, Default)] +pub struct InitFromFrontierParams { + /// Hex-encoded serialized frontier from z_gettreestate + pub frontier_hex: String, + /// Block height at which this frontier was captured + pub block_height: u32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deserialize_append_params() { + let json = r#"{"block_height":100,"commitments":["abcd"],"expected_root":null}"#; + let params: AppendCommitmentsParams = serde_json::from_str(json).unwrap(); + assert_eq!(params.block_height, 100); + assert_eq!(params.commitments, vec!["abcd"]); + } + + #[test] + fn test_deserialize_load_params() { + let json = r#"{"state":"{\"shards\":[]}"}"#; + let params: LoadParams = serde_json::from_str(json).unwrap(); + assert!(params.state.contains("shards")); + } + + #[test] + fn test_deserialize_truncate_params() { + let json = r#"{"block_height":50}"#; + let params: TruncateParams = serde_json::from_str(json).unwrap(); + assert_eq!(params.block_height, 50); + } + + #[test] + fn test_deserialize_init_from_frontier_params() { + let json = r#"{"frontier_hex":"aabb","block_height":999}"#; + let params: InitFromFrontierParams = serde_json::from_str(json).unwrap(); + assert_eq!(params.frontier_hex, "aabb"); + assert_eq!(params.block_height, 999); + } +} diff --git a/packages/wasm-privacy-coin/src/tree.rs b/packages/wasm-privacy-coin/src/tree.rs new file mode 100644 index 00000000000..9df5518e623 --- /dev/null +++ b/packages/wasm-privacy-coin/src/tree.rs @@ -0,0 +1,560 @@ +use incrementalmerkletree::{Address, Hashable, Level, Marking, Position, Retention}; +use orchard::tree::MerkleHashOrchard; +use serde::{Deserialize, Serialize}; +use shardtree::{ + store::memory::MemoryShardStore, + store::{Checkpoint, ShardStore, TreeState}, + LocatedPrunableTree, Node, RetentionFlags, ShardTree, Tree, +}; +use std::collections::BTreeSet; +use std::sync::{Arc, Mutex}; + +/// Orchard commitment tree depth +pub const DEPTH: u8 = 32; +/// Number of checkpoints to retain (allows reorgs up to this depth) +pub const MAX_CHECKPOINTS: usize = 100; +/// Shard height for the tree (log2 of shard size) +pub const SHARD_HEIGHT: u8 = 16; + +pub type OrchardShardTree = + ShardTree, DEPTH, SHARD_HEIGHT>; + +type PrunableT = Tree>, (MerkleHashOrchard, RetentionFlags)>; + +// --------------------------------------------------------------------------- +// Structural serialization types (ported from coins-sandbox/orchard-wasm-shard) +// --------------------------------------------------------------------------- + +/// Serde-tagged enum mirroring `PrunableTree` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum TreeNode { + Nil, + Leaf { + h: String, + f: u8, + }, + Parent { + a: String, + l: Box, + r: Box, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SerializedShard { + pub root_addr_level: u8, + pub root_addr_index: u64, + pub tree: TreeNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SerializedCheckpoint { + pub id: u32, + pub position: Option, + pub marks_removed: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PersistedShardTreeState { + pub shards: Vec, + pub cap: TreeNode, + pub checkpoints: Vec, + pub tip_height: Option, + pub leaf_count: u64, +} + +// --------------------------------------------------------------------------- +// Tree <-> TreeNode conversion +// --------------------------------------------------------------------------- + +fn serialize_tree(tree: &PrunableT) -> TreeNode { + match &**tree { + Node::Nil => TreeNode::Nil, + Node::Leaf { value: (h, flags) } => TreeNode::Leaf { + h: hex::encode(h.to_bytes()), + f: flags.bits(), + }, + Node::Parent { ann, left, right } => TreeNode::Parent { + a: ann + .as_ref() + .map(|h| hex::encode(h.to_bytes())) + .unwrap_or_default(), + l: Box::new(serialize_tree(left)), + r: Box::new(serialize_tree(right)), + }, + } +} + +fn deserialize_tree(node: &TreeNode) -> Result { + match node { + TreeNode::Nil => Ok(Tree::empty()), + TreeNode::Leaf { h, f } => { + let hash = parse_hash_hex(h)?; + let flags = + RetentionFlags::from_bits(*f).ok_or_else(|| format!("invalid retention flags: {}", f))?; + Ok(Tree::leaf((hash, flags))) + } + TreeNode::Parent { a, l, r } => { + let ann = if a.is_empty() { + None + } else { + Some(Arc::new(parse_hash_hex(a)?)) + }; + Ok(Tree::parent( + ann, + deserialize_tree(l)?, + deserialize_tree(r)?, + )) + } + } +} + +// --------------------------------------------------------------------------- +// ShardTree <-> PersistedShardTreeState extraction / restoration +// --------------------------------------------------------------------------- + +pub fn extract_state( + tree: &OrchardShardTree, + tip_height: Option, + leaf_count: u64, +) -> Result { + let store = tree.store(); + + let shard_roots = store + .get_shard_roots() + .map_err(|e| format!("get_shard_roots error: {:?}", e))?; + + let mut shards: Vec = Vec::new(); + for addr in shard_roots { + let shard = store + .get_shard(addr) + .map_err(|e| format!("get_shard error: {:?}", e))?; + if let Some(shard) = shard { + shards.push(SerializedShard { + root_addr_level: shard.root_addr().level().into(), + root_addr_index: shard.root_addr().index(), + tree: serialize_tree(shard.root()), + }); + } + } + + let cap = serialize_tree( + &store + .get_cap() + .map_err(|e| format!("get_cap error: {:?}", e))?, + ); + + let count = store + .checkpoint_count() + .map_err(|e| format!("checkpoint_count error: {:?}", e))?; + + let mut checkpoints: Vec = Vec::new(); + if count > 0 { + store + .for_each_checkpoint(count, |id, cp| { + checkpoints.push(SerializedCheckpoint { + id: *id, + position: match cp.tree_state() { + TreeState::Empty => None, + TreeState::AtPosition(pos) => Some(u64::from(pos)), + }, + marks_removed: cp.marks_removed().iter().map(|p| u64::from(*p)).collect(), + }); + Ok(()) + }) + .map_err(|e| format!("for_each_checkpoint error: {:?}", e))?; + } + + Ok(PersistedShardTreeState { + shards, + cap, + checkpoints, + tip_height, + leaf_count, + }) +} + +pub fn restore_state(state: &PersistedShardTreeState) -> Result { + let mut store = MemoryShardStore::empty(); + + // Restore shards + for s in &state.shards { + let addr = Address::from_parts(Level::from(s.root_addr_level), s.root_addr_index); + let tree = deserialize_tree(&s.tree)?; + let located = LocatedPrunableTree::from_parts(addr, tree) + .map_err(|a| format!("invalid shard address: {:?}", a))?; + store + .put_shard(located) + .map_err(|e| format!("put_shard error: {:?}", e))?; + } + + // Restore cap + let cap_tree = deserialize_tree(&state.cap)?; + store + .put_cap(cap_tree) + .map_err(|e| format!("put_cap error: {:?}", e))?; + + // Restore checkpoints + for cp in &state.checkpoints { + let tree_state = match cp.position { + None => TreeState::Empty, + Some(pos) => TreeState::AtPosition(Position::from(pos)), + }; + let marks_removed: BTreeSet = cp + .marks_removed + .iter() + .map(|p| Position::from(*p)) + .collect(); + let checkpoint = Checkpoint::from_parts(tree_state, marks_removed); + store + .add_checkpoint(cp.id, checkpoint) + .map_err(|e| format!("add_checkpoint error: {:?}", e))?; + } + + Ok(ShardTree::new(store, MAX_CHECKPOINTS)) +} + +// --------------------------------------------------------------------------- +// Global in-memory tree state +// --------------------------------------------------------------------------- + +pub struct GlobalTree { + pub tree: OrchardShardTree, + pub tip_height: Option, + pub leaf_count: u64, +} + +pub static TREE: Mutex> = Mutex::new(None); + +fn with_tree(f: F) -> Result +where + F: FnOnce(&mut GlobalTree) -> Result, +{ + let mut guard = TREE.lock().map_err(|e| format!("lock poisoned: {}", e))?; + let gt = guard + .as_mut() + .ok_or_else(|| "TREE_NOT_INITIALIZED: call init or load first".to_string())?; + f(gt) +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Parse a hex-encoded commitment into a MerkleHashOrchard +pub fn parse_hash_hex(hex_str: &str) -> Result { + let bytes = hex::decode(hex_str).map_err(|e| format!("hex decode error: {}", e))?; + if bytes.len() != 32 { + return Err(format!("expected 32 bytes, got {}", bytes.len())); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Option::from(MerkleHashOrchard::from_bytes(&arr)) + .ok_or_else(|| "invalid MerkleHashOrchard value".to_string()) +} + +/// Get the root hash at the most recent checkpoint. +/// +/// Uses `root_at_checkpoint_depth` instead of `frontier()` because after a +/// save→load round-trip the store may only contain the active shard + cap. +/// `frontier()` requires full shard data along the frontier path, whereas +/// `root_at_checkpoint_depth` computes the root from shard roots combined +/// with cap annotations, which works with partial shard data. +pub fn get_checkpoint_root(tree: &OrchardShardTree) -> Result { + let root = tree + .root_at_checkpoint_depth(Some(0)) + .map_err(|e| format!("root_at_checkpoint error: {:?}", e))? + .ok_or("no checkpoint available to compute root")?; + Ok(hex::encode(root.to_bytes())) +} + +/// Read a Bitcoin-style CompactSize from a byte slice, returning (value, bytes_consumed). +fn read_compact_size(data: &[u8]) -> Result<(u64, usize), String> { + if data.is_empty() { + return Err("unexpected EOF reading compact size".to_string()); + } + match data[0] { + 0..=252 => Ok((data[0] as u64, 1)), + 253 => { + if data.len() < 3 { + return Err("unexpected EOF reading compact size u16".to_string()); + } + Ok((u16::from_le_bytes([data[1], data[2]]) as u64, 3)) + } + 254 => { + if data.len() < 5 { + return Err("unexpected EOF reading compact size u32".to_string()); + } + Ok(( + u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as u64, + 5, + )) + } + 255 => { + if data.len() < 9 { + return Err("unexpected EOF reading compact size u64".to_string()); + } + Ok(( + u64::from_le_bytes([ + data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8], + ]), + 9, + )) + } + } +} + +/// Read a 32-byte MerkleHashOrchard from a byte slice. +fn read_hash(data: &[u8]) -> Result<(MerkleHashOrchard, usize), String> { + if data.len() < 32 { + return Err("unexpected EOF reading hash".to_string()); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&data[..32]); + let hash = Option::from(MerkleHashOrchard::from_bytes(&arr)) + .ok_or_else(|| "invalid MerkleHashOrchard value in frontier".to_string())?; + Ok((hash, 32)) +} + +/// Read an Optional in zcashd v0 encoding: 0x00 = None, 0x01 = Some(32 bytes). +fn read_optional_hash(data: &[u8]) -> Result<(Option, usize), String> { + if data.is_empty() { + return Err("unexpected EOF reading optional flag".to_string()); + } + match data[0] { + 0x00 => Ok((None, 1)), + 0x01 => { + let (hash, n) = read_hash(&data[1..])?; + Ok((Some(hash), 1 + n)) + } + b => Err(format!("invalid optional flag byte: 0x{:02x}", b)), + } +} + +/// Initialize tree from a serialized frontier (from z_gettreestate). +/// +/// The `orchardTree` field from zcashd's `z_gettreestate` uses the **v0 CommitmentTree** +/// encoding (the legacy incremental merkle tree format): +/// +/// ```text +/// Optional(left) | Optional(right) | Vector(parents: Vec) +/// ``` +/// +/// Where: +/// - Optional: 0x00 = absent, 0x01 = present followed by 32-byte hash +/// - Vector: CompactSize(length) followed by N elements +/// +/// The position is derived from the tree structure: if `right` is present, the lowest +/// bit of position is 1. Each present parent at level `i` means bit `i+1` (or `i+1` +/// for the right-shifted position) is set, indicating a completed left subtree at that level. +/// +/// This function converts the CommitmentTree into a `NonEmptyFrontier` (position, leaf, ommers) +/// and inserts it into a new ShardTree. +pub fn init_from_frontier(frontier_hex: &str, block_height: u32) -> Result { + use incrementalmerkletree::frontier::NonEmptyFrontier; + + let bytes = hex::decode(frontier_hex).map_err(|e| format!("hex decode error: {}", e))?; + let mut offset = 0; + + // Parse CommitmentTree v0 format: Optional(left) | Optional(right) | Vector(parents) + let (left, n) = read_optional_hash(&bytes[offset..])?; + offset += n; + + let (right, n) = read_optional_hash(&bytes[offset..])?; + offset += n; + + let (parent_count, n) = read_compact_size(&bytes[offset..])?; + offset += n; + + if parent_count > DEPTH as u64 { + return Err(format!( + "parent_count {} exceeds tree depth {}", + parent_count, DEPTH + )); + } + + let mut parents: Vec> = Vec::with_capacity(parent_count as usize); + for _ in 0..parent_count { + let (parent, n) = read_optional_hash(&bytes[offset..])?; + offset += n; + parents.push(parent); + } + + let left = left.ok_or("commitment tree has no left leaf — tree is empty")?; + + // Convert CommitmentTree to NonEmptyFrontier. + // + // The CommitmentTree stores a path through the Merkle tree: + // - `left` is the left child at the leaf level + // - `right` (if present) is the right child at the leaf level + // - `parents[i]` is the sibling at level i+1 (present = completed left subtree) + // + // The frontier leaf is the most recently appended value: + // - If `right` is present, the leaf is `right` and `left` becomes an ommer + // - Otherwise, the leaf is `left` + // + // Position is reconstructed from the structure: + // - Bit 0 is set if `right` is present + // - Bit i+1 is set if `parents[i]` is present + let (leaf, mut ommers, mut position_val) = if let Some(right_hash) = right { + // right is present: leaf = right, left becomes first ommer + (right_hash, vec![left], 1u64) + } else { + (left, vec![], 0u64) + }; + + for (i, parent) in parents.iter().enumerate() { + if let Some(hash) = parent { + position_val |= 1u64 << (i + 1); + ommers.push(*hash); + } + } + + let position = Position::from(position_val); + let frontier = NonEmptyFrontier::from_parts(position, leaf, ommers) + .map_err(|e| format!("frontier construction error: {:?}", e))?; + + let leaf_count = u64::from(frontier.position()) + 1; + + let mut tree = ShardTree::new(MemoryShardStore::empty(), MAX_CHECKPOINTS); + tree + .insert_frontier_nodes( + frontier, + Retention::Checkpoint { + id: block_height, + marking: Marking::None, + }, + ) + .map_err(|e| format!("insert_frontier_nodes error: {}", e))?; + + let root_hex = get_checkpoint_root(&tree)?; + + let mut guard = TREE.lock().map_err(|e| format!("lock poisoned: {}", e))?; + if guard.is_some() { + return Err("ALREADY_INITIALIZED: tree already contains state; call load or truncate instead".to_string()); + } + *guard = Some(GlobalTree { + tree, + tip_height: Some(block_height), + leaf_count, + }); + + Ok(root_hex) +} + +/// Load tree from persisted JSON state +pub fn load(state_json: &str) -> Result<(), String> { + let state: PersistedShardTreeState = + serde_json::from_str(state_json).map_err(|e| format!("JSON parse error: {}", e))?; + let tree = restore_state(&state)?; + let mut guard = TREE.lock().map_err(|e| format!("lock poisoned: {}", e))?; + if guard.is_some() { + eprintln!("warning: overwriting existing tree state via load"); + } + *guard = Some(GlobalTree { + tree, + tip_height: state.tip_height, + leaf_count: state.leaf_count, + }); + Ok(()) +} + +/// Serialize current tree state to JSON +pub fn save() -> Result { + with_tree(|gt| { + let state = extract_state(>.tree, gt.tip_height, gt.leaf_count)?; + serde_json::to_string(&state).map_err(|e| format!("JSON serialize error: {}", e)) + }) +} + +/// Append commitments to the in-memory tree, checkpoint, verify root +pub fn append_commitments( + block_height: u32, + commitments: &[String], + expected_root: Option<&str>, +) -> Result { + with_tree(|gt| { + if commitments.is_empty() { + // No commitments — still record an empty checkpoint so truncate_to_checkpoint + // can target any processed block height. + gt.tree + .checkpoint(block_height) + .map_err(|e| format!("checkpoint error: {}", e))?; + gt.tip_height = Some(block_height); + if gt.leaf_count == 0 { + let empty_root = MerkleHashOrchard::empty_root(Level::from(DEPTH)); + return Ok(hex::encode(empty_root.to_bytes())); + } + return get_checkpoint_root(>.tree); + } + + for (i, cmx_hex) in commitments.iter().enumerate() { + let hash = parse_hash_hex(cmx_hex)?; + let is_last = i == commitments.len() - 1; + let retention = if is_last { + Retention::Checkpoint { + id: block_height, + marking: Marking::None, + } + } else { + Retention::Ephemeral + }; + gt.tree + .append(hash, retention) + .map_err(|e| format!("append error: {}", e))?; + gt.leaf_count += 1; + } + + gt.tip_height = Some(block_height); + + let root_hex = get_checkpoint_root(>.tree)?; + + if let Some(expected) = expected_root { + if !expected.is_empty() && root_hex != expected { + return Err(format!( + "ROOT_MISMATCH: computed {} but expected {}", + root_hex, expected + )); + } + } + + Ok(root_hex) + }) +} + +/// Truncate the tree back to a specific checkpoint (block height) +pub fn truncate_to_checkpoint(block_height: u32) -> Result { + with_tree(|gt| { + let ok = gt + .tree + .truncate_to_checkpoint(&block_height) + .map_err(|e| format!("truncate error: {:?}", e))?; + if !ok { + return Err(format!( + "CHECKPOINT_NOT_FOUND: no checkpoint for block height {}", + block_height + )); + } + gt.tip_height = Some(block_height); + gt.leaf_count = gt + .tree + .max_leaf_position(Some(0)) + .map_err(|e| format!("max_leaf_position error: {:?}", e))? + .map(|p| u64::from(p) + 1) + .unwrap_or(0); + get_checkpoint_root(>.tree) + }) +} + +/// Return tree metadata +pub fn get_info() -> Result<(Option, u64, u32), String> { + with_tree(|gt| { + let checkpoint_count = gt + .tree + .store() + .checkpoint_count() + .map_err(|e| format!("checkpoint_count error: {:?}", e))?; + Ok((gt.tip_height, gt.leaf_count, checkpoint_count as u32)) + }) +}