Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion src/io/sqlite_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use std::collections::HashMap;
use std::fs;
use std::future::Future;
use std::path::PathBuf;
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

Expand Down Expand Up @@ -216,6 +216,7 @@ impl SqliteStoreInner {
) -> io::Result<Self> {
let db_file_name = db_file_name.unwrap_or(DEFAULT_SQLITE_DB_FILE_NAME.to_string());
let kv_table_name = kv_table_name.unwrap_or(DEFAULT_KV_TABLE_NAME.to_string());
Self::check_db_file_name(&db_file_name)?;

fs::create_dir_all(data_dir.clone()).map_err(|e| {
let msg = format!(
Expand Down Expand Up @@ -313,6 +314,24 @@ impl SqliteStoreInner {
Ok(Self { connection, data_dir, kv_table_name, write_version_locks, next_sort_order })
}

fn check_db_file_name(db_file_name: &str) -> io::Result<()> {
let mut components = Path::new(db_file_name).components().peekable();
if components.peek().is_none() {
let msg = format!("Invalid database file name: {db_file_name}");
return Err(io::Error::new(io::ErrorKind::InvalidInput, msg));
}

// Accept only normal path components. Anything else can be a potential path
// traversal or other weird path prefix that would allow absolute paths,
// current-directory components, or parent-directory traversal.
if !components.all(|c| matches!(c, Component::Normal(_))) {
let msg = format!("Invalid database file name: {db_file_name}");
return Err(io::Error::new(io::ErrorKind::InvalidInput, msg));
}

Ok(())
}

fn get_inner_lock_ref(&self, locking_key: String) -> Arc<Mutex<u64>> {
let mut outer_lock = self.write_version_locks.lock().expect("lock");
Arc::clone(&outer_lock.entry(locking_key).or_default())
Expand Down Expand Up @@ -679,6 +698,39 @@ mod tests {
do_test_store(&store_0, &store_1)
}

#[test]
fn rejects_db_file_names_that_escape_data_dir() {
let mut temp_path = random_storage_path();
temp_path.push("rejects_db_file_names_that_escape_data_dir");

for db_file_name in ["", ".", "..", "../escaped.sqlite", "nested/../escaped.sqlite"] {
let res = SqliteStore::new(
temp_path.clone(),
Some(db_file_name.to_string()),
Some("test_table".to_string()),
);
let err = match res {
Ok(_) => panic!("accepted invalid database file name: {}", db_file_name),
Err(e) => e,
};
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}

#[cfg(unix)]
{
let res = SqliteStore::new(
temp_path,
Some("/tmp/escaped.sqlite".to_string()),
Some("test_table".to_string()),
);
let err = match res {
Ok(_) => panic!("accepted invalid absolute database file name"),
Err(e) => e,
};
assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
}
}

#[tokio::test]
async fn test_sqlite_store_paginated_listing() {
let mut temp_path = random_storage_path();
Expand Down