package orchestrator

import (
	"context"
	"errors"

	"github.com/flothus/tmux-xterm-research/server-go/internal/harness/store"
)

// RunMetrics bundles the measurements recorded for a single run so that runs
// executed under different orgs can be compared (plan §17 Phase H).
type RunMetrics struct {
	// RunID identifies the run these metrics belong to.
	RunID string
	// OrgID is the org attributed to the run; empty when none was recorded.
	OrgID string
	// OrgVersion is the version of that org; 0 when none was recorded.
	OrgVersion int
	// RoleSetVersion is the role-set version; empty when none was recorded.
	RoleSetVersion string
	// TokensTotal is the run's total token count.
	TokensTotal int64
	// CostUSDTotal is the run's total cost in US dollars.
	CostUSDTotal float64
	// TasksCompleted is the number of the run's tasks in the completed state.
	TasksCompleted int
	// TasksFailed is the number of the run's tasks in the failed state.
	TasksFailed int
	// TasksEscalated is the number of the run's tasks in the escalated state.
	TasksEscalated int
	// HandoffCount is the number of handoffs across the run's tasks.
	HandoffCount int
	// ThrashCount is the number of coordination-thrash events in the run.
	ThrashCount int
	// ParseFailures is the sum of parse failures over the run's agents.
	ParseFailures int
	// UserRatingMean is the mean user rating. It is nullable: the value is 0
	// and carries no meaning unless HasUserRating is true.
	UserRatingMean float64
	// HasUserRating reports whether UserRatingMean holds a real value.
	HasUserRating bool
}

// LoadRunMetrics assembles the RunMetrics for runID from the store.
//
// Two queries are mandatory: the run's own row in runs (org attribution,
// token and cost totals) and the per-state task counts. An error from either
// of them, including an error surfaced while iterating the task rows, is
// returned unwrapped together with whatever part of the metrics was filled
// in so far.
//
// The remaining counters (handoffs, thrash events, parse failures and the
// user rating) are best-effort: if one of those queries fails, the
// corresponding field keeps its zero value and no error is reported.
func (o *Orchestrator) LoadRunMetrics(ctx context.Context, runID string) (RunMetrics, error) {
	m := RunMetrics{RunID: runID}
	err := o.St.DB().QueryRowContext(ctx,
		`SELECT IFNULL(org_id,''), IFNULL(org_version,0), IFNULL(role_set_version,''), tokens_total, cost_usd_total
		   FROM runs WHERE id=?`, runID,
	).Scan(&m.OrgID, &m.OrgVersion, &m.RoleSetVersion, &m.TokensTotal, &m.CostUSDTotal)
	if err != nil {
		return m, err
	}

	// Tasks by state. States other than completed/failed/escalated are not
	// part of the metrics and are skipped.
	rows, err := o.St.DB().QueryContext(ctx, `SELECT state, COUNT(*) FROM tasks WHERE run_id=? GROUP BY state`, runID)
	if err != nil {
		return m, err
	}
	// Covers every early return below; closing again after the loop has
	// drained the rows is harmless.
	defer rows.Close()
	for rows.Next() {
		var state string
		var count int
		if err := rows.Scan(&state, &count); err != nil {
			return m, err
		}
		switch state {
		case string(StateCompleted):
			m.TasksCompleted = count
		case string(StateFailed):
			m.TasksFailed = count
		case string(StateEscalated):
			m.TasksEscalated = count
		}
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// without this check a failure would silently yield partial counts.
	if err := rows.Err(); err != nil {
		return m, err
	}

	// The derived counters below are deliberately best-effort: a failing
	// query leaves the field at zero instead of failing the whole load, so
	// the Scan errors are intentionally discarded.

	// Handoffs.
	_ = o.St.DB().QueryRowContext(ctx, `SELECT COUNT(*) FROM handoffs WHERE task_id IN (SELECT id FROM tasks WHERE run_id=?)`, runID).Scan(&m.HandoffCount)
	// Thrash events.
	_ = o.St.DB().QueryRowContext(ctx, `SELECT COUNT(*) FROM events WHERE run_id=? AND kind='coordination.thrash'`, runID).Scan(&m.ThrashCount)
	// Parse failures (sum across agents in the run).
	_ = o.St.DB().QueryRowContext(ctx, `SELECT IFNULL(SUM(parse_failures),0) FROM agents WHERE run_id=?`, runID).Scan(&m.ParseFailures)

	// User rating mean (across all evaluations with evaluator_id='user' for
	// run-kind targets). Only rows that actually carry an overall score are
	// counted: COUNT(expr) skips NULLs, whereas COUNT(*) would let score-less
	// evaluations drag the mean towards zero.
	var ratingSum float64
	var ratingCount int
	_ = o.St.DB().QueryRowContext(ctx,
		`SELECT IFNULL(SUM(json_extract(scores_json,'$.overall.value')),0), COUNT(json_extract(scores_json,'$.overall.value'))
		   FROM evaluations WHERE evaluator_id='user' AND target_kind='run' AND target_id=?`,
		runID,
	).Scan(&ratingSum, &ratingCount)
	if ratingCount > 0 {
		m.UserRatingMean = ratingSum / float64(ratingCount)
		m.HasUserRating = true
	}
	return m, nil
}

// CompareRuns is a thin diff of two RunMetrics. It returns a map keyed by
// metric name whose entries hold the value from a (A), the value from b (B)
// and Delta = B - A. Phase H done-criterion: produce comparable measurements.
//
// Keys always present: tokens_total, cost_usd_total, tasks_completed,
// tasks_failed, tasks_escalated, handoff_count, thrash_count and
// parse_failures. The user_rating key is present only when both runs have a
// user rating, because a missing rating is stored as 0 and would otherwise
// produce a meaningless delta.
func CompareRuns(a, b RunMetrics) map[string]struct {
	A, B, Delta float64
} {
	// Alias for the anonymous entry type so the return type stays identical
	// to the declared signature.
	type entry = struct{ A, B, Delta float64 }

	out := make(map[string]entry, 9)
	put := func(key string, av, bv float64) {
		out[key] = entry{A: av, B: bv, Delta: bv - av}
	}

	put("tokens_total", float64(a.TokensTotal), float64(b.TokensTotal))
	put("cost_usd_total", a.CostUSDTotal, b.CostUSDTotal)
	put("tasks_completed", float64(a.TasksCompleted), float64(b.TasksCompleted))
	put("tasks_failed", float64(a.TasksFailed), float64(b.TasksFailed))
	// Escalated tasks are tracked in RunMetrics like the other task states,
	// so they belong in the comparison as well.
	put("tasks_escalated", float64(a.TasksEscalated), float64(b.TasksEscalated))
	put("handoff_count", float64(a.HandoffCount), float64(b.HandoffCount))
	put("thrash_count", float64(a.ThrashCount), float64(b.ThrashCount))
	put("parse_failures", float64(a.ParseFailures), float64(b.ParseFailures))
	if a.HasUserRating && b.HasUserRating {
		put("user_rating", a.UserRatingMean, b.UserRatingMean)
	}
	return out
}

// SetRunOrg attributes a run to an org so that RunMetrics can report it. It
// is meant to be called before work is dispatched for the run.
//
// An empty orgID is rejected. Otherwise, inside a single store transaction:
//  1. an orgs row for orgID is inserted if none exists yet (name and
//     yaml_path are both set to orgID), which keeps the foreign key
//     satisfied even when test setup did not pre-insert the org; an existing
//     row is left untouched;
//  2. the runs row identified by runID gets org_id, org_version and
//     role_set_version set (roleSetVersion is passed through nullable).
//
// The result of the UPDATE is not inspected here, only its error.
func (o *Orchestrator) SetRunOrg(ctx context.Context, runID, orgID string, version int, roleSetVersion string) error {
	if orgID == "" {
		return errors.New("orchestrator: orgID required")
	}

	// yaml_path = id for tests.
	const ensureOrgSQL = `INSERT INTO orgs(id, name, version, yaml_path) VALUES(?, ?, ?, ?)
			 ON CONFLICT(id) DO NOTHING`
	const attachOrgSQL = `UPDATE runs SET org_id=?, org_version=?, role_set_version=? WHERE id=?`

	attach := func(q store.Querier) error {
		_, err := q.Exec(ensureOrgSQL, orgID, orgID, version, orgID)
		if err != nil {
			return err
		}
		_, err = q.Exec(attachOrgSQL, orgID, version, nullable(roleSetVersion), runID)
		return err
	}
	return o.St.Tx(ctx, attach)
}
