// -*- Mode: Go; indent-tabs-mode: t -*-

/*
 * Copyright (C) 2016 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package state_test

import (
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	. "gopkg.in/check.v1"
	"gopkg.in/tomb.v2"

	"github.com/snapcore/snapd/overlord/state"
)

// taskRunnerSuite groups the TaskRunner tests; it carries no state of its own.
type taskRunnerSuite struct{}

var _ = Suite(new(taskRunnerSuite))

// stateBackend is a minimal backend for state.New used by these tests. It
// keeps the smallest duration handed to EnsureBefore since ensureBefore was
// last set, and can optionally notify a channel on every EnsureBefore call.
type stateBackend struct {
	mu               sync.Mutex    // held by EnsureBefore while updating ensureBefore
	ensureBefore     time.Duration // lowered (never raised) by EnsureBefore
	ensureBeforeSeen chan<- bool   // if non-nil, receives true on each EnsureBefore call
}

// Checkpoint discards the data and always succeeds.
func (b *stateBackend) Checkpoint([]byte) error {
	return nil
}

// EnsureBefore lowers ensureBefore to d when d is smaller than the current
// value. Afterwards, outside the lock, it signals ensureBeforeSeen if one was
// configured (the send blocks if the channel is full).
func (b *stateBackend) EnsureBefore(d time.Duration) {
	func() {
		b.mu.Lock()
		defer b.mu.Unlock()
		if d >= b.ensureBefore {
			return
		}
		b.ensureBefore = d
	}()

	if b.ensureBeforeSeen == nil {
		return
	}
	b.ensureBeforeSeen <- true
}

// RequestRestart is a no-op for this test backend.
func (b *stateBackend) RequestRestart(state.RestartType) {}

// ensureChange repeatedly runs Ensure followed by Wait on r until chg reaches
// a ready status. It fails the test, listing every task's status, after ten
// rounds, or earlier when a round ends without any EnsureBefore call having
// lowered sb.ensureBefore to zero or below (nothing asked for another pass).
func ensureChange(c *C, r *state.TaskRunner, sb *stateBackend, chg *state.Change) {
	st := chg.State()
	for round := 0; round < 10; round++ {
		// Reset the marker; an EnsureBefore(0) during the round lowers it.
		sb.ensureBefore = time.Hour
		r.Ensure()
		r.Wait()

		st.Lock()
		ready := chg.Status().Ready()
		st.Unlock()
		if ready {
			return
		}
		if sb.ensureBefore > 0 {
			break
		}
	}

	st.Lock()
	var summary []string
	for _, task := range chg.Tasks() {
		summary = append(summary, task.Summary()+":"+task.Status().String())
	}
	st.Unlock()
	c.Fatalf("Change didn't reach final state without blocking: %s", strings.Join(summary, " "))
}

// sequenceTests drives TestSequenceTests. The result field encodes the
// expected order in which the task handlers will be called, assuming the
// provided setup is in place.
//
// Setup options:
//     <task>:was-<status>        - set task status before calling ensure (must be sensible)
//     <task>:(do|undo)-block     - block handler until task tomb dies
//     <task>:(do|undo)-retry     - return from handler with state.Retry (once)
//     <task>:(do|undo)-error     - return from handler with an error
//     <task>:(do|undo)-set-ready - handler sets the task to Done (if Doing) or
//                                  Undone (if Undoing) before returning nil
//     chg:abort                  - call abort on the change
//
// Task wait order: ( t11 | t12 ) => ( t21 ) => ( t31 | t32 )
//
// Task t12 has no undo.
//
// Final task statuses are tested based on the resulting events list.
//
var sequenceTests = []struct{ setup, result string }{{
	setup:  "",
	result: "t11:do t12:do t21:do t31:do t32:do",
}, {
	setup:  "t11:was-done t12:was-doing",
	result: "t12:do t21:do t31:do t32:do",
}, {
	setup:  "t11:was-done t12:was-doing chg:abort",
	result: "t11:undo",
}, {
	setup:  "t12:do-retry",
	result: "t11:do t12:do t12:do-retry t12:do t21:do t31:do t32:do",
}, {
	setup:  "t11:do-block t12:do-error",
	result: "t11:do t11:do-block t12:do t12:do-error t11:do-unblock t11:undo",
}, {
	setup:  "t11:do-error t12:do-block",
	result: "t11:do t11:do-error t12:do t12:do-block t12:do-unblock",
}, {
	setup:  "t11:do-block t11:do-retry t12:do-error",
	result: "t11:do t11:do-block t12:do t12:do-error t11:do-unblock t11:do-retry t11:undo",
}, {
	setup:  "t11:do-error t12:do-block t12:do-retry",
	result: "t11:do t11:do-error t12:do t12:do-block t12:do-unblock t12:do-retry",
}, {
	setup:  "t31:do-error t21:undo-error",
	result: "t11:do t12:do t21:do t31:do t31:do-error t32:do t32:undo t21:undo t21:undo-error t11:undo",
}, {
	setup:  "t21:do-set-ready",
	result: "t11:do t12:do t21:do t31:do t32:do",
}, {
	setup:  "t31:do-error t21:undo-set-ready",
	result: "t11:do t12:do t21:do t31:do t31:do-error t32:do t32:undo t21:undo t11:undo",
}}

// TestSequenceTests runs every case in sequenceTests against a fresh change
// made of five tasks (t11, t12, t21, t31, t32). For each case it checks the
// order of handler events and the final status of every task.
func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
	sb := &stateBackend{}
	st := state.New(sb)
	r := state.NewTaskRunner(st)
	defer r.Stop()

	// Handlers report what they do on ch as "<summary>:<label>[-<event>]".
	// The per-task flags "<label>-block", "<label>-retry", "<label>-error" and
	// "<label>-set-ready" (set from the test setup) alter their behavior.
	ch := make(chan string, 256)
	fn := func(label string) state.HandlerFunc {
		return func(task *state.Task, tomb *tomb.Tomb) error {
			st.Lock()
			defer st.Unlock()
			ch <- task.Summary() + ":" + label
			var isSet bool
			if task.Get(label+"-block", &isSet) == nil && isSet {
				ch <- task.Summary() + ":" + label + "-block"
				// Release the state lock while blocked so the runner can proceed.
				st.Unlock()
				<-tomb.Dying()
				st.Lock()
				ch <- task.Summary() + ":" + label + "-unblock"
			}
			if task.Get(label+"-retry", &isSet) == nil && isSet {
				// Only retry once.
				task.Set(label+"-retry", false)
				ch <- task.Summary() + ":" + label + "-retry"
				return &state.Retry{}
			}
			if task.Get(label+"-error", &isSet) == nil && isSet {
				ch <- task.Summary() + ":" + label + "-error"
				return errors.New("boom")
			}
			if task.Get(label+"-set-ready", &isSet) == nil && isSet {
				switch task.Status() {
				case state.DoingStatus:
					task.SetStatus(state.DoneStatus)
				case state.UndoingStatus:
					task.SetStatus(state.UndoneStatus)
				}
			}
			return nil
		}
	}
	r.AddHandler("do", fn("do"), nil)
	r.AddHandler("do-undo", fn("do"), fn("undo"))

	// Lower-cased status names, as used by "<task>:was-<status>" setup items.
	// This does not depend on the test case, so it is built once.
	statuses := make(map[string]state.Status)
	for s := state.DefaultStatus; s <= state.ErrorStatus; s++ {
		statuses[strings.ToLower(s.String())] = s
	}

	for _, test := range sequenceTests {
		st.Lock()

		// Delete previous changes.
		st.Prune(1, 1)

		chg := st.NewChange("install", "...")
		tasks := make(map[string]*state.Task)
		for _, name := range strings.Fields("t11 t12 t21 t31 t32") {
			if name == "t12" {
				tasks[name] = st.NewTask("do", name)
			} else {
				tasks[name] = st.NewTask("do-undo", name)
			}
			chg.AddTask(tasks[name])
		}
		tasks["t21"].WaitFor(tasks["t11"])
		tasks["t21"].WaitFor(tasks["t12"])
		tasks["t31"].WaitFor(tasks["t21"])
		tasks["t32"].WaitFor(tasks["t21"])
		st.Unlock()

		c.Logf("-----")
		c.Logf("Testing setup: %s", test.setup)

		// Reset and prepare initial task state.
		st.Lock()
		for _, t := range chg.Tasks() {
			t.SetStatus(state.DefaultStatus)
			t.Set("do-error", false)
			t.Set("do-block", false)
			t.Set("undo-error", false)
			t.Set("undo-block", false)
		}
		for _, item := range strings.Fields(test.setup) {
			if item == "chg:abort" {
				chg.Abort()
				continue
			}
			kv := strings.Split(item, ":")
			if strings.HasPrefix(kv[1], "was-") {
				tasks[kv[0]].SetStatus(statuses[strings.TrimPrefix(kv[1], "was-")])
			} else {
				tasks[kv[0]].Set(kv[1], true)
			}
		}
		st.Unlock()

		// Run change until final.
		ensureChange(c, r, sb, chg)

		// Compute order of events observed.
		var events []string
		var done bool
		for !done {
			select {
			case ev := <-ch:
				events = append(events, ev)
				// Make t11/t12 and t31/t32 always show up in the same order
				// if they're next to each other: bubble the new event back
				// while its predecessor has the same wave digit (ev[1]) and a
				// higher index digit (ev[2]). Unblock events are not moved.
				for i := len(events) - 2; i >= 0; i-- {
					prev := events[i]
					next := events[i+1]
					switch strings.Split(next, ":")[1] {
					case "do-unblock", "undo-unblock":
					default:
						if prev[1] == next[1] && prev[2] > next[2] {
							events[i], events[i+1] = next, prev
							continue
						}
					}
					break
				}
			default:
				done = true
			}
		}

		c.Logf("Expected result: %s", test.result)
		c.Assert(strings.Join(events, " "), Equals, test.result, Commentf("setup: %s", test.setup))

		// Compute final expected status for tasks.
		finalStatus := make(map[string]state.Status)
		// ... default when no handler is called
		for tname := range tasks {
			finalStatus[tname] = state.HoldStatus
		}
		// ... overwrite based on relevant setup
		for _, item := range strings.Fields(test.setup) {
			if item == "chg:abort" && strings.Contains(test.setup, "t12:was-doing") {
				// t12 has no undo so must hold if asked to abort when was doing.
				finalStatus["t12"] = state.HoldStatus
			}
			kv := strings.Split(item, ":")
			if !strings.HasPrefix(kv[1], "was-") {
				continue
			}
			// The "was-" prefix is stripped here, so cases must name the
			// bare status ("error", not "was-error").
			switch strings.TrimPrefix(kv[1], "was-") {
			case "do", "doing", "done":
				finalStatus[kv[0]] = state.DoneStatus
			case "abort", "undo", "undoing", "undone":
				if kv[0] == "t12" {
					finalStatus[kv[0]] = state.DoneStatus // no undo for t12
				} else {
					finalStatus[kv[0]] = state.UndoneStatus
				}
			case "error":
				finalStatus[kv[0]] = state.ErrorStatus
			case "hold":
				finalStatus[kv[0]] = state.HoldStatus
			}
		}
		// ... and overwrite based on events observed.
		for _, ev := range events {
			kv := strings.Split(ev, ":")
			switch kv[1] {
			case "do":
				finalStatus[kv[0]] = state.DoneStatus
			case "undo":
				finalStatus[kv[0]] = state.UndoneStatus
			case "do-error", "undo-error":
				finalStatus[kv[0]] = state.ErrorStatus
			case "do-retry":
				if kv[0] == "t12" && finalStatus["t11"] == state.ErrorStatus {
					// t12 has no undo so must hold if asked to abort on retry.
					finalStatus["t12"] = state.HoldStatus
				}
			}
		}

		st.Lock()
		var gotStatus, wantStatus []string
		for _, task := range chg.Tasks() {
			gotStatus = append(gotStatus, task.Summary()+":"+task.Status().String())
			wantStatus = append(wantStatus, task.Summary()+":"+finalStatus[task.Summary()].String())
		}
		st.Unlock()

		c.Logf("Expected statuses: %s", strings.Join(wantStatus, " "))
		comment := Commentf("calls: %s", test.result)
		c.Assert(strings.Join(gotStatus, " "), Equals, strings.Join(wantStatus, " "), comment)
	}
}

// TestExternalAbort checks that aborting a change from outside, while one of
// its tasks is blocked in its handler, makes a later Ensure kill that task.
func (ts *taskRunnerSuite) TestExternalAbort(c *C) {
	backend := &stateBackend{}
	st := state.New(backend)
	runner := state.NewTaskRunner(st)
	defer runner.Stop()

	running := make(chan bool)
	blockUntilKilled := func(_ *state.Task, tb *tomb.Tomb) error {
		running <- true
		<-tb.Dying()
		return nil
	}
	runner.AddHandler("blocking", blockUntilKilled, nil)

	st.Lock()
	chg := st.NewChange("install", "...")
	chg.AddTask(st.NewTask("blocking", "..."))
	st.Unlock()

	runner.Ensure()
	<-running

	st.Lock()
	chg.Abort()
	st.Unlock()

	// Unless the abort makes Ensure kill the blocked task, this never returns.
	ensureChange(c, runner, backend, chg)
}

// TestStopHandlerJustFinishing checks that a handler which ignores the dying
// tomb and returns nil leaves its task Done after the runner is stopped.
func (ts *taskRunnerSuite) TestStopHandlerJustFinishing(c *C) {
	st := state.New(&stateBackend{})
	r := state.NewTaskRunner(st)
	defer r.Stop()

	started := make(chan bool)
	r.AddHandler("just-finish", func(_ *state.Task, tb *tomb.Tomb) error {
		started <- true
		<-tb.Dying()
		// Ignore the stop request and report success anyway.
		return nil
	}, nil)

	st.Lock()
	chg := st.NewChange("install", "...")
	task := st.NewTask("just-finish", "...")
	chg.AddTask(task)
	st.Unlock()

	r.Ensure()
	<-started
	r.Stop()

	st.Lock()
	status := task.Status()
	st.Unlock()
	c.Check(status, Equals, state.DoneStatus)
}

// TestStopAskForRetry checks that a handler answering a stop request with
// state.Retry leaves its task in Doing after the runner is stopped.
func (ts *taskRunnerSuite) TestStopAskForRetry(c *C) {
	st := state.New(&stateBackend{})
	r := state.NewTaskRunner(st)
	defer r.Stop()

	started := make(chan bool)
	r.AddHandler("ask-for-retry", func(_ *state.Task, tb *tomb.Tomb) error {
		started <- true
		<-tb.Dying()
		// Ask to be run again rather than finishing.
		return &state.Retry{}
	}, nil)

	st.Lock()
	chg := st.NewChange("install", "...")
	task := st.NewTask("ask-for-retry", "...")
	chg.AddTask(task)
	st.Unlock()

	r.Ensure()
	<-started
	r.Stop()

	st.Lock()
	status := task.Status()
	st.Unlock()
	c.Check(status, Equals, state.DoingStatus)
}

// TestRetryAfterDuration checks three things about a handler that first
// returns state.Retry{After: time.Minute}:
//   - the task is rescheduled via EnsureBefore;
//   - it is not rerun by an Ensure that comes too early;
//   - it is rerun once the mocked clock reaches the scheduled time.
func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) {
	ensureBeforeTick := make(chan bool, 1)
	sb := &stateBackend{
		ensureBefore:     time.Hour,
		ensureBeforeSeen: ensureBeforeTick,
	}
	st := state.New(sb)
	r := state.NewTaskRunner(st)
	defer r.Stop()

	ch := make(chan bool)
	ask := 0
	r.AddHandler("ask-for-retry", func(t *state.Task, _ *tomb.Tomb) error {
		ask++
		if ask == 1 {
			return &state.Retry{After: time.Minute}
		}
		ch <- true
		return nil
	}, nil)

	st.Lock()
	chg := st.NewChange("install", "...")
	t := st.NewTask("ask-for-retry", "...")
	chg.AddTask(t)
	st.Unlock()

	tock := time.Now()
	restore := state.MockTime(tock)
	defer restore()
	r.Ensure() // will run and be rescheduled in a minute
	select {
	case <-ensureBeforeTick:
	case <-time.After(2 * time.Second):
		c.Fatal("EnsureBefore wasn't called")
	}

	// The state lock is paired explicitly (no defer) in the middle of this
	// test. A c.Fatal while the lock is released must not end up running a
	// deferred Unlock on an unlocked state, which would mask the real failure.
	st.Lock()
	c.Check(t.Status(), Equals, state.DoingStatus)
	c.Check(ask, Equals, 1)
	c.Check(sb.ensureBefore, Equals, 1*time.Minute)
	schedule := t.AtTime()
	c.Check(schedule.IsZero(), Equals, false)

	// Move the clock forward, but not as far as the schedule.
	state.MockTime(tock.Add(5 * time.Second))
	sb.ensureBefore = time.Hour
	st.Unlock()

	r.Ensure() // too soon

	st.Lock()
	c.Check(t.Status(), Equals, state.DoingStatus)
	c.Check(ask, Equals, 1)
	c.Check(sb.ensureBefore, Equals, 55*time.Second)
	c.Check(t.AtTime().Equal(schedule), Equals, true)

	// Now reach the scheduled time.
	state.MockTime(schedule)
	sb.ensureBefore = time.Hour
	st.Unlock()

	r.Ensure() // time to run again
	select {
	case <-ch:
	case <-time.After(2 * time.Second):
		c.Fatal("handler wasn't called")
	}

	// wait for handler to finish
	r.Wait()

	st.Lock()
	defer st.Unlock()
	c.Check(t.Status(), Equals, state.DoneStatus)
	c.Check(ask, Equals, 2)
	c.Check(sb.ensureBefore, Equals, time.Hour)
	c.Check(t.AtTime().IsZero(), Equals, true)
}

// TestTaskSerialization checks that a SetBlocked predicate can serialize
// tasks. do2 only starts after do1 has finished, and do1's completion
// triggers an EnsureBefore(0) so the blocked task gets picked up.
func (ts *taskRunnerSuite) TestTaskSerialization(c *C) {
	ensureBeforeTick := make(chan bool, 1)
	sb := &stateBackend{
		ensureBefore:     time.Hour,
		ensureBeforeSeen: ensureBeforeTick,
	}
	st := state.New(sb)
	r := state.NewTaskRunner(st)
	defer r.Stop()

	// expect waits up to two seconds for a value on ch, failing with msg otherwise.
	expect := func(ch <-chan bool, msg string) {
		select {
		case <-ch:
		case <-time.After(2 * time.Second):
			c.Fatal(msg)
		}
	}

	do1Step := make(chan bool)
	do2Ran := make(chan bool)
	r.AddHandler("do1", func(*state.Task, *tomb.Tomb) error {
		// First send: started. Second send: the test lets do1 finish.
		do1Step <- true
		do1Step <- true
		return nil
	}, nil)
	r.AddHandler("do2", func(*state.Task, *tomb.Tomb) error {
		do2Ran <- true
		return nil
	}, nil)

	// Start do1 first, and do2 only once nothing else is running.
	startedDo1 := false
	r.SetBlocked(func(t *state.Task, running []*state.Task) bool {
		switch t.Kind() {
		case "do1":
			startedDo1 = true
			return false
		case "do2":
			return len(running) != 0 || !startedDo1
		}
		return false
	})

	st.Lock()
	chg := st.NewChange("install", "...")
	t1 := st.NewTask("do1", "...")
	chg.AddTask(t1)
	t2 := st.NewTask("do2", "...")
	chg.AddTask(t2)
	st.Unlock()

	r.Ensure() // only do1 may start
	expect(do1Step, "do1 wasn't called")
	c.Check(ensureBeforeTick, HasLen, 0)
	c.Check(do2Ran, HasLen, 0)

	r.Ensure() // do1 still running, nothing new starts
	c.Check(ensureBeforeTick, HasLen, 0)
	c.Check(do2Ran, HasLen, 0)

	// Let do1 finish, then expect an EnsureBefore(0) call.
	expect(do1Step, "do1 wasn't continued")
	expect(ensureBeforeTick, "EnsureBefore wasn't called")
	c.Check(sb.ensureBefore, Equals, time.Duration(0))

	r.Ensure() // do2 can start now
	expect(do2Ran, "do2 wasn't called")

	// No further EnsureBefore calls.
	c.Check(ensureBeforeTick, HasLen, 0)
}

// TestPrematureChangeReady checks that while an undo handler is still running
// the change is neither IsReady nor in a ready status, and chg.Err() is nil.
func (ts *taskRunnerSuite) TestPrematureChangeReady(c *C) {
	st := state.New(&stateBackend{})
	r := state.NewTaskRunner(st)
	defer r.Stop()

	gate := make(chan bool)
	noop := func(*state.Task, *tomb.Tomb) error { return nil }
	undoBlocking := func(*state.Task, *tomb.Tomb) error {
		gate <- true // tell the test the undo has started
		<-gate       // wait until the test lets it finish
		return nil
	}
	r.AddHandler("block-undo", noop, undoBlocking)
	r.AddHandler("fail", func(*state.Task, *tomb.Tomb) error {
		return errors.New("BAM")
	}, nil)

	st.Lock()
	chg := st.NewChange("install", "...")
	t1 := st.NewTask("block-undo", "...")
	t2 := st.NewTask("fail", "...")
	chg.AddTask(t1)
	chg.AddTask(t2)
	st.Unlock()

	r.Ensure() // "fail" errors out
	r.Wait()
	r.Ensure() // "block-undo" starts undoing and blocks
	<-gate

	// Deferred before the Unlock below, so it runs after it: release the
	// blocked undo and wait for it to finish.
	defer func() {
		gate <- true
		r.Wait()
	}()

	st.Lock()
	defer st.Unlock()

	if premature := chg.IsReady() || chg.Status().Ready(); premature {
		c.Errorf("Change considered ready prematurely")
	}
	c.Assert(chg.Err(), IsNil)
}

// TestCleanup checks the cleanup handler across Ensure rounds once the change
// is done. It is not called before then. A failing cleanup is retried on the
// next round, a successful one marks the change clean, and it is not called
// again afterwards.
func (ts *taskRunnerSuite) TestCleanup(c *C) {
	sb := &stateBackend{}
	st := state.New(sb)
	r := state.NewTaskRunner(st)
	defer r.Stop()

	succeed := func(*state.Task, *tomb.Tomb) error { return nil }
	r.AddHandler("clean-it", succeed, nil)
	r.AddHandler("other", succeed, nil)

	called := 0
	r.AddCleanup("clean-it", func(*state.Task, *tomb.Tomb) error {
		called++
		if called > 1 {
			return nil
		}
		return fmt.Errorf("retry me")
	})

	st.Lock()
	chg := st.NewChange("install", "...")
	t1 := st.NewTask("clean-it", "...")
	t2 := st.NewTask("other", "...")
	chg.AddTask(t1)
	chg.AddTask(t2)
	st.Unlock()

	chgIsClean := func() bool {
		st.Lock()
		defer st.Unlock()
		return chg.IsClean()
	}

	// Mark tasks as done; cleanup must not have run yet.
	ensureChange(c, r, sb, chg)
	c.Assert(chgIsClean(), Equals, false)
	c.Assert(called, Equals, 0)

	// First round errors, second succeeds, third must not call it again.
	rounds := []struct {
		clean bool
		calls int
	}{
		{false, 1},
		{true, 2},
		{true, 2},
	}
	for _, want := range rounds {
		r.Ensure()
		r.Wait()
		c.Assert(chgIsClean(), Equals, want.clean)
		c.Assert(called, Equals, want.calls)
	}
}
