package orchestrator

import (
	"context"
	"testing"
	"time"

	"github.com/flothus/tmux-xterm-research/server-go/internal/harness/store"
)

// TestEnforceMaxDepth builds a four-task parent chain and checks that
// EnforceMaxDepth with MaxDepth 3 accepts the root task but rejects the leaf.
func TestEnforceMaxDepth(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()

	// must aborts the test when a fixture step fails, so a broken setup is
	// reported as such instead of as a misleading assertion failure below.
	must := func(step string, err error) {
		t.Helper()
		if err != nil {
			t.Fatalf("setup: %s: %v", step, err)
		}
	}

	must("create run r1", o.CreateRun(ctx, "r1", ""))

	// Build a chain t1→t2→t3→t4: every task after t1 names the previous
	// task as its parent, so t4 has three ancestors.
	chain := []Task{
		{ID: "t1", RunID: "r1", Title: "t1"},
		{ID: "t2", RunID: "r1", ParentTaskID: "t1", Title: "t2"},
		{ID: "t3", RunID: "r1", ParentTaskID: "t2", Title: "t3"},
		{ID: "t4", RunID: "r1", ParentTaskID: "t3", Title: "t4"},
	}
	for _, task := range chain {
		must("create task "+task.Title, o.CreateTask(ctx, task))
	}

	p := Policy{MaxDepth: 3}
	// The root of the chain must be within the limit.
	if err := o.EnforceMaxDepth(ctx, "t1", p); err != nil {
		t.Errorf("depth at t1: unexpected %v", err)
	}
	// The leaf of the chain must be rejected.
	if err := o.EnforceMaxDepth(ctx, "t4", p); err == nil {
		t.Errorf("depth at t4: expected error, got nil")
	}
}

// TestEnforceMaxFanout creates a parent task with three children and checks
// that EnforceMaxFanout with MaxFanout 3 reports an error for that parent.
func TestEnforceMaxFanout(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()

	// must aborts the test when a fixture step fails, so a broken setup is
	// reported as such instead of as a misleading assertion failure below.
	must := func(step string, err error) {
		t.Helper()
		if err != nil {
			t.Fatalf("setup: %s: %v", step, err)
		}
	}

	must("create run r1", o.CreateRun(ctx, "r1", ""))
	must("create parent task p", o.CreateTask(ctx, Task{ID: "p", RunID: "r1", Title: "parent"}))

	const children = 3
	for i := 0; i < children; i++ {
		// IDs are c0, c1, c2; the single-digit rune trick is valid
		// only while children stays below 10.
		id := "c" + string(rune('0'+i))
		must("create child task "+id, o.CreateTask(ctx, Task{ID: id, RunID: "r1", ParentTaskID: "p", Title: "child"}))
	}

	p := Policy{MaxFanout: children}
	if err := o.EnforceMaxFanout(ctx, "p", p); err == nil {
		t.Errorf("expected fanout error for parent with %d children and MaxFanout %d, got nil", children, p.MaxFanout)
	}
}

// TestDetectDelegationCycle sets up task t1 owned by agent A with child task
// t2 owned by agent B, then checks that DetectDelegationCycle reports a cycle
// when t2 would be delegated to A and no cycle when delegated to agent C.
func TestDetectDelegationCycle(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()

	// must aborts the test when a fixture step fails. Without this, a failed
	// owner assignment would let the "no cycle" check pass vacuously.
	must := func(step string, err error) {
		t.Helper()
		if err != nil {
			t.Fatalf("setup: %s: %v", step, err)
		}
	}

	must("create run r1", o.CreateRun(ctx, "r1", ""))

	// Agent rows are inserted directly through the store's DB handle.
	for _, agentID := range []string{"A", "B"} {
		_, err := o.St.DB().Exec(
			`INSERT INTO agents(id, run_id, status, spawned_at) VALUES(?, 'r1', 'running', ?)`,
			agentID, store.FmtTime(time.Now().UTC()),
		)
		must("insert agent "+agentID, err)
	}

	must("create task t1", o.CreateTask(ctx, Task{ID: "t1", RunID: "r1", Title: "x"}))
	must("assign t1 to A", o.AssignOwner(ctx, "t1", "A"))
	must("create task t2", o.CreateTask(ctx, Task{ID: "t2", RunID: "r1", ParentTaskID: "t1", Title: "y"}))
	must("assign t2 to B", o.AssignOwner(ctx, "t2", "B"))

	// Now if we want to delegate from t2 down to A again, that's a cycle.
	cyc, err := o.DetectDelegationCycle(ctx, "t2", "A")
	if err != nil {
		t.Fatal(err)
	}
	if !cyc {
		t.Errorf("expected cycle A↺")
	}

	// New agent C is fine: it owns no task in the chain set up above.
	cyc, err = o.DetectDelegationCycle(ctx, "t2", "C")
	if err != nil {
		t.Fatal(err)
	}
	if cyc {
		t.Errorf("expected no cycle for fresh C")
	}
}

// TestDetectThrash inserts five acked clarify messages from A to B on task t1
// and checks that DetectThrash with ThrashMaxExchanges 4 reports exactly one
// escalation and leaves t1 in the escalated state.
func TestDetectThrash(t *testing.T) {
	o := newOrch(t)
	ctx := context.Background()

	if err := o.CreateRun(ctx, "r1", ""); err != nil {
		t.Fatalf("setup: create run r1: %v", err)
	}
	if err := o.CreateTask(ctx, Task{ID: "t1", RunID: "r1", Title: "x"}); err != nil {
		t.Fatalf("setup: create task t1: %v", err)
	}

	// Insert 5 clarify messages from A to B on t1, one more than the
	// ThrashMaxExchanges limit used below. All rows share one timestamp.
	now := store.Now()
	for i := 0; i < 5; i++ {
		_, err := o.St.DB().Exec(
			`INSERT INTO messages(id, run_id, from_agent, to_agent, type, payload_json, task_id, status, ttl_ms, next_visible_at, created_at)
			 VALUES(?, 'r1', 'A', 'B', 'clarify', '{}', 't1', 'acked', 1000, ?, ?)`,
			"m"+string(rune('0'+i)), store.FmtTime(now), store.FmtTime(now),
		)
		if err != nil {
			t.Fatal(err)
		}
	}

	policy := Policy{ThrashMaxExchanges: 4}
	n, err := o.DetectThrash(ctx, "r1", policy)
	if err != nil {
		t.Fatal(err)
	}
	if n != 1 {
		t.Errorf("thrash escalations = %d, want 1", n)
	}

	// Read the task state back directly; a failed read is reported as such
	// rather than as an empty-state mismatch.
	var state string
	if err := o.St.DB().QueryRow(`SELECT state FROM tasks WHERE id='t1'`).Scan(&state); err != nil {
		t.Fatalf("read state of t1: %v", err)
	}
	if state != string(StateEscalated) {
		t.Errorf("task state = %s, want escalated", state)
	}
}
