/*
 * sta_ratings.c - give star-ratings to associated STAs
 *
 * Copyright (C) 2025 Genexis AB.
 *
 * Author: anjan.chanda@iopsys.eu
 * See LICENSE file for license related information.
 */
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <inttypes.h>

#include <easy/easy.h>
#include <wifidefs.h>
#include <wifiutils.h>

#include "debug.h"
#include "wifi_dataelements.h"

/*
 * Per-STA state kept between calls to sta_ratings_calculate().
 *
 * Holds the STA identity, the AP (radio) limits cached for the BSS the STA
 * was last seen on, and a snapshot of the STA counters from the previous
 * sample so that per-interval deltas can be computed.
 */
struct sta_ratings_context {
	uint8_t macaddr[6];           /* STA MAC address, set by sta_ratings_init() */
	time_t tsp;                   /* sta->last_updated of the previous sample */
	uint8_t gen;                  /* STA Wi-Fi generation (WIFI_4 .. WIFI_7) */
	uint8_t rcpi;                 /* uplink RCPI of the previous sample */
	float rating;                 /* last computed rating; -1 until one is computed */
	int valid;                    /* non-zero once the AP data below is filled */
	uint8_t bssid[6];             /* BSS the cached AP data belongs to */
	uint32_t a1_max;              /* max downlink PHY rate from wifi_mcs2rate(); compared against dl_rate (TODO confirm unit is Mbps) */
	uint32_t a2_max;              /* max uplink PHY rate; currently set equal to a1_max */
	uint8_t ap_gen;               /* AP radio Wi-Fi generation */
	int ap_maxbw;                 /* AP max bandwidth as an enum wifi_bw value */
	int ap_maxmcs;                /* AP max MCS index */
	int ap_maxnss;                /* AP max spatial streams */
	uint64_t tx_bytes;            /* ap -> sta */
	uint64_t rx_bytes;            /* sta -> ap */
	uint32_t tx_pkts;             /* not updated by the code visible in this file */
	uint32_t rx_pkts;             /* not updated by the code visible in this file */
	uint32_t tx_errors;           /* not updated by the code visible in this file */
	uint32_t rx_errors;           /* rx-error counter snapshot */
	uint32_t rtx_pkts;            /* retransmission counter snapshot */
	uint32_t dl_rate;             /* ap -> sta; stored as sta->dl_rate / 1000 (NOTE(review): Mbps if sta->dl_rate is Kbps — confirm) */
	uint32_t ul_rate;             /* sta -> ap; stored as sta->ul_rate / 1000 (same unit note as dl_rate) */
};

/*
 * sta_ratings_init - allocate a rating context for one STA.
 * @ctx: out parameter; set to the new context, or to NULL on failure
 * @sta_macaddr: 6-byte MAC address of the STA the context will track
 *
 * The context starts zeroed (no cached AP data) with no rating (-1).
 * Release it with sta_ratings_free().
 *
 * Returns 0 on success, -1 on NULL arguments or allocation failure.
 */
int sta_ratings_init(void **ctx, uint8_t *sta_macaddr)
{
	struct sta_ratings_context *p;

	if (!ctx)
		return -1;

	*ctx = NULL;
	if (!sta_macaddr)
		return -1;

	p = calloc(1, sizeof(*p));
	if (!p)
		return -1;

	memcpy(p->macaddr, sta_macaddr, 6);
	p->rating = -1;	/* "no rating yet" marker checked by sta_ratings_calculate() */
	*ctx = p;
	return 0;
}

/*
 * sta_ratings_free - release a context obtained from sta_ratings_init().
 * @ctx: context to free; NULL is accepted (free(NULL) is a no-op)
 */
void sta_ratings_free(void *ctx)
{
	free(ctx);
}

/*
 * Wi-Fi generations as reported by get_sta_generation() and
 * get_ap_generation(). WIFI_GEN_UNKNOWN means none of the HT/VHT/HE/EHT
 * capability bits was flagged valid. WIFI_3 is not returned by those helpers;
 * it only serves as the lower bound of the accepted range.
 */
enum {
	WIFI_GEN_UNKNOWN = 0,
	WIFI_3 = 3,
	WIFI_4 = 4,
	WIFI_5 = 5,
	WIFI_6 = 6,
	WIFI_7 = 7,
};

/* STA generations outside [WIFI_GEN_MIN, WIFI_GEN_MAX] make the rating fail. */
#define WIFI_GEN_MAX	WIFI_7
#define WIFI_GEN_MIN	WIFI_3

/*
 * Accepted uplink RCPI range; a STA outside it is not rated. The same bounds
 * normalize RCPI (via square roots) when the STA is idle.
 * NOTE(review): RCPI_INVALID is not referenced in the code shown here.
 */
#define RCPI_MAX	220
#define RCPI_MIN	40
#define RCPI_INVALID	0

/*
 * Per-interval clamps for the rx-error and retransmission deltas. They are
 * also the maxima used to normalize those deltas to [0, 1].
 */
#define RXERR_MAX	10
#define TXRTX_MAX	10

/*
 * Average rx or tx throughput, in bytes per second, above which a STA is
 * treated as active (rated on PHY-rate headroom) rather than idle (rated on RCPI).
 */
#define STA_IDLE_THPUT	10000

/*
 * bw_value - channel width in MHz for an enum wifi_bw value.
 *
 * 80+80 is reported as its aggregate 160 MHz. BW_AUTO, BW_UNKNOWN and any
 * unrecognized value map to 0.
 */
uint32_t bw_value(enum wifi_bw bw)
{
	uint32_t mhz;

	switch (bw) {
	case BW20:
		mhz = 20;
		break;
	case BW40:
		mhz = 40;
		break;
	case BW80:
		mhz = 80;
		break;
	case BW160:
	case BW8080:
		mhz = 160;
		break;
	case BW320:
		mhz = 320;
		break;
	case BW_AUTO:
	case BW_UNKNOWN:
	default:
		mhz = 0;
		break;
	}

	return mhz;
}

/*
 * get_sta_generation - Wi-Fi generation of a STA from its capability flags.
 *
 * The newest capability flagged valid wins: EHT -> WIFI_7, HE -> WIFI_6,
 * VHT -> WIFI_5, HT -> WIFI_4. Returns WIFI_GEN_UNKNOWN when none is set.
 */
uint8_t get_sta_generation(struct wifi_sta_element *sta)
{
	if (sta->caps.valid & EHT_CAP_VALID)
		return WIFI_7;

	if (sta->caps.valid & HE_CAP_VALID)
		return WIFI_6;

	if (sta->caps.valid & VHT_CAP_VALID)
		return WIFI_5;

	if (sta->caps.valid & HT_CAP_VALID)
		return WIFI_4;

	return WIFI_GEN_UNKNOWN;
}

/*
 * get_ap_max_nss - maximum number of spatial streams supported by a radio.
 *
 * Placeholder: it always reports 4 and does not inspect @radio yet.
 */
int get_ap_max_nss(struct wifi_radio_element *radio)
{
	(void)radio;	/* TODO: derive from the radio's capabilities */

	return 4;
}

/*
 * get_ap_max_mcs - highest MCS index for the newest PHY a radio supports.
 *
 * Returns 13 for EHT, 11 for HE, 9 for VHT, 31 for HT, or 0 when none of
 * these capabilities is flagged valid.
 *
 * NOTE(review): the HT value looks like a combined index spanning multiple
 * spatial streams, while the caller also passes max-nss to wifi_mcs2rate().
 * Confirm this is what wifi_mcs2rate() expects for HT.
 */
int get_ap_max_mcs(struct wifi_radio_element *radio)
{
	if (radio->caps.valid & EHT_CAP_VALID)
		return 13;

	if (radio->caps.valid & HE_CAP_VALID)
		return 11;

	if (radio->caps.valid & VHT_CAP_VALID)
		return 9;

	if (radio->caps.valid & HT_CAP_VALID)
		return 31;

	return 0;
}

/*
 * get_ap_generation - Wi-Fi generation of a radio from its capability flags.
 *
 * The newest capability flagged valid decides the result (EHT > HE > VHT > HT).
 * Returns WIFI_GEN_UNKNOWN when none of them is set.
 */
uint8_t get_ap_generation(struct wifi_radio_element *radio)
{
	return (radio->caps.valid & EHT_CAP_VALID) ? WIFI_7 :
	       (radio->caps.valid & HE_CAP_VALID)  ? WIFI_6 :
	       (radio->caps.valid & VHT_CAP_VALID) ? WIFI_5 :
	       (radio->caps.valid & HT_CAP_VALID)  ? WIFI_4 :
	       WIFI_GEN_UNKNOWN;
}

/*
 * get_ap_max_bandwidth - widest channel bandwidth among a radio's current
 * operating classes.
 *
 * Returns the enum wifi_bw value with the largest width in MHz, as reported
 * by bw_value(). The result is BW20 if no entry resolves to anything wider.
 * Returns -1 when the radio reports no current operating class.
 *
 * Entries are compared by their width in MHz rather than by raw enum value.
 * This keeps the result independent of the declaration order of enum wifi_bw,
 * so a non-width value such as BW_AUTO or BW_UNKNOWN can never be picked as
 * the "widest".
 */
static int get_ap_max_bandwidth(struct wifi_radio_element *radio)
{
	int max_bw = BW20;

	if (!radio->cur_opclass.num_opclass)
		return -1;

	for (int i = 0; i < radio->cur_opclass.num_opclass; i++) {
		struct wifi_opclass e = {0};

		/* skip opclass ids the lookup table does not know */
		if (wifi_get_opclass_entry(radio->cur_opclass.opclass[i].id, &e))
			continue;

		if (bw_value(e.bw) > bw_value(max_bw))
			max_bw = e.bw;
	}

	return max_bw;
}

int find_sta_in_network_device(struct wifi_network_device *dev, uint8_t *sta_macaddr,
			       struct wifi_radio_element **radio,
			       struct wifi_bss_element **bss,
			       struct wifi_sta_element **sta)
{
	struct wifi_radio_element *r = NULL;
	struct wifi_bss_element *b = NULL;
	struct wifi_sta_element *s = NULL;

	list_for_each_entry(r, &dev->radiolist, list) {
		list_for_each_entry(b, &r->bsslist, list) {
			list_for_each_entry(s, &b->stalist, list) {
				if (!memcmp(s->macaddr, sta_macaddr, 6)) {
					*radio = r;
					*bss = b;
					*sta = s;
					return 0;
				}
			}
		}
	}

	return -1;
}

/*
 * sta_ratings_calculate - rate the link quality of an associated STA.
 * @ctx: context returned by sta_ratings_init() for @sta_macaddr
 * @dev: network device whose radios/BSSs are searched for the STA
 * @sta_macaddr: 6-byte MAC address of the STA
 *
 * The rating is built from these terms:
 *   a1/a2 - headroom between the AP's max PHY rate and the STA's current
 *           downlink/uplink rate (used only when the STA is active)
 *   b     - number of Wi-Fi generations the STA lags behind the AP
 *   d     - rx errors since the previous call, clamped to RXERR_MAX
 *   e     - retransmissions since the previous call, clamped to TXRTX_MAX
 *   r     - uplink RCPI (used only when the STA is idle)
 * A STA is active when its average rx or tx throughput since the previous
 * call exceeds STA_IDLE_THPUT bytes/s. Each term is normalized to [0, 1] and
 * weighted. The resulting [0, 1] score is mapped to [1, 5] and rounded to
 * two decimals.
 *
 * If no time has elapsed since the previous sample and a rating exists, the
 * cached rating is returned unchanged.
 *
 * Returns the rating in [1.0, 5.0], or -1.0 in these cases: invalid
 * arguments, STA not found, or AP/STA data missing or out of range.
 */
float sta_ratings_calculate(void *ctx, struct wifi_network_device *dev, uint8_t *sta_macaddr)
{
	struct sta_ratings_context *sctx = (struct sta_ratings_context *)ctx;
	float a1_nl = 0.0f, a2_nl = 0.0f, b_nl = 0.0f, d_nl = 0.0f, e_nl = 0.0f, r_nl = 0.0f;
	uint64_t tx_avg_thput = 0, rx_avg_thput = 0;
	uint64_t tx_bytes = 0, rx_bytes = 0;
	uint8_t b = 0, d = 0, e = 0, r = 0;
	/* full-width counter deltas: a uint8_t would wrap before the clamp below */
	uint32_t rxerr = 0, txrtx = 0;
	uint32_t a1 = 0, a2 = 0;

	struct wifi_radio_element *radio = NULL;
	struct wifi_bss_element *bss = NULL;
	struct wifi_sta_element *sta = NULL;

	uint8_t gen = WIFI_GEN_UNKNOWN;
	uint8_t b_max = WIFI_7 - WIFI_3;
	uint8_t d_max = RXERR_MAX;
	uint8_t e_max = TXRTX_MAX;

	/* non-idle wts */
	float a1_wt = 2.0f / 5.0f;
	float a2_wt = 1.0f / 5.0f;
	float b_wt = 1.0f / 5.0f;
	float d_wt = 0.5f / 5.0f;
	float e_wt = 0.5f / 5.0f;
	float r_wt = 0.0f;

	bool sta_idle = true;
	float rating = 0.0f;
	double secs = 0.0;
	int ret;

	if (!dev || !sta_macaddr)
		return -1.0f;

	if (!sctx || memcmp(sctx->macaddr, sta_macaddr, 6)) {
		dbg("%s: Invalid context %p for STA " MACFMT"\n",
			__func__, ctx, MAC2STR(sta_macaddr));
		return -1;
	}

	ret = find_sta_in_network_device(dev, sta_macaddr, &radio, &bss, &sta);
	if (ret) {
		dbg("%s: STA " MACFMT " not found Node " MACFMT "\n", __func__,
			MAC2STR(sta_macaddr), MAC2STR(dev->macaddr));

		return -1;
	}

	/* initial fill or when STA's bss changes */
	if (!sctx->valid || hwaddr_is_zero(sctx->bssid) || memcmp(sctx->bssid, bss->bssid, 6)) {
		int ap_maxbw, ap_maxmcs, ap_gen, ap_maxnss;

		/*
		 * Invalidate the cached AP data before refreshing it. If any step
		 * below fails, the next call then retries instead of pairing the
		 * new bssid with the previous BSS's limits.
		 */
		sctx->valid = 0;
		memcpy(sctx->bssid, bss->bssid, 6);

		ap_maxbw = get_ap_max_bandwidth(radio);
		if (ap_maxbw == -1)
			goto err_out;

		ap_gen = get_ap_generation(radio);
		if (ap_gen == WIFI_GEN_UNKNOWN)
			goto err_out;

		ap_maxmcs = get_ap_max_mcs(radio);
		if (ap_maxmcs == 0)
			goto err_out;

		ap_maxnss = get_ap_max_nss(radio);
		if (ap_maxnss == -1)
			goto err_out;

		sctx->ap_gen = ap_gen;
		sctx->ap_maxbw = ap_maxbw;
		sctx->ap_maxmcs = ap_maxmcs;
		sctx->ap_maxnss = ap_maxnss;

		sctx->a1_max = wifi_mcs2rate(sctx->ap_maxmcs,
					     bw_value(sctx->ap_maxbw),
					     sctx->ap_maxnss,
					     sctx->ap_gen > WIFI_5 ? WIFI_1xLTF_GI800 : WIFI_GI400);

		dbg("AP: max-bw = %u, gen = %u, max-mcs = %u, max-nss = %u, max-rate = %u\n",
		    bw_value(sctx->ap_maxbw), sctx->ap_gen,
		    sctx->ap_maxmcs, sctx->ap_maxnss, sctx->a1_max);

		/* a zero max rate would make the a1/a2 normalization divide by zero */
		if (sctx->a1_max == 0)
			goto err_out;

		sctx->a2_max = sctx->a1_max;
		sctx->valid = 1;
		sctx->tsp = sta->last_updated;
	}

	gen = get_sta_generation(sta);
	if (gen < WIFI_GEN_MIN || gen > WIFI_GEN_MAX)
		goto err_out;

	sctx->gen = gen;
	/* a STA newer than the AP cannot use its extra capabilities */
	if (gen > sctx->ap_gen)
		gen = sctx->ap_gen;

	if (sta->rcpi < RCPI_MIN || sta->rcpi > RCPI_MAX)
		goto err_out;

	/* guarantees a1/a2 below cannot underflow */
	if (sta->dl_rate / 1000 > sctx->a1_max || sta->ul_rate / 1000 > sctx->a2_max)
		goto err_out;

	secs = difftime(sta->last_updated, sctx->tsp);
	if (secs > 0) {
		rx_bytes = sta->rx_bytes - sctx->rx_bytes;
		tx_bytes = sta->tx_bytes - sctx->tx_bytes;
		rxerr = sta->rx_errors - sctx->rx_errors;
		txrtx = sta->rtx_pkts - sctx->rtx_pkts;

		rx_avg_thput = rx_bytes / secs;
		tx_avg_thput = tx_bytes / secs;

		if (rx_avg_thput > STA_IDLE_THPUT || tx_avg_thput > STA_IDLE_THPUT)
			sta_idle = false;
	}

	/* snapshot counters for the next interval */
	sctx->tsp = sta->last_updated;
	sctx->rx_bytes = sta->rx_bytes;
	sctx->tx_bytes = sta->tx_bytes;
	sctx->rx_errors = sta->rx_errors;
	sctx->rtx_pkts = sta->rtx_pkts;
	sctx->dl_rate = sta->dl_rate/1000;
	sctx->ul_rate = sta->ul_rate/1000;
	sctx->rcpi = sta->rcpi;

	/* no new sample since last call: return the cached rating */
	if (secs == 0 && sctx->rating != -1)
		return sctx->rating;

	if (rxerr > RXERR_MAX)
		rxerr = RXERR_MAX;

	if (txrtx > TXRTX_MAX)
		txrtx = TXRTX_MAX;

	/* derive a, b, c, d, e, r */
	a1 = sctx->a1_max - sctx->dl_rate;
	a2 = sctx->a2_max - sctx->ul_rate;
	b = sctx->ap_gen - gen;
	d = (uint8_t)rxerr;	/* clamped above, fits in uint8_t */
	e = (uint8_t)txrtx;
	r = sta->rcpi;

	dbg("STA: " MACFMT"\n", MAC2STR(sta->macaddr));
	dbg("state = %s\n", sta_idle ? "Idle" : "Active");
	dbg("Txbytes = %"PRIu64", Rxbytes = %"PRIu64" in %.1fs "
	    "(avg-thput: tx = %" PRIu64", rx = %"PRIu64")\n",
	    tx_bytes, rx_bytes, secs, tx_avg_thput, rx_avg_thput);
	dbg("Dl-rate = %u, a1 = %u\n", sctx->dl_rate, a1);
	dbg("Ul-rate = %u, a2 = %u\n", sctx->ul_rate, a2);
	dbg("Gen = %u, b = %hhu\n", sctx->gen, b);
	dbg("Rx-errors = d = %hhu\n", d);
	dbg("Tx-retransits = e = %hhu\n", e);
	dbg("Ul-rcpi = r = %u\n", r);

	if (sta_idle) {
		r_nl = (sqrtf(r) - sqrtf(RCPI_MIN)) / (sqrtf(RCPI_MAX) - sqrtf(RCPI_MIN));
		r_wt = 2.0f / 4.0f;
		b_wt = 1.0f / 4.0f;
		d_wt = 0.5f / 4.0f;
		e_wt = 0.5f / 4.0f;
	} else {
		/*
		 * (a / a_max)^3 computed in float. The integer form a^3 / a_max^3
		 * overflows uint32_t once the max PHY rate exceeds ~1625.
		 */
		float a1_ratio = (float)a1 / (float)sctx->a1_max;
		float a2_ratio = (float)a2 / (float)sctx->a2_max;

		a1_nl = a1_ratio * a1_ratio * a1_ratio;
		a2_nl = a2_ratio * a2_ratio * a2_ratio;
	}

	b_nl = (float)(b * b) / (float)(b_max * b_max);
	d_nl = (float)(d * d * d) / (float)(d_max * d_max * d_max);
	e_nl = (float)(e * e) / (float)(e_max * e_max);

	dbg("a1 (normalized) = %f,   wt*a1 = %f\n", a1_nl, a1_wt * a1_nl);
	dbg("a2 (normalized) = %f    wt*a2 = %f\n", a2_nl, a2_wt * a2_nl);
	dbg("b  (normalized) = %f    wt*b  = %f\n", b_nl, b_wt * b_nl);
	dbg("d  (normalized) = %f    wt*d  = %f\n", d_nl, d_wt * d_nl);
	dbg("e  (normalized) = %f    wt*e  = %f\n", e_nl, e_wt * e_nl);
	dbg("r  (normalized) = %f    wt*r  = %f\n", r_nl, r_wt * r_nl);

	if (sta_idle) {
		/* RCPI contributes up to +0.5 on top of a 0.5 baseline */
		rating = r_wt * r_nl + 0.5f - (b_wt * b_nl + d_wt * d_nl + e_wt * e_nl);
	} else {
		/* weights sum to 1, so each penalty keeps the score in [0, 1] */
		rating = 1.0f - (a1_wt * a1_nl +
				 a2_wt * a2_nl +
				 b_wt * b_nl +
				 d_wt * d_nl +
				 e_wt * e_nl);
	}

	dbg("rating [0 ~ 1] = %f\n", rating);
	rating = 1.0f + 4.0f * rating;
	dbg("rating [1 ~ 5] = %f\n", rating);
	sctx->rating = roundf(rating * 100.0f) / 100.0f;

	return sctx->rating;

err_out:
	return -1.0f;
}
