Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions gemma/api_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,36 @@ struct ServerState {
std::chrono::steady_clock::time_point last_access;
};

std::unordered_map<std::string, Session> sessions;
// Lock ordering: always acquire sessions_mutex before inference_mutex.
std::unordered_map<std::string, std::shared_ptr<Session>> sessions;
std::mutex sessions_mutex;
std::mutex inference_mutex;

// Cleanup old sessions after 30 minutes of inactivity
// Cleanup old sessions after 30 minutes of inactivity.
// Sessions currently in use are kept alive by shared_ptr held by handlers.
void CleanupOldSessions() {
std::lock_guard<std::mutex> lock(sessions_mutex);
auto now = std::chrono::steady_clock::now();
for (auto it = sessions.begin(); it != sessions.end();) {
if (now - it->second.last_access > std::chrono::minutes(30)) {
if (now - it->second->last_access > std::chrono::minutes(30)) {
it = sessions.erase(it);
} else {
++it;
}
}
}

// Get or create session with KV cache
Session& GetOrCreateSession(const std::string& session_id) {
// Get or create session with KV cache. Returns shared_ptr so the session
// remains alive even if erased from the map by cleanup.
std::shared_ptr<Session> GetOrCreateSession(const std::string& session_id) {
std::lock_guard<std::mutex> lock(sessions_mutex);
auto& session = sessions[session_id];
if (!session.kv_cache) {
session.kv_cache = std::make_unique<KVCache>(
if (!session) {
session = std::make_shared<Session>();
session->kv_cache = std::make_unique<KVCache>(
gemma->Config(), InferenceArgs(), env->ctx.allocator);
}
session.last_access = std::chrono::steady_clock::now();
session->last_access = std::chrono::steady_clock::now();
return session;
}
};
Expand Down Expand Up @@ -185,9 +189,9 @@ void HandleGenerateContentNonStreaming(ServerState& state,
try {
json request = json::parse(req.body);

// Get or create session
// Get or create session (acquires sessions_mutex, then releases it).
std::string session_id = request.value("sessionId", GenerateSessionId());
auto& session = state.GetOrCreateSession(session_id);
auto session = state.GetOrCreateSession(session_id);

// Extract prompt from API format
std::string prompt;
Expand All @@ -201,7 +205,7 @@ void HandleGenerateContentNonStreaming(ServerState& state,
return;
}

// Lock for inference
// Lock for inference (after sessions_mutex is released).
std::lock_guard<std::mutex> lock(state.inference_mutex);

// Set up runtime config
Expand All @@ -212,7 +216,7 @@ void HandleGenerateContentNonStreaming(ServerState& state,
// Tokenize prompt
std::vector<int> tokens = WrapAndTokenize(
state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
state.gemma->Config().wrapping, session.abs_pos, prompt);
state.gemma->Config().wrapping, session->abs_pos, prompt);

// Run inference with KV cache
TimingInfo timing_info = {.verbosity = 0};
Expand All @@ -223,12 +227,12 @@ void HandleGenerateContentNonStreaming(ServerState& state,
runtime_config.stream_token = [&output, &state, &session, &tokens](
int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
if (session->abs_pos < tokens.size()) {
session->abs_pos++;
return true;
}

session.abs_pos++;
session->abs_pos++;

// Check for EOS
if (state.gemma->Config().IsEOS(token)) {
Expand All @@ -243,15 +247,15 @@ void HandleGenerateContentNonStreaming(ServerState& state,
return true;
};

state.gemma->Generate(runtime_config, tokens, session.abs_pos, prefix_end,
*session.kv_cache, *state.env, timing_info);
state.gemma->Generate(runtime_config, tokens, session->abs_pos, prefix_end,
*session->kv_cache, *state.env, timing_info);

// Create response
json response = CreateAPIResponse(output.str(), false);
response["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}
{"candidatesTokenCount", session->abs_pos - tokens.size()},
{"totalTokenCount", session->abs_pos}
};

res.set_content(response.dump(), "application/json");
Expand Down Expand Up @@ -279,9 +283,9 @@ void HandleGenerateContentStreaming(ServerState& state,
try {
json request = json::parse(req.body);

// Get or create session
// Get or create session (acquires sessions_mutex, then releases it).
std::string session_id = request.value("sessionId", GenerateSessionId());
auto& session = state.GetOrCreateSession(session_id);
auto session = state.GetOrCreateSession(session_id);

// Extract prompt from API format
std::string prompt;
Expand All @@ -301,33 +305,35 @@ void HandleGenerateContentStreaming(ServerState& state,
res.set_header("Connection", "keep-alive");
res.set_header("X-Session-Id", session_id);

// Set up chunked content provider for SSE
// Set up chunked content provider for SSE. The lambda captures `session`
// by value (shared_ptr copy), keeping the session alive even if cleanup
// erases it from the map.
res.set_chunked_content_provider(
"text/event-stream", [&state, request, prompt, session_id](
"text/event-stream", [&state, request, prompt, session](
size_t offset, httplib::DataSink& sink) {
try {
// Lock for inference
// Lock for inference (sessions_mutex is NOT held here —
// consistent ordering: sessions_mutex before inference_mutex).
std::lock_guard<std::mutex> lock(state.inference_mutex);
auto& session = state.GetOrCreateSession(session_id);

// Set up runtime config
RuntimeConfig runtime_config = ParseGenerationConfig(request);

// Tokenize prompt
std::vector<int> tokens = WrapAndTokenize(
state.gemma->Tokenizer(), state.gemma->ChatTemplate(),
state.gemma->Config().wrapping, session.abs_pos, prompt);
state.gemma->Config().wrapping, session->abs_pos, prompt);

// Stream token callback
std::string accumulated_text;
auto stream_token = [&](int token, float) {
// Skip prompt tokens
if (session.abs_pos < tokens.size()) {
session.abs_pos++;
if (session->abs_pos < tokens.size()) {
session->abs_pos++;
return true;
}

session.abs_pos++;
session->abs_pos++;

// Check for EOS
if (state.gemma->Config().IsEOS(token)) {
Expand Down Expand Up @@ -355,16 +361,16 @@ void HandleGenerateContentStreaming(ServerState& state,
TimingInfo timing_info = {.verbosity = 0};
size_t prefix_end = 0;

state.gemma->Generate(runtime_config, tokens, session.abs_pos,
prefix_end, *session.kv_cache, *state.env,
state.gemma->Generate(runtime_config, tokens, session->abs_pos,
prefix_end, *session->kv_cache, *state.env,
timing_info);

// Send final event using unified formatter
json final_event = CreateAPIResponse("", false);
final_event["usageMetadata"] = {
{"promptTokenCount", tokens.size()},
{"candidatesTokenCount", session.abs_pos - tokens.size()},
{"totalTokenCount", session.abs_pos}};
{"candidatesTokenCount", session->abs_pos - tokens.size()},
{"totalTokenCount", session->abs_pos}};

std::string final_sse = "data: " + final_event.dump() + "\n\n";
sink.write(final_sse.data(), final_sse.size());
Expand Down
14 changes: 9 additions & 5 deletions gemma/bindings/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,15 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
active_conversation->abs_pos--;
}

// Copy result buffer to output C-string (ensure null termination)
strncpy(output, result_buffer.c_str(), max_output_chars - 1);
output[max_output_chars - 1] = '\0';

return static_cast<int>(strlen(output));
// Copy result buffer to output C-string (ensure null termination).
if (max_output_chars <= 0) {
return 0;
}
const size_t n = static_cast<size_t>(max_output_chars);
const size_t copy_len = HWY_MIN(result_buffer.size(), n - 1);
memcpy(output, result_buffer.c_str(), copy_len);
output[copy_len] = '\0';
return static_cast<int>(copy_len);
}

// Public Generate method (wrapper for text-only)
Expand Down
25 changes: 18 additions & 7 deletions io/blob_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,13 @@ class BlobStore {
bool ParseHeaderAndDirectoryV2(const File& file) {
is_file_v2_ = true;
// Read header from the end of the file.
size_t offset = file.FileSize() - sizeof(header_);
if (!file.Read(offset, sizeof(header_), &header_)) {
const uint64_t file_bytes = file.FileSize();
if (file_bytes < sizeof(header_)) {
HWY_WARN("File is too small to contain a BlobStore header.");
return false;
}
size_t pos = file_bytes - sizeof(header_);
if (!file.Read(pos, sizeof(header_), &header_)) {
HWY_WARN("Failed to read BlobStore header.");
return false;
}
Expand All @@ -199,14 +204,20 @@ class BlobStore {
return false;
}
directory_.resize(header_.num_blobs * 2);
const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
offset -= directory_bytes;
// Read directory immediately before the header.
if (!file.Read(offset, directory_bytes, directory_.data())) {

// Read directory, which ends at the start of the header.
const size_t directory_bytes = 2 * kU128Bytes * header_.num_blobs;
if (directory_bytes > pos) {
HWY_WARN("Directory is larger than the file size.");
return false;
}
pos -= directory_bytes;
if (!file.Read(pos, directory_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
return false;
}
HWY_ASSERT(IsValid(file.FileSize()));

HWY_ASSERT(IsValid(file_bytes));
return true;
}

Expand Down
19 changes: 12 additions & 7 deletions paligemma/image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,29 +163,34 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
}

void Image::Set(int width, int height, const float* data) {
HWY_ASSERT(width > 0 && height > 0);
HWY_ASSERT(width <= SIZE_MAX / 3 && width * 3 <= SIZE_MAX / height);
width_ = width;
height_ = height;
int num_elements = width * height * 3;
const size_t num_elements = static_cast<size_t>(width) * height * 3;
data_.resize(num_elements);
data_.assign(data, data + num_elements);
float min_value = std::numeric_limits<float>::infinity();
float max_value = -std::numeric_limits<float>::infinity();
for (int i = 0; i < num_elements; ++i) {
for (size_t i = 0; i < num_elements; ++i) {
if (data_[i] < min_value) min_value = data_[i];
if (data_[i] > max_value) max_value = data_[i];
}
// -> out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
float in_range = max_value - min_value;
if (in_range == 0.0f) in_range = 1.0f;
float scale = 2.0f / in_range;
for (int i = 0; i < num_elements; ++i) {
for (size_t i = 0; i < num_elements; ++i) {
data_[i] = (data_[i] - min_value) * scale - 1.0f;
}
}

// This is surprisingly inexpensive for small images (2 ms).
void Image::Resize(int new_width, int new_height) {
std::vector<float> new_data(new_width * new_height * 3);
HWY_ASSERT(new_width > 0 && new_height > 0);
HWY_ASSERT(new_width <= SIZE_MAX / 3 &&
new_width * 3 <= SIZE_MAX / new_height);
std::vector<float> new_data(static_cast<size_t>(new_width) * new_height * 3);
// TODO: go to bilinear interpolation, or antialias.
// E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from
// jpegxl/lib/jxl/convolve_slow.cc
Expand All @@ -195,8 +200,8 @@ void Image::Resize(int new_width, int new_height) {
int old_i = NearestNeighbor(i, new_height, height_);
int old_j = NearestNeighbor(j, new_width, width_);
for (int k = 0; k < 3; ++k) {
new_data[(i * new_width + j) * 3 + k] =
data_[(old_i * width_ + old_j) * 3 + k];
new_data[(static_cast<size_t>(i) * new_width + j) * 3 + k] =
data_[(static_cast<size_t>(old_i) * width_ + old_j) * 3 + k];
}
}
}
Expand Down Expand Up @@ -228,7 +233,7 @@ void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim,
const size_t patch_dim = div_patch_dim.GetDivisor();
const size_t bytes_per_row = (patch_dim * kBytesPerPixel);
const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel);
HWY_ASSERT(size() == width_ * height_ * kNumChannels);
HWY_ASSERT(size() == static_cast<size_t>(width_) * height_ * kNumChannels);
HWY_ASSERT(div_patch_dim.Remainder(width_) == 0);
HWY_ASSERT(div_patch_dim.Remainder(height_) == 0);
const size_t patches_x = div_patch_dim.Divide(width_);
Expand Down
4 changes: 2 additions & 2 deletions python/gemma_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class GemmaModel {
return result;
}

// For a PaliGemma model, sets the image to run on. Subseqent calls to
// For a PaliGemma model, sets the image to run on. Subsequent calls to
// Generate* will use this image. Throws an error for other models.
void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) {
Expand All @@ -174,7 +174,7 @@ class GemmaModel {
int width = buffer.shape[1];
float* ptr = static_cast<float*>(buffer.ptr);
gcpp::Image c_image;
c_image.Set(height, width, ptr);
c_image.Set(width, height, ptr);
const size_t image_size = config.vit_config.image_size;
c_image.Resize(image_size, image_size);
image_tokens_.reset(new gcpp::ImageTokens(
Expand Down
Loading