From 3004b6ee3cf32452c76640fec8d160d1dd4cb120 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 12 Jun 2026 02:55:03 -0700 Subject: [PATCH] Serialization: EnumValid uses switch, not sentinel; add framing Calling visitor() on a sub-struct adds size field, whereas the direct call to VisitFields did not. This should be a no-op for existing models. PiperOrigin-RevId: 931040513 --- gemma/configs.h | 55 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index 89cc9906..1f470180 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -114,58 +114,76 @@ AttentionImpl GetAttentionImpl(const std::string& impl); enum class PostNormType { None, Scale, - kSentinel // must be last }; static inline bool EnumValid(PostNormType type) { - return static_cast(type) < - static_cast(PostNormType::kSentinel); + switch (type) { + case PostNormType::None: + case PostNormType::Scale: + return true; + default: + return false; + } } // Post qk projection operation type. enum class PostQKType { Rope, HalfRope, - kSentinel // must be last }; static inline bool EnumValid(PostQKType type) { - return static_cast(type) < - static_cast(PostNormType::kSentinel); + switch (type) { + case PostQKType::Rope: + case PostQKType::HalfRope: + return true; + default: + return false; + } } // FFW activation function. enum class ActivationType { Gelu, - kSentinel // must be last }; static inline bool EnumValid(ActivationType type) { - return static_cast(type) < - static_cast(ActivationType::kSentinel); + switch (type) { + case ActivationType::Gelu: + return true; + default: + return false; + } } // Attention query scale. enum class QueryScaleType { SqrtKeySize, SqrtModelDimDivNumHeads, - kSentinel // must be last }; static inline bool EnumValid(QueryScaleType type) { - return static_cast(type) < - static_cast(QueryScaleType::kSentinel); + switch (type) { + case QueryScaleType::SqrtKeySize: + case QueryScaleType::SqrtModelDimDivNumHeads: + return true; + default: + return false; + } } // Residual connection type. enum class ResidualType { Add, - kSentinel // must be last }; static inline bool EnumValid(ResidualType type) { - return static_cast(type) < - static_cast(ResidualType::kSentinel); + switch (type) { + case ResidualType::Add: + return true; + default: + return false; + } } template @@ -292,7 +310,8 @@ struct LayerConfig : public IFields { visitor(activation); visitor(post_qk); visitor(use_qk_norm); - internal.VisitFields(visitor); + // Visiting includes size prefix, whereas calling VisitFields would inline. + visitor(internal); // Append new fields here, then update `python/configs.cc`. } @@ -411,10 +430,10 @@ struct ModelConfig : public IFields { visitor(scale_base_names); - internal.VisitFields(visitor); - visitor(use_global_timescale); + // Visiting includes size prefix, whereas calling VisitFields would inline. + visitor(internal); // Append new fields here, then update `python/configs.cc`. }