// Copyright 2018 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License. See the AUTHORS file
// for names of contributors.

package main

import (
	"context"
	"fmt"
	"runtime"
	"strconv"
	"time"

	_ "github.com/lib/pq"

	"github.com/cockroachdb/cockroach/pkg/util/binfetcher"
)

// registerVersion registers the "version/mixedWith=<version>/nodes=<n>" tests
// with r, once for 3 and once for 5 cockroach nodes. Each test provisions one
// extra node which is used exclusively to run the workload generators.
//
// The test starts the cluster at a fixed older binary version, runs load
// against it, and then (while the load keeps running) performs a rolling
// upgrade to the current binary, a rolling downgrade back to the old binary,
// and a second rolling upgrade, checking after every step that all nodes
// still answer queries.
func registerVersion(r *registry) {
	// runVersion executes the mixed-version test against cluster c, using
	// `version` as the old binary version fetched via binfetcher.
	runVersion := func(ctx context.Context, t *test, c *cluster, version string) {
		// The last node of the cluster only runs the workload; nodes 1..nodes
		// run cockroach.
		nodes := c.nodes - 1
		goos := ifLocal(runtime.GOOS, "linux")

		b, err := binfetcher.Download(ctx, binfetcher.Options{
			Binary:  "cockroach",
			Version: version,
			GOOS:    goos,
			GOARCH:  "amd64",
		})
		if err != nil {
			t.Fatal(err)
		}

		c.Put(ctx, workload, "./workload", c.Node(nodes+1))

		c.Put(ctx, b, "./cockroach", c.Range(1, nodes))
		// Force disable encryption.
		// TODO(mberhault): allow it once version >= 2.1.
		c.Start(ctx, c.Range(1, nodes), startArgsDontEncrypt)

		stageDuration := 10 * time.Minute
		buffer := 10 * time.Minute
		if local {
			c.l.Printf("local mode: speeding up test\n")
			stageDuration = 10 * time.Second
			buffer = time.Minute
		}

		// The upgrader below calls sleepAndCheck 3*nodes+2 times in total (one
		// initial call, nodes-1 plus one for the first upgrade, nodes for the
		// downgrade, nodes for the second upgrade, and one final call). The
		// workloads need to keep running for at least that long.
		loadDuration := " --duration=" + (time.Duration(3*nodes+2)*stageDuration + buffer).String()

		// The %d verb in each command is filled in with `nodes` below so that
		// the workloads connect to all cockroach nodes.
		workloads := []string{
			"./workload run tpcc --tolerate-errors --wait=false --drop --init --warehouses=1 " + loadDuration + " {pgurl:1-%d}",
			"./workload run kv --tolerate-errors --init" + loadDuration + " {pgurl:1-%d}",
		}

		m := newMonitor(ctx, c, c.Range(1, nodes))
		for i, cmd := range workloads {
			cmd := cmd // loop-local copy
			i := i     // ditto
			m.Go(func(ctx context.Context) error {
				cmd = fmt.Sprintf(cmd, nodes)
				childL, err := c.l.ChildLogger("workload " + strconv.Itoa(i))
				if err != nil {
					return err
				}
				return c.RunL(ctx, childL, c.Node(nodes+1), cmd)
			})
		}

		m.Go(func(ctx context.Context) error {
			l, err := c.l.ChildLogger("upgrader")
			if err != nil {
				return err
			}

			// checkNode opens a connection to the given node, runs a trivial
			// query against it, and closes the connection again before
			// returning. Keeping this in its own function makes the deferred
			// Close run per node rather than piling up in the caller's loop.
			checkNode := func(node int) error {
				db := c.Conn(ctx, node)
				defer db.Close()
				rows, err := db.Query(`SHOW DATABASES`)
				if err != nil {
					return err
				}
				return rows.Close()
			}

			// sleepAndCheck waits for stageDuration (or until ctx is canceled)
			// and then verifies that every cockroach node still serves queries.
			//
			// NB: the number of calls to `sleepAndCheck` needs to be reflected
			// in `loadDuration`.
			sleepAndCheck := func() error {
				t.WorkerStatus("sleeping")
				select {
				case <-ctx.Done():
					return ctx.Err()
				case <-time.After(stageDuration):
				}
				// Make sure everyone is still running. Each node is queried
				// through its own connection.
				for i := 1; i <= nodes; i++ {
					t.WorkerStatus("checking ", i)
					if err := checkNode(i); err != nil {
						return err
					}
				}
				return nil
			}

			// Connection to node 1, used for reading and changing cluster
			// settings throughout the test.
			db := c.Conn(ctx, 1)
			defer db.Close()

			// First let the load generators run in the cluster at `version`.
			if err := sleepAndCheck(); err != nil {
				return err
			}

			// stop shuts down the given node via `./cockroach quit` followed by
			// c.Stop, after telling the monitor to expect the process to die.
			stop := func(node int) error {
				m.ExpectDeath()
				l.Printf("stopping node %d\n", node)
				port := fmt.Sprintf("{pgport:%d}", node)
				// Note that the following command line needs to run against both v2.0
				// and the current branch. Do not change it in a manner that is
				// incompatible with 2.0.
				if err := c.RunE(ctx, c.Node(node), "./cockroach quit --insecure --port="+port); err != nil {
					return err
				}
				// NB: we still call Stop to make sure the process is dead when we try
				// to restart it (or we'll catch an error from the RocksDB dir being
				// locked). This won't happen unless run with --local due to timing.
				// However, it serves as a reminder that `./cockroach quit` doesn't yet
				// work well enough -- ideally all listeners and engines are closed by
				// the time it returns to the client.
				//
				// TODO(tschottdorf): should return an error. I doubt that we want to
				// call these *testing.T-style methods on goroutines.
				c.Stop(ctx, c.Node(node))
				return nil
			}

			// Remember the cluster version before any upgrade so that it can be
			// used as the value of cluster.preserve_downgrade_option below.
			var oldVersion string
			if err := db.QueryRowContext(ctx, `SHOW CLUSTER SETTING version`).Scan(&oldVersion); err != nil {
				return err
			}
			l.Printf("cluster version is %s\n", oldVersion)

			// Now perform a rolling restart into the new binary. The last node
			// is handled separately below.
			for i := 1; i < nodes; i++ {
				t.WorkerStatus("upgrading ", i)
				l.Printf("upgrading %d\n", i)
				if err := stop(i); err != nil {
					return err
				}
				c.Put(ctx, cockroach, "./cockroach", c.Node(i))
				c.Start(ctx, c.Node(i), startArgsDontEncrypt)
				if err := sleepAndCheck(); err != nil {
					return err
				}
			}

			l.Printf("stopping last node\n")
			// Stop the last node.
			if err := stop(nodes); err != nil {
				return err
			}

			// Set cluster.preserve_downgrade_option to be the old cluster version to
			// prevent upgrade.
			l.Printf("preventing automatic upgrade\n")
			if _, err := db.ExecContext(ctx,
				fmt.Sprintf("SET CLUSTER SETTING cluster.preserve_downgrade_option = '%s';", oldVersion),
			); err != nil {
				return err
			}

			// Do upgrade for the last node.
			l.Printf("upgrading last node\n")
			c.Put(ctx, cockroach, "./cockroach", c.Node(nodes))
			c.Start(ctx, c.Node(nodes), startArgsDontEncrypt)
			if err := sleepAndCheck(); err != nil {
				return err
			}

			// Changed our mind, let's roll that back.
			for i := 1; i <= nodes; i++ {
				l.Printf("downgrading node %d\n", i)
				t.WorkerStatus("downgrading", i)
				if err := stop(i); err != nil {
					return err
				}
				c.Put(ctx, b, "./cockroach", c.Node(i))
				c.Start(ctx, c.Node(i), startArgsDontEncrypt)
				if err := sleepAndCheck(); err != nil {
					return err
				}
			}

			// OK, let's go forward again.
			for i := 1; i <= nodes; i++ {
				l.Printf("upgrading node %d (again)\n", i)
				t.WorkerStatus("upgrading", i, "(again)")
				if err := stop(i); err != nil {
					return err
				}
				c.Put(ctx, cockroach, "./cockroach", c.Node(i))
				c.Start(ctx, c.Node(i), startArgsDontEncrypt)
				if err := sleepAndCheck(); err != nil {
					return err
				}
			}

			// Reset cluster.preserve_downgrade_option to allow auto upgrade.
			l.Printf("reenabling auto-upgrade\n")
			if _, err := db.ExecContext(ctx,
				"RESET CLUSTER SETTING cluster.preserve_downgrade_option;",
			); err != nil {
				return err
			}

			return sleepAndCheck()
		})
		m.Wait()
	}

	// The old binary version that the current binary is mixed with.
	const version = "v2.0.5"
	for _, n := range []int{3, 5} {
		r.Add(testSpec{
			Name:       fmt.Sprintf("version/mixedWith=%s/nodes=%d", version, n),
			MinVersion: "v2.1.0",
			// One extra node on top of the n cockroach nodes runs the workload.
			Nodes:  nodes(n + 1),
			Stable: true, // DO NOT COPY to new tests
			Run: func(ctx context.Context, t *test, c *cluster) {
				runVersion(ctx, t, c, version)
			},
		})
	}
}
