package orchestrator

import (
	"context"
	"path/filepath"
	"testing"
	"time"

	"github.com/flothus/tmux-xterm-research/server-go/internal/harness/event"
	"github.com/flothus/tmux-xterm-research/server-go/internal/harness/store"
)

// newOrch builds an Orchestrator backed by a fresh store in a per-test
// temporary directory. The store is closed via t.Cleanup; because cleanups
// run last-registered-first, the close happens before t.TempDir removes the
// directory.
func newOrch(t *testing.T) *Orchestrator {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "harness.db")
	st, err := store.Open(dbPath)
	if err != nil {
		t.Fatalf("open store: %v", err)
	}
	// Best-effort close; the temp directory is discarded afterwards anyway.
	t.Cleanup(func() { st.Close() })
	return New(st, event.NewBus(st))
}

func TestFSMLegalTransitions(t *testing.T) {
	cases := []struct {
		from, to TaskState
		want     bool
	}{
		{StateCreated, StateAssigned, true},
		{StateAssigned, StateInProgress, true},
		{StateInProgress, StateCompleted, true},
		{StateInProgress, StateFailed, true},
		{StateCreated, StateInProgress, false}, // must go through assigned
		{StateCompleted, StateInProgress, false},
		{StateFailed, StateRequeued, true},
	}
	for _, c := range cases {
		if got := IsLegalTransition(c.from, c.to); got != c.want {
			t.Errorf("%s → %s: got %v, want %v", c.from, c.to, got, c.want)
		}
	}
}

// TestTaskCreateAndTransition walks a task through
// created → assigned → in_progress → completed and then checks that moving
// a completed task back to in_progress is rejected.
func TestTaskCreateAndTransition(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()
	if err := o.CreateRun(ctx, "r1", "do x"); err != nil {
		t.Fatalf("create run: %v", err)
	}
	if err := o.CreateTask(ctx, Task{ID: "t1", RunID: "r1", Title: "subtask"}); err != nil {
		t.Fatalf("create task: %v", err)
	}
	// Create the owner agent so the FK in AssignOwner is satisfied. The error
	// is checked: a failed insert would otherwise surface later as a
	// confusing FK failure from AssignOwner.
	if _, err := o.St.DB().Exec(
		`INSERT INTO agents(id, run_id, status, spawned_at) VALUES('agent-1', 'r1', 'running', ?)`,
		store.FmtTime(time.Now().UTC()),
	); err != nil {
		t.Fatalf("insert agent: %v", err)
	}
	if err := o.AssignOwner(ctx, "t1", "agent-1"); err != nil {
		t.Fatalf("assign owner: %v", err)
	}
	if err := o.Transition(ctx, "t1", StateInProgress); err != nil {
		t.Fatalf("transition to in_progress: %v", err)
	}
	if err := o.Transition(ctx, "t1", StateCompleted); err != nil {
		t.Fatalf("transition to completed: %v", err)
	}
	// Illegal transition rejected.
	if err := o.Transition(ctx, "t1", StateInProgress); err == nil {
		t.Errorf("expected illegal-transition error, got nil")
	}
}

// TestReapStalledTasks creates an in-progress task whose deadline is already
// in the past and checks that ReapStalledTasks reaps exactly that task and
// leaves it in the failed state.
func TestReapStalledTasks(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()
	// Every setup step is checked so a broken fixture fails here with its
	// real cause, not later as a misleading "reaped 0" assertion.
	if err := o.CreateRun(ctx, "r1", ""); err != nil {
		t.Fatalf("create run: %v", err)
	}
	if _, err := o.St.DB().Exec(
		`INSERT INTO agents(id, run_id, status, spawned_at) VALUES('a1', 'r1', 'running', ?)`,
		store.FmtTime(time.Now().UTC()),
	); err != nil {
		t.Fatalf("insert agent: %v", err)
	}
	past := time.Now().UTC().Add(-1 * time.Minute)
	if err := o.CreateTask(ctx, Task{ID: "t-stalled", RunID: "r1", Title: "x", Deadline: &past}); err != nil {
		t.Fatalf("create task: %v", err)
	}
	if err := o.AssignOwner(ctx, "t-stalled", "a1"); err != nil {
		t.Fatalf("assign owner: %v", err)
	}
	if err := o.Transition(ctx, "t-stalled", StateInProgress); err != nil {
		t.Fatalf("transition to in_progress: %v", err)
	}

	n, err := o.ReapStalledTasks(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if n != 1 {
		t.Fatalf("reaped %d, want 1", n)
	}
	var state string
	if err := o.St.DB().QueryRow(`SELECT state FROM tasks WHERE id='t-stalled'`).Scan(&state); err != nil {
		t.Fatalf("read task state: %v", err)
	}
	if state != string(StateFailed) {
		t.Fatalf("state = %s, want failed", state)
	}
}

// TestCostCeilingKills pushes a run's recorded cost above the ceiling passed
// to EnforceCostCeiling. It checks that the run is killed with reason
// cost_usd_over_ceiling and that its in-flight task ends up abandoned.
func TestCostCeilingKills(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()
	// Setup errors are checked so fixture failures report their real cause.
	if err := o.CreateRun(ctx, "r1", ""); err != nil {
		t.Fatalf("create run: %v", err)
	}
	if _, err := o.St.DB().Exec(
		`INSERT INTO agents(id, run_id, status, spawned_at) VALUES('a', 'r1', 'running', ?)`,
		store.FmtTime(time.Now().UTC()),
	); err != nil {
		t.Fatalf("insert agent: %v", err)
	}
	if err := o.CreateTask(ctx, Task{ID: "t-x", RunID: "r1", Title: "x"}); err != nil {
		t.Fatalf("create task: %v", err)
	}
	if err := o.AssignOwner(ctx, "t-x", "a"); err != nil {
		t.Fatalf("assign owner: %v", err)
	}
	if err := o.Transition(ctx, "t-x", StateInProgress); err != nil {
		t.Fatalf("transition to in_progress: %v", err)
	}
	// Bump cost on the run beyond the ceiling.
	if _, err := o.St.DB().Exec(`UPDATE runs SET cost_usd_total=10 WHERE id='r1'`); err != nil {
		t.Fatalf("bump run cost: %v", err)
	}

	n, err := o.EnforceCostCeiling(ctx, "", 0, 1.0)
	if err != nil {
		t.Fatal(err)
	}
	if n != 1 {
		t.Fatalf("killed %d, want 1", n)
	}
	var status, reason string
	// Scanning a NULL kill_reason into a string fails; surface that instead of
	// reporting an empty reason.
	if err := o.St.DB().QueryRow(`SELECT status, kill_reason FROM runs WHERE id='r1'`).Scan(&status, &reason); err != nil {
		t.Fatalf("read run status: %v", err)
	}
	if status != "killed" || reason != "cost_usd_over_ceiling" {
		t.Fatalf("run status/reason = %s/%s", status, reason)
	}
	var taskState string
	if err := o.St.DB().QueryRow(`SELECT state FROM tasks WHERE id='t-x'`).Scan(&taskState); err != nil {
		t.Fatalf("read task state: %v", err)
	}
	// Cost-ceiling killed the run, which triggers EndRun's cascade: every
	// in-flight task transitions to `abandoned` (was `failed` before the
	// abandoned/failed split — see TestEndRunAbandonsDanglingTasks).
	if taskState != string(StateAbandoned) {
		t.Fatalf("task state = %s, want abandoned (run-end cascade)", taskState)
	}
}

// TestReapDeadAgents inserts a running agent whose heartbeat is an hour old
// and checks that ReapDeadAgents, given a 5-minute threshold, reaps it. After
// reaping, the agent must be terminated and its in-progress task failed.
func TestReapDeadAgents(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()
	// Setup errors are checked so fixture failures report their real cause.
	if err := o.CreateRun(ctx, "r1", ""); err != nil {
		t.Fatalf("create run: %v", err)
	}
	// Insert agent with stale heartbeat.
	stale := store.FmtTime(time.Now().UTC().Add(-1 * time.Hour))
	if _, err := o.St.DB().Exec(
		`INSERT INTO agents(id, run_id, status, spawned_at, heartbeat_at) VALUES('a1', 'r1', 'running', ?, ?)`,
		store.FmtTime(time.Now().UTC()), stale,
	); err != nil {
		t.Fatalf("insert agent: %v", err)
	}
	if err := o.CreateTask(ctx, Task{ID: "t1", RunID: "r1", Title: "x"}); err != nil {
		t.Fatalf("create task: %v", err)
	}
	if err := o.AssignOwner(ctx, "t1", "a1"); err != nil {
		t.Fatalf("assign owner: %v", err)
	}
	if err := o.Transition(ctx, "t1", StateInProgress); err != nil {
		t.Fatalf("transition to in_progress: %v", err)
	}

	n, _, err := o.ReapDeadAgents(ctx, 5*time.Minute)
	if err != nil {
		t.Fatal(err)
	}
	if n != 1 {
		t.Fatalf("reaped %d, want 1", n)
	}
	var status string
	if err := o.St.DB().QueryRow(`SELECT status FROM agents WHERE id='a1'`).Scan(&status); err != nil {
		t.Fatalf("read agent status: %v", err)
	}
	if status != "terminated" {
		t.Fatalf("agent status = %s, want terminated", status)
	}
	var taskState string
	if err := o.St.DB().QueryRow(`SELECT state FROM tasks WHERE id='t1'`).Scan(&taskState); err != nil {
		t.Fatalf("read task state: %v", err)
	}
	if taskState != string(StateFailed) {
		t.Fatalf("task state = %s, want failed", taskState)
	}
}
