Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,58 +114,76 @@ AttentionImpl GetAttentionImpl(const std::string& impl);
enum class PostNormType {
None,
Scale,
kSentinel // must be last
};

static inline bool EnumValid(PostNormType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(PostNormType::kSentinel);
switch (type) {
case PostNormType::None:
case PostNormType::Scale:
return true;
default:
return false;
}
}

// Post qk projection operation type.
enum class PostQKType {
Rope,
HalfRope,
kSentinel // must be last
};

static inline bool EnumValid(PostQKType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(PostNormType::kSentinel);
switch (type) {
case PostQKType::Rope:
case PostQKType::HalfRope:
return true;
default:
return false;
}
}

// FFW activation function.
enum class ActivationType {
Gelu,
kSentinel // must be last
};

static inline bool EnumValid(ActivationType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(ActivationType::kSentinel);
switch (type) {
case ActivationType::Gelu:
return true;
default:
return false;
}
}

// Attention query scale.
enum class QueryScaleType {
SqrtKeySize,
SqrtModelDimDivNumHeads,
kSentinel // must be last
};

static inline bool EnumValid(QueryScaleType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(QueryScaleType::kSentinel);
switch (type) {
case QueryScaleType::SqrtKeySize:
case QueryScaleType::SqrtModelDimDivNumHeads:
return true;
default:
return false;
}
}

// Residual connection type.
enum class ResidualType {
Add,
kSentinel // must be last
};

static inline bool EnumValid(ResidualType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(ResidualType::kSentinel);
switch (type) {
case ResidualType::Add:
return true;
default:
return false;
}
}

template <size_t kNum>
Expand Down Expand Up @@ -292,7 +310,8 @@ struct LayerConfig : public IFields {
visitor(activation);
visitor(post_qk);
visitor(use_qk_norm);
internal.VisitFields(visitor);
// Visiting includes size prefix, whereas calling VisitFields would inline.
visitor(internal);
// Append new fields here, then update `python/configs.cc`.
}

Expand Down Expand Up @@ -411,10 +430,10 @@ struct ModelConfig : public IFields {

visitor(scale_base_names);

internal.VisitFields(visitor);

visitor(use_global_timescale);

// Visiting includes size prefix, whereas calling VisitFields would inline.
visitor(internal);
// Append new fields here, then update `python/configs.cc`.
}

Expand Down
Loading