diff --git a/src/io/sqlite_store/mod.rs b/src/io/sqlite_store/mod.rs index 076aeef9b..830945a09 100644 --- a/src/io/sqlite_store/mod.rs +++ b/src/io/sqlite_store/mod.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; use std::fs; use std::future::Future; -use std::path::PathBuf; +use std::path::{Component, Path, PathBuf}; use std::sync::atomic::{AtomicI64, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; @@ -216,6 +216,7 @@ impl SqliteStoreInner { ) -> io::Result { let db_file_name = db_file_name.unwrap_or(DEFAULT_SQLITE_DB_FILE_NAME.to_string()); let kv_table_name = kv_table_name.unwrap_or(DEFAULT_KV_TABLE_NAME.to_string()); + Self::check_db_file_name(&db_file_name)?; fs::create_dir_all(data_dir.clone()).map_err(|e| { let msg = format!( @@ -313,6 +314,24 @@ impl SqliteStoreInner { Ok(Self { connection, data_dir, kv_table_name, write_version_locks, next_sort_order }) } + fn check_db_file_name(db_file_name: &str) -> io::Result<()> { + let mut components = Path::new(db_file_name).components().peekable(); + if components.peek().is_none() { + let msg = format!("Invalid database file name: {db_file_name}"); + return Err(io::Error::new(io::ErrorKind::InvalidInput, msg)); + } + + // Accept only normal path components. Anything else can be a potential path + // traversal or other weird path prefix that would allow absolute paths, + // current-directory components, or parent-directory traversal. + if !components.all(|c| matches!(c, Component::Normal(_))) { + let msg = format!("Invalid database file name: {db_file_name}"); + return Err(io::Error::new(io::ErrorKind::InvalidInput, msg)); + } + + Ok(()) + } + fn get_inner_lock_ref(&self, locking_key: String) -> Arc> { let mut outer_lock = self.write_version_locks.lock().expect("lock"); Arc::clone(&outer_lock.entry(locking_key).or_default()) @@ -679,6 +698,39 @@ mod tests { do_test_store(&store_0, &store_1) } + #[test] + fn rejects_db_file_names_that_escape_data_dir() { + let mut temp_path = random_storage_path(); + temp_path.push("rejects_db_file_names_that_escape_data_dir"); + + for db_file_name in ["", ".", "..", "../escaped.sqlite", "nested/../escaped.sqlite"] { + let res = SqliteStore::new( + temp_path.clone(), + Some(db_file_name.to_string()), + Some("test_table".to_string()), + ); + let err = match res { + Ok(_) => panic!("accepted invalid database file name: {}", db_file_name), + Err(e) => e, + }; + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + } + + #[cfg(unix)] + { + let res = SqliteStore::new( + temp_path, + Some("/tmp/escaped.sqlite".to_string()), + Some("test_table".to_string()), + ); + let err = match res { + Ok(_) => panic!("accepted invalid absolute database file name"), + Err(e) => e, + }; + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + } + } + #[tokio::test] async fn test_sqlite_store_paginated_listing() { let mut temp_path = random_storage_path();