ensure temp file cleanup after tests

This commit is contained in:
nils måsén 2023-08-12 18:01:08 +02:00
parent e46cca23dc
commit abeed88313
2 changed files with 18 additions and 16 deletions

View file

@ -468,12 +468,14 @@ func GetSecretsFromFiles(rootCmd *cobra.Command) {
"notification-url", "notification-url",
} }
for _, secret := range secrets { for _, secret := range secrets {
getSecretFromFile(flags, secret) if err := getSecretFromFile(flags, secret); err != nil {
log.Fatalf("failed to get secret from flag %v: %s", secret, err)
}
} }
} }
// getSecretFromFile will check if the flag contains a reference to a file; if it does, replaces the value of the flag with the contents of the file. // getSecretFromFile will check if the flag contains a reference to a file; if it does, replaces the value of the flag with the contents of the file.
func getSecretFromFile(flags *pflag.FlagSet, secret string) { func getSecretFromFile(flags *pflag.FlagSet, secret string) error {
flag := flags.Lookup(secret) flag := flags.Lookup(secret)
if sliceValue, ok := flag.Value.(pflag.SliceValue); ok { if sliceValue, ok := flag.Value.(pflag.SliceValue); ok {
oldValues := sliceValue.GetSlice() oldValues := sliceValue.GetSlice()
@ -482,7 +484,7 @@ func getSecretFromFile(flags *pflag.FlagSet, secret string) {
if value != "" && isFile(value) { if value != "" && isFile(value) {
file, err := os.Open(value) file, err := os.Open(value)
if err != nil { if err != nil {
log.Fatal(err) return err
} }
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
@ -492,25 +494,26 @@ func getSecretFromFile(flags *pflag.FlagSet, secret string) {
} }
values = append(values, line) values = append(values, line)
} }
if err := file.Close(); err != nil {
return err
}
} else { } else {
values = append(values, value) values = append(values, value)
} }
} }
sliceValue.Replace(values) return sliceValue.Replace(values)
return
} }
value := flag.Value.String() value := flag.Value.String()
if value != "" && isFile(value) { if value != "" && isFile(value) {
file, err := os.ReadFile(value) content, err := os.ReadFile(value)
if err != nil { if err != nil {
log.Fatal(err) return err
}
err = flags.Set(secret, strings.TrimSpace(string(file)))
if err != nil {
log.Error(err)
} }
return flags.Set(secret, strings.TrimSpace(string(content)))
} }
return nil
} }
func isFile(s string) bool { func isFile(s string) bool {

View file

@ -1,13 +1,12 @@
package flags package flags
import ( import (
"os"
"testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"os"
"testing"
) )
func TestEnvConfig_Defaults(t *testing.T) { func TestEnvConfig_Defaults(t *testing.T) {
@ -60,9 +59,9 @@ func TestGetSecretsFromFilesWithFile(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Write the secret to the temporary file. // Write the secret to the temporary file.
secret := []byte(value) _, err = file.Write([]byte(value))
_, err = file.Write(secret)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, file.Close())
t.Setenv("WATCHTOWER_NOTIFICATION_EMAIL_SERVER_PASSWORD", file.Name()) t.Setenv("WATCHTOWER_NOTIFICATION_EMAIL_SERVER_PASSWORD", file.Name())