diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index dae95035..5af3b184 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -73,7 +73,7 @@ jobs: expected_description_prefix: Arch Linux expected_os_release_major: "" expected_distro_release_major: "" - require_release: "true" + require_release: "false" - name: Oracle Linux 9 image: oraclelinux:9 expected_os_name: OracleLinux @@ -100,7 +100,9 @@ jobs: CGO_ENABLED: 0 GOOS: linux GOARCH: amd64 - run: go build -o dist/facts-linux-amd64 ./cmd/facts + run: | + mkdir -p dist + go build -o dist/facts-linux-amd64 ./cmd/facts - name: Validate distro facts shell: bash diff --git a/.gitignore b/.gitignore index 0aea083c..25688d33 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ *.gem *.rbc /facts +/facts.exe /.config +/coverage.out /coverage/ /InstalledFiles /pkg/ diff --git a/Makefile b/Makefile index ab1e9a34..6ab58b4f 100644 --- a/Makefile +++ b/Makefile @@ -6,8 +6,8 @@ COUNT ?= 5 LIMACTL ?= limactl LIMA_INSTANCE_PREFIX ?= facts -LIMA_GO_VERSION ?= 1.25.0 -LIMA_GO_SERIES ?= 1.25 +LIMA_GO_VERSION ?= 1.26.4 +LIMA_GO_SERIES ?= 1.26 LIMA_GOARCH ?= arm64 LIMA_CPUS ?= 6 LIMA_MEMORY ?= 10 diff --git a/engine.go b/engine.go index c078023f..b48d2cb5 100644 --- a/engine.go +++ b/engine.go @@ -120,6 +120,9 @@ func WithSystemDefaults() Option { // early, err satisfies errors.Is(err, ctx.Err()). Not-applicable facts are // absent from the Snapshot and contribute nothing to err. func (e *Engine) Discover(ctx context.Context, queries ...string) (*Snapshot, error) { + if e == nil || e.inner == nil { + return nil, errors.New("facts: uninitialized Engine") + } inner, err := e.inner.Discover(ctx, queries...) return &Snapshot{inner: inner}, err } diff --git a/engine_test.go b/engine_test.go index f3e8394e..406ec8a9 100644 --- a/engine_test.go +++ b/engine_test.go @@ -62,6 +62,18 @@ func TestNew_defaultEngineIsHermetic(t *testing.T) { } } +func TestEngineDiscover_uninitializedReceiverReturnsError(t *testing.T) { + var nilEngine *Engine + if snap, err := nilEngine.Discover(context.Background()); err == nil || snap != nil { + t.Fatalf("nil Engine Discover() = %#v, %v, want nil snapshot and error", snap, err) + } + + var zero Engine + if snap, err := zero.Discover(context.Background()); err == nil || snap != nil { + t.Fatalf("zero Engine Discover() = %#v, %v, want nil snapshot and error", snap, err) + } +} + func TestWithExternalDirs_loadsExactlyOptedDirs(t *testing.T) { t.Setenv("FACTER_env_probe", "leaked") dir := t.TempDir() @@ -595,6 +607,24 @@ func TestAs_shapeMismatchFailsLoudly(t *testing.T) { } } +func TestAs_rejectsMapAnyKeyStringCollisions(t *testing.T) { + eng, err := New(WithFact("ambiguous", func(context.Context) (any, error) { + return map[any]any{"1": "string-key", 1: "int-key"}, nil + })) + if err != nil { + t.Fatal(err) + } + snap, err := eng.Discover(context.Background()) + if err != nil { + t.Fatal(err) + } + + _, err = As[map[string]string](snap, "ambiguous") + if err == nil || !strings.Contains(err.Error(), `duplicate map key after string normalization: "1"`) { + t.Fatalf("As ambiguous err = %v, want duplicate normalized key error", err) + } +} + func TestAs_missingFactReturnsErrFactNotFound(t *testing.T) { snap := hermeticSnapshot() diff --git a/internal/app/app.go b/internal/app/app.go index a3e453da..ac164049 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -498,6 +498,10 @@ func writeVersionQuery(stdout io.Writer, jsonOutput, yamlOutput, hoconOutput boo if err != nil { return err } + if strings.HasSuffix(out, "\n") { + _, err = fmt.Fprint(stdout, out) + return err + } _, err = fmt.Fprintln(stdout, out) return err } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index f49d2b71..6d72ed72 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -708,6 +708,17 @@ func TestRun_queryYAML(t *testing.T) { } } +func TestRun_queryFacterversionYAMLHasSingleTrailingNewline(t *testing.T) { + var stdout, stderr bytes.Buffer + + if err := Run(&stdout, &stderr, []string{"--yaml", "facterversion"}); err != nil { + t.Fatal(err) + } + if got, want := stdout.String(), "facterversion: "+engine.Version+"\n"; got != want { + t.Fatalf("stdout = %q, want %q", got, want) + } +} + func TestRun_queryHOCON(t *testing.T) { var stdout, stderr bytes.Buffer diff --git a/internal/app/loghandler.go b/internal/app/loghandler.go index 77445710..6488ef0c 100644 --- a/internal/app/loghandler.go +++ b/internal/app/loghandler.go @@ -4,6 +4,7 @@ import ( "context" "io" "log/slog" + "sync" ) // stderrLogHandler renders engine diagnostics as Ruby-compatible stderr lines @@ -17,6 +18,7 @@ type stderrLogHandler struct { color bool debug bool verbose bool + mu sync.Mutex } func (h *stderrLogHandler) Enabled(_ context.Context, level slog.Level) bool { @@ -33,6 +35,9 @@ func (h *stderrLogHandler) Enabled(_ context.Context, level slog.Level) bool { } func (h *stderrLogHandler) Handle(_ context.Context, record slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + switch { case record.Level >= slog.LevelError: case record.Level >= slog.LevelWarn: diff --git a/internal/app/loghandler_test.go b/internal/app/loghandler_test.go index 7464a2cc..b8d05ade 100644 --- a/internal/app/loghandler_test.go +++ b/internal/app/loghandler_test.go @@ -3,6 +3,7 @@ package app import ( "bytes" "log/slog" + "sync" "testing" ) @@ -37,3 +38,18 @@ func TestStderrLogHandlerDropsErrorClassKeepsWarnDebug(t *testing.T) { } }) } + +func TestStderrLogHandlerConcurrentHandle(t *testing.T) { + var stderr bytes.Buffer + logger := slog.New(&stderrLogHandler{stderr: &stderr}) + + var wg sync.WaitGroup + for range 8 { + wg.Go(func() { + for range 50 { + logger.Warn("heads up") + } + }) + } + wg.Wait() +} diff --git a/internal/cli/arguments.go b/internal/cli/arguments.go index 5906ddb7..1f5f4252 100644 --- a/internal/cli/arguments.go +++ b/internal/cli/arguments.go @@ -12,8 +12,18 @@ func PrepareArguments(args []string) []string { priority := make([]string, 0, len(prepared)) normal := make([]string, 0, len(prepared)) + afterDelimiter := false for i := 0; i < len(prepared); i++ { arg := prepared[i] + if afterDelimiter { + normal = append(normal, arg) + continue + } + if arg == "--" { + normal = append(normal, arg) + afterDelimiter = true + continue + } if IsTaskFlag(arg) || IsTask(arg) { priority = append(priority, arg) continue @@ -29,7 +39,11 @@ func PrepareArguments(args []string) []string { func expandShortOptions(args []string) []string { expanded := make([]string, 0, len(args)) - for _, arg := range args { + for i, arg := range args { + if arg == "--" { + expanded = append(expanded, args[i:]...) + break + } if len(arg) <= 2 || arg[0] != '-' || arg[1] == '-' || strings.ContainsRune(arg, '=') { expanded = append(expanded, arg) continue @@ -48,6 +62,9 @@ func expandShortOptions(args []string) []string { func containsKnownTaskOrMappedFlag(args []string) bool { for i := 0; i < len(args); i++ { arg := args[i] + if arg == "--" { + return false + } if IsTask(arg) || IsTaskFlag(arg) { return true } diff --git a/internal/cli/arguments_test.go b/internal/cli/arguments_test.go index 0317c72b..edf433ab 100644 --- a/internal/cli/arguments_test.go +++ b/internal/cli/arguments_test.go @@ -30,6 +30,14 @@ func TestPrepareArguments_reordersShortVersionFlag(t *testing.T) { } } +func TestPrepareArguments_preservesFlagsAfterDelimiterAsQueries(t *testing.T) { + got := PrepareArguments([]string{"--", "-v"}) + want := []string{"query", "--", "-v"} + if !slices.Equal(got, want) { + t.Fatalf("PrepareArguments() = %v, want %v", got, want) + } +} + func TestPrepareArguments_doesNotPromoteTaskFlagWithInlineValue(t *testing.T) { got := PrepareArguments([]string{"--help=topic"}) want := []string{"query", "--help=topic"} @@ -98,7 +106,7 @@ func TestValidateOptions_allowsRepeatedExternalDir(t *testing.T) { } func TestValidateOptions_rejectsMissingRequiredOptionValue(t *testing.T) { - err := ValidateOptions([]string{"query", "--external-dir", "--no-external-facts", "site"}) + err := ValidateOptions([]string{"query", "--external-dir"}) if err == nil { t.Fatal("ValidateOptions() err = nil, want missing option value error") } @@ -107,6 +115,27 @@ func TestValidateOptions_rejectsMissingRequiredOptionValue(t *testing.T) { } } +func TestValidateOptions_allowsDashPrefixedOptionValues(t *testing.T) { + err := ValidateOptions([]string{"query", "--external-dir", "-facts", "site"}) + if err != nil { + t.Fatalf("ValidateOptions() err = %v, want nil", err) + } +} + +func TestValidateOptions_stopsAtDelimiter(t *testing.T) { + err := ValidateOptions([]string{"query", "--", "-v"}) + if err != nil { + t.Fatalf("ValidateOptions() err = %v, want nil", err) + } +} + +func TestValidateOptions_stopsAtFirstQuery(t *testing.T) { + err := ValidateOptions([]string{"query", "os.name", "--missing-fact-name"}) + if err != nil { + t.Fatalf("ValidateOptions() err = %v, want nil", err) + } +} + func TestValidateOptions_rejectsUnknownConcatenatedShortFlag(t *testing.T) { args := PrepareArguments([]string{"-jdtz"}) err := ValidateOptions(args) diff --git a/internal/cli/validation.go b/internal/cli/validation.go index fb84fadd..afe7ff64 100644 --- a/internal/cli/validation.go +++ b/internal/cli/validation.go @@ -47,6 +47,15 @@ func validateOptions(args []string) error { logLevel := "" for i := 0; i < len(args); i++ { arg := args[i] + if arg == "--" { + break + } + if !strings.HasPrefix(arg, "-") { + if IsTask(arg) { + continue + } + break + } if strings.HasPrefix(arg, "-") { seenRaw[rawOption(arg)] = true option, ok := LookupOption(arg) @@ -64,7 +73,7 @@ func validateOptions(args []string) error { } } if OptionTakesSeparateValue(arg) { - if i+1 >= len(args) || strings.HasPrefix(args[i+1], "-") { + if i+1 >= len(args) { return fmt.Errorf("%s requires a value", CanonicalOption(arg)) } i++ diff --git a/internal/engine/cache.go b/internal/engine/cache.go index 33372853..98a590d8 100644 --- a/internal/engine/cache.go +++ b/internal/engine/cache.go @@ -19,7 +19,7 @@ const cacheFormatVersion = 1 var ( cacheRemove = os.Remove - cacheWriteFile = os.WriteFile + cacheWriteFile = writeCacheFile ) // DefaultCachePath returns the platform default directory for cached fact groups. @@ -191,7 +191,7 @@ func (fc *FactCache) CacheFacts(facts []ResolvedFact) error { } grouped := make(map[string]map[string]any) for _, fact := range facts { - group, _, ok := fc.cacheGroupForFact(fact.Name) + group, _, ok := fc.cacheGroupForResolvedFact(fact) if !ok { continue } @@ -241,6 +241,38 @@ func warnCacheWriteFailure(err error, log *slog.Logger) bool { return true } +func writeCacheFile(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + tmp, err := os.CreateTemp(dir, filepath.Base(path)+".tmp-*") + if err != nil { + return err + } + tmpPath := tmp.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Chmod(perm); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + if err := os.Rename(tmpPath, path); err != nil { + return err + } + cleanup = false + return nil +} + func (fc *FactCache) cacheGroupForFact(name string) (string, time.Duration, bool) { bestGroup := "" bestTTL := time.Duration(0) @@ -404,7 +436,7 @@ func parseTTLDuration(value string) (time.Duration, bool) { return 0, false } amount, err := strconv.Atoi(fields[0]) - if err != nil { + if err != nil || amount < 0 { return 0, false } unit := "" @@ -418,6 +450,9 @@ func parseTTLDuration(value string) (time.Duration, bool) { if !ok { return 0, false } + if multiplier > 0 && time.Duration(amount) > time.Duration(1<<63-1)/multiplier { + return 0, false + } duration := time.Duration(amount) * multiplier if multiplier < time.Second { duration = duration.Truncate(time.Second) diff --git a/internal/engine/cache_test.go b/internal/engine/cache_test.go index 0fea75ea..84c0a5ef 100644 --- a/internal/engine/cache_test.go +++ b/internal/engine/cache_test.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" "reflect" + "runtime" + "strconv" "strings" "testing" "time" @@ -215,6 +217,55 @@ func TestFactCache_cacheFactsWritesConfiguredGroups(t *testing.T) { } } +func TestFactCache_cacheFactsWritesExternalFileBasenameGroup(t *testing.T) { + dir := t.TempDir() + cache := NewFactCache(dir, []FactTTL{{Fact: "ext_file.txt", TTL: "1 hour"}}, nil, discardLog()) + + if err := cache.CacheFacts([]ResolvedFact{ + {Name: "my_external_fact", Value: "ext_fact", Type: "file", File: "/tmp/ext_file.txt"}, + }); err != nil { + t.Fatal(err) + } + + data := readJSONFile(t, filepath.Join(dir, "ext_file.txt")) + if data["cache_format_version"] != float64(1) { + t.Fatalf("cache_format_version = %#v, want 1", data["cache_format_version"]) + } + if data["my_external_fact"] != "ext_fact" { + t.Fatalf("my_external_fact = %#v, want ext_fact", data["my_external_fact"]) + } +} + +func TestWriteCacheFileWritesFinalFileAndRemovesTemp(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "facts.cache") + + if err := writeCacheFile(path, []byte("cached"), 0o600); err != nil { + t.Fatal(err) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(data) != "cached" { + t.Fatalf("cache file content = %q, want cached", data) + } + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if got := info.Mode().Perm(); runtime.GOOS != "windows" && got != 0o600 { + t.Fatalf("cache file mode = %v, want 0600", got) + } + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 || entries[0].Name() != "facts.cache" { + t.Fatalf("cache dir entries = %#v, want only final cache file", entries) + } +} + func TestFactCache_ignoresUnsafeCacheGroupNames(t *testing.T) { dir := t.TempDir() outside := filepath.Join(dir, "..", "outside-cache") @@ -391,6 +442,17 @@ func TestParseTTLDuration_matchesRubyUnits(t *testing.T) { } } +func TestParseTTLDuration_rejectsNegativeAndOverflowingValues(t *testing.T) { + tooManyHours := strconv.FormatInt(int64((time.Duration(1<<63-1)/time.Hour)+1), 10) + " h" + for _, input := range []string{"-1 seconds", tooManyHours} { + t.Run(input, func(t *testing.T) { + if got, ok := parseTTLDuration(input); ok { + t.Fatalf("parseTTLDuration(%q) = %v, true; want false", input, got) + } + }) + } +} + func writeJSONFile(t *testing.T, path string, data map[string]any) { t.Helper() encoded, err := json.Marshal(data) diff --git a/internal/engine/config.go b/internal/engine/config.go index d241f267..f0c69d64 100644 --- a/internal/engine/config.go +++ b/internal/engine/config.go @@ -84,7 +84,7 @@ func CurrentDefaultExternalFactDirs() []string { runtime.GOOS == "windows", runtime.GOOS != "windows" && os.Geteuid() == 0, os.Getenv("HOME"), - firstNonEmpty(os.Getenv("ProgramData"), os.Getenv("APPDATA")), + os.Getenv("ProgramData"), ) } @@ -304,15 +304,10 @@ func quotedConfigValues(content string) []string { } func configSection(content, name string) string { - start := strings.Index(strings.ToLower(content), strings.ToLower(name)) - if start < 0 { - return "" - } - open := strings.IndexByte(content[start:], '{') + open := configSectionOpenBrace(content, name) if open < 0 { return "" } - open += start depth := 0 for i := open; i < len(content); i++ { switch content[i] { @@ -328,6 +323,95 @@ func configSection(content, name string) string { return content[open+1:] } +func configSectionOpenBrace(content, name string) int { + depth := 0 + for i := 0; i < len(content); { + i = skipConfigSpaceCommentsAndStrings(content, i) + if i >= len(content) { + return -1 + } + switch content[i] { + case '{': + depth++ + i++ + continue + case '}': + if depth > 0 { + depth-- + } + i++ + continue + } + if depth != 0 { + i++ + continue + } + if !configNameAt(content, i, name) { + i++ + continue + } + j := skipConfigSpaceCommentsAndStrings(content, i+len(name)) + if j >= len(content) || content[j] != ':' && content[j] != '=' { + i++ + continue + } + j = skipConfigSpaceCommentsAndStrings(content, j+1) + if j < len(content) && content[j] == '{' { + return j + } + i++ + } + return -1 +} + +func configNameAt(content string, pos int, name string) bool { + if pos > 0 { + prev := content[pos-1] + if prev == '_' || prev == '-' || prev == '.' || prev >= '0' && prev <= '9' || prev >= 'A' && prev <= 'Z' || prev >= 'a' && prev <= 'z' { + return false + } + } + if len(content)-pos < len(name) || !strings.EqualFold(content[pos:pos+len(name)], name) { + return false + } + if next := pos + len(name); next < len(content) { + b := content[next] + if b == '_' || b == '-' || b == '.' || b >= '0' && b <= '9' || b >= 'A' && b <= 'Z' || b >= 'a' && b <= 'z' { + return false + } + } + return true +} + +func skipConfigSpaceCommentsAndStrings(content string, i int) int { + for i < len(content) { + switch content[i] { + case ' ', '\t', '\r', '\n', ',': + i++ + case '#': + for i < len(content) && content[i] != '\n' { + i++ + } + case '"': + i++ + for i < len(content) { + if content[i] == '\\' { + i += 2 + continue + } + if content[i] == '"' { + i++ + break + } + i++ + } + default: + return i + } + } + return i +} + func lowerConfigValues(values []string) []string { for i, value := range values { values[i] = strings.ToLower(value) diff --git a/internal/engine/config_test.go b/internal/engine/config_test.go index 167f3fb9..ebbdd053 100644 --- a/internal/engine/config_test.go +++ b/internal/engine/config_test.go @@ -56,6 +56,40 @@ fact-groups : { } } +func TestParseConfig_ignoresSectionNamesInsideStringsAndComments(t *testing.T) { + path := filepath.Join(t.TempDir(), "facter.conf") + content := `global : { + note : "facts : { ttls : [ { \"wrong\" : \"1 day\" } ] }", + facts : { ttls : [ { "nested" : "1 day" } ] }, +} +# cli : { debug : true } +facts : { + ttls : [ + { "right" : "2 days" }, + ], +} +cli : { + verbose : true, +}` + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + got, err := ParseConfig(path, discardLog()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got.TTLs, []FactTTL{{Fact: "right", TTL: "2 days"}}) { + t.Fatalf("TTLs = %#v, want real facts section only", got.TTLs) + } + if got.Debug { + t.Fatal("Debug = true, want commented cli section ignored") + } + if !got.Verbose { + t.Fatal("Verbose = false, want real cli section parsed") + } +} + func TestParseConfig_collectsRepeatedDirectoryEntries(t *testing.T) { path := filepath.Join(t.TempDir(), "facter.conf") content := `global : { @@ -595,6 +629,33 @@ func TestParseConfig_returnsConfiguredFactTTLs(t *testing.T) { } } +func TestParseConfig_TTLsUseExactFactsSection(t *testing.T) { + path := filepath.Join(t.TempDir(), "facter.conf") + content := `facts-extra : { + ttls : [ + { "bad" : "1 hour" } + ], +} +facts : { + ttls : [ + { "timezone" : "30 days" } + ], +}` + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + config, err := ParseConfig(path, discardLog()) + if err != nil { + t.Fatal(err) + } + got := config.TTLs + want := []FactTTL{{Fact: "timezone", TTL: "30 days"}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("Config.TTLs = %#v, want %#v", got, want) + } +} + func TestParseConfig_acceptsBareFactNamesAndValues(t *testing.T) { path := filepath.Join(t.TempDir(), "facter.conf") content := `facts : { @@ -757,6 +818,39 @@ func TestFilterBlockedFacts_blocksExactNameAndRoot(t *testing.T) { } } +func TestFilterBlockedFacts_prunesBlockedDescendantsFromStructuredParents(t *testing.T) { + facts := []ResolvedFact{ + { + Name: "os", + Value: map[string]any{ + "name": "Ubuntu", + "release": map[string]any{ + "full": "24.04", + "major": "24", + }, + }, + Type: "core", + }, + } + + got := FilterBlockedFacts(facts, map[string]bool{"os.release.major": true}) + want := []ResolvedFact{ + { + Name: "os", + Value: map[string]any{ + "name": "Ubuntu", + "release": map[string]any{ + "full": "24.04", + }, + }, + Type: "core", + }, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("FilterBlockedFacts(os.release.major) = %#v, want %#v", got, want) + } +} + func TestParseConfig_returnsConfiguredFactGroups(t *testing.T) { path := filepath.Join(t.TempDir(), "facter.conf") content := `fact-groups : { diff --git a/internal/engine/detector.go b/internal/engine/detector.go index b960d325..ee6e234b 100644 --- a/internal/engine/detector.go +++ b/internal/engine/detector.go @@ -24,6 +24,16 @@ func DetectOSIdentifier(hostOS, linuxDistroID string) (string, error) { return linuxDistroID, nil } return "linux", nil + case strings.Contains(hostOS, "freebsd"): + return "freebsd", nil + case strings.Contains(hostOS, "openbsd"): + return "openbsd", nil + case strings.Contains(hostOS, "netbsd"): + return "netbsd", nil + case strings.Contains(hostOS, "dragonfly"): + return "dragonfly", nil + case strings.Contains(hostOS, "illumos") || strings.Contains(hostOS, "sunos") || strings.Contains(hostOS, "solaris"): + return "illumos", nil default: return "", fmt.Errorf("%w: %q", ErrUnknownOS, hostOS) } @@ -98,5 +108,8 @@ func capitalizeOSName(name string) string { if name == "" { return "" } + if name[1:] != strings.ToLower(name[1:]) { + return strings.ToUpper(name[:1]) + name[1:] + } return strings.ToUpper(name[:1]) + strings.ToLower(name[1:]) } diff --git a/internal/engine/detector_test.go b/internal/engine/detector_test.go index 9f0f859f..2f420494 100644 --- a/internal/engine/detector_test.go +++ b/internal/engine/detector_test.go @@ -100,6 +100,15 @@ func TestDetectOSHierarchyUsesFirstKnownFamilyLikeRubyDetector(t *testing.T) { } } +func TestDetectOSHierarchyPreservesMixedCaseFamilyNames(t *testing.T) { + hierarchy := []any{"RedHat"} + got := DetectOSHierarchy(hierarchy, "my_linux_distro", "RedHat", discardLog()) + want := []string{"RedHat"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("DetectOSHierarchy() = %#v, want %#v", got, want) + } +} + func TestDetectOSIdentifier_matchesRubyHostOSMapping(t *testing.T) { tests := []struct { name string @@ -113,6 +122,12 @@ func TestDetectOSIdentifier_matchesRubyHostOSMapping(t *testing.T) { {name: "windows mswin", hostOS: "mswin", want: "windows"}, {name: "linux distro", hostOS: "linux", distro: "redhat", want: "redhat"}, {name: "linux fallback", hostOS: "linux", want: "linux"}, + {name: "freebsd", hostOS: "freebsd13", want: "freebsd"}, + {name: "openbsd", hostOS: "openbsd7.5", want: "openbsd"}, + {name: "netbsd", hostOS: "netbsd10", want: "netbsd"}, + {name: "dragonfly", hostOS: "dragonfly6.4", want: "dragonfly"}, + {name: "illumos", hostOS: "illumos", want: "illumos"}, + {name: "sunos illumos family", hostOS: "sunos5.11", want: "illumos"}, {name: "unknown", hostOS: "my_custom_os", wantErr: ErrUnknownOS}, } diff --git a/internal/engine/disks.go b/internal/engine/disks.go index 6640f36c..29c57c8a 100644 --- a/internal/engine/disks.go +++ b/internal/engine/disks.go @@ -70,8 +70,7 @@ func disksFact(root string, host hostOS) map[string]any { disk["type"] = "hdd" } } - if sectors, err := strconv.Atoi(readSysfsString(root, name, "size", host.readFile)); err == nil && sectors > 0 { - sizeBytes := sectors * 512 + if sizeBytes, ok := linuxSectorSizeBytes(readSysfsString(root, name, "size", host.readFile)); ok { disk["size_bytes"] = sizeBytes disk["size"] = bytesToHumanReadable(sizeBytes) } @@ -258,15 +257,26 @@ func discoverPartitions(root string, host hostOS) map[string]any { } func addLinuxPartitionSize(partition map[string]any, root, name string, readFile fileReader) { - sectors, err := strconv.Atoi(readSysfsString(root, name, "size", readFile)) - if err != nil || sectors < 0 { - sectors = 0 + sizeBytes, ok := linuxSectorSizeBytes(readSysfsString(root, name, "size", readFile)) + if !ok { + sizeBytes = 0 } - sizeBytes := sectors * 512 partition["size_bytes"] = sizeBytes partition["size"] = bytesToHumanReadable(sizeBytes) } +func linuxSectorSizeBytes(value string) (any, bool) { + sectors, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64) + if err != nil || sectors <= 0 || sectors > int64(1<<63-1)/512 { + return nil, false + } + sizeBytes := sectors * 512 + if sizeBytes <= int64(^uint(0)>>1) { + return int(sizeBytes), true + } + return sizeBytes, true +} + func linuxLSBLKVersion(output string) (int, int, bool) { match := linuxLSBLKVersionPattern.FindStringSubmatch(output) if match == nil { @@ -1147,17 +1157,23 @@ func rootFromLinuxCmdline(input string) string { } func linuxDeviceForPartitionID(partitionID, blkidOutput string) string { - _, id, ok := strings.Cut(partitionID, "=") + idKey, id, ok := strings.Cut(partitionID, "=") if !ok || id == "" { return "" } for line := range strings.SplitSeq(blkidOutput, "\n") { - if !strings.Contains(line, id) { + device, _, ok := strings.Cut(line, ":") + if !ok { continue } - device, _, ok := strings.Cut(line, ":") - if ok { - return strings.TrimSpace(device) + for field := range strings.FieldsSeq(line) { + key, rawValue, ok := strings.Cut(field, "=") + if !ok || !strings.EqualFold(key, idKey) { + continue + } + if strings.Trim(rawValue, `"`) == id { + return strings.TrimSpace(device) + } } } return "" @@ -1480,6 +1496,8 @@ func mountpointsFactWithSkip(entries []mountEntry, stat func(string) (mountStat, } if len(entry.Options) > 0 { mountpoint["options"] = append([]string(nil), entry.Options...) + } else if entry.Device != "" || entry.Filesystem != "" { + mountpoint["options"] = []string{} } if len(mountpoint) == 0 { continue diff --git a/internal/engine/disks_test.go b/internal/engine/disks_test.go index 3eaedb55..b67f0d71 100644 --- a/internal/engine/disks_test.go +++ b/internal/engine/disks_test.go @@ -1002,6 +1002,30 @@ func TestDisksCoreFactsUsesSessionHostForLinuxDiskPartitionAndMountpointFacts(t } } +func TestDisksFactOmitsOverflowingLinuxSectorSize(t *testing.T) { + host := &fakeHostOS{ + dirs: map[string][]os.DirEntry{ + "/sys/block": fakeDirEntries("sdz"), + }, + files: map[string][]byte{ + "/sys/block/sdz/device/model": []byte("OverflowDisk\n"), + "/sys/block/sdz/size": []byte("18014398509481984\n"), + }, + stats: map[string]os.FileInfo{ + "/sys/block/sdz/device": fakeFileInfo{name: "device", mode: os.ModeDir, isDir: true}, + }, + } + + got := disksFact("/sys/block", host) + disk := got["sdz"].(map[string]any) + if _, ok := disk["size_bytes"]; ok { + t.Fatalf("size_bytes = %#v, want omitted for overflowing sector count", disk["size_bytes"]) + } + if _, ok := disk["size"]; ok { + t.Fatalf("size = %#v, want omitted for overflowing sector count", disk["size"]) + } +} + func TestCurrentPartitionsUsesSessionHostGlobForIllumos(t *testing.T) { const vtoc = `* Dimensions: * 512 bytes/sector @@ -1188,6 +1212,20 @@ func TestMountpointsFactIncludesDeviceFilesystemAndOptions(t *testing.T) { } } +func TestMountpointsFactIncludesEmptyOptionsForParsedMountEntries(t *testing.T) { + got := mountpointsFact([]mountEntry{{Device: "devfs", Path: "/dev", Filesystem: "devfs"}}, func(string) (mountStat, bool) { + return mountStat{}, false + }) + mountpoint := got["/dev"].(map[string]any) + options, ok := mountpoint["options"].([]string) + if !ok { + t.Fatalf("options = %#v, want empty []string", mountpoint["options"]) + } + if len(options) != 0 { + t.Fatalf("options = %#v, want empty", options) + } +} + func TestMountpointsFactOmitsEmptyEntries(t *testing.T) { t.Parallel() @@ -1262,6 +1300,15 @@ func TestResolveLinuxRootMountDeviceMatchesRubyResolver(t *testing.T) { blkid: `/dev/xvda1: UUID="f3d" PARTUUID="a2f52878-01"`, want: "/dev/xvda1", }, + { + name: "partuuid matches exact blkid field", + cmdline: "console=tty0 root=PARTUUID=a2f52878-01 rw", + blkid: strings.Join([]string{ + `/dev/xvda1: UUID="not-a2f52878-01" PARTUUID="other"`, + `/dev/xvdb1: UUID="uuid-root" PARTUUID="a2f52878-01"`, + }, "\n"), + want: "/dev/xvdb1", + }, { name: "partuuid remains when blkid cannot map", cmdline: "console=tty0 root=PARTUUID=a2f52878-01 rw", @@ -1360,6 +1407,7 @@ tmpfs on /tmp/example path (tmpfs, local, nosuid) "/dev": map[string]any{ "device": "devfs", "filesystem": "devfs", + "options": []string{}, }, "/tmp/example path": map[string]any{ "device": "tmpfs", diff --git a/internal/engine/dmi.go b/internal/engine/dmi.go index 7c147b2f..11a83271 100644 --- a/internal/engine/dmi.go +++ b/internal/engine/dmi.go @@ -371,7 +371,7 @@ func mapFromValues(source map[string]string, names map[string]string) map[string func dmiChassisTypeName(value string) string { types := []string{ - "Other", "", "Desktop", "Low Profile Desktop", "Pizza Box", "Mini Tower", "Tower", + "Other", "Unknown", "Desktop", "Low Profile Desktop", "Pizza Box", "Mini Tower", "Tower", "Portable", "Laptop", "Notebook", "Hand Held", "Docking Station", "All in One", "Sub Notebook", "Space-Saving", "Lunch Box", "Main System Chassis", "Expansion Chassis", "SubChassis", "Bus Expansion Chassis", "Peripheral Chassis", "Storage Chassis", "Rack Mount Chassis", diff --git a/internal/engine/dmi_test.go b/internal/engine/dmi_test.go index 231c2589..fb6f3760 100644 --- a/internal/engine/dmi_test.go +++ b/internal/engine/dmi_test.go @@ -193,6 +193,23 @@ func TestDMIFact_replacesInvalidUTF8InLinuxSysfsValues(t *testing.T) { } } +func TestDMIFact_mapsUnknownLinuxNumericChassisType(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "chassis_type"), []byte("2\n"), 0o600); err != nil { + t.Fatal(err) + } + + got := dmiFact(dir) + want := map[string]any{ + "chassis": map[string]any{ + "type": "Unknown", + }, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("dmiFact() = %#v, want %#v", got, want) + } +} + func TestFreeBSDDMIFacts_returnsStructuredFacts(t *testing.T) { values := map[string]string{ "smbios.bios.reldate": "12/12/2018", diff --git a/internal/engine/ec2.go b/internal/engine/ec2.go index d3b51596..d55e2dbd 100644 --- a/internal/engine/ec2.go +++ b/internal/engine/ec2.go @@ -91,7 +91,12 @@ func linuxAWSCloudProvider(name string, ec2Metadata map[string]any, euid int, ex if strings.EqualFold(name, "aws") || euid != 0 || !executable("/opt/puppetlabs/puppet/bin/virt-what") { return true } - return strings.TrimSpace(run("/opt/puppetlabs/puppet/bin/virt-what")) == "aws" + for field := range strings.FieldsSeq(run("/opt/puppetlabs/puppet/bin/virt-what")) { + if strings.EqualFold(field, "aws") { + return true + } + } + return false } func fileExecutable(path string) bool { @@ -117,7 +122,7 @@ func (ec *ec2Client) metadata(ctx context.Context) map[string]any { } func (ec *ec2Client) userdata(ctx context.Context) string { - body, ok := ec.get(ctx, "user-data/") + body, ok := ec.getRaw(ctx, "user-data/") if !ok { return "" } @@ -164,6 +169,14 @@ func metadataChildren(body string) []string { } func (ec *ec2Client) get(ctx context.Context, path string) (string, bool) { + body, ok := ec.getRaw(ctx, path) + if !ok { + return "", false + } + return strings.TrimSpace(body), true +} + +func (ec *ec2Client) getRaw(ctx context.Context, path string) (string, bool) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, ec.baseURL+"/"+path, nil) if err != nil { return "", false @@ -183,7 +196,7 @@ func (ec *ec2Client) get(ctx context.Context, path string) (string, bool) { if err != nil { return "", false } - return strings.TrimSpace(string(data)), true + return string(data), true } func (ec *ec2Client) v2Token(ctx context.Context) string { diff --git a/internal/engine/ec2_test.go b/internal/engine/ec2_test.go index 031177fc..64c82547 100644 --- a/internal/engine/ec2_test.go +++ b/internal/engine/ec2_test.go @@ -172,7 +172,7 @@ func TestEC2Facts_returnsMetadataAndUserdataForAWSHypervisors(t *testing.T) { case "GET /latest/meta-data/instance_type": _, _ = w.Write([]byte("c1.medium")) case "GET /latest/user-data/": - _, _ = w.Write([]byte("userdata")) + _, _ = w.Write([]byte(" #!/bin/sh\necho userdata\n")) default: http.NotFound(w, r) } @@ -189,7 +189,7 @@ func TestEC2Facts_returnsMetadataAndUserdataForAWSHypervisors(t *testing.T) { if got, want := metadata["instance_type"], "c1.medium"; got != want { t.Fatalf("ec2_metadata.instance_type = %#v, want %#v", got, want) } - if got, want := got["ec2_userdata"], "userdata"; got != want { + if got, want := got["ec2_userdata"], " #!/bin/sh\necho userdata\n"; got != want { t.Fatalf("ec2_userdata = %#v, want %#v", got, want) } if fact := factByName(facts, "cloud.provider"); fact == nil || fact.Value != "aws" { @@ -234,6 +234,11 @@ func TestLinuxAWSCloudProviderRequiresVirtWhatAWSForRootKVM(t *testing.T) { t.Fatal("linuxAWSCloudProvider(kvm root virt-what=aws) = false, want true") } + run = func(string, ...string) string { return "kvm\naws\n" } + if !linuxAWSCloudProvider("kvm", metadata, 0, executable, run) { + t.Fatal("linuxAWSCloudProvider(kvm root virt-what includes aws) = false, want true") + } + if !linuxAWSCloudProvider("kvm", metadata, 512, executable, func(string, ...string) string { return "kvm" }) { t.Fatal("linuxAWSCloudProvider(kvm non-root) = false, want true") } diff --git a/internal/engine/external.go b/internal/engine/external.go index d9107fad..aa6fd4f2 100644 --- a/internal/engine/external.go +++ b/internal/engine/external.go @@ -331,10 +331,19 @@ func (l externalFactLoader) loadExternalFactFile(path string, mode os.FileMode) ext := strings.ToLower(filepath.Ext(path)) switch ext { case ".txt": + if !mode.IsRegular() { + return nil, nil + } return l.loadExternalTxtFacts(path) case ".json": + if !mode.IsRegular() { + return nil, nil + } return l.loadExternalJSONFacts(path) case ".yaml", ".yml": + if !mode.IsRegular() { + return nil, nil + } return l.loadExternalYAMLFacts(path) case ".rb": l.s.warn(fmt.Sprintf("Ruby fact files are not supported by the Go port; skipping %s. Rewrite it as an executable external fact (see docs/CUSTOM_FACT_MIGRATION.md).", path)) @@ -391,6 +400,7 @@ func (l externalFactLoader) loadExternalTxtFacts(path string) ([]ResolvedFact, e } func parseKeyValueFacts(scanner *bufio.Scanner) ([]ResolvedFact, error) { + scanner.Buffer(make([]byte, 0, 64*1024), externalFactMaxBytes) facts := []ResolvedFact{} for scanner.Scan() { name, value, ok := strings.Cut(scanner.Text(), "=") @@ -544,6 +554,10 @@ func (l externalFactLoader) loadExternalJSONFacts(path string) ([]ResolvedFact, if err := decoder.Decode(&value); err != nil { return nil, nil } + var trailing any + if err := decoder.Decode(&trailing); err != io.EOF { + return nil, nil + } values, ok := value.(map[string]any) if !ok { @@ -781,6 +795,9 @@ func normalizeStructuredValue(value any) (any, error) { switch v := value.(type) { case json.Number: if i, err := v.Int64(); err == nil { + if i > 1<<31-1 || i < -1<<31 { + return i, nil + } return int(i), nil } if f, err := v.Float64(); err == nil { diff --git a/internal/engine/external_test.go b/internal/engine/external_test.go index f96cd69b..44a9c572 100644 --- a/internal/engine/external_test.go +++ b/internal/engine/external_test.go @@ -422,6 +422,37 @@ func TestLoadExternalFacts_jsonFacts(t *testing.T) { } } +func TestLoadExternalFacts_ignoresJSONWithTrailingTokens(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "site.json"), []byte(`{"site":"lab"} garbage`), 0o600); err != nil { + t.Fatal(err) + } + + got, err := LoadExternalFacts(testSession, []string{dir}) + if err != nil { + t.Fatalf("LoadExternalFacts(testSession) err = %v, want nil for malformed structured file", err) + } + if len(got) != 0 { + t.Fatalf("LoadExternalFacts(testSession) = %#v, want no facts", got) + } +} + +func TestLoadExternalFacts_preservesLargeJSONIntegerAsInt64(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "site.json"), []byte(`{"big":2147483648}`), 0o600); err != nil { + t.Fatal(err) + } + + got, err := LoadExternalFacts(testSession, []string{dir}) + if err != nil { + t.Fatal(err) + } + want := []ResolvedFact{{Name: "big", Value: int64(2147483648), Type: "external"}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("LoadExternalFacts(testSession) = %#v, want %#v", got, want) + } +} + func TestLoadExternalFacts_yamlFacts(t *testing.T) { dir := t.TempDir() content := []byte("site: lab\nfeatures:\n - yaml\n - external\nnested:\n enabled: true\ncount: 3\n") @@ -456,6 +487,44 @@ func TestLoadExternalFacts_yamlFacts(t *testing.T) { } } +func TestExternalFactLoader_ignoresNonRegularStructuredFiles(t *testing.T) { + opened := false + host := &fakeExternalFactLoaderHost{ + openFunc: func(string) (io.ReadCloser, error) { + opened = true + return io.NopCloser(strings.NewReader(`{"site":"lab"}`)), nil + }, + } + + got, err := externalFactLoader{s: testSession, host: host}.loadExternalFactFile("site.json", os.ModeNamedPipe) + if err != nil { + t.Fatalf("loadExternalFactFile() err = %v, want nil", err) + } + if len(got) != 0 { + t.Fatalf("loadExternalFactFile() = %#v, want no facts", got) + } + if opened { + t.Fatal("loadExternalFactFile opened non-regular structured file") + } +} + +func TestLoadExternalFacts_acceptsLongKeyValueLineWithinLimit(t *testing.T) { + dir := t.TempDir() + value := strings.Repeat("x", 70*1024) + if err := os.WriteFile(filepath.Join(dir, "site.txt"), []byte("site="+value+"\n"), 0o600); err != nil { + t.Fatal(err) + } + + got, err := LoadExternalFacts(testSession, []string{dir}) + if err != nil { + t.Fatalf("LoadExternalFacts(testSession) err = %v, want nil", err) + } + want := []ResolvedFact{{Name: "site", Value: value, Type: "external"}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("LoadExternalFacts(testSession) = %#v, want long site fact", got) + } +} + func TestLoadExternalFacts_yamlTimestampValuesStayStrings(t *testing.T) { dir := t.TempDir() content := []byte("testsfact:\n time: 2020-04-28 01:44:08.148119000 +01:01\n") diff --git a/internal/engine/fact.go b/internal/engine/fact.go index 3833588f..66ff13a9 100644 --- a/internal/engine/fact.go +++ b/internal/engine/fact.go @@ -63,14 +63,22 @@ func collectFacts(facts []ResolvedFact, includeTypedDotted bool) (map[string]any // ValueForQuery returns the value selected by fact.UserQuery from fact.Value. func ValueForQuery(fact ResolvedFact) any { + value, _ := valueForQuery(fact) + return value +} + +func valueForQuery(fact ResolvedFact) (any, bool) { query := fact.UserQuery if query == "" || query == fact.Name { - return fact.Value + if fact.Value == nil { + return nil, fact.Type == "custom" || fact.Type == "external" + } + return fact.Value, true } if !strings.HasPrefix(query, fact.Name+".") { - return dig(fact.Value, strings.Split(query, ".")) + return digValue(fact.Value, strings.Split(query, ".")) } - return dig(fact.Value, strings.Split(strings.TrimPrefix(query, fact.Name+"."), ".")) + return digValue(fact.Value, strings.Split(strings.TrimPrefix(query, fact.Name+"."), ".")) } func insert(root map[string]any, parts []string, value any) bool { @@ -78,6 +86,9 @@ func insert(root map[string]any, parts []string, value any) bool { return false } if len(parts) == 1 { + if _, ok := root[parts[0]].(map[string]any); ok { + return false + } root[parts[0]] = value return true } @@ -105,16 +116,21 @@ func factTypeLabel(factType string) string { } func dig(value any, parts []string) any { + value, _ = digValue(value, parts) + return value +} + +func digValue(value any, parts []string) (any, bool) { if len(parts) == 0 { - return value + return value, true } switch v := value.(type) { case map[string]any: next, ok := v[parts[0]] if !ok { - return nil + return nil, false } - return dig(next, parts[1:]) + return digValue(next, parts[1:]) case map[any]any: next, ok := v[parts[0]] if !ok { @@ -127,35 +143,35 @@ func dig(value any, parts []string) any { } } if !ok { - return nil + return nil, false } - return dig(next, parts[1:]) + return digValue(next, parts[1:]) case []any: index, err := strconv.Atoi(parts[0]) if err != nil || index < 0 || index >= len(v) { - return nil + return nil, false } - return dig(v[index], parts[1:]) + return digValue(v[index], parts[1:]) case []string: index, err := strconv.Atoi(parts[0]) if err != nil || index < 0 || index >= len(v) { - return nil + return nil, false } if len(parts) > 1 { - return nil + return nil, false } - return v[index] + return v[index], true case []int: index, err := strconv.Atoi(parts[0]) if err != nil || index < 0 || index >= len(v) { - return nil + return nil, false } if len(parts) > 1 { - return nil + return nil, false } - return v[index] + return v[index], true default: - return nil + return nil, false } } diff --git a/internal/engine/factsutil.go b/internal/engine/factsutil.go index f1f6f970..13feed74 100644 --- a/internal/engine/factsutil.go +++ b/internal/engine/factsutil.go @@ -39,7 +39,7 @@ func discoverFamily(id string) string { // carries Red Hat branding that must not leak into os.distro. func usesRedHatReleaseDistro(id string) bool { switch strings.ToLower(id) { - case "ol", "amzn": + case "ol", "oel", "oraclelinux", "amzn", "amazon": return false } return discoverFamily(id) == "RedHat" diff --git a/internal/engine/factsutil_test.go b/internal/engine/factsutil_test.go index 928bdf70..83593c6e 100644 --- a/internal/engine/factsutil_test.go +++ b/internal/engine/factsutil_test.go @@ -46,6 +46,19 @@ func TestDiscoverFamily_matchesRubyFactsUtils(t *testing.T) { } } +func TestUsesRedHatReleaseDistroExcludesOracleAndAmazonAliases(t *testing.T) { + for _, id := range []string{"ol", "oel", "oraclelinux", "amzn", "amazon"} { + t.Run(id, func(t *testing.T) { + if usesRedHatReleaseDistro(id) { + t.Fatalf("usesRedHatReleaseDistro(%q) = true, want false", id) + } + }) + } + if !usesRedHatReleaseDistro("rhel") { + t.Fatal("usesRedHatReleaseDistro(rhel) = false, want true") + } +} + func TestReleaseHashFromString_matchesRubyFactsUtils(t *testing.T) { tests := []struct { name string diff --git a/internal/engine/formatter.go b/internal/engine/formatter.go index ef0d5ec4..22c9edee 100644 --- a/internal/engine/formatter.go +++ b/internal/engine/formatter.go @@ -183,8 +183,7 @@ func yamlSequenceLines(values []any, depth int) []string { lines := make([]string, 0, len(values)) for _, value := range values { if childMap, ok := value.(map[string]any); ok { - lines = append(lines, indent+"-") - lines = append(lines, yamlLines(childMap, depth+1)...) + lines = append(lines, indent+"- "+yamlInlineMap(childMap)) continue } lines = append(lines, indent+"- "+yamlScalar(value)) @@ -272,14 +271,14 @@ func hoconScalar(value any) string { case nil: return "" case string: - if strings.Contains(v, "_") { + if !isPlainHOCONString(v) { return strconv.Quote(v) } return v case int: return strconv.Itoa(v) case float64: - return strconv.FormatFloat(v, 'f', 1, 64) + return strconv.FormatFloat(v, 'f', -1, 64) case bool: return strconv.FormatBool(v) case map[string]any: @@ -341,9 +340,17 @@ func yamlScalar(value any) string { case int: return strconv.Itoa(v) case float64: - return strconv.FormatFloat(v, 'f', 1, 64) + return strconv.FormatFloat(v, 'f', -1, 64) case bool: return strconv.FormatBool(v) + case map[string]any: + return yamlInlineMap(v) + case []any: + parts := make([]string, 0, len(v)) + for _, item := range v { + parts = append(parts, yamlScalar(item)) + } + return "[" + strings.Join(parts, ", ") + "]" case []string: parts := make([]string, 0, len(v)) for _, item := range v { @@ -361,6 +368,27 @@ func yamlScalar(value any) string { } } +func isPlainHOCONString(value string) bool { + if value == "" || strings.Contains(value, "_") { + return false + } + for _, r := range value { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '.' { + continue + } + return false + } + return true +} + +func yamlInlineMap(value map[string]any) string { + parts := make([]string, 0, len(value)) + for _, key := range sortedKeys(value) { + parts = append(parts, yamlKey(key)+": "+yamlScalar(value[key])) + } + return strings.Join(parts, ", ") +} + func isPlainYAMLString(value string) bool { if value == "" { return false diff --git a/internal/engine/formatter_test.go b/internal/engine/formatter_test.go index cbfaa2d8..99605648 100644 --- a/internal/engine/formatter_test.go +++ b/internal/engine/formatter_test.go @@ -278,11 +278,23 @@ func TestFormatYAML_quotesWindowsPath(t *testing.T) { func TestFormatYAML_formatsFloatWithoutQuotes(t *testing.T) { facts := []ResolvedFact{ - {Name: "memory", Value: 1024.0}, + {Name: "load_average", Value: 1.35}, } got := FormatYAML(facts) - want := "memory: 1024.0\n" + want := "load_average: 1.35\n" + if got != want { + t.Fatalf("FormatYAML() = %q, want %q", got, want) + } +} + +func TestFormatYAML_formatsNestedArrayValuesAsYAML(t *testing.T) { + facts := []ResolvedFact{ + {Name: "nested", Value: []any{[]any{"a", "b"}, map[string]any{"name": "c"}}, UserQuery: "nested"}, + } + + got := FormatYAML(facts) + want := "nested:\n- [a, b]\n- name: c\n" if got != want { t.Fatalf("FormatYAML() = %q, want %q", got, want) } @@ -389,6 +401,30 @@ func TestFormatHOCON_formatsArrayValues(t *testing.T) { } } +func TestFormatHOCON_quotesUnsafeStringValues(t *testing.T) { + facts := []ResolvedFact{ + {Name: "external.payload", Value: "a=b # not syntax"}, + } + + got := FormatHOCON(facts) + want := "external={\n payload=\"a=b # not syntax\"\n}\n" + if got != want { + t.Fatalf("FormatHOCON() = %q, want %q", got, want) + } +} + +func TestFormatHOCON_preservesFloatPrecision(t *testing.T) { + facts := []ResolvedFact{ + {Name: "load_average", Value: 1.35}, + } + + got := FormatHOCON(facts) + want := "load_average=1.35\n" + if got != want { + t.Fatalf("FormatHOCON() = %q, want %q", got, want) + } +} + func TestFormatHOCON_singleNilQueryReturnsEmptyScalar(t *testing.T) { facts := []ResolvedFact{ {Name: "my_external_fact", UserQuery: "my_external_fact", Value: nil}, diff --git a/internal/engine/gce.go b/internal/engine/gce.go index 801d6e76..d28d3912 100644 --- a/internal/engine/gce.go +++ b/internal/engine/gce.go @@ -170,7 +170,7 @@ func (gc *gceClient) get(ctx context.Context, path string) (string, bool) { return "", false } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK || resp.Header.Get("Metadata-Flavor") == "" && resp.Header.Get("Metadata-flavor") == "" { + if resp.StatusCode != http.StatusOK || resp.Header.Get("Metadata-Flavor") != "Google" { return "", false } data, err := io.ReadAll(io.LimitReader(resp.Body, gceMaxBodyBytes)) diff --git a/internal/engine/gce_test.go b/internal/engine/gce_test.go index 3c429407..07fd3632 100644 --- a/internal/engine/gce_test.go +++ b/internal/engine/gce_test.go @@ -95,6 +95,18 @@ func TestGCEFactsSkipInvalidMetadata(t *testing.T) { } } +func TestGCEFactsRequireGoogleMetadataFlavor(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Metadata-Flavor", "NotGoogle") + _, _ = w.Write([]byte(`{"some":"metadata"}`)) + })) + t.Cleanup(server.Close) + + if got := gceFacts(context.Background(), newGCEClient(server.URL, server.Client())); len(got) != 0 { + t.Fatalf("gceFacts(context.Background(), ) = %#v, want no facts for spoofed metadata flavor", got) + } +} + func TestLinuxGCEFactsSkipsMetadataWhenBIOSVendorIsNotGoogleLikeRuby(t *testing.T) { var requested atomic.Bool server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { diff --git a/internal/engine/groups.go b/internal/engine/groups.go index f322a510..2d7c69ee 100644 --- a/internal/engine/groups.go +++ b/internal/engine/groups.go @@ -211,7 +211,47 @@ func FilterBlockedFacts(facts []ResolvedFact, blocked map[string]bool) []Resolve if blocked[fact.Name] || blocked[root] { continue } + fact.Value = pruneBlockedDescendants(fact.Name, fact.Value, blocked) filtered = append(filtered, fact) } return filtered } + +func pruneBlockedDescendants(name string, value any, blocked map[string]bool) any { + var pruned any + for blockedName := range blocked { + if !strings.HasPrefix(blockedName, name+".") { + continue + } + if pruned == nil { + pruned = deepCopyValue(value) + } + pruned = pruneDottedValue(pruned, strings.Split(strings.TrimPrefix(blockedName, name+"."), ".")) + } + if pruned == nil { + return value + } + return pruned +} + +func pruneDottedValue(value any, parts []string) any { + if len(parts) == 0 { + return value + } + switch v := value.(type) { + case map[string]any: + if len(parts) == 1 { + delete(v, parts[0]) + return v + } + child, ok := v[parts[0]] + if !ok { + return v + } + v[parts[0]] = pruneDottedValue(child, parts[1:]) + if childMap, ok := v[parts[0]].(map[string]any); ok && len(childMap) == 0 { + delete(v, parts[0]) + } + } + return value +} diff --git a/internal/engine/identity.go b/internal/engine/identity.go index 09cc273a..2cdf8e89 100644 --- a/internal/engine/identity.go +++ b/internal/engine/identity.go @@ -23,7 +23,11 @@ func identityFact(s *Session) map[string]any { } privileged := os.Geteuid() == 0 - info := identityInfo{Privileged: &privileged} + info := identityInfo{ + UID: strconv.Itoa(os.Getuid()), + GID: strconv.Itoa(os.Getgid()), + Privileged: &privileged, + } current, err := osuser.Current() if err != nil { return identityFactFromInfo(runtime.GOOS, info) diff --git a/internal/engine/networking_test.go b/internal/engine/networking_test.go index abda78c4..6f77cd84 100644 --- a/internal/engine/networking_test.go +++ b/internal/engine/networking_test.go @@ -397,7 +397,7 @@ func TestLinuxProcGetenvForPIDMatchesRubyProcHelper(t *testing.T) { if path != "/proc/1/environ" { t.Fatalf("path = %q, want /proc/1/environ", path) } - return []string{"container=podman", "bubbles=", "HOME=/root"}, true + return []string{"container=podman\x00bubbles=\x00HOME=/root\x00"}, true } tests := []struct { diff --git a/internal/engine/proc.go b/internal/engine/proc.go index 8d1a0aa9..f36dda71 100644 --- a/internal/engine/proc.go +++ b/internal/engine/proc.go @@ -11,9 +11,11 @@ func linuxProcGetenvForPID(pid int, field string, readLines procEnvironReader) ( prefix := field + "=" lines, _ := readLines(fmt.Sprintf("/proc/%d/environ", pid), nil) for _, line := range lines { - value, ok := strings.CutPrefix(line, prefix) - if ok { - return value, true + for part := range strings.SplitSeq(line, "\x00") { + value, ok := strings.CutPrefix(part, prefix) + if ok { + return value, true + } } } return "", false diff --git a/internal/engine/processors.go b/internal/engine/processors.go index fcc68c4f..d758104f 100644 --- a/internal/engine/processors.go +++ b/internal/engine/processors.go @@ -25,7 +25,7 @@ func currentProcessorISA(s *Session, goos, fallback string, run commandRunner) s if isa := s.cachedPlatformProcessorInfo().ISA; isa != "" { return isa } - return "" + return fallback } if goos == "plan9" { return plan9ProcessorISA(s.readFile, fallback) @@ -435,12 +435,16 @@ func parseLinuxProcessorTopology(input string) (int, int) { return 0, 0 } -func currentLinuxProcessorPhysicalCount(cpuinfoPath, sysCPUPath string, readFile fileReader) int { - data, err := readFile(cpuinfoPath) - if err != nil || len(data) == 0 { - return 0 +func currentLinuxProcessorPhysicalCount(cpuinfoPath, sysCPUPath string, host hostOS) int { + if host == nil { + host = osHost{} } - return linuxProcessorPhysicalCount(string(data), sysCPUPath, readFile) + data, err := host.readFile(cpuinfoPath) + cpuinfo := "" + if err == nil { + cpuinfo = string(data) + } + return linuxProcessorPhysicalCountWithReaders(cpuinfo, sysCPUPath, host.readFile, host.readDir) } func linuxProcessorPhysicalCount(cpuinfo, sysCPUPath string, readFiles ...fileReader) int { @@ -448,6 +452,10 @@ func linuxProcessorPhysicalCount(cpuinfo, sysCPUPath string, readFiles ...fileRe if len(readFiles) > 0 && readFiles[0] != nil { readFile = readFiles[0] } + return linuxProcessorPhysicalCountWithReaders(cpuinfo, sysCPUPath, readFile, os.ReadDir) +} + +func linuxProcessorPhysicalCountWithReaders(cpuinfo, sysCPUPath string, readFile fileReader, readDir func(string) ([]os.DirEntry, error)) int { physicalIDs := make(map[string]struct{}) for line := range strings.SplitSeq(cpuinfo, "\n") { key, value, ok := strings.Cut(line, ":") @@ -463,7 +471,7 @@ func linuxProcessorPhysicalCount(cpuinfo, sysCPUPath string, readFiles ...fileRe return len(physicalIDs) } - entries, err := os.ReadDir(sysCPUPath) + entries, err := readDir(sysCPUPath) if err != nil { return 0 } @@ -614,7 +622,7 @@ func processorsCoreFacts(s *Session) []ResolvedFact { platformProcessors = s.cachedPlatformProcessorInfo() } if runtime.GOOS == "linux" { - platformProcessors.PhysicalCount = currentLinuxProcessorPhysicalCount("/proc/cpuinfo", "/sys/devices/system/cpu", s.readFile) + platformProcessors.PhysicalCount = currentLinuxProcessorPhysicalCount("/proc/cpuinfo", "/sys/devices/system/cpu", s.host) } processorCount := runtime.NumCPU() if platformProcessors.LogicalCount > 0 { diff --git a/internal/engine/processors_test.go b/internal/engine/processors_test.go index bf5c9ce0..d9cde155 100644 --- a/internal/engine/processors_test.go +++ b/internal/engine/processors_test.go @@ -326,6 +326,15 @@ func TestCurrentProcessorISAUsesOpenBSDUnameProcessor(t *testing.T) { } } +func TestCurrentProcessorISAWindowsFallsBackWhenWMIHasNoISA(t *testing.T) { + s := NewSession() + s.host = &fakeHostOS{platform: "windows", runOutput: ""} + + if got := currentProcessorISA(s, "windows", "amd64", func(string, ...string) string { return "" }); got != "amd64" { + t.Fatalf("currentProcessorISA(windows) = %q, want amd64 fallback", got) + } +} + func TestCoreFacts_processorSpeedOmittedWhenProbeYieldsNothing(t *testing.T) { collection := Collection(CoreFacts(testSession)) processors, ok := collection["processors"].(map[string]any) @@ -530,6 +539,26 @@ func TestLinuxProcessorPhysicalCountFallsBackToSysfsPackageIDsLikeRuby(t *testin } } +func TestCurrentLinuxProcessorPhysicalCountUsesHostSysfsWhenCPUInfoMissing(t *testing.T) { + host := &fakeHostOS{ + files: map[string][]byte{ + "/sys/devices/system/cpu/cpu0/topology/physical_package_id": []byte("0\n"), + "/sys/devices/system/cpu/cpu1/topology/physical_package_id": []byte("1\n"), + }, + dirs: map[string][]os.DirEntry{ + "/sys/devices/system/cpu": fakeDirEntries("cpu0", "cpu1"), + }, + } + + got := currentLinuxProcessorPhysicalCount("/proc/cpuinfo", "/sys/devices/system/cpu", host) + if got != 2 { + t.Fatalf("currentLinuxProcessorPhysicalCount() = %d, want sysfs fallback count 2", got) + } + if !reflect.DeepEqual(host.readDirCalls, []string{"/sys/devices/system/cpu"}) { + t.Fatalf("readDir calls = %#v, want sysfs path", host.readDirCalls) + } +} + func TestParseLinuxProcessorExtensions_derivesX86Levels(t *testing.T) { input := "flags : fpu cx8 cmov mmx fxsr sse2 syscall lm cx16 lahf_lm popcnt sse4_1 sse4_2 ssse3 abm avx avx2 bmi1 bmi2 f16c fma movbe xsave\n" diff --git a/internal/engine/projection.go b/internal/engine/projection.go index a842a918..a01ea085 100644 --- a/internal/engine/projection.go +++ b/internal/engine/projection.go @@ -69,7 +69,7 @@ func (p *Projection) Select(queries []string) []ResolvedFact { // missing-vs-nil contract. func (p *Projection) LookupValue(query string) (value any, found bool) { fact := p.Select([]string{query})[0] - if v := ValueForQuery(fact); v != nil { + if v, found := valueForQuery(fact); found { return v, true } if (fact.Type == "custom" || fact.Type == "external") && fact.Value == nil && fact.UserQuery == fact.Name { @@ -152,7 +152,7 @@ func findFactIn(facts []ResolvedFact, collection map[string]any, query string) R return fact } } - if value := dig(collection, strings.Split(query, ".")); value != nil { + if value, found := digValue(collection, strings.Split(query, ".")); found { return ResolvedFact{Name: query, UserQuery: query, Value: value} } return ResolvedFact{Name: query, UserQuery: query, Type: "nil"} diff --git a/internal/engine/projection_test.go b/internal/engine/projection_test.go index eb288f7a..0a117916 100644 --- a/internal/engine/projection_test.go +++ b/internal/engine/projection_test.go @@ -224,6 +224,46 @@ func TestSnapshotValueReusesSnapshotTree(t *testing.T) { } } +func TestSnapshotValueDistinguishesNestedNilFromMissing(t *testing.T) { + sn := newSnapshot([]ResolvedFact{ + {Name: "external", Value: map[string]any{"blank": nil}, Type: "external"}, + }, nil) + + value, err := sn.Value("external.blank") + if err != nil { + t.Fatalf("Value() error = %v, want nil", err) + } + if value != nil { + t.Fatalf("Value() = %#v, want nil", value) + } +} + +func TestSnapshotReturnedMutableValuesDoNotAffectSnapshot(t *testing.T) { + sn := newSnapshot([]ResolvedFact{ + {Name: "site", Value: map[string]any{"roles": []string{"web"}}, Type: "external"}, + }, nil) + + tree := sn.Tree() + tree["site"].(map[string]any)["roles"].([]string)[0] = "db" + value, err := sn.Value("site.roles.0") + if err != nil { + t.Fatalf("Value() error = %v", err) + } + if value != "web" { + t.Fatalf("Value() after Tree mutation = %#v, want web", value) + } + + facts := sn.Facts() + facts[0].Value.(map[string]any)["roles"].([]string)[0] = "db" + value, err = sn.Value("site.roles.0") + if err != nil { + t.Fatalf("Value() error = %v", err) + } + if value != "web" { + t.Fatalf("Value() after Facts mutation = %#v, want web", value) + } +} + // sameMap reports whether a and b are the same underlying map instance. func sameMap(a, b map[string]any) bool { return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer() diff --git a/internal/engine/query.go b/internal/engine/query.go index 4fd52990..566e2b04 100644 --- a/internal/engine/query.go +++ b/internal/engine/query.go @@ -18,7 +18,8 @@ func SelectWithDottedFacts(facts []ResolvedFact, queries []string, includeTypedD func factMatchesQuery(factName, query string) bool { if strings.Contains(factName, ".*") && !strings.Contains(query, ".") { - matched, err := regexp.MatchString("^"+factName+"$", query) + pattern := strings.ReplaceAll(regexp.QuoteMeta(factName), `\.\*`, `.*`) + matched, err := regexp.MatchString("^"+pattern+"$", query) return err == nil && matched } return query == factName || strings.HasPrefix(query, factName+".") diff --git a/internal/engine/query_test.go b/internal/engine/query_test.go index 2da2f1eb..9de7b5fb 100644 --- a/internal/engine/query_test.go +++ b/internal/engine/query_test.go @@ -91,6 +91,25 @@ func TestSelect_matchesWildcardFactNameLikeRubyQueryParser(t *testing.T) { } } +func TestSelect_wildcardFactNameEscapesOtherRegexpCharacters(t *testing.T) { + facts := []ResolvedFact{ + {Name: "metric[prod].*", Value: "literal", Type: "external"}, + } + + selected := Select(facts, []string{"metric[prod]cpu"}) + if len(selected) != 1 { + t.Fatalf("Select() returned %d facts, want 1", len(selected)) + } + + got := selected[0] + if got.Name != "metric[prod].*" { + t.Fatalf("Name = %q, want metric[prod].*", got.Name) + } + if got.Value != "literal" { + t.Fatalf("Value = %#v, want literal", got.Value) + } +} + func TestSelect_doesNotMatchWildcardNameForDottedStructuredQuery(t *testing.T) { facts := []ResolvedFact{ {Name: "ssh.*key", Value: "wildcard", Type: "external"}, @@ -127,6 +146,19 @@ func TestCollectionWithDottedFactsKeepsExistingScalarOnNestedCollision(t *testin } } +func TestCollectionWithDottedFactsKeepsExistingMapOnScalarCollision(t *testing.T) { + facts := []ResolvedFact{ + {Name: "mygroup.fact1", Value: "g1_f1_value", Type: "custom"}, + {Name: "mygroup", Value: "scalar_value", Type: "custom"}, + } + + got := CollectionWithDottedFacts(facts, true) + want := map[string]any{"mygroup": map[string]any{"fact1": "g1_f1_value"}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("CollectionWithDottedFacts() = %#v, want %#v", got, want) + } +} + // Collisions are reported once at discovery (newSnapshot), not by the formatter // path. CollectionWithDottedFacts itself is diagnostic-silent so the formatter // and query paths that re-run collection never re-emit. diff --git a/internal/engine/selinux.go b/internal/engine/selinux.go index b8ec1995..3eff6a6e 100644 --- a/internal/engine/selinux.go +++ b/internal/engine/selinux.go @@ -19,18 +19,25 @@ func selinuxFactsForPlatform(goos, mountsPath, configPath string, readFile fileR func selinuxFacts(mountsPath, configPath string, readFile fileReader) []ResolvedFact { mountpoint := selinuxMountpoint(mountsPath, readFile) configMode, configPolicy, hasConfig := readSELinuxConfig(configPath, readFile) - enabled := mountpoint != "" && hasConfig + enabled := mountpoint != "" values := map[string]any{"enabled": enabled} if enabled { - values["config_mode"] = configMode - values["config_policy"] = configPolicy + if hasConfig { + if configMode != "" { + values["config_mode"] = configMode + } + if configPolicy != "" { + values["config_policy"] = configPolicy + } + } values["policy_version"] = readOptionalText(filepath.Join(mountpoint, "policyvers"), readFile) - enforced := strings.TrimSpace(readText(filepath.Join(mountpoint, "enforce"), readFile)) == "1" - values["enforced"] = enforced - if enforced { - values["current_mode"] = "enforcing" - } else { - values["current_mode"] = "permissive" + if enforced, ok := readSELinuxEnforce(filepath.Join(mountpoint, "enforce"), readFile); ok { + values["enforced"] = enforced + if enforced { + values["current_mode"] = "enforcing" + } else { + values["current_mode"] = "permissive" + } } } @@ -80,6 +87,21 @@ func readSELinuxConfig(path string, readFile fileReader) (mode, policy string, o return mode, policy, true } +func readSELinuxEnforce(path string, readFile fileReader) (bool, bool) { + data, err := readFile(path) + if err != nil { + return false, false + } + switch strings.TrimSpace(string(data)) { + case "1": + return true, true + case "0": + return false, true + default: + return false, false + } +} + // selinuxCoreFacts assembles the selinux category facts (os.selinux), emitted // only on Linux. func selinuxCoreFacts(s *Session) []ResolvedFact { diff --git a/internal/engine/selinux_test.go b/internal/engine/selinux_test.go index 31f669c8..c0ae53be 100644 --- a/internal/engine/selinux_test.go +++ b/internal/engine/selinux_test.go @@ -48,8 +48,12 @@ func TestSELinuxFactsDisabledWithoutMountpointOrConfig(t *testing.T) { writeFile(t, filepath.Join(dir, "mounts"), "none /sys/fs/selinux selinuxfs rw 0 0\n") core = selinuxFacts(filepath.Join(dir, "mounts"), filepath.Join(dir, "missing-config"), os.ReadFile) - if got := Collection(core)["os"].(map[string]any)["selinux"].(map[string]any)["enabled"]; got != false { - t.Fatalf("os.selinux.enabled = %#v, want false without config", got) + selinux := Collection(core)["os"].(map[string]any)["selinux"].(map[string]any) + if got := selinux["enabled"]; got != true { + t.Fatalf("os.selinux.enabled = %#v, want true with selinuxfs even without config", got) + } + if _, ok := selinux["config_mode"]; ok { + t.Fatalf("os.selinux.config_mode = %#v, want omitted without config", selinux["config_mode"]) } } @@ -95,3 +99,25 @@ func TestSELinuxFactsKeepsMissingPolicyVersionNil(t *testing.T) { t.Fatalf("os.selinux.policy_version = %#v, want nil", got) } } + +func TestSELinuxFactsOmitsEnforcementWhenEnforceIsUnreadableOrInvalid(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + mountpoint := filepath.Join(dir, "selinux") + if err := os.Mkdir(mountpoint, 0o700); err != nil { + t.Fatal(err) + } + writeFile(t, filepath.Join(dir, "mounts"), "none "+mountpoint+" selinuxfs rw 0 0\n") + writeFile(t, filepath.Join(dir, "config"), "SELINUX=enforcing\nSELINUXTYPE=targeted\n") + writeFile(t, filepath.Join(mountpoint, "enforce"), "unexpected\n") + + core := selinuxFacts(filepath.Join(dir, "mounts"), filepath.Join(dir, "config"), os.ReadFile) + selinux := Collection(core)["os"].(map[string]any)["selinux"].(map[string]any) + if _, ok := selinux["enforced"]; ok { + t.Fatalf("os.selinux.enforced = %#v, want omitted for invalid enforce content", selinux["enforced"]) + } + if _, ok := selinux["current_mode"]; ok { + t.Fatalf("os.selinux.current_mode = %#v, want omitted for invalid enforce content", selinux["current_mode"]) + } +} diff --git a/internal/engine/snapshot.go b/internal/engine/snapshot.go index 878c2fd0..f3499673 100644 --- a/internal/engine/snapshot.go +++ b/internal/engine/snapshot.go @@ -28,6 +28,7 @@ func newSnapshot(facts []ResolvedFact, log *slog.Logger) *Snapshot { if log == nil { log = slog.New(slog.DiscardHandler) } + facts = cloneFacts(facts) tree, collisions := collectFacts(facts, false) for _, fact := range collisions { reportCollectionCollision(log, fact) @@ -80,7 +81,15 @@ func (sn *Snapshot) All() iter.Seq2[string, any] { // Facts returns the resolved facts backing the Snapshot, for the CLI's // formatter pipeline. func (sn *Snapshot) Facts() []ResolvedFact { - return slices.Clone(sn.facts) + return cloneFacts(sn.facts) +} + +func cloneFacts(facts []ResolvedFact) []ResolvedFact { + out := slices.Clone(facts) + for i := range out { + out[i].Value = deepCopyValue(out[i].Value) + } + return out } func deepCopyValue(value any) any { @@ -91,6 +100,18 @@ func deepCopyValue(value any) any { out[key] = deepCopyValue(item) } return out + case map[string]string: + out := make(map[string]string, len(v)) + for key, item := range v { + out[key] = item + } + return out + case map[string][]string: + out := make(map[string][]string, len(v)) + for key, item := range v { + out[key] = slices.Clone(item) + } + return out case map[any]any: out := make(map[any]any, len(v)) for key, item := range v { @@ -103,11 +124,76 @@ func deepCopyValue(value any) any { out[i] = deepCopyValue(item) } return out + case []string: + return slices.Clone(v) + case []int: + return slices.Clone(v) + default: + return deepCopyReflect(value) + } +} + +func deepCopyReflect(value any) any { + rv := reflect.ValueOf(value) + if !rv.IsValid() { + return value + } + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return value + } + out := reflect.MakeSlice(rv.Type(), rv.Len(), rv.Len()) + for i := range rv.Len() { + setReflectValue(out.Index(i), deepCopyValue(rv.Index(i).Interface())) + } + return out.Interface() + case reflect.Array: + out := reflect.New(rv.Type()).Elem() + for i := range rv.Len() { + setReflectValue(out.Index(i), deepCopyValue(rv.Index(i).Interface())) + } + return out.Interface() + case reflect.Map: + if rv.IsNil() { + return value + } + out := reflect.MakeMapWithSize(rv.Type(), rv.Len()) + for _, key := range rv.MapKeys() { + item := deepCopyValue(rv.MapIndex(key).Interface()) + itemValue := reflect.ValueOf(item) + if item == nil { + itemValue = reflect.Zero(rv.Type().Elem()) + } + if itemValue.IsValid() && itemValue.Type().AssignableTo(rv.Type().Elem()) { + out.SetMapIndex(key, itemValue) + } else if itemValue.IsValid() && itemValue.Type().ConvertibleTo(rv.Type().Elem()) { + out.SetMapIndex(key, itemValue.Convert(rv.Type().Elem())) + } else { + out.SetMapIndex(key, rv.MapIndex(key)) + } + } + return out.Interface() default: return value } } +func setReflectValue(dst reflect.Value, value any) { + if value == nil { + dst.SetZero() + return + } + rv := reflect.ValueOf(value) + if rv.Type().AssignableTo(dst.Type()) { + dst.Set(rv) + return + } + if rv.Type().ConvertibleTo(dst.Type()) { + dst.Set(rv.Convert(dst.Type())) + } +} + // NormalizeCustomValue canonicalizes a custom fact value: time.Time values // become RFC 3339 strings and string-keyed maps become map[string]any, so the // canonical tree holds only tree-shaped data. @@ -171,6 +257,39 @@ func customValueContainsNullByte(value any) bool { } } return false + default: + return customValueReflectContainsNullByte(value) + } +} + +func customValueReflectContainsNullByte(value any) bool { + rv := reflect.ValueOf(value) + if !rv.IsValid() { + return false + } + switch rv.Kind() { + case reflect.Interface, reflect.Pointer: + if rv.IsNil() { + return false + } + return customValueContainsNullByte(rv.Elem().Interface()) + case reflect.Slice, reflect.Array: + for i := range rv.Len() { + if customValueContainsNullByte(rv.Index(i).Interface()) { + return true + } + } + return false + case reflect.Map: + for _, key := range rv.MapKeys() { + if key.Kind() == reflect.String && strings.ContainsRune(key.String(), '\x00') { + return true + } + if customValueContainsNullByte(rv.MapIndex(key).Interface()) { + return true + } + } + return false default: return false } diff --git a/internal/engine/snapshot_test.go b/internal/engine/snapshot_test.go index f1cb0d3b..8bc6f72e 100644 --- a/internal/engine/snapshot_test.go +++ b/internal/engine/snapshot_test.go @@ -75,6 +75,8 @@ func TestCustomValueContainsNullByte(t *testing.T) { {"map value with NUL", map[string]any{"k": "v\x00"}, true}, {"map key with NUL", map[string]any{"k\x00": "v"}, true}, {"nested NUL deep in slice-of-map", []any{map[string]any{"k": []any{"x\x00"}}}, true}, + {"typed slice with NUL element", []string{"a", "b\x00"}, true}, + {"typed map with slice NUL element", map[string][]string{"k": {"v\x00"}}, true}, {"non-string scalar is never a NUL", 42, false}, } @@ -132,3 +134,17 @@ func TestDeepCopyValueHandlesAnyKeyedMap(t *testing.T) { t.Errorf("original map[any]any mutated through copy: inner = %q, want %q", got, "orig") } } + +func TestDeepCopyValueHandlesTypedSlicesInMaps(t *testing.T) { + original := map[string][]int{"numbers": {1, 2}} + + copied, ok := deepCopyValue(original).(map[string][]int) + if !ok { + t.Fatalf("deepCopyValue returned %T, want map[string][]int", deepCopyValue(original)) + } + copied["numbers"][0] = 99 + + if got := original["numbers"][0]; got != 1 { + t.Errorf("original typed slice mutated through copy: numbers[0] = %d, want 1", got) + } +} diff --git a/internal/engine/ssh.go b/internal/engine/ssh.go index bdcb7b50..9b40d088 100644 --- a/internal/engine/ssh.go +++ b/internal/engine/ssh.go @@ -30,13 +30,13 @@ func discoverSSHHostKeysForPlatform(goos, programData string, readFile fileReade if programData == "" { return nil } - paths = []string{filepath.Join(programData, "ssh")} + paths = []string{sshJoin(goos, programData, "ssh")} } files := []string{"ssh_host_rsa_key.pub", "ssh_host_dsa_key.pub", "ssh_host_ecdsa_key.pub", "ssh_host_ed25519_key.pub"} keys := make([]sshHostKey, 0, len(files)) for _, dir := range paths { for _, file := range files { - data, err := readFile(filepath.Join(dir, file)) + data, err := readFile(sshJoin(goos, dir, file)) if err != nil { continue } @@ -53,6 +53,13 @@ func discoverSSHHostKeysForPlatform(goos, programData string, readFile fileReade return keys } +func sshJoin(goos, dir, name string) string { + if goos == "windows" { + return strings.TrimRight(dir, `\/`) + `\` + name + } + return filepath.Join(dir, name) +} + func parseSSHHostPublicKey(line string) (sshHostKey, bool) { fields := strings.Fields(line) if len(fields) < 2 { @@ -62,11 +69,10 @@ func parseSSHHostPublicKey(line string) (sshHostKey, bool) { if !ok { return sshHostKey{}, false } - encodedKey := sshBase64Key(fields[1]) - if encodedKey == "" { + if fields[1] == "" { return sshHostKey{}, false } - decoded, err := base64.StdEncoding.DecodeString(encodedKey) + decoded, err := base64.StdEncoding.DecodeString(fields[1]) if err != nil { return sshHostKey{}, false } @@ -81,17 +87,6 @@ func parseSSHHostPublicKey(line string) (sshHostKey, bool) { }, true } -func sshBase64Key(key string) string { - var b strings.Builder - b.Grow(len(key)) - for _, r := range key { - if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '+' || r == '/' || r == '=' { - b.WriteRune(r) - } - } - return b.String() -} - func sshKeyName(keyType string) (string, int, bool) { switch keyType { case "ssh-rsa": @@ -120,6 +115,9 @@ func sshFactsForPlatform(goos string, keys []sshHostKey) []ResolvedFact { } structured := make(map[string]any, len(keys)) for _, key := range keys { + if _, exists := structured[key.Name]; exists { + continue + } structured[key.Name] = map[string]any{ "fingerprints": map[string]any{ "sha1": key.SHA1, diff --git a/internal/engine/ssh_test.go b/internal/engine/ssh_test.go index 77e05bc9..39be0cea 100644 --- a/internal/engine/ssh_test.go +++ b/internal/engine/ssh_test.go @@ -51,20 +51,9 @@ func TestParseSSHHostPublicKeyBuildsStructuredFacts(t *testing.T) { } } -func TestParseSSHHostPublicKeyIgnoresNonBase64CharactersForFingerprints(t *testing.T) { - entry, ok := parseSSHHostPublicKey("ssh-rsa -_YWJj root@example") - if !ok { - t.Fatal("parseSSHHostPublicKey() ok = false, want true") - } - - if got, want := entry.Key, "-_YWJj"; got != want { - t.Fatalf("entry.Key = %q, want original key %q", got, want) - } - if got, want := entry.SHA1, "SSHFP 1 1 a9993e364706816aba3e25717850c26c9cd0d89d"; got != want { - t.Fatalf("entry.SHA1 = %q, want %q", got, want) - } - if got, want := entry.SHA256, "SSHFP 1 2 ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"; got != want { - t.Fatalf("entry.SHA256 = %q, want %q", got, want) +func TestParseSSHHostPublicKeyRejectsInvalidBase64Key(t *testing.T) { + if entry, ok := parseSSHHostPublicKey("ssh-rsa -_YWJj root@example"); ok { + t.Fatalf("parseSSHHostPublicKey() = %#v, true; want rejected malformed key", entry) } } @@ -135,9 +124,9 @@ func assertDiscoverSSHHostKeysPOSIXSearchesRubyPathsAndOrder(t *testing.T, goos func TestDiscoverSSHHostKeysWindowsReadsProgramDataSSH(t *testing.T) { readFile := func(path string) ([]byte, error) { switch path { - case filepath.Join(`C:\ProgramData`, "ssh", "ssh_host_rsa_key.pub"): + case `C:\ProgramData\ssh\ssh_host_rsa_key.pub`: return []byte("ssh-rsa YWJj root@example"), nil - case filepath.Join(`C:\ProgramData`, "ssh", "ssh_host_ecdsa_key.pub"): + case `C:\ProgramData\ssh\ssh_host_ecdsa_key.pub`: return []byte("ecdsa-sha2-nistp256 ZGVm root@example"), nil default: return nil, os.ErrNotExist @@ -154,6 +143,19 @@ func TestDiscoverSSHHostKeysWindowsReadsProgramDataSSH(t *testing.T) { } } +func TestSSHFactsPreserveFirstDuplicateKeyType(t *testing.T) { + collection := Collection(sshFacts([]sshHostKey{ + {Name: "rsa", Type: "ssh-rsa", Key: "first", SHA1: "first-sha1", SHA256: "first-sha256"}, + {Name: "rsa", Type: "ssh-rsa", Key: "second", SHA1: "second-sha1", SHA256: "second-sha256"}, + })) + + ssh := collection["ssh"].(map[string]any) + rsa := ssh["rsa"].(map[string]any) + if got := rsa["key"]; got != "first" { + t.Fatalf("ssh.rsa.key = %#v, want first discovered key", got) + } +} + func TestSSHFactsWindowsUnprivilegedSkipsDiscovery(t *testing.T) { t.Parallel() diff --git a/internal/engine/timezone.go b/internal/engine/timezone.go index 98af6cd4..7af11c6b 100644 --- a/internal/engine/timezone.go +++ b/internal/engine/timezone.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "time" + "unicode/utf8" "golang.org/x/text/encoding" "golang.org/x/text/encoding/charmap" @@ -28,6 +29,9 @@ func currentWindowsTimezone(goos, zone, apiCodepage string, registryCodepage fun if goos != "windows" || zone == "" { return "" } + if utf8.ValidString(zone) { + return zone + } codepage := apiCodepage if codepage == "" { codepage = registryCodepage() diff --git a/internal/engine/timezone_test.go b/internal/engine/timezone_test.go index 3c9a93c7..4056a694 100644 --- a/internal/engine/timezone_test.go +++ b/internal/engine/timezone_test.go @@ -5,6 +5,20 @@ import ( "time" ) +func TestWindowsTimezoneKeepsValidUTF8ZoneName(t *testing.T) { + t.Parallel() + + zone := "Hora estándar" + got := currentWindowsTimezone("windows", zone, "850", func() string { + t.Fatal("registry codepage should not be used for valid UTF-8") + return "" + }) + + if got != zone { + t.Fatalf("currentWindowsTimezone() = %q, want %q", got, zone) + } +} + func TestWindowsTimezoneUsesAPICodepage(t *testing.T) { t.Parallel() diff --git a/internal/engine/uptime.go b/internal/engine/uptime.go index 7d04149e..8c13c7af 100644 --- a/internal/engine/uptime.go +++ b/internal/engine/uptime.go @@ -47,11 +47,15 @@ func currentUptimeInfo(s *Session, goos string, readFile fileReader, run command } if goos == "linux" { virtual := detectLinuxVirtualization(currentLinuxVirtualizationInputWithCommands(s, run)) - return currentLinuxUptimeInfo(readFile, run, now, virtual.Name == "docker") + return currentLinuxUptimeInfo(readFile, run, now, linuxContainerUptimeUsesPID1(virtual.Name)) } return currentPosixUptime(readFile, run, now) } +func linuxContainerUptimeUsesPID1(name string) bool { + return name == "docker" || name == "kubernetes" +} + func currentLinuxUptimeInfo(readFile fileReader, run commandRunner, now func() time.Time, docker bool) uptimeInfo { if docker { seconds := parseDockerElapsedTimeSeconds(run("ps", "-o", "etime=", "-p", "1")) @@ -361,13 +365,20 @@ func uptimeCoreFacts(s *Session) []ResolvedFact { if runtime.GOOS == "plan9" { return plan9UptimeCoreFacts(s.cachedUptime()) } - uptime := s.cachedUptime() - loadAverages := s.cachedLoadAverages() - return []ResolvedFact{ + return uptimeFacts(s.cachedUptime(), s.cachedLoadAverages()) +} + +func uptimeFacts(uptime uptimeInfo, loadAverages map[string]any) []ResolvedFact { + facts := []ResolvedFact{ {Name: "load_averages", Value: loadAverages}, - {Name: "system_uptime.days", Value: int(uptime.Duration.Hours()) / 24}, - {Name: "system_uptime.hours", Value: int(uptime.Duration.Hours())}, - {Name: "system_uptime.seconds", Value: int(uptime.Duration.Seconds())}, - {Name: "system_uptime.uptime", Value: uptimeString(uptime)}, } + if uptime.Known { + facts = append(facts, + ResolvedFact{Name: "system_uptime.days", Value: int64(uptime.Duration.Hours()) / 24}, + ResolvedFact{Name: "system_uptime.hours", Value: int64(uptime.Duration.Hours())}, + ResolvedFact{Name: "system_uptime.seconds", Value: int64(uptime.Duration.Seconds())}, + ) + } + facts = append(facts, ResolvedFact{Name: "system_uptime.uptime", Value: uptimeString(uptime)}) + return facts } diff --git a/internal/engine/uptime_test.go b/internal/engine/uptime_test.go index cb5fe449..5bfb1fcf 100644 --- a/internal/engine/uptime_test.go +++ b/internal/engine/uptime_test.go @@ -59,6 +59,52 @@ func TestUptimeStringReturnsUnknownWhenSecondsAreUnknown(t *testing.T) { } } +func TestUptimeFactsOmitNumericFieldsWhenUptimeUnknown(t *testing.T) { + got := Collection(uptimeFacts(uptimeInfo{}, emptyLoadAverages())) + systemUptime, ok := got["system_uptime"].(map[string]any) + if !ok { + t.Fatalf("system_uptime = %#v, want map", got["system_uptime"]) + } + if got := systemUptime["uptime"]; got != "unknown" { + t.Fatalf("system_uptime.uptime = %#v, want unknown", got) + } + for _, key := range []string{"days", "hours", "seconds"} { + if _, ok := systemUptime[key]; ok { + t.Fatalf("system_uptime.%s present for unknown uptime: %#v", key, systemUptime) + } + } +} + +func TestUptimeFactsUseInt64DurationFields(t *testing.T) { + got := Collection(uptimeFacts(uptimeInfo{Duration: time.Duration(1<<33) * time.Second, Known: true}, emptyLoadAverages())) + systemUptime := got["system_uptime"].(map[string]any) + if seconds, ok := systemUptime["seconds"].(int64); !ok || seconds != int64(1<<33) { + t.Fatalf("system_uptime.seconds = %#v, want int64 %d", systemUptime["seconds"], int64(1<<33)) + } +} + +func TestCurrentUptimeInfoUsesPID1ElapsedTimeForKubernetes(t *testing.T) { + s := NewSession() + s.host = &fakeHostOS{ + platform: "linux", + files: map[string][]byte{ + "/proc/1/cgroup": []byte("0::/kubepods.slice/pod123\n"), + }, + } + run := func(name string, args ...string) string { + if name == "ps" && reflect.DeepEqual(args, []string{"-o", "etime=", "-p", "1"}) { + return "01:02" + } + return "" + } + + got := currentUptimeInfo(s, "linux", s.readFile, run, time.Now) + want := uptimeInfo{Duration: 62 * time.Second, Known: true} + if got != want { + t.Fatalf("currentUptimeInfo(kubernetes) = %#v, want %#v", got, want) + } +} + func TestParseUptimeCommandSeconds_matchesRubyFacterUptimeParser(t *testing.T) { t.Parallel() diff --git a/internal/engine/virtual.go b/internal/engine/virtual.go index fa6eaddf..3d7af859 100644 --- a/internal/engine/virtual.go +++ b/internal/engine/virtual.go @@ -440,8 +440,13 @@ func detectLinuxVirtualization(input linuxVirtualizationInput) virtualization { if name := virtWhatVirtualization(input.VirtWhatOutput, input.ProcStatus); name != "" { return virtualization{Name: name, IsVirtual: true} } - if name := dmiProductHypervisor(input.DMIProductName); name != "" { - return virtualization{Name: name, IsVirtual: true} + if virtual := detectDMIHostVirtualization(dmiVirtualizationInput{ + Manufacturer: input.DMISysVendor, + ProductName: input.DMIProductName, + BIOSVendor: input.DMIBIOSVendor, + PCIOutput: input.LspciOutput, + }); virtual.IsVirtual { + return virtual } if name := parseVMwareCommand(input.VMwareCommand); name != "" { return virtualization{Name: name, IsVirtual: true} diff --git a/internal/engine/virtual_test.go b/internal/engine/virtual_test.go index ecabbad5..b2a62675 100644 --- a/internal/engine/virtual_test.go +++ b/internal/engine/virtual_test.go @@ -119,6 +119,40 @@ func TestDetectLinuxVirtualization_detectsOpenVZ(t *testing.T) { } } +func TestDetectLinuxVirtualization_detectsKVMFromDMI(t *testing.T) { + tests := []struct { + name string + input linuxVirtualizationInput + want virtualization + }{ + { + name: "qemu system vendor", + input: linuxVirtualizationInput{ + DMISysVendor: "QEMU", + DMIProductName: "Standard PC (i440FX + PIIX, 1996)", + }, + want: virtualization{Name: "kvm", IsVirtual: true}, + }, + { + name: "seabios vendor", + input: linuxVirtualizationInput{ + DMIBIOSVendor: "SeaBIOS", + DMIProductName: "Standard PC (i440FX + PIIX, 1996)", + }, + want: virtualization{Name: "kvm", IsVirtual: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := detectLinuxVirtualization(tt.input) + if got != tt.want { + t.Fatalf("detectLinuxVirtualization() = %#v, want %#v", got, tt.want) + } + }) + } +} + func TestDetectLinuxVirtualization_detectsDMIProductHypervisors(t *testing.T) { got := detectLinuxVirtualization(linuxVirtualizationInput{DMIProductName: "Bochs Machine"}) want := virtualization{Name: "bochs", IsVirtual: true} diff --git a/internal/schema/schema.go b/internal/schema/schema.go index bad41fce..a990bd18 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -197,18 +197,98 @@ func (s Schema) MissingEntries(paths []string, platform string) []string { if wildcardPrefixAbsent(pattern, paths) { continue } - present := false + if !schemaEntryPresent(pattern, entry, paths) { + missing = append(missing, pattern) + } + } + return missing +} + +func schemaEntryPresent(pattern string, entry Entry, paths []string) bool { + patternSegments := splitPath(pattern) + lastWildcard := lastSegmentIndex(patternSegments, "*") + if lastWildcard == -1 || lastWildcard == len(patternSegments)-1 { for _, path := range paths { if MatchesPath(pattern, entry, path) { + return true + } + } + return false + } + + concretePatterns := concreteWildcardPatterns(patternSegments, paths) + if len(concretePatterns) == 0 { + return false + } + for _, concrete := range concretePatterns { + present := false + for _, path := range paths { + if MatchesPath(concrete, entry, path) { present = true break } } if !present { - missing = append(missing, pattern) + return false } } - return missing + return true +} + +func concreteWildcardPatterns(patternSegments []string, paths []string) []string { + lastWildcard := lastSegmentIndex(patternSegments, "*") + seen := make(map[string]bool) + var out []string + for _, path := range paths { + pathSegments := splitPath(path) + if len(pathSegments) <= lastWildcard { + continue + } + concrete := make([]string, len(patternSegments)) + matches := true + for i, segment := range patternSegments { + if i > lastWildcard { + concrete[i] = segment + continue + } + if segment == "*" { + concrete[i] = pathSegments[i] + continue + } + concrete[i] = segment + if segment != pathSegments[i] { + matches = false + break + } + } + if !matches { + continue + } + pattern := joinEscapedSegments(concrete) + if !seen[pattern] { + seen[pattern] = true + out = append(out, pattern) + } + } + sort.Strings(out) + return out +} + +func joinEscapedSegments(segments []string) string { + escaped := make([]string, len(segments)) + for i, segment := range segments { + escaped[i] = escapeSegment(segment) + } + return strings.Join(escaped, ".") +} + +func lastSegmentIndex(segments []string, target string) int { + for i := len(segments) - 1; i >= 0; i-- { + if segments[i] == target { + return i + } + } + return -1 } // FlattenTree reduces the canonical tree to sorted leaf paths: maps recurse diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 10782968..b2f653f1 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -208,6 +208,37 @@ func TestSchemaMissingEntriesRequiresWildcardChildWhenCollectionExists(t *testin } } +func TestSchemaMissingEntriesAcceptsEscapedWildcardCollectionMember(t *testing.T) { + s := Schema{ + "mountpoints.*.size": { + Type: "string", + Description: "mount size", + Platforms: []string{"linux"}, + }, + } + + got := s.MissingEntries([]string{`mountpoints./etc/resolv\.conf.size`}, "linux") + if len(got) != 0 { + t.Fatalf("MissingEntries() = %#v, want none", got) + } +} + +func TestSchemaMissingEntriesRequiresWildcardChildForEachCollectionMember(t *testing.T) { + s := Schema{ + "mountpoints.*.size": { + Type: "string", + Description: "mount size", + Platforms: []string{"linux"}, + }, + } + + got := s.MissingEntries([]string{"mountpoints.root.size", "mountpoints.data.device"}, "linux") + want := []string{"mountpoints.*.size"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("MissingEntries() = %#v, want %#v", got, want) + } +} + func TestSchemaMissingEntriesSkipsAbsentNestedWildcardCollection(t *testing.T) { s := Schema{ "a.*.b.*.c": { diff --git a/snapshot.go b/snapshot.go index e22f8493..d3c2ef5a 100644 --- a/snapshot.go +++ b/snapshot.go @@ -51,7 +51,11 @@ func As[T any](s *Snapshot, query string) (T, error) { if err != nil { return zero, err } - encoded, err := json.Marshal(jsonValue(value)) + jsonReady, err := jsonValue(value) + if err != nil { + return zero, fmt.Errorf("fact %q: encode canonical value: %w", query, err) + } + encoded, err := json.Marshal(jsonReady) if err != nil { return zero, fmt.Errorf("fact %q: encode canonical value: %w", query, err) } @@ -65,27 +69,43 @@ func As[T any](s *Snapshot, query string) (T, error) { // jsonValue rewrites map[any]any nodes (YAML decoding artifacts) into // map[string]any so the value round-trips through encoding/json. -func jsonValue(value any) any { +func jsonValue(value any) (any, error) { switch v := value.(type) { case map[string]any: out := make(map[string]any, len(v)) for key, item := range v { - out[key] = jsonValue(item) + value, err := jsonValue(item) + if err != nil { + return nil, err + } + out[key] = value } - return out + return out, nil case map[any]any: out := make(map[string]any, len(v)) for key, item := range v { - out[fmt.Sprint(key)] = jsonValue(item) + name := fmt.Sprint(key) + if _, exists := out[name]; exists { + return nil, fmt.Errorf("duplicate map key after string normalization: %q", name) + } + value, err := jsonValue(item) + if err != nil { + return nil, err + } + out[name] = value } - return out + return out, nil case []any: out := make([]any, len(v)) for i, item := range v { - out[i] = jsonValue(item) + value, err := jsonValue(item) + if err != nil { + return nil, err + } + out[i] = value } - return out + return out, nil default: - return value + return value, nil } } diff --git a/tools/freebsd-release-gate.sh b/tools/freebsd-release-gate.sh index 2b755441..d2174c0c 100755 --- a/tools/freebsd-release-gate.sh +++ b/tools/freebsd-release-gate.sh @@ -17,7 +17,7 @@ fi FACT_SET="os.name os.family os.release os.architecture os.hardware kernel.name \ kernel.release.full kernel.version.full kernel.release.major virtual is_virtual networking \ memory memory.system.total processors processors.count dmi system_uptime \ -load_averages mountpoints" +load_averages mountpoints disks partitions path" # shellcheck disable=SC2086 out="$("$FACTS_BIN" --json $FACT_SET)" @@ -38,10 +38,11 @@ printf '%s\n' "$out" | grep -Eq '"os.family"[[:space:]]*:[[:space:]]*"FreeBSD"' for key in os.release os.architecture os.hardware kernel.release.full kernel.version.full \ kernel.release.major virtual is_virtual networking memory memory.system.total \ processors processors.count dmi system_uptime load_averages \ - mountpoints; do - printf '%s\n' "$out" | grep -Eq "\"$key\"[[:space:]]*:" \ + mountpoints disks partitions path; do + key_re=$(printf '%s\n' "$key" | sed 's/\./\\./g') + printf '%s\n' "$out" | grep -Eq "\"$key_re\"[[:space:]]*:" \ || fail "missing fact $key" - printf '%s\n' "$out" | grep -Eq "\"$key\"[[:space:]]*:[[:space:]]*null" \ + printf '%s\n' "$out" | grep -Eq "\"$key_re\"[[:space:]]*:[[:space:]]*null" \ && fail "fact $key is null" done diff --git a/tools/illumos-release-gate.sh b/tools/illumos-release-gate.sh index b23f14ab..465a623a 100755 --- a/tools/illumos-release-gate.sh +++ b/tools/illumos-release-gate.sh @@ -14,6 +14,10 @@ if [ "$(uname -s)" != "SunOS" ]; then echo "illumos-release-gate.sh must run on illumos/SunOS" >&2 exit 1 fi +if ! [ -r /etc/release ] || ! grep -Eiq 'illumos|omnios|openindiana|smartos' /etc/release; then + echo "illumos-release-gate.sh must run on illumos/OmniOS, not Oracle Solaris" >&2 + exit 1 +fi FACT_SET="os.name os.family os.release os.architecture os.hardware kernel.name \ kernel.release.full kernel.version.full kernel.release.major virtual is_virtual \ diff --git a/tools/netbsd-release-gate.sh b/tools/netbsd-release-gate.sh index 61943874..dc3cbe0f 100755 --- a/tools/netbsd-release-gate.sh +++ b/tools/netbsd-release-gate.sh @@ -40,9 +40,10 @@ for key in os.release os.architecture os.hardware kernel.release.full kernel.ver kernel.release.major virtual is_virtual networking memory memory.system.total \ processors processors.count dmi system_uptime load_averages \ mountpoints disks partitions; do - printf '%s\n' "$out" | grep -Eq "\"$key\"[[:space:]]*:" \ + key_re=$(printf '%s\n' "$key" | sed 's/\./\\./g') + printf '%s\n' "$out" | grep -Eq "\"$key_re\"[[:space:]]*:" \ || fail "missing fact $key" - printf '%s\n' "$out" | grep -Eq "\"$key\"[[:space:]]*:[[:space:]]*null" \ + printf '%s\n' "$out" | grep -Eq "\"$key_re\"[[:space:]]*:[[:space:]]*null" \ && fail "fact $key is null" done diff --git a/tools/supportedfacts/main.go b/tools/supportedfacts/main.go index 82b757aa..e6743629 100644 --- a/tools/supportedfacts/main.go +++ b/tools/supportedfacts/main.go @@ -40,7 +40,11 @@ func renderDocs(schemaFile string) (map[string]string, error) { "docs/supported-facts/README.md": renderIndex(schema), } for _, p := range factschema.Platforms() { - docs["docs/supported-facts/"+p.ID+".md"] = renderPlatform(schema, p) + doc, err := renderPlatform(schema, p) + if err != nil { + return nil, err + } + docs["docs/supported-facts/"+p.ID+".md"] = doc } return docs, nil } @@ -57,12 +61,16 @@ func renderIndex(schema factschema.Schema) string { return b.String() } -func renderPlatform(schema factschema.Schema, p factschema.Platform) string { +func renderPlatform(schema factschema.Schema, p factschema.Platform) (string, error) { entries := schema.EntriesForPlatform(p.ID) + example, err := exampleOutput(p.ID) + if err != nil { + return "", err + } var b strings.Builder writeHeader(&b, p.Label+" Supported Facts") fmt.Fprintf(&b, "Generated from [`docs/schema/facts.yaml`](../schema/facts.yaml). `%s` entries may be absent on a host when their preconditions do not hold.\n\n", "conditional") - fmt.Fprintf(&b, "## Example Output\n\n```console\n$ facts --json\n%s\n```\n\n", exampleOutput(p.ID)) + fmt.Fprintf(&b, "## Example Output\n\n```console\n$ facts --json\n%s\n```\n\n", example) fmt.Fprintf(&b, "## Fact Contract\n\n%d schema entries include `%s`.\n\n", len(entries), p.ID) b.WriteString("| Fact | Type | Conditional | Description |\n| --- | --- | --- | --- |\n") for _, item := range entries { @@ -77,7 +85,7 @@ func renderPlatform(schema factschema.Schema, p factschema.Platform) string { escapeMarkdown(item.Entry.Description), ) } - return b.String() + return b.String(), nil } func writeHeader(b *strings.Builder, title string) { @@ -89,12 +97,16 @@ func escapeMarkdown(s string) string { return strings.ReplaceAll(s, "|", `\|`) } -func exampleOutput(platform string) string { +func exampleOutput(platform string) (string, error) { + raw, ok := exampleJSON[platform] + if !ok { + return "", fmt.Errorf("missing example JSON for platform %q", platform) + } var out bytes.Buffer - if err := json.Indent(&out, []byte(exampleJSON[platform]), "", " "); err != nil { - panic(err) + if err := json.Indent(&out, []byte(raw), "", " "); err != nil { + return "", fmt.Errorf("indent example JSON for platform %q: %w", platform, err) } - return out.String() + return out.String(), nil } var exampleJSON = map[string]string{ diff --git a/tools/supportedfacts/main_test.go b/tools/supportedfacts/main_test.go index 1756c124..9253a254 100644 --- a/tools/supportedfacts/main_test.go +++ b/tools/supportedfacts/main_test.go @@ -48,6 +48,24 @@ func TestRenderedDocsUseSchemaPlatformVocabulary(t *testing.T) { } } +func TestExampleOutputReturnsErrorForMissingPlatform(t *testing.T) { + if _, err := exampleOutput("missing-platform"); err == nil { + t.Fatal("exampleOutput(missing-platform) err = nil, want error") + } +} + +func TestExampleOutputReturnsErrorForMalformedJSON(t *testing.T) { + original := exampleJSON["linux"] + exampleJSON["linux"] = "{" + t.Cleanup(func() { + exampleJSON["linux"] = original + }) + + if _, err := exampleOutput("linux"); err == nil { + t.Fatal("exampleOutput(malformed) err = nil, want error") + } +} + func repoRoot(t *testing.T) string { t.Helper() dir, err := os.Getwd() diff --git a/tools/windows-release-gate.ps1 b/tools/windows-release-gate.ps1 index 471b54d0..f2413a88 100644 --- a/tools/windows-release-gate.ps1 +++ b/tools/windows-release-gate.ps1 @@ -6,6 +6,7 @@ param( ) $ErrorActionPreference = "Stop" +$RemoveBuiltFacts = $false $isWindowsHost = $IsWindows if ($null -eq $isWindowsHost) { @@ -19,7 +20,8 @@ if (-not $isWindowsHost) { try { if ($FactsPath -eq "") { - $FactsPath = Join-Path ([System.IO.Path]::GetTempPath()) "facts-release-gate.exe" + $FactsPath = Join-Path ([System.IO.Path]::GetTempPath()) ("facts-release-gate-{0}.exe" -f [System.Guid]::NewGuid()) + $RemoveBuiltFacts = $true & go build -o $FactsPath ./cmd/facts if ($LASTEXITCODE -ne 0) { throw "go build ./cmd/facts failed with exit code $LASTEXITCODE" @@ -141,4 +143,9 @@ catch { Write-Error "windows-release-gate failed: $_" exit 1 } +finally { + if ($RemoveBuiltFacts -and (Test-Path -LiteralPath $FactsPath)) { + Remove-Item -LiteralPath $FactsPath -Force -ErrorAction SilentlyContinue + } +} exit 0