diff --git a/gemma/api_server.cc b/gemma/api_server.cc index 8f71043a..c713635f 100644 --- a/gemma/api_server.cc +++ b/gemma/api_server.cc @@ -61,16 +61,18 @@ struct ServerState { std::chrono::steady_clock::time_point last_access; }; - std::unordered_map sessions; + // Lock ordering: always acquire sessions_mutex before inference_mutex. + std::unordered_map> sessions; std::mutex sessions_mutex; std::mutex inference_mutex; - // Cleanup old sessions after 30 minutes of inactivity + // Cleanup old sessions after 30 minutes of inactivity. + // Sessions currently in use are kept alive by shared_ptr held by handlers. void CleanupOldSessions() { std::lock_guard lock(sessions_mutex); auto now = std::chrono::steady_clock::now(); for (auto it = sessions.begin(); it != sessions.end();) { - if (now - it->second.last_access > std::chrono::minutes(30)) { + if (now - it->second->last_access > std::chrono::minutes(30)) { it = sessions.erase(it); } else { ++it; @@ -78,15 +80,17 @@ struct ServerState { } } - // Get or create session with KV cache - Session& GetOrCreateSession(const std::string& session_id) { + // Get or create session with KV cache. Returns shared_ptr so the session + // remains alive even if erased from the map by cleanup. + std::shared_ptr GetOrCreateSession(const std::string& session_id) { std::lock_guard lock(sessions_mutex); auto& session = sessions[session_id]; - if (!session.kv_cache) { - session.kv_cache = std::make_unique( + if (!session) { + session = std::make_shared(); + session->kv_cache = std::make_unique( gemma->Config(), InferenceArgs(), env->ctx.allocator); } - session.last_access = std::chrono::steady_clock::now(); + session->last_access = std::chrono::steady_clock::now(); return session; } }; @@ -185,9 +189,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, try { json request = json::parse(req.body); - // Get or create session + // Get or create session (acquires sessions_mutex, then releases it). std::string session_id = request.value("sessionId", GenerateSessionId()); - auto& session = state.GetOrCreateSession(session_id); + auto session = state.GetOrCreateSession(session_id); // Extract prompt from API format std::string prompt; @@ -201,7 +205,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, return; } - // Lock for inference + // Lock for inference (after sessions_mutex is released). std::lock_guard lock(state.inference_mutex); // Set up runtime config @@ -212,7 +216,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, // Tokenize prompt std::vector tokens = WrapAndTokenize( state.gemma->Tokenizer(), state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, session.abs_pos, prompt); + state.gemma->Config().wrapping, session->abs_pos, prompt); // Run inference with KV cache TimingInfo timing_info = {.verbosity = 0}; @@ -223,12 +227,12 @@ void HandleGenerateContentNonStreaming(ServerState& state, runtime_config.stream_token = [&output, &state, &session, &tokens]( int token, float) { // Skip prompt tokens - if (session.abs_pos < tokens.size()) { - session.abs_pos++; + if (session->abs_pos < tokens.size()) { + session->abs_pos++; return true; } - session.abs_pos++; + session->abs_pos++; // Check for EOS if (state.gemma->Config().IsEOS(token)) { @@ -243,15 +247,15 @@ void HandleGenerateContentNonStreaming(ServerState& state, return true; }; - state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end, - *session.kv_cache, *state.env, timing_info); + state.gemma->Generate(runtime_config, tokens, session->abs_pos, prefix_end, + *session->kv_cache, *state.env, timing_info); // Create response json response = CreateAPIResponse(output.str(), false); response["usageMetadata"] = { {"promptTokenCount", tokens.size()}, - {"candidatesTokenCount", session.abs_pos - tokens.size()}, - {"totalTokenCount", session.abs_pos} + {"candidatesTokenCount", session->abs_pos - tokens.size()}, + {"totalTokenCount", session->abs_pos} }; res.set_content(response.dump(), "application/json"); @@ -279,9 +283,9 @@ void HandleGenerateContentStreaming(ServerState& state, try { json request = json::parse(req.body); - // Get or create session + // Get or create session (acquires sessions_mutex, then releases it). std::string session_id = request.value("sessionId", GenerateSessionId()); - auto& session = state.GetOrCreateSession(session_id); + auto session = state.GetOrCreateSession(session_id); // Extract prompt from API format std::string prompt; @@ -301,14 +305,16 @@ void HandleGenerateContentStreaming(ServerState& state, res.set_header("Connection", "keep-alive"); res.set_header("X-Session-Id", session_id); - // Set up chunked content provider for SSE + // Set up chunked content provider for SSE. The lambda captures `session` + // by value (shared_ptr copy), keeping the session alive even if cleanup + // erases it from the map. res.set_chunked_content_provider( - "text/event-stream", [&state, request, prompt, session_id]( + "text/event-stream", [&state, request, prompt, session]( size_t offset, httplib::DataSink& sink) { try { - // Lock for inference + // Lock for inference (sessions_mutex is NOT held here — + // consistent ordering: sessions_mutex before inference_mutex). std::lock_guard lock(state.inference_mutex); - auto& session = state.GetOrCreateSession(session_id); // Set up runtime config RuntimeConfig runtime_config = ParseGenerationConfig(request); @@ -316,18 +322,18 @@ void HandleGenerateContentStreaming(ServerState& state, // Tokenize prompt std::vector tokens = WrapAndTokenize( state.gemma->Tokenizer(), state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, session.abs_pos, prompt); + state.gemma->Config().wrapping, session->abs_pos, prompt); // Stream token callback std::string accumulated_text; auto stream_token = [&](int token, float) { // Skip prompt tokens - if (session.abs_pos < tokens.size()) { - session.abs_pos++; + if (session->abs_pos < tokens.size()) { + session->abs_pos++; return true; } - session.abs_pos++; + session->abs_pos++; // Check for EOS if (state.gemma->Config().IsEOS(token)) { @@ -355,16 +361,16 @@ void HandleGenerateContentStreaming(ServerState& state, TimingInfo timing_info = {.verbosity = 0}; size_t prefix_end = 0; - state.gemma->Generate(runtime_config, tokens, session.abs_pos, - prefix_end, *session.kv_cache, *state.env, + state.gemma->Generate(runtime_config, tokens, session->abs_pos, + prefix_end, *session->kv_cache, *state.env, timing_info); // Send final event using unified formatter json final_event = CreateAPIResponse("", false); final_event["usageMetadata"] = { {"promptTokenCount", tokens.size()}, - {"candidatesTokenCount", session.abs_pos - tokens.size()}, - {"totalTokenCount", session.abs_pos}}; + {"candidatesTokenCount", session->abs_pos - tokens.size()}, + {"totalTokenCount", session->abs_pos}}; std::string final_sse = "data: " + final_event.dump() + "\n\n"; sink.write(final_sse.data(), final_sse.size()); diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index d86d460d..bdddb555 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -252,11 +252,15 @@ int GemmaContext::GenerateInternal(const char* prompt_string, active_conversation->abs_pos--; } - // Copy result buffer to output C-string (ensure null termination) - strncpy(output, result_buffer.c_str(), max_output_chars - 1); - output[max_output_chars - 1] = '\0'; - - return static_cast(strlen(output)); + // Copy result buffer to output C-string (ensure null termination). + if (max_output_chars <= 0) { + return 0; + } + const size_t n = static_cast(max_output_chars); + const size_t copy_len = HWY_MIN(result_buffer.size(), n - 1); + memcpy(output, result_buffer.c_str(), copy_len); + output[copy_len] = '\0'; + return static_cast(copy_len); } // Public Generate method (wrapper for text-only) diff --git a/io/blob_store.cc b/io/blob_store.cc index 0b14e63f..b19dc978 100644 --- a/io/blob_store.cc +++ b/io/blob_store.cc @@ -182,8 +182,13 @@ class BlobStore { bool ParseHeaderAndDirectoryV2(const File& file) { is_file_v2_ = true; // Read header from the end of the file. - size_t offset = file.FileSize() - sizeof(header_); - if (!file.Read(offset, sizeof(header_), &header_)) { + const uint64_t file_bytes = file.FileSize(); + if (file_bytes < sizeof(header_)) { + HWY_WARN("File is too small to contain a BlobStore header."); + return false; + } + size_t pos = file_bytes - sizeof(header_); + if (!file.Read(pos, sizeof(header_), &header_)) { HWY_WARN("Failed to read BlobStore header."); return false; } @@ -199,14 +204,20 @@ class BlobStore { return false; } directory_.resize(header_.num_blobs * 2); - const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs; - offset -= directory_bytes; - // Read directory immediately before the header. - if (!file.Read(offset, directory_bytes, directory_.data())) { + + // Read directory, which ends at the start of the header. + const size_t directory_bytes = 2 * kU128Bytes * header_.num_blobs; + if (directory_bytes > pos) { + HWY_WARN("Directory is larger than the file size."); + return false; + } + pos -= directory_bytes; + if (!file.Read(pos, directory_bytes, directory_.data())) { HWY_WARN("Failed to read BlobStore directory."); return false; } - HWY_ASSERT(IsValid(file.FileSize())); + + HWY_ASSERT(IsValid(file_bytes)); return true; } diff --git a/paligemma/image.cc b/paligemma/image.cc index b2c54195..451467ba 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -163,14 +163,16 @@ bool Image::ReadPPM(const hwy::Span& buf) { } void Image::Set(int width, int height, const float* data) { + HWY_ASSERT(width > 0 && height > 0); + HWY_ASSERT(width <= SIZE_MAX / 3 && width * 3 <= SIZE_MAX / height); width_ = width; height_ = height; - int num_elements = width * height * 3; + const size_t num_elements = static_cast(width) * height * 3; data_.resize(num_elements); data_.assign(data, data + num_elements); float min_value = std::numeric_limits::infinity(); float max_value = -std::numeric_limits::infinity(); - for (int i = 0; i < num_elements; ++i) { + for (size_t i = 0; i < num_elements; ++i) { if (data_[i] < min_value) min_value = data_[i]; if (data_[i] > max_value) max_value = data_[i]; } @@ -178,14 +180,17 @@ void Image::Set(int width, int height, const float* data) { float in_range = max_value - min_value; if (in_range == 0.0f) in_range = 1.0f; float scale = 2.0f / in_range; - for (int i = 0; i < num_elements; ++i) { + for (size_t i = 0; i < num_elements; ++i) { data_[i] = (data_[i] - min_value) * scale - 1.0f; } } // This is surprisingly inexpensive for small images (2 ms). void Image::Resize(int new_width, int new_height) { - std::vector new_data(new_width * new_height * 3); + HWY_ASSERT(new_width > 0 && new_height > 0); + HWY_ASSERT(new_width <= SIZE_MAX / 3 && + new_width * 3 <= SIZE_MAX / new_height); + std::vector new_data(static_cast(new_width) * new_height * 3); // TODO: go to bilinear interpolation, or antialias. // E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from // jpegxl/lib/jxl/convolve_slow.cc @@ -195,8 +200,8 @@ void Image::Resize(int new_width, int new_height) { int old_i = NearestNeighbor(i, new_height, height_); int old_j = NearestNeighbor(j, new_width, width_); for (int k = 0; k < 3; ++k) { - new_data[(i * new_width + j) * 3 + k] = - data_[(old_i * width_ + old_j) * 3 + k]; + new_data[(static_cast(i) * new_width + j) * 3 + k] = + data_[(static_cast(old_i) * width_ + old_j) * 3 + k]; } } } @@ -228,7 +233,7 @@ void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim, const size_t patch_dim = div_patch_dim.GetDivisor(); const size_t bytes_per_row = (patch_dim * kBytesPerPixel); const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel); - HWY_ASSERT(size() == width_ * height_ * kNumChannels); + HWY_ASSERT(size() == static_cast(width_) * height_ * kNumChannels); HWY_ASSERT(div_patch_dim.Remainder(width_) == 0); HWY_ASSERT(div_patch_dim.Remainder(height_) == 0); const size_t patches_x = div_patch_dim.Divide(width_); diff --git a/python/gemma_py.cc b/python/gemma_py.cc index ec045b01..ecf8b355 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -156,7 +156,7 @@ class GemmaModel { return result; } - // For a PaliGemma model, sets the image to run on. Subseqent calls to + // For a PaliGemma model, sets the image to run on. Subsequent calls to // Generate* will use this image. Throws an error for other models. void SetImage(const py::array_t& image) { @@ -174,7 +174,7 @@ class GemmaModel { int width = buffer.shape[1]; float* ptr = static_cast(buffer.ptr); gcpp::Image c_image; - c_image.Set(height, width, ptr); + c_image.Set(width, height, ptr); const size_t image_size = config.vit_config.image_size; c_image.Resize(image_size, image_size); image_tokens_.reset(new gcpp::ImageTokens(