From 023c1a7d93ea6018daeb69712af6db89ebc0c1e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?nils=20m=C3=A5s=C3=A9n?= Date: Fri, 5 Jan 2024 19:12:11 +0100 Subject: [PATCH 1/4] fix(lifecycle): cleanup lifecycle - removes unwieldy SkipUpdate return value in favor of errors.Is - generalizes the code for all four phases - allows timeout to be defined for all phases - enables explicit unit in timeout label values (in addition to implicit minutes) --- internal/actions/mocks/client.go | 11 ++-- internal/util/time.go | 15 +++++ pkg/container/client.go | 32 ++++------ pkg/container/client_test.go | 2 +- pkg/container/container.go | 38 ------------ pkg/container/errors.go | 3 + pkg/container/metadata.go | 101 +++++++++++++++++++------------ pkg/lifecycle/lifecycle.go | 101 +++++++++---------------------- pkg/types/container.go | 9 +-- pkg/types/lifecycle.go | 27 +++++++++ 10 files changed, 160 insertions(+), 179 deletions(-) create mode 100644 internal/util/time.go create mode 100644 pkg/types/lifecycle.go diff --git a/internal/actions/mocks/client.go b/internal/actions/mocks/client.go index 737404a..b482475 100644 --- a/internal/actions/mocks/client.go +++ b/internal/actions/mocks/client.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + c "github.com/containrrr/watchtower/pkg/container" t "github.com/containrrr/watchtower/pkg/types" ) @@ -72,16 +73,16 @@ func (client MockClient) GetContainer(_ t.ContainerID) (t.Container, error) { } // ExecuteCommand is a mock method -func (client MockClient) ExecuteCommand(_ t.ContainerID, command string, _ int) (SkipUpdate bool, err error) { +func (client MockClient) ExecuteCommand(_ t.ContainerID, command string, _ time.Duration) error { switch command { case "/PreUpdateReturn0.sh": - return false, nil + return nil case "/PreUpdateReturn1.sh": - return false, fmt.Errorf("command exited with code 1") + return fmt.Errorf("command exited with code 1") case "/PreUpdateReturn75.sh": - return true, nil + return c.ErrorLifecycleSkip default: - return false, nil + return nil } } diff --git a/internal/util/time.go b/internal/util/time.go new file mode 100644 index 0000000..3ae7c01 --- /dev/null +++ b/internal/util/time.go @@ -0,0 +1,15 @@ +package util + +import ( + "strconv" + "time" +) + +// ParseDuration parses the input string as a duration, treating a plain number as implicitly using the specified unit +func ParseDuration(input string, unitlessUnit time.Duration) (time.Duration, error) { + if unitless, err := strconv.Atoi(input); err == nil { + return unitlessUnit * time.Duration(unitless), nil + } + + return time.ParseDuration(input) +} diff --git a/pkg/container/client.go b/pkg/container/client.go index c6c37de..3af08db 100644 --- a/pkg/container/client.go +++ b/pkg/container/client.go @@ -31,7 +31,7 @@ type Client interface { StartContainer(t.Container) (t.ContainerID, error) RenameContainer(t.Container, string) error IsContainerStale(t.Container, t.UpdateParams) (stale bool, latestImage t.ImageID, err error) - ExecuteCommand(containerID t.ContainerID, command string, timeout int) (SkipUpdate bool, err error) + ExecuteCommand(containerID t.ContainerID, command string, timeout time.Duration) error RemoveImageByID(t.ImageID) error WarnOnHeadPullFailed(container t.Container) bool } @@ -439,7 +439,7 @@ func (client dockerClient) RemoveImageByID(id t.ImageID) error { return err } -func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command string, timeout int) (SkipUpdate bool, err error) { +func (client dockerClient) ExecuteCommand(containerID t.ContainerID, 
command string, timeout time.Duration) error { bg := context.Background() clog := log.WithField("containerID", containerID) @@ -452,7 +452,7 @@ func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command str exec, err := client.api.ContainerExecCreate(bg, string(containerID), execConfig) if err != nil { - return false, err + return err } response, attachErr := client.api.ContainerExecAttach(bg, exec.ID, types.ExecStartCheck{ @@ -467,7 +467,7 @@ func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command str execStartCheck := types.ExecStartCheck{Detach: false, Tty: true} err = client.api.ContainerExecStart(bg, exec.ID, execStartCheck) if err != nil { - return false, err + return err } var output string @@ -484,24 +484,16 @@ func (client dockerClient) ExecuteCommand(containerID t.ContainerID, command str // Inspect the exec to get the exit code and print a message if the // exit code is not success. - skipUpdate, err := client.waitForExecOrTimeout(bg, exec.ID, output, timeout) - if err != nil { - return true, err - } - - return skipUpdate, nil + return client.waitForExecOrTimeout(bg, exec.ID, output, timeout) } -func (client dockerClient) waitForExecOrTimeout(bg context.Context, ID string, execOutput string, timeout int) (SkipUpdate bool, err error) { +func (client dockerClient) waitForExecOrTimeout(ctx context.Context, ID string, execOutput string, timeout time.Duration) error { const ExTempFail = 75 - var ctx context.Context - var cancel context.CancelFunc if timeout > 0 { - ctx, cancel = context.WithTimeout(bg, time.Duration(timeout)*time.Minute) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) defer cancel() - } else { - ctx = bg } for { @@ -516,7 +508,7 @@ func (client dockerClient) waitForExecOrTimeout(bg context.Context, ID string, e }).Debug("Awaiting timeout or completion") if err != nil { - return false, err + return err } if execInspect.Running { time.Sleep(1 * time.Second) @@ -527,15 +519,15 @@ func (client dockerClient) waitForExecOrTimeout(bg context.Context, ID string, e } if execInspect.ExitCode == ExTempFail { - return true, nil + return ErrorLifecycleSkip } if execInspect.ExitCode > 0 { - return false, fmt.Errorf("command exited with code %v %s", execInspect.ExitCode, execOutput) + return fmt.Errorf("command exited with code %v", execInspect.ExitCode) } break } - return false, nil + return nil } func (client dockerClient) waitForStopOrTimeout(c t.Container, waitTime time.Duration) error { diff --git a/pkg/container/client_test.go b/pkg/container/client_test.go index 4e75409..2fea103 100644 --- a/pkg/container/client_test.go +++ b/pkg/container/client_test.go @@ -308,7 +308,7 @@ var _ = Describe("the client", func() { ), ) - _, err := client.ExecuteCommand(containerID, cmd, 1) + err := client.ExecuteCommand(containerID, cmd, 1) Expect(err).NotTo(HaveOccurred()) // Note: Since Execute requires opening up a raw TCP stream to the daemon for the output, this will fail // when using the mock API server. Regardless of the outcome, the log should include the container ID diff --git a/pkg/container/container.go b/pkg/container/container.go index 10ed677..000a33d 100644 --- a/pkg/container/container.go +++ b/pkg/container/container.go @@ -219,44 +219,6 @@ func (c Container) IsWatchtower() bool { return ContainsWatchtowerLabel(c.containerInfo.Config.Labels) } -// PreUpdateTimeout checks whether a container has a specific timeout set -// for how long the pre-update command is allowed to run. 
This value is expressed -// either as an integer, in minutes, or as 0 which will allow the command/script -// to run indefinitely. Users should be cautious with the 0 option, as that -// could result in watchtower waiting forever. -func (c Container) PreUpdateTimeout() int { - var minutes int - var err error - - val := c.getLabelValueOrEmpty(preUpdateTimeoutLabel) - - minutes, err = strconv.Atoi(val) - if err != nil || val == "" { - return 1 - } - - return minutes -} - -// PostUpdateTimeout checks whether a container has a specific timeout set -// for how long the post-update command is allowed to run. This value is expressed -// either as an integer, in minutes, or as 0 which will allow the command/script -// to run indefinitely. Users should be cautious with the 0 option, as that -// could result in watchtower waiting forever. -func (c Container) PostUpdateTimeout() int { - var minutes int - var err error - - val := c.getLabelValueOrEmpty(postUpdateTimeoutLabel) - - minutes, err = strconv.Atoi(val) - if err != nil || val == "" { - return 1 - } - - return minutes -} - // StopSignal returns the custom stop signal (if any) that is encoded in the // container's metadata. If the container has not specified a custom stop // signal, the empty string "" is returned. diff --git a/pkg/container/errors.go b/pkg/container/errors.go index 05dc722..2caa2a0 100644 --- a/pkg/container/errors.go +++ b/pkg/container/errors.go @@ -6,3 +6,6 @@ var errorNoImageInfo = errors.New("no available image info") var errorNoContainerInfo = errors.New("no available container info") var errorInvalidConfig = errors.New("container configuration missing or invalid") var errorLabelNotFound = errors.New("label was not found in container") + +// ErrorLifecycleSkip is returned by a lifecycle hook when the exit code of the command indicated that it ought to be skipped +var ErrorLifecycleSkip = errors.New("skipping container as the pre-update command returned exit code 75 (EX_TEMPFAIL)") diff --git a/pkg/container/metadata.go b/pkg/container/metadata.go index 8ac5f34..f15b89a 100644 --- a/pkg/container/metadata.go +++ b/pkg/container/metadata.go @@ -1,43 +1,28 @@ package container -import "strconv" +import ( + "errors" + "fmt" + "strconv" + "time" -const ( - watchtowerLabel = "com.centurylinklabs.watchtower" - signalLabel = "com.centurylinklabs.watchtower.stop-signal" - enableLabel = "com.centurylinklabs.watchtower.enable" - monitorOnlyLabel = "com.centurylinklabs.watchtower.monitor-only" - noPullLabel = "com.centurylinklabs.watchtower.no-pull" - dependsOnLabel = "com.centurylinklabs.watchtower.depends-on" - zodiacLabel = "com.centurylinklabs.zodiac.original-image" - scope = "com.centurylinklabs.watchtower.scope" - preCheckLabel = "com.centurylinklabs.watchtower.lifecycle.pre-check" - postCheckLabel = "com.centurylinklabs.watchtower.lifecycle.post-check" - preUpdateLabel = "com.centurylinklabs.watchtower.lifecycle.pre-update" - postUpdateLabel = "com.centurylinklabs.watchtower.lifecycle.post-update" - preUpdateTimeoutLabel = "com.centurylinklabs.watchtower.lifecycle.pre-update-timeout" - postUpdateTimeoutLabel = "com.centurylinklabs.watchtower.lifecycle.post-update-timeout" + "github.com/containrrr/watchtower/internal/util" + wt "github.com/containrrr/watchtower/pkg/types" + + "github.com/sirupsen/logrus" ) -// GetLifecyclePreCheckCommand returns the pre-check command set in the container metadata or an empty string -func (c Container) GetLifecyclePreCheckCommand() string { - return c.getLabelValueOrEmpty(preCheckLabel) -} - 
-// GetLifecyclePostCheckCommand returns the post-check command set in the container metadata or an empty string -func (c Container) GetLifecyclePostCheckCommand() string { - return c.getLabelValueOrEmpty(postCheckLabel) -} - -// GetLifecyclePreUpdateCommand returns the pre-update command set in the container metadata or an empty string -func (c Container) GetLifecyclePreUpdateCommand() string { - return c.getLabelValueOrEmpty(preUpdateLabel) -} - -// GetLifecyclePostUpdateCommand returns the post-update command set in the container metadata or an empty string -func (c Container) GetLifecyclePostUpdateCommand() string { - return c.getLabelValueOrEmpty(postUpdateLabel) -} +const ( + namespace = "com.centurylinklabs.watchtower" + watchtowerLabel = namespace + signalLabel = namespace + ".stop-signal" + enableLabel = namespace + ".enable" + monitorOnlyLabel = namespace + ".monitor-only" + noPullLabel = namespace + ".no-pull" + dependsOnLabel = namespace + ".depends-on" + zodiacLabel = "com.centurylinklabs.zodiac.original-image" + scope = namespace + ".scope" +) // ContainsWatchtowerLabel takes a map of labels and values and tells // the consumer whether it contains a valid watchtower instance label @@ -46,22 +31,62 @@ func ContainsWatchtowerLabel(labels map[string]string) bool { return ok && val == "true" } -func (c Container) getLabelValueOrEmpty(label string) string { +// GetLifecycleCommand returns the lifecycle command set in the container metadata or an empty string +func (c *Container) GetLifecycleCommand(phase wt.LifecyclePhase) string { + label := fmt.Sprintf("%v.lifecycle.%v", namespace, phase) + value, found := c.getLabelValue(label) + + if !found { + return "" + } + + return value +} + +// GetLifecycleTimeout checks whether a container has a specific timeout set +// for how long the lifecycle command is allowed to run. This value is expressed +// either as a duration, an integer (minutes implied), or as 0 which will allow the command/script +// to run indefinitely. Users should be cautious with the 0 option, as that +// could result in watchtower waiting forever. 
+func (c *Container) GetLifecycleTimeout(phase wt.LifecyclePhase) time.Duration { + label := fmt.Sprintf("%v.lifecycle.%v-timeout", namespace, phase) + timeout, err := c.getDurationLabelValue(label, time.Minute) + + if err != nil { + timeout = time.Minute + if !errors.Is(err, errorLabelNotFound) { + logrus.WithError(err).Errorf("could not parse timeout label value for %v lifecycle", phase) + } + } + + return timeout +} + +func (c *Container) getLabelValueOrEmpty(label string) string { if val, ok := c.containerInfo.Config.Labels[label]; ok { return val } return "" } -func (c Container) getLabelValue(label string) (string, bool) { +func (c *Container) getLabelValue(label string) (string, bool) { val, ok := c.containerInfo.Config.Labels[label] return val, ok } -func (c Container) getBoolLabelValue(label string) (bool, error) { +func (c *Container) getBoolLabelValue(label string) (bool, error) { if strVal, ok := c.containerInfo.Config.Labels[label]; ok { value, err := strconv.ParseBool(strVal) return value, err } return false, errorLabelNotFound } + +func (c *Container) getDurationLabelValue(label string, unitlessUnit time.Duration) (time.Duration, error) { + value, found := c.getLabelValue(label) + if !found || len(value) < 1 { + return 0, errorLabelNotFound + } + + return util.ParseDuration(value, unitlessUnit) +} diff --git a/pkg/lifecycle/lifecycle.go b/pkg/lifecycle/lifecycle.go index c0f962e..db93fc0 100644 --- a/pkg/lifecycle/lifecycle.go +++ b/pkg/lifecycle/lifecycle.go @@ -6,101 +6,60 @@ import ( log "github.com/sirupsen/logrus" ) -// ExecutePreChecks tries to run the pre-check lifecycle hook for all containers included by the current filter. -func ExecutePreChecks(client container.Client, params types.UpdateParams) { - containers, err := client.ListContainers(params.Filter) - if err != nil { - return - } - for _, currentContainer := range containers { - ExecutePreCheckCommand(client, currentContainer) - } -} - -// ExecutePostChecks tries to run the post-check lifecycle hook for all containers included by the current filter. -func ExecutePostChecks(client container.Client, params types.UpdateParams) { - containers, err := client.ListContainers(params.Filter) - if err != nil { - return - } - for _, currentContainer := range containers { - ExecutePostCheckCommand(client, currentContainer) - } -} +type ExecCommandFunc func(client container.Client, container types.Container) // ExecutePreCheckCommand tries to run the pre-check lifecycle hook for a single container. func ExecutePreCheckCommand(client container.Client, container types.Container) { - clog := log.WithField("container", container.Name()) - command := container.GetLifecyclePreCheckCommand() - if len(command) == 0 { - clog.Debug("No pre-check command supplied. Skipping") - return - } - - clog.Debug("Executing pre-check command.") - _, err := client.ExecuteCommand(container.ID(), command, 1) + err := ExecuteLifeCyclePhaseCommand(types.PreCheck, client, container) if err != nil { - clog.Error(err) + log.WithField("container", container.Name()).Error(err) } } // ExecutePostCheckCommand tries to run the post-check lifecycle hook for a single container. func ExecutePostCheckCommand(client container.Client, container types.Container) { - clog := log.WithField("container", container.Name()) - command := container.GetLifecyclePostCheckCommand() - if len(command) == 0 { - clog.Debug("No post-check command supplied. 
Skipping") - return - } - - clog.Debug("Executing post-check command.") - _, err := client.ExecuteCommand(container.ID(), command, 1) + err := ExecuteLifeCyclePhaseCommand(types.PostCheck, client, container) if err != nil { - clog.Error(err) + log.WithField("container", container.Name()).Error(err) } } // ExecutePreUpdateCommand tries to run the pre-update lifecycle hook for a single container. -func ExecutePreUpdateCommand(client container.Client, container types.Container) (SkipUpdate bool, err error) { - timeout := container.PreUpdateTimeout() - command := container.GetLifecyclePreUpdateCommand() - clog := log.WithField("container", container.Name()) - - if len(command) == 0 { - clog.Debug("No pre-update command supplied. Skipping") - return false, nil - } - - if !container.IsRunning() || container.IsRestarting() { - clog.Debug("Container is not running. Skipping pre-update command.") - return false, nil - } - - clog.Debug("Executing pre-update command.") - return client.ExecuteCommand(container.ID(), command, timeout) +func ExecutePreUpdateCommand(client container.Client, container types.Container) error { + return ExecuteLifeCyclePhaseCommand(types.PreUpdate, client, container) } // ExecutePostUpdateCommand tries to run the post-update lifecycle hook for a single container. func ExecutePostUpdateCommand(client container.Client, newContainerID types.ContainerID) { newContainer, err := client.GetContainer(newContainerID) - timeout := newContainer.PostUpdateTimeout() - if err != nil { log.WithField("containerID", newContainerID.ShortID()).Error(err) return } - clog := log.WithField("container", newContainer.Name()) - - command := newContainer.GetLifecyclePostUpdateCommand() - if len(command) == 0 { - clog.Debug("No post-update command supplied. Skipping") - return - } - - clog.Debug("Executing post-update command.") - _, err = client.ExecuteCommand(newContainerID, command, timeout) + err = ExecuteLifeCyclePhaseCommand(types.PostUpdate, client, newContainer) if err != nil { - clog.Error(err) + log.WithField("container", newContainer.Name()).Error(err) } } + +// ExecuteLifeCyclePhaseCommand tries to run the corresponding lifecycle hook for a single container. +func ExecuteLifeCyclePhaseCommand(phase types.LifecyclePhase, client container.Client, container types.Container) error { + + timeout := container.GetLifecycleTimeout(phase) + command := container.GetLifecycleCommand(phase) + clog := log.WithField("container", container.Name()) + + if len(command) == 0 { + clog.Debugf("No %v command supplied. Skipping", phase) + return nil + } + + if !container.IsRunning() || container.IsRestarting() { + clog.Debugf("Container is not running. 
Skipping %v command.", phase) + return nil + } + + clog.Debugf("Executing %v command.", phase) + return client.ExecuteCommand(container.ID(), command, timeout) +} diff --git a/pkg/types/container.go b/pkg/types/container.go index 8a22f44..cc8be9f 100644 --- a/pkg/types/container.go +++ b/pkg/types/container.go @@ -2,6 +2,7 @@ package types import ( "strings" + "time" "github.com/docker/docker/api/types" dc "github.com/docker/docker/api/types/container" @@ -60,18 +61,14 @@ type Container interface { StopSignal() string HasImageInfo() bool ImageInfo() *types.ImageInspect - GetLifecyclePreCheckCommand() string - GetLifecyclePostCheckCommand() string - GetLifecyclePreUpdateCommand() string - GetLifecyclePostUpdateCommand() string + GetLifecycleCommand(LifecyclePhase) string + GetLifecycleTimeout(LifecyclePhase) time.Duration VerifyConfiguration() error SetStale(bool) IsStale() bool IsNoPull(UpdateParams) bool SetLinkedToRestarting(bool) IsLinkedToRestarting() bool - PreUpdateTimeout() int - PostUpdateTimeout() int IsRestarting() bool GetCreateConfig() *dc.Config GetCreateHostConfig() *dc.HostConfig diff --git a/pkg/types/lifecycle.go b/pkg/types/lifecycle.go new file mode 100644 index 0000000..2c7b46c --- /dev/null +++ b/pkg/types/lifecycle.go @@ -0,0 +1,27 @@ +package types + +import "fmt" + +type LifecyclePhase int + +const ( + PreCheck LifecyclePhase = iota + PreUpdate + PostUpdate + PostCheck +) + +func (p LifecyclePhase) String() string { + switch p { + case PreCheck: + return "pre-check" + case PreUpdate: + return "pre-update" + case PostUpdate: + return "post-update" + case PostCheck: + return "post-check" + default: + return fmt.Sprintf("invalid(%d)", p) + } +} From cb8e86d7051bb98b3495ed17c2a0068970490820 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?nils=20m=C3=A5s=C3=A9n?= Date: Fri, 5 Jan 2024 19:24:11 +0100 Subject: [PATCH 2/4] fix(container): rename Stale to MarkedForUpdate renames the container.Stale field to what it's actually used for, as staleness is not the only factor used to decide whether a container should be updated anymore also hides the private field along with linkedToRestarting --- internal/actions/update_test.go | 2 +- pkg/container/container.go | 26 +++++++++++++------------- pkg/types/container.go | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/actions/update_test.go b/internal/actions/update_test.go index 9209dcd..c09d4ba 100644 --- a/internal/actions/update_test.go +++ b/internal/actions/update_test.go @@ -371,7 +371,7 @@ var _ = Describe("the update action", func() { ExposedPorts: map[nat.Port]struct{}{}, }) - provider.SetStale(true) + provider.SetMarkedForUpdate(true) consumer := CreateMockContainerWithConfig( "test-container-consumer", diff --git a/pkg/container/container.go b/pkg/container/container.go index 000a33d..91067da 100644 --- a/pkg/container/container.go +++ b/pkg/container/container.go @@ -27,31 +27,31 @@ func NewContainer(containerInfo *types.ContainerJSON, imageInfo *types.ImageInsp // Container represents a running Docker container. 
type Container struct { - LinkedToRestarting bool - Stale bool + linkedToRestarting bool + markedForUpdate bool containerInfo *types.ContainerJSON imageInfo *types.ImageInspect } -// IsLinkedToRestarting returns the current value of the LinkedToRestarting field for the container +// IsLinkedToRestarting returns the current value of the linkedToRestarting field for the container func (c *Container) IsLinkedToRestarting() bool { - return c.LinkedToRestarting + return c.linkedToRestarting } -// IsStale returns the current value of the Stale field for the container -func (c *Container) IsStale() bool { - return c.Stale +// IsMarkedForUpdate returns the current value of the markedForUpdate field for the container +func (c *Container) IsMarkedForUpdate() bool { + return c.markedForUpdate } -// SetLinkedToRestarting sets the LinkedToRestarting field for the container +// SetLinkedToRestarting sets the linkedToRestarting field for the container func (c *Container) SetLinkedToRestarting(value bool) { - c.LinkedToRestarting = value + c.linkedToRestarting = value } -// SetStale implements sets the Stale field for the container -func (c *Container) SetStale(value bool) { - c.Stale = value +// SetMarkedForUpdate sets the markedForUpdate field for the container +func (c *Container) SetMarkedForUpdate(value bool) { + c.markedForUpdate = value } // ContainerInfo fetches JSON info for the container @@ -208,7 +208,7 @@ func (c Container) Links() []string { // ToRestart return whether the container should be restarted, either because // is stale or linked to another stale container. func (c Container) ToRestart() bool { - return c.Stale || c.LinkedToRestarting + return c.markedForUpdate || c.linkedToRestarting } // IsWatchtower returns a boolean flag indicating whether or not the current diff --git a/pkg/types/container.go b/pkg/types/container.go index cc8be9f..3c00fb0 100644 --- a/pkg/types/container.go +++ b/pkg/types/container.go @@ -64,8 +64,8 @@ type Container interface { GetLifecycleCommand(LifecyclePhase) string GetLifecycleTimeout(LifecyclePhase) time.Duration VerifyConfiguration() error - SetStale(bool) - IsStale() bool + SetMarkedForUpdate(bool) + IsMarkedForUpdate() bool IsNoPull(UpdateParams) bool SetLinkedToRestarting(bool) IsLinkedToRestarting() bool From a6949dede9dd4b7c6f4b469c963256f4e7012265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?nils=20m=C3=A5s=C3=A9n?= Date: Fri, 5 Jan 2024 22:57:37 +0100 Subject: [PATCH 3/4] refactor(update): clean up actions/update - move common arguments to a shared struct - remove unused fields - fix outdated names - improve logging/error handling --- internal/actions/mocks/client.go | 2 +- internal/actions/update.go | 184 ++++++++++++++++++------------- pkg/container/client.go | 5 +- pkg/lifecycle/lifecycle.go | 23 ---- 4 files changed, 112 insertions(+), 102 deletions(-) diff --git a/internal/actions/mocks/client.go b/internal/actions/mocks/client.go index b482475..8af1ac6 100644 --- a/internal/actions/mocks/client.go +++ b/internal/actions/mocks/client.go @@ -87,7 +87,7 @@ func (client MockClient) ExecuteCommand(_ t.ContainerID, command string, _ time. 
} // IsContainerStale is true if not explicitly stated in TestData for the mock client -func (client MockClient) IsContainerStale(cont t.Container, params t.UpdateParams) (bool, t.ImageID, error) { +func (client MockClient) IsContainerStale(cont t.Container, _ t.UpdateParams) (bool, t.ImageID, error) { stale, found := client.TestData.Staleness[cont.Name()] if !found { stale = true diff --git a/internal/actions/update.go b/internal/actions/update.go index 8853c6e..440f4a0 100644 --- a/internal/actions/update.go +++ b/internal/actions/update.go @@ -2,6 +2,7 @@ package actions import ( "errors" + "fmt" "github.com/containrrr/watchtower/internal/util" "github.com/containrrr/watchtower/pkg/container" @@ -9,37 +10,52 @@ import ( "github.com/containrrr/watchtower/pkg/session" "github.com/containrrr/watchtower/pkg/sorter" "github.com/containrrr/watchtower/pkg/types" + log "github.com/sirupsen/logrus" ) +type updateSession struct { + client container.Client + params types.UpdateParams + progress *session.Progress +} + // Update looks at the running Docker containers to see if any of the images // used to start those containers have been updated. If a change is detected in // any of the images, the associated containers are stopped and restarted with // the new image. func Update(client container.Client, params types.UpdateParams) (types.Report, error) { - log.Debug("Checking containers for updated images") - progress := &session.Progress{} - staleCount := 0 + log.Debug("Starting new update session") + us := updateSession{client: client, params: params, progress: &session.Progress{}} - if params.LifecycleHooks { - lifecycle.ExecutePreChecks(client, params) - } + us.TryExecuteLifecycleCommands(types.PreCheck) - containers, err := client.ListContainers(params.Filter) - if err != nil { + if err := us.run(); err != nil { return nil, err } - staleCheckFailed := 0 + us.TryExecuteLifecycleCommands(types.PostCheck) + + return us.progress.Report(), nil +} + +func (us *updateSession) run() (err error) { + + containers, err := us.client.ListContainers(us.params.Filter) + if err != nil { + return err + } for i, targetContainer := range containers { - stale, newestImage, err := client.IsContainerStale(targetContainer, params) - shouldUpdate := stale && !params.NoRestart && !targetContainer.IsMonitorOnly(params) + stale, newestImage, err := us.client.IsContainerStale(targetContainer, us.params) + shouldUpdate := stale && !us.params.NoRestart && !targetContainer.IsMonitorOnly(us.params) + if err == nil && shouldUpdate { // Check to make sure we have all the necessary information for recreating the container err = targetContainer.VerifyConfiguration() - // If the image information is incomplete and trace logging is enabled, log it for further diagnosis if err != nil && log.IsLevelEnabled(log.TraceLevel) { + // If the image information is incomplete and trace logging is enabled, log it for further diagnosis + log.WithError(err).Trace("Cannot obtain enough information to recreate container") imageInfo := targetContainer.ImageInfo() log.Tracef("Image info: %#v", imageInfo) log.Tracef("Container info: %#v", targetContainer.ContainerInfo()) @@ -51,62 +67,52 @@ func Update(client container.Client, params types.UpdateParams) (types.Report, e if err != nil { log.Infof("Unable to update container %q: %v. 
Proceeding to next.", targetContainer.Name(), err) - stale = false - staleCheckFailed++ - progress.AddSkipped(targetContainer, err) + us.progress.AddSkipped(targetContainer, err) + containers[i].SetMarkedForUpdate(false) } else { - progress.AddScanned(targetContainer, newestImage) - } - containers[i].SetStale(stale) - - if stale { - staleCount++ + us.progress.AddScanned(targetContainer, newestImage) + containers[i].SetMarkedForUpdate(shouldUpdate) } } containers, err = sorter.SortByDependencies(containers) if err != nil { - return nil, err + return fmt.Errorf("failed to sort containers for updating: %v", err) } UpdateImplicitRestart(containers) var containersToUpdate []types.Container for _, c := range containers { - if !c.IsMonitorOnly(params) { + if c.ToRestart() { containersToUpdate = append(containersToUpdate, c) - progress.MarkForUpdate(c.ID()) + us.progress.MarkForUpdate(c.ID()) } } - if params.RollingRestart { - progress.UpdateFailed(performRollingRestart(containersToUpdate, client, params)) + if us.params.RollingRestart { + us.performRollingRestart(containersToUpdate) } else { - failedStop, stoppedImages := stopContainersInReversedOrder(containersToUpdate, client, params) - progress.UpdateFailed(failedStop) - failedStart := restartContainersInSortedOrder(containersToUpdate, client, params, stoppedImages) - progress.UpdateFailed(failedStart) + stoppedImages := us.stopContainersInReversedOrder(containersToUpdate) + us.restartContainersInSortedOrder(containersToUpdate, stoppedImages) } - if params.LifecycleHooks { - lifecycle.ExecutePostChecks(client, params) - } - return progress.Report(), nil + return nil } -func performRollingRestart(containers []types.Container, client container.Client, params types.UpdateParams) map[types.ContainerID]error { +func (us *updateSession) performRollingRestart(containers []types.Container) { cleanupImageIDs := make(map[types.ImageID]bool, len(containers)) failed := make(map[types.ContainerID]error, len(containers)) for i := len(containers) - 1; i >= 0; i-- { if containers[i].ToRestart() { - err := stopStaleContainer(containers[i], client, params) + err := us.stopContainer(containers[i]) if err != nil { failed[containers[i].ID()] = err } else { - if err := restartStaleContainer(containers[i], client, params); err != nil { + if err := us.restartContainer(containers[i]); err != nil { failed[containers[i].ID()] = err - } else if containers[i].IsStale() { + } else if containers[i].IsMarkedForUpdate() { // Only add (previously) stale containers' images to cleanup cleanupImageIDs[containers[i].ImageID()] = true } @@ -114,17 +120,17 @@ func performRollingRestart(containers []types.Container, client container.Client } } - if params.Cleanup { - cleanupImages(client, cleanupImageIDs) + if us.params.Cleanup { + us.cleanupImages(cleanupImageIDs) } - return failed + us.progress.UpdateFailed(failed) } -func stopContainersInReversedOrder(containers []types.Container, client container.Client, params types.UpdateParams) (failed map[types.ContainerID]error, stopped map[types.ImageID]bool) { - failed = make(map[types.ContainerID]error, len(containers)) +func (us *updateSession) stopContainersInReversedOrder(containers []types.Container) (stopped map[types.ImageID]bool) { + failed := make(map[types.ContainerID]error, len(containers)) stopped = make(map[types.ImageID]bool, len(containers)) for i := len(containers) - 1; i >= 0; i-- { - if err := stopStaleContainer(containers[i], client, params); err != nil { + if err := us.stopContainer(containers[i]); err != nil { 
failed[containers[i].ID()] = err } else { // NOTE: If a container is restarted due to a dependency this might be empty @@ -132,47 +138,51 @@ func stopContainersInReversedOrder(containers []types.Container, client containe } } - return + us.progress.UpdateFailed(failed) + + return stopped } -func stopStaleContainer(container types.Container, client container.Client, params types.UpdateParams) error { - if container.IsWatchtower() { - log.Debugf("This is the watchtower container %s", container.Name()) +func (us *updateSession) stopContainer(c types.Container) error { + if c.IsWatchtower() { + log.Debugf("This is the watchtower container %s", c.Name()) return nil } - if !container.ToRestart() { + if !c.ToRestart() { return nil } // Perform an additional check here to prevent us from stopping a linked container we cannot restart - if container.IsLinkedToRestarting() { - if err := container.VerifyConfiguration(); err != nil { + if c.IsLinkedToRestarting() { + if err := c.VerifyConfiguration(); err != nil { return err } } - if params.LifecycleHooks { - skipUpdate, err := lifecycle.ExecutePreUpdateCommand(client, container) + if us.params.LifecycleHooks { + err := lifecycle.ExecuteLifeCyclePhaseCommand(types.PreUpdate, us.client, c) if err != nil { + + if errors.Is(err, container.ErrorLifecycleSkip) { + log.Debug(err) + return err + } + log.Error(err) log.Info("Skipping container as the pre-update command failed") return err } - if skipUpdate { - log.Debug("Skipping container as the pre-update command returned exit code 75 (EX_TEMPFAIL)") - return errors.New("skipping container as the pre-update command returned exit code 75 (EX_TEMPFAIL)") - } } - if err := client.StopContainer(container, params.Timeout); err != nil { + if err := us.client.StopContainer(c, us.params.Timeout); err != nil { log.Error(err) return err } return nil } -func restartContainersInSortedOrder(containers []types.Container, client container.Client, params types.UpdateParams, stoppedImages map[types.ImageID]bool) map[types.ContainerID]error { +func (us *updateSession) restartContainersInSortedOrder(containers []types.Container, stoppedImages map[types.ImageID]bool) { cleanupImageIDs := make(map[types.ImageID]bool, len(containers)) failed := make(map[types.ContainerID]error, len(containers)) @@ -181,58 +191,58 @@ func restartContainersInSortedOrder(containers []types.Container, client contain continue } if stoppedImages[c.SafeImageID()] { - if err := restartStaleContainer(c, client, params); err != nil { + if err := us.restartContainer(c); err != nil { failed[c.ID()] = err - } else if c.IsStale() { + } else if c.IsMarkedForUpdate() { // Only add (previously) stale containers' images to cleanup cleanupImageIDs[c.ImageID()] = true } } } - if params.Cleanup { - cleanupImages(client, cleanupImageIDs) + if us.params.Cleanup { + us.cleanupImages(cleanupImageIDs) } - return failed + us.progress.UpdateFailed(failed) } -func cleanupImages(client container.Client, imageIDs map[types.ImageID]bool) { +func (us *updateSession) cleanupImages(imageIDs map[types.ImageID]bool) { for imageID := range imageIDs { if imageID == "" { continue } - if err := client.RemoveImageByID(imageID); err != nil { + if err := us.client.RemoveImageByID(imageID); err != nil { log.Error(err) } } } -func restartStaleContainer(container types.Container, client container.Client, params types.UpdateParams) error { - // Since we can't shutdown a watchtower container immediately, we need to - // start the new one while the old one is still running. 
This prevents us - // from re-using the same container name so we first rename the current - // instance so that the new one can adopt the old name. +func (us *updateSession) restartContainer(container types.Container) error { if container.IsWatchtower() { - if err := client.RenameContainer(container, util.RandName()); err != nil { + // Since we can't shut down a watchtower container immediately, we need to + // start the new one while the old one is still running. This prevents us + // from re-using the same container name, so we first rename the current + // instance so that the new one can adopt the old name. + if err := us.client.RenameContainer(container, util.RandName()); err != nil { log.Error(err) return nil } } - if !params.NoRestart { - if newContainerID, err := client.StartContainer(container); err != nil { + if !us.params.NoRestart { + if newContainerID, err := us.client.StartContainer(container); err != nil { log.Error(err) return err - } else if container.ToRestart() && params.LifecycleHooks { - lifecycle.ExecutePostUpdateCommand(client, newContainerID) + } else if container.ToRestart() && us.params.LifecycleHooks { + lifecycle.ExecutePostUpdateCommand(us.client, newContainerID) } } return nil } // UpdateImplicitRestart iterates through the passed containers, setting the -// `LinkedToRestarting` flag if any of it's linked containers are marked for restart +// `linkedToRestarting` flag if any of its linked containers are marked for restart func UpdateImplicitRestart(containers []types.Container) { for ci, c := range containers { @@ -265,3 +275,23 @@ func linkedContainerMarkedForRestart(links []string, containers []types.Containe } return "" } + +// TryExecuteLifecycleCommands tries to run the corresponding lifecycle hook for all containers included by the current filter. +func (us *updateSession) TryExecuteLifecycleCommands(phase types.LifecyclePhase) { + if !us.params.LifecycleHooks { + return + } + + containers, err := us.client.ListContainers(us.params.Filter) + if err != nil { + log.WithError(err).Warn("Skipping lifecycle commands. Failed to list containers.") + return + } + + for _, c := range containers { + err := lifecycle.ExecuteLifeCyclePhaseCommand(phase, us.client, c) + if err != nil { + log.WithField("container", c.Name()).Error(err) + } + } +} diff --git a/pkg/container/client.go b/pkg/container/client.go index 3af08db..e10c268 100644 --- a/pkg/container/client.go +++ b/pkg/container/client.go @@ -396,7 +396,10 @@ func (client dockerClient) PullImage(ctx context.Context, container t.Container) return err } - defer response.Close() + defer func() { + _ = response.Close() + }() + // the pull request will be aborted prematurely unless the response is read if _, err = io.ReadAll(response); err != nil { log.Error(err) diff --git a/pkg/lifecycle/lifecycle.go b/pkg/lifecycle/lifecycle.go index db93fc0..a0c22ac 100644 --- a/pkg/lifecycle/lifecycle.go +++ b/pkg/lifecycle/lifecycle.go @@ -6,29 +6,6 @@ import ( log "github.com/sirupsen/logrus" ) -type ExecCommandFunc func(client container.Client, container types.Container) - -// ExecutePreCheckCommand tries to run the pre-check lifecycle hook for a single container. -func ExecutePreCheckCommand(client container.Client, container types.Container) { - err := ExecuteLifeCyclePhaseCommand(types.PreCheck, client, container) - if err != nil { - log.WithField("container", container.Name()).Error(err) - } -} - -// ExecutePostCheckCommand tries to run the post-check lifecycle hook for a single container. 
-func ExecutePostCheckCommand(client container.Client, container types.Container) { - err := ExecuteLifeCyclePhaseCommand(types.PostCheck, client, container) - if err != nil { - log.WithField("container", container.Name()).Error(err) - } -} - -// ExecutePreUpdateCommand tries to run the pre-update lifecycle hook for a single container. -func ExecutePreUpdateCommand(client container.Client, container types.Container) error { - return ExecuteLifeCyclePhaseCommand(types.PreUpdate, client, container) -} - // ExecutePostUpdateCommand tries to run the post-update lifecycle hook for a single container. func ExecutePostUpdateCommand(client container.Client, newContainerID types.ContainerID) { newContainer, err := client.GetContainer(newContainerID) From a42eb28f3914901e00e72a80c1d25c864cc72567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?nils=20m=C3=A5s=C3=A9n?= Date: Sat, 6 Jan 2024 14:39:55 +0100 Subject: [PATCH 4/4] fix broken tests --- pkg/container/client_test.go | 2 +- pkg/container/container_test.go | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pkg/container/client_test.go b/pkg/container/client_test.go index 2fea103..687c1a0 100644 --- a/pkg/container/client_test.go +++ b/pkg/container/client_test.go @@ -308,7 +308,7 @@ var _ = Describe("the client", func() { ), ) - err := client.ExecuteCommand(containerID, cmd, 1) + err := client.ExecuteCommand(containerID, cmd, time.Minute) Expect(err).NotTo(HaveOccurred()) // Note: Since Execute requires opening up a raw TCP stream to the daemon for the output, this will fail // when using the mock API server. Regardless of the outcome, the log should include the container ID diff --git a/pkg/container/container_test.go b/pkg/container/container_test.go index a129afe..ce6e960 100644 --- a/pkg/container/container_test.go +++ b/pkg/container/container_test.go @@ -6,6 +6,7 @@ import ( "github.com/docker/go-connections/nat" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "time" ) var _ = Describe("the container", func() { @@ -380,10 +381,10 @@ var _ = Describe("the container", func() { "com.centurylinklabs.watchtower.lifecycle.pre-update-timeout": "3", "com.centurylinklabs.watchtower.lifecycle.post-update-timeout": "5", })) - preTimeout := c.PreUpdateTimeout() - Expect(preTimeout).To(Equal(3)) - postTimeout := c.PostUpdateTimeout() - Expect(postTimeout).To(Equal(5)) + preTimeout := c.GetLifecycleTimeout(types.PreUpdate) + Expect(preTimeout).To(Equal(3 * time.Minute)) + postTimeout := c.GetLifecycleTimeout(types.PostUpdate) + Expect(postTimeout).To(Equal(5 * time.Minute)) }) })
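For context, here is a minimal standalone sketch of how the new timeout label parsing behaves after this series. The local parseDuration below mirrors the ParseDuration helper added in internal/util/time.go; the label values and the fallback behaviour are taken from GetLifecycleTimeout in pkg/container/metadata.go, while the concrete sample values are illustrative only.

package main

import (
	"fmt"
	"strconv"
	"time"
)

// parseDuration mirrors internal/util.ParseDuration from this series:
// a bare integer is interpreted in the given implicit unit (minutes for
// the lifecycle timeout labels), anything else must be a valid Go
// duration string such as "30s" or "1h".
func parseDuration(input string, unitlessUnit time.Duration) (time.Duration, error) {
	if unitless, err := strconv.Atoi(input); err == nil {
		return unitlessUnit * time.Duration(unitless), nil
	}
	return time.ParseDuration(input)
}

func main() {
	// Sample timeout label values; illustrative only.
	values := []string{"3", "90s", "0", "bogus"}
	for _, v := range values {
		d, err := parseDuration(v, time.Minute)
		if err != nil {
			// In GetLifecycleTimeout a parse error logs the problem and
			// falls back to the one-minute default.
			fmt.Printf("%-8q -> parse error (%v), falling back to %v\n", v, err, time.Minute)
			continue
		}
		// A zero duration lets the lifecycle command run indefinitely.
		fmt.Printf("%-8q -> %v\n", v, d)
	}
}

Running this prints 3m0s for "3" (implicit minutes, matching the existing label behaviour and the updated container_test.go expectations), 1m30s for "90s" (the new explicit-unit form), 0s for "0", and the one-minute fallback for an unparsable value.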