diff --git a/pkgs/sops-install-secrets/main.go b/pkgs/sops-install-secrets/main.go index 1aeade9..3680e21 100644 --- a/pkgs/sops-install-secrets/main.go +++ b/pkgs/sops-install-secrets/main.go @@ -15,6 +15,7 @@ import ( "path/filepath" "strconv" "strings" + "syscall" "github.com/Mic92/sops-nix/pkgs/sshkeys" @@ -102,6 +103,28 @@ type appContext struct { checkMode CheckMode } +func secureSymlinkChown(symlinkToCheck, expectedTarget string, owner, group int) error { + fd, err := unix.Open(symlinkToCheck, unix.O_CLOEXEC|unix.O_PATH|unix.O_NOFOLLOW, 0) + if err != nil { + return fmt.Errorf("Failed to open %s: %w", symlinkToCheck, err) + } + defer unix.Close(fd) + + buf := make([]byte, len(expectedTarget) + 1) // oversize by one to detect trunc + n, err := unix.Readlinkat(fd, "", buf) + if err != nil { + return fmt.Errorf("couldn't readlinkat %s", symlinkToCheck) + } + if n > len(expectedTarget) || string(buf[:n]) != expectedTarget { + return fmt.Errorf("symlink %s does not point to %s", symlinkToCheck, expectedTarget) + } + err = unix.Fchownat(fd, "", owner, group, unix.AT_EMPTY_PATH) + if err != nil { + return fmt.Errorf("cannot change owner of '%s' to %d/%d: %w", symlinkToCheck, owner, group, err) + } + return nil +} + func readManifest(path string) (*manifest, error) { file, err := os.Open(path) if err != nil { @@ -116,6 +139,17 @@ func readManifest(path string) (*manifest, error) { return &m, nil } +func linksAreEqual(linkTarget, targetFile string, info os.FileInfo, secret *secret) bool { + validUG := true; + if stat, ok := info.Sys().(*syscall.Stat_t); ok { + validUG = validUG && int(stat.Uid) == secret.owner + validUG = validUG && int(stat.Gid) == secret.group + } else { + panic("Failed to cast fileInfo Sys() to *syscall.Stat_t. This is possibly an unsupported OS.") + } + return linkTarget == targetFile && validUG +} + func symlinkSecret(targetFile string, secret *secret) error { for { stat, err := os.Lstat(secret.Path) @@ -123,6 +157,9 @@ func symlinkSecret(targetFile string, secret *secret) error { if err := os.Symlink(targetFile, secret.Path); err != nil { return fmt.Errorf("Cannot create symlink '%s': %w", secret.Path, err) } + if err := secureSymlinkChown(secret.Path, targetFile, secret.owner, secret.group); err != nil { + return fmt.Errorf("Cannot chown symlink '%s': %w", secret.Path, err) + } return nil } else if err != nil { return fmt.Errorf("Cannot stat '%s': %w", secret.Path, err) @@ -133,7 +170,7 @@ func symlinkSecret(targetFile string, secret *secret) error { continue } else if err != nil { return fmt.Errorf("Cannot read symlink '%s': %w", secret.Path, err) - } else if linkTarget == targetFile { + } else if linksAreEqual(linkTarget, targetFile, stat, secret) { return nil } } @@ -154,7 +191,7 @@ func symlinkSecrets(targetDir string, secrets []secret) error { return fmt.Errorf("Cannot create parent directory of '%s': %w", secret.Path, err) } if err := symlinkSecret(targetFile, &secret); err != nil { - return err + return fmt.Errorf("Failed to symlink secret '%s': %w", secret.Path, err) } } return nil