package orchestrator

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"

	"github.com/flothus/tmux-xterm-research/server-go/internal/harness/event"
)

// Policy holds the live run policy. In later phases it will be loaded from
// the org yaml; during Phase A it is built directly by tests and the main
// binary.
type Policy struct {
	// MaxDepth limits how many ancestors a task may have (EnforceMaxDepth).
	MaxDepth int

	// MaxFanout limits how many open children a task may have
	// (EnforceMaxFanout).
	MaxFanout int

	// MaxTokensPerRun is presumably a per-run token budget; it is not
	// enforced in this file — TODO confirm where it is checked.
	MaxTokensPerRun int64

	// MaxCostUSDPerRun is presumably a per-run spend cap in US dollars; it is
	// not enforced in this file — TODO confirm where it is checked.
	MaxCostUSDPerRun float64

	// ThrashMaxExchanges is the message count per (from, to, task) tuple at
	// which DetectThrash flags a task.
	ThrashMaxExchanges int

	// HandoffTimeout is not read in this file; presumably it bounds how long
	// a handoff may wait — verify against callers.
	HandoffTimeout time.Duration
}

// DefaultPolicy returns the plan's default policy: depth 4, fanout 3,
// 2,000,000 tokens and $5.00 per run, 4 thrash exchanges, and a one-minute
// handoff timeout.
func DefaultPolicy() Policy {
	var p Policy
	p.MaxDepth, p.MaxFanout = 4, 3
	p.MaxTokensPerRun = 2_000_000
	p.MaxCostUSDPerRun = 5.00
	p.ThrashMaxExchanges = 4
	p.HandoffTimeout = time.Minute
	return p
}

// DetectThrash looks for coordination thrash in runID. For every
// (from_agent, to_agent, task_id) tuple with at least
// policy.ThrashMaxExchanges clarify/answer/delegate/query messages, it emits
// a coordination.thrash event and transitions the task to StateEscalated.
// It returns the number of thrash events emitted.
//
// Tasks that are already completed, failed, abandoned, or escalated are
// skipped, as are tasks whose row no longer exists. Skipping escalated tasks
// stops repeated detector runs from re-flagging a task. It also means a task
// that trips the limit on several tuples is escalated once, not once per
// tuple.
//
// "Thread" would ideally be derived from in_reply_to chains: messages sharing
// the same root message via in_reply_to belong to one thread. For Phase A we
// approximate thread as the task_id (most thrash happens on a single task).
// This catches the common case without graph traversal. The tuple is
// directional, so A→B and B→A messages are counted separately.
//
// If a database query fails, the wrapped error is returned along with the
// number of events already emitted.
func (o *Orchestrator) DetectThrash(ctx context.Context, runID string, policy Policy) (int, error) {
	rows, err := o.St.DB().QueryContext(ctx,
		`SELECT from_agent, to_agent, task_id, COUNT(*) as exchanges
		   FROM messages
		  WHERE run_id=? AND task_id IS NOT NULL AND type IN ('clarify','answer','delegate','query')
		  GROUP BY from_agent, to_agent, task_id
		 HAVING COUNT(*) >= ?`,
		runID, policy.ThrashMaxExchanges,
	)
	if err != nil {
		return 0, fmt.Errorf("orchestrator: query thrash candidates for run %s: %w", runID, err)
	}
	defer rows.Close()

	type thrash struct {
		from, to, task string
		count          int
	}
	// Collect every candidate first. The per-task lookups below then do not
	// run while this result set is still being read.
	var found []thrash
	for rows.Next() {
		var t thrash
		if err := rows.Scan(&t.from, &t.to, &t.task, &t.count); err != nil {
			return 0, fmt.Errorf("orchestrator: scan thrash candidate: %w", err)
		}
		found = append(found, t)
	}
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("orchestrator: read thrash candidates: %w", err)
	}

	escalated := 0
	for _, t := range found {
		// An aggregate without GROUP BY always yields exactly one row. A
		// deleted task therefore shows up as taskRows == 0 rather than as a
		// scan error.
		var taskRows int
		var state string
		err := o.St.DB().QueryRowContext(ctx,
			`SELECT COUNT(*), IFNULL(MAX(state),'') FROM tasks WHERE id=?`, t.task,
		).Scan(&taskRows, &state)
		if err != nil {
			return escalated, fmt.Errorf("orchestrator: load state of task %s: %w", t.task, err)
		}
		if taskRows == 0 {
			continue // task row is gone; nothing to escalate
		}
		switch state {
		case string(StateCompleted), string(StateFailed), string(StateEscalated), "abandoned":
			// Already closed or already escalated. This is the same closed
			// set EnforceMaxFanout excludes.
			continue
		}

		// Emission is best-effort; a failed emit does not block escalation.
		_, _ = o.Bus.Emit(ctx, event.Event{
			Kind: event.KindCoordinationThrash, RunID: runID, TaskID: t.task,
			Payload: map[string]any{
				"from": t.from, "to": t.to, "exchanges": t.count,
				"limit": policy.ThrashMaxExchanges,
			},
		})
		// Best-effort. If the transition is rejected, the task stays open and
		// will be flagged again on the next detector run.
		_ = o.Transition(ctx, t.task, StateEscalated)
		escalated++
	}
	return escalated, nil
}

// DetectDelegationCycle reports whether delegating a subtask of taskID to
// proposedChildOwner would create a cycle. It walks taskID and its ancestors
// via parent_task_id. It returns true if any task on that chain is already
// owned by proposedChildOwner, or if the parent chain loops back on a task it
// has already visited. It is used before issuing a delegate message.
//
// Tasks with no owner never match. An empty proposedChildOwner therefore
// never reports an ownership cycle (only a looping parent chain). Lookup
// errors are wrapped and returned with false.
func (o *Orchestrator) DetectDelegationCycle(ctx context.Context, taskID, proposedChildOwner string) (bool, error) {
	visited := map[string]bool{}
	current := taskID
	for current != "" {
		if visited[current] {
			// The parent chain itself is cyclic.
			return true, nil
		}
		visited[current] = true
		var owner, parent string
		err := o.St.DB().QueryRowContext(ctx, `SELECT IFNULL(owner_agent_id,''), IFNULL(parent_task_id,'') FROM tasks WHERE id=?`, current).Scan(&owner, &parent)
		if err != nil {
			return false, fmt.Errorf("orchestrator: load task %s for delegation cycle check: %w", current, err)
		}
		// IFNULL maps an unowned task to "". Require a real owner so unowned
		// ancestors never match an empty proposedChildOwner.
		if owner != "" && owner == proposedChildOwner {
			return true, nil
		}
		current = parent
	}
	return false, nil
}

// EnforceMaxDepth walks the parent chain of taskID and returns an error if
// taskID has policy.MaxDepth or more ancestors.
//
// The walk stops as soon as the limit is reached, so it issues at most
// policy.MaxDepth lookups and always terminates. A parent chain that loops
// back on itself is reported as exceeding the limit. Lookup errors are
// wrapped with %w and returned.
func (o *Orchestrator) EnforceMaxDepth(ctx context.Context, taskID string, policy Policy) error {
	depth := 0
	current := taskID
	// Counting ancestors beyond MaxDepth cannot change the outcome, so stop there.
	for current != "" && depth < policy.MaxDepth {
		var parent string
		err := o.St.DB().QueryRowContext(ctx, `SELECT IFNULL(parent_task_id,'') FROM tasks WHERE id=?`, current).Scan(&parent)
		if err != nil {
			return fmt.Errorf("orchestrator: load parent of task %s: %w", current, err)
		}
		if parent == "" {
			break
		}
		depth++
		current = parent
	}
	if depth >= policy.MaxDepth {
		return fmt.Errorf("orchestrator: max_depth %d exceeded at task %s", policy.MaxDepth, taskID)
	}
	return nil
}

// EnforceMaxFanout returns an error when taskID already has policy.MaxFanout
// or more open children. "Open" means any state other than completed, failed,
// abandoned, or escalated. Query errors are returned unchanged.
func (o *Orchestrator) EnforceMaxFanout(ctx context.Context, taskID string, policy Policy) error {
	const openChildrenSQL = `SELECT COUNT(*) FROM tasks WHERE parent_task_id=? AND state NOT IN ('completed','failed','abandoned','escalated')`

	var open int
	if err := o.St.DB().QueryRowContext(ctx, openChildrenSQL, taskID).Scan(&open); err != nil {
		return err
	}
	if open < policy.MaxFanout {
		return nil
	}
	return fmt.Errorf("orchestrator: max_fanout %d exceeded for parent %s (current=%d)", policy.MaxFanout, taskID, open)
}

// EmitPolicyViolation is the canonical way other modules raise policy errors.
// It publishes a policy-violation event for runID/agentID whose payload
// carries the violation kind and detail. This makes the violation visible in
// the dashboard and trace. Any error from the bus is discarded.
func (o *Orchestrator) EmitPolicyViolation(ctx context.Context, runID, agentID, kind, detail string) {
	payload := map[string]any{"kind": kind, "detail": detail}
	ev := event.Event{
		Kind:    event.KindPolicyViolation,
		RunID:   runID,
		AgentID: agentID,
		Payload: payload,
	}
	// Fire-and-forget: this method has no way to report a failed emit.
	_, _ = o.Bus.Emit(ctx, ev)
}

// intentHash returns a short, stable fingerprint of message intent for
// loop/thrash diagnostics. The fingerprint is the first 6 bytes of the
// SHA-256 of s, encoded as 12 lowercase hex characters. It is not
// load-bearing for correctness.
func intentHash(s string) string {
	const prefixLen = 6 // digest bytes kept
	digest := sha256.Sum256([]byte(s))
	return hex.EncodeToString(digest[:prefixLen])
}

// Keep intentHash referenced until the later-phase cycle detector uses it.
var _ = intentHash
