/*
 * Copyright (C) 2024 IOPSYS Software Solutions AB. All rights reserved.
 * Copyright (C) 2025 Genexis Sweden AB.
 */

#include "steer.h"

#include <easymesh.h>
#include <libubox/list.h>
#include <map_module.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <time.h>

#include "utils/debug.h"
#include "cntlr.h"
#include "acs.h"
#include "cntlr_cmdu.h"
#include "cntlr_ubus.h"
#include "config.h"
#include "sta.h"
#include "steer_module.h"
#include "timer.h"
#include "wifi_dataelements.h"


struct wifi_steer_history *sta_lookup_steer_attempt(struct sta *s,
						    uint8_t *src_bssid,
						    uint8_t *dst_bssid)
{
	struct wifi_multiap_sta *mapsta = &s->de_sta->mapsta;
	int i;

	cntlr_dbg(LOG_STEER, "%s: sta = " MACFMT", src-ap = "MACFMT"\n", __func__,
		  MAC2STR(s->macaddr), MAC2STR(src_bssid));

	/* find imcomplete steering attempt */
	for (i = 0; i < mapsta->num_steer_hist; i++) {
		int idx = (mapsta->first + i) % MAX_STEER_HISTORY;

		if (!memcmp(mapsta->steer_history[idx].src_bssid, src_bssid, 6) &&
		    !mapsta->steer_history[idx].complete) {
			if (!dst_bssid ||
			    !memcmp(mapsta->steer_history[idx].dst_bssid, dst_bssid, 6)) {
				return &mapsta->steer_history[idx];
			}
		}
	}

	cntlr_dbg(LOG_STEER, "%s: Steer attempt for sta = " MACFMT" not found\n",
		  __func__, MAC2STR(s->macaddr));

	return NULL;
}

void cntlr_update_sta_steer_counters(struct controller *c,
				     uint8_t *sta_mac,
				     uint8_t *src_bssid,
				     uint8_t *dst_bssid,
				     uint32_t mode,
				     enum steer_trigger trigger,
				     uint8_t dst_rcpi)
{
	cntlr_trace(LOG_STEER, "%s:--->\n", __func__);

	struct sta *s = cntlr_find_sta(c->sta_table, sta_mac);
	struct wifi_multiap_sta *mapsta;
	struct wifi_steer_history *a;

	if (!s) {
		cntlr_dbg(LOG_STA, "%s: Unknown STA "MACFMT"\n", __func__,
			  MAC2STR(sta_mac));
		return;
	}

	mapsta = &s->de_sta->mapsta;

	/* update history entry */
	a = &mapsta->steer_history[mapsta->next];
	memset(a, 0, sizeof(struct wifi_steer_history));
	timestamp_update(&a->time);
	time(&a->steer_time);
	if (src_bssid)
		memcpy(a->src_bssid, src_bssid, 6);

	if (dst_bssid)
		memcpy(a->dst_bssid, dst_bssid, 6);

	a->trigger = trigger;
	if (trigger == STEER_TRIGGER_LINK_QUALITY) {
		a->src_rcpi = s->de_sta->rcpi;
		a->dst_rcpi = dst_rcpi;
	}

	mapsta->next = (mapsta->next + 1) % MAX_STEER_HISTORY;
	if (mapsta->num_steer_hist < MAX_STEER_HISTORY) {
		mapsta->num_steer_hist++;
	} else {
		mapsta->first = (mapsta->first + 1) % MAX_STEER_HISTORY;
	}

	//FIXME: mode mappings

	switch (mode) {
	case STEER_MODE_ASSOC_CTL:
		a->method = STEER_METHOD_ASSOC_CTL;
		break;
	case STEER_MODE_BTM_REQ:
		a->method = STEER_METHOD_BTM_REQ;
		s->de_sta->mapsta.steer_summary.btm_attempt_cnt++;
		c->dlem.network.steer_summary.btm_attempt_cnt++;
		break;
	case STEER_MODE_OPPORTUNITY:
		a->method = STEER_METHOD_BTM_REQ;
		/*TODO: add counter for opportunity (incl blacklis count) */
		break;
	default:
		a->method = STEER_METHOD_BTM_REQ;	/* default method */
		break;
	}

	/* Record tsp for most recent steer attempt */
	timestamp_update(&s->de_sta->mapsta.steer_summary.last_attempt_tsp);
}

/*
 * Publish a CNTLR_EVENT_CLIENT_STEER_REQUEST event for a client steering
 * request issued to the BSS @bssid.
 *
 * The payload is a JSON object of the form
 *   {"bssid":"<mac>"[,"sta_mac":"<mac>"][,"target_bssid":"<mac>"]}
 * where "sta_mac" is present only when @sta_nr != 0 and "target_bssid" only
 * when @bssid_nr != 0. Only the first entry of @stas and of @target_bssid
 * is reported.
 *
 * Every append is a bounded snprintf() at the current end of the buffer.
 * The previous strncat(buf, src, sizeof(buf) - strlen(buf)) idiom was
 * off by one: strncat writes up to n characters PLUS a terminating NUL,
 * so it could store one byte past the end of ev_data.
 */
void cntlr_notify_client_steer_req_evt(struct controller *c,
			uint8_t *bssid, uint32_t sta_nr, uint8_t stas[][6],
			uint32_t bssid_nr, uint8_t target_bssid[][6])
{
	char ev_data[1024] = {0};
	size_t len;

	snprintf(ev_data, sizeof(ev_data),
		 "{\"bssid\":\""MACFMT"\"",
		 MAC2STR(bssid));

	if (sta_nr) {
		/* TODO: use blob_buf directly to provide further STA MACs */
		len = strlen(ev_data);
		snprintf(ev_data + len, sizeof(ev_data) - len,
			 ",\"sta_mac\":\""MACFMT"\"", MAC2STR(stas[0]));
	}

	if (bssid_nr) {
		/* TODO: use blob_buf directly to provide further target MACs */
		len = strlen(ev_data);
		snprintf(ev_data + len, sizeof(ev_data) - len,
			 ",\"target_bssid\":\""MACFMT"\"", MAC2STR(target_bssid[0]));
	}

	/* ev_data is always NUL-terminated, so at least one byte is free. */
	len = strlen(ev_data);
	snprintf(ev_data + len, sizeof(ev_data) - len, "}");

	cntlr_notify_event(c, CNTLR_EVENT_CLIENT_STEER_REQUEST, ev_data);
}

/*
 * Publish a CNTLR_EVENT_CLIENT_STEER_RESULT event for the STA @sta_mac.
 * Payload: {"sta_mac":"<mac>","status":<result>}
 */
void cntlr_notify_client_steer_result(struct controller *c,
				      uint8_t *sta_mac, int result)
{
	char payload[1024] = {0};

	snprintf(payload, sizeof(payload),
		 "{\"sta_mac\":\"" MACFMT "\",\"status\":%d}",
		 MAC2STR(sta_mac), result);

	cntlr_notify_event(c, CNTLR_EVENT_CLIENT_STEER_RESULT, payload);
}

/*
 * Expiry handler of a STA's btm_req_timer, which cntlr_steer_sta() arms
 * after sending a BTM-based steering request. It counts a BTM failure
 * for the STA and network-wide, and reports STEER_RESULT_FAIL_TIMEOUT.
 * NOTE(review): presumably the timer is cancelled elsewhere when the BTM
 * report arrives, so reaching here means no response in time - confirm.
 */
void cntlr_btm_req_timer_cb(atimer_t *t)
{
	struct sta *s;
	struct controller *c;

	cntlr_trace(LOG_STEER, "%s:--->\n", __func__);

	s = container_of(t, struct sta, btm_req_timer);
	c = s->cntlr;

	s->de_sta->mapsta.steer_summary.btm_failure_cnt++;
	c->dlem.network.steer_summary.btm_failure_cnt++;

	cntlr_notify_client_steer_result(c, s->macaddr, STEER_RESULT_FAIL_TIMEOUT);
}

/*
 * Publish a CNTLR_EVENT_BSTA_STEER_REQUEST event describing a backhaul
 * steering request for the bSTA @bsta of agent @agent towards @tbssid on
 * @opclass / @channel. Payload:
 *   {"agent":"<mac>","bsta":"<mac>","target_bssid":"<mac>",
 *    "opclass":<n>,"channel":<n>}
 */
void cntlr_notify_backhaul_steer_req_evt(struct controller *c, uint8_t *agent,
					 uint8_t *bsta, uint8_t *tbssid,
					 uint8_t opclass, uint8_t channel)
{
	char ev_data[512] = {0};

	snprintf(ev_data, sizeof(ev_data),
		 "{"
		 "\"agent\":\"" MACFMT "\","
		 "\"bsta\":\"" MACFMT "\","
		 "\"target_bssid\":\"" MACFMT "\","
		 "\"opclass\":%u,"
		 "\"channel\":%u"
		 "}",
		 MAC2STR(agent), MAC2STR(bsta), MAC2STR(tbssid),
		 opclass, channel);

	cntlr_notify_event(c, CNTLR_EVENT_BSTA_STEER_REQUEST, ev_data);
}

/*
 * Steer client STA @s from its current BSS (s->bssid) to @target_bssid.
 *
 * @mode selects the mechanism:
 *  - STEER_MODE_ASSOC_CTL: send a timed-block client association control
 *    request for @s on its current BSS (validity period 10). The CMDU mid
 *    is kept in s->latest_assoc_cntrl_mid so the ACK can be matched.
 *  - STEER_MODE_BTM_REQ: like OPPORTUNITY, unless the STA is in the
 *    controller's global BTM exclude list, in which case nothing is sent.
 *  - STEER_MODE_OPPORTUNITY: send a client steering request and arm
 *    s->btm_req_timer (BTM_RESP_EXP_TIMEOUT seconds) while the BTM report
 *    and tunneled BTM response are expected.
 * Any other mode is logged and ignored.
 *
 * After a successful send, the attempt is recorded in the steer history with
 * a trigger derived from @reason (LOW_THPUT -> THPUT, LOW_RCPI ->
 * LINK_QUALITY, otherwise UNKNOWN).
 *
 * If @target_bssid equals s->bssid, nothing is sent and the no_candidate_cnt
 * counters are bumped instead.
 *
 * Returns 0 when the steer was issued or deliberately skipped, or the
 * non-zero error of the failing CMDU send.
 */
int cntlr_steer_sta(struct controller *c, struct sta *s, uint8_t *target_bssid,
		    uint32_t mode, uint32_t reason)
{
	enum steer_trigger trigger_reason;
	uint16_t mid;
	int ret;

	cntlr_trace(LOG_STEER, "%s:--->\n", __func__);

	if (target_bssid == NULL || hwaddr_is_zero(target_bssid)) {
		cntlr_dbg(LOG_STEER, "%s: steer verdict = OK, but target AP = NULL\n", __func__);
		return 0;
	}

	if (memcmp(target_bssid, s->bssid, 6) == 0) {
		/* Already on the selected BSS: count it as "no candidate". */
		s->de_sta->mapsta.steer_summary.no_candidate_cnt++;
		c->dlem.network.steer_summary.no_candidate_cnt++;
		cntlr_dbg(LOG_STEER,
			  "%s: " MACFMT " connected to best AP! No steer needed.\n",
			  __func__, MAC2STR(s->macaddr));
		return 0;
	}

	cntlr_info(LOG_STEER,
		  "%s: Try to steer " MACFMT " from " MACFMT " to " MACFMT "\n",
		  __func__, MAC2STR(s->macaddr), MAC2STR(s->bssid),
		  MAC2STR(target_bssid));

	if (mode == STEER_MODE_ASSOC_CTL) {
		ret = cntlr_send_client_assoc_ctrl_request(c,
							   s->agent_almacaddr,
							   s->bssid,
							   ASSOC_CTRL_TIMED_BLOCK,
							   10, /* validity period */
							   1,
							   s->macaddr,
							   &mid);
		if (ret) {
			cntlr_warn(LOG_STEER, "%s: Failed to send cmdu for assoc control!\n", __func__);
			return ret;
		}

		/* Keep the mid so assoc control success can be checked in the ACK. */
		s->latest_assoc_cntrl_mid = mid;
		cntlr_dbg(LOG_STEER, "%s: STA assoc control mid = %u\n",
			  __func__, mid);
	} else if (mode == STEER_MODE_BTM_REQ || mode == STEER_MODE_OPPORTUNITY) {
		if (mode == STEER_MODE_BTM_REQ &&
		    is_sta_in_controller_btm_exclude(c, s->macaddr)) {
			cntlr_dbg(LOG_STEER,
				  "%s: STA " MACFMT " in global BTM exclude list - skip steering\n",
				  __func__, MAC2STR(s->macaddr));
			return 0;
		}

		ret = cntlr_send_client_steer_request(c,
						      s->agent_almacaddr,
						      s->bssid, 0,
						      1, (uint8_t (*)[6])s->macaddr,
						      1, (uint8_t (*)[6])target_bssid,
						      mode,
						      reason);
		if (ret) {
			cntlr_warn(LOG_STEER, "%s: Failed to send cmdu for steering sta!\n", __func__);
			return ret;
		}

		/*
		 * A Client Steering BTM Report message and a Tunneled
		 * BTM-Response message are expected for the STA.
		 */
		timer_set(&s->btm_req_timer, BTM_RESP_EXP_TIMEOUT * 1000);
	} else {
		/* STEER_MODE_UNDEFINED and anything unrecognized */
		cntlr_dbg(LOG_STEER, "%s: steer mode is undefined\n", __func__);
		return 0;
	}

	if (reason == STEER_REASON_LOW_THPUT)
		trigger_reason = STEER_TRIGGER_THPUT;
	else if (reason == STEER_REASON_LOW_RCPI)
		trigger_reason = STEER_TRIGGER_LINK_QUALITY;
	else
		trigger_reason = STEER_TRIGGER_UNKNOWN;

	cntlr_update_sta_steer_counters(c, s->macaddr, s->bssid,
					target_bssid, mode,
					trigger_reason,
					/* to->rcpi */ 255);	//FIXME

	return 0;
}

/*
 * Refresh the steering view of STA @s (s->steer_data) from controller state.
 *
 * For a disconnected STA, this only clears ss->bss.connected. Otherwise it:
 *  - describes the serving BSS (bssid, agent, SSID). When the serving radio
 *    is found, it also copies that radio's policy thresholds (if a policy
 *    exists) and its current opclass/control channel; otherwise these
 *    fields keep their previous values;
 *  - clears ss->target;
 *  - rebuilds ss->nbrlist with candidate BSSs.
 *
 * A BSS is a candidate when:
 *  - its radio has a non-zero current opclass and control channel;
 *  - its SSID equals the STA's SSID exactly (when the STA's SSID is known);
 *  - it is a bBSS for a backhaul STA, or an fBSS for any other STA.
 * For a bSTA, the BSSs of the bSTA's own agent are skipped. Collection
 * stops at MAX_NUM_NBRS entries.
 *
 * A bSTA whose interface is unknown to the controller is logged and left
 * untouched.
 */
void cntlr_update_sta_steer_data(struct controller *c, struct sta *s)
{
	struct steer_sta *ss = s->steer_data;
	uint8_t *bsta_agent_id = NULL;
	struct node *n = NULL;
	struct netif_radio *r;

	cntlr_dbg(LOG_STEER | LOG_BSTEER, "%s:---> STA = " MACFMT " (is bsta = %d)\n",
		  __func__, MAC2STR(s->macaddr), s->is_bsta);

	if (s->state == STA_DISCONNECTED) {
		ss->bss.connected = 0;
		return;
	}

	if (s->is_bsta) {
		struct netif_iface *bsta_iface = cntlr_find_bsta(c, s->macaddr);

		if (!bsta_iface) {
			cntlr_dbg(LOG_BSTEER, "%s: skip unknown bSTA " MACFMT" interface\n",
				  __func__, MAC2STR(s->macaddr));
			return;
		}

		/* The bSTA's own agent is excluded from the candidate list below. */
		bsta_agent_id = bsta_iface->agent->almacaddr;
	}

	/* Serving BSS */
	ss->bss.connected = 1;
	memcpy(ss->bss.bssid, s->bssid, 6);
	memcpy(ss->bss.agent, s->agent_almacaddr, 6);
	memset(ss->bss.ssid, 0, sizeof(ss->bss.ssid));
	memcpy(ss->bss.ssid, s->ssid, s->ssidlen);

	r = cntlr_find_radio_with_bssid(c, s->bssid);
	if (r) {
		struct radio_policy *rp;

		rp = cntlr_get_radio_policy(&c->cfg, r->radio_el->macaddr);
		if (rp) {
			ss->bss.rcpi_threshold = rp->rcpi_threshold;
			ss->bss.report_rcpi_threshold = rp->report_rcpi_threshold;
			ss->bss.report_rcpi_hysteresis_margin = rp->report_rcpi_hysteresis_margin;
			ss->bss.util_threshold = rp->util_threshold;
			ss->bss.report_util_threshold = rp->report_util_threshold;
		}

		ss->bss.opclass = ctrl_radio_cur_opclass_id(r->radio_el);
		ss->bss.channel = ctrl_radio_cur_opclass_ctrl_chan(r->radio_el);
	}

	/* Clear any target left over from a previous evaluation. */
	memset(&ss->target, 0, sizeof(struct steer_sta_target_bss));

	/* Rebuild the candidate neighbor list from scratch. */
	memset(ss->nbrlist, 0, ss->num_nbr * sizeof(struct steer_sta_target_bss));
	ss->num_nbr = 0;

	list_for_each_entry(n, &c->nodelist, list) {
		/* for bSTA, do not include self-node in nbrlist */
		if (s->is_bsta && !memcmp(n->almacaddr, bsta_agent_id, 6))
			continue;

		list_for_each_entry(r, &n->radiolist, list) {
			struct netif_iface *p = NULL;
			uint8_t ctrl_channel;	/* 20MHz */
			uint8_t opclass;

			/* NOTE: candidates are currently not filtered by band. */
			opclass = ctrl_radio_cur_opclass_id(r->radio_el);
			ctrl_channel = ctrl_radio_cur_opclass_ctrl_chan(r->radio_el);
			if (opclass == 0 || ctrl_channel == 0)
				continue;

			list_for_each_entry(p, &r->iflist, list) {
				struct steer_sta_target_bss *t;
				bool eligible;

				if (!p->bss)
					continue;

				/*
				 * Require an exact SSID match. Comparing only
				 * s->ssidlen bytes would also accept BSSs whose
				 * SSID merely starts with the STA's SSID (e.g.
				 * "Home-guest" for "Home"), so additionally
				 * require the BSS SSID to end right there.
				 * NOTE(review): assumes p->bss->ssid is a
				 * NUL-terminated array and s->ssidlen excludes
				 * the terminator - confirm against the struct
				 * definitions.
				 */
				if (s->ssidlen > 0 &&
				    (memcmp(p->bss->ssid, s->ssid, s->ssidlen) ||
				     ((size_t)s->ssidlen < sizeof(p->bss->ssid) &&
				      p->bss->ssid[s->ssidlen] != '\0')))
					continue;

				/* bSTAs may only move to bBSSs, clients only to fBSSs. */
				eligible = s->is_bsta ? p->bss->is_bbss : p->bss->is_fbss;
				if (!eligible)
					continue;

				cntlr_dbg(LOG_STEER | LOG_BSTEER,
					  "%s: bss " MACFMT" fBSS = %d, bBSS = %d\n", __func__,
					  MAC2STR(p->bss->bssid), p->bss->is_fbss, p->bss->is_bbss);

				t = &ss->nbrlist[ss->num_nbr];
				memcpy(t->bssid, p->bss->bssid, 6);
				memcpy(t->agent, n->almacaddr, 6);
				memcpy(t->ruid, r->radio_el->macaddr, 6);
				t->opclass = opclass;
				t->channel = ctrl_channel;
				ss->num_nbr++;

				if (ss->num_nbr >= MAX_NUM_NBRS)
					return;
			}
		}
	}
}

/*
 * Ask the agent that owns backhaul STA @bsta to steer it to @tbssid on
 * @opclass / @channel.
 *
 * Returns -1 if @bsta is not a known bSTA interface, otherwise the result
 * of cntlr_send_backhaul_steer_request().
 */
int cntlr_steer_bsta(struct controller *c, uint8_t *bsta,
		     uint8_t *tbssid, uint8_t opclass, uint8_t channel)
{
	struct netif_iface *bsta_iface;

	bsta_iface = cntlr_find_bsta(c, bsta);
	if (bsta_iface == NULL) {
		cntlr_dbg(LOG_BSTEER, "%s: skip unknown bSTA " MACFMT" interface\n",
			  __func__, MAC2STR(bsta));
		return -1;
	}

	/* The request is addressed to the agent hosting the bSTA interface. */
	return cntlr_send_backhaul_steer_request(c, bsta_iface->agent->almacaddr,
						 bsta, tbssid, opclass, channel);
}

/*
 * Placeholder hook for backhaul-STA steer plugins. It currently does
 * nothing; the arguments are accepted for interface compatibility.
 */
void cntlr_inform_bsteer_modules(struct controller *c, struct sta *s, uint16_t rxcmdu_type)
{
	/* TODO: bSTA steer plugins */
	(void)c;
	(void)s;
	(void)rxcmdu_type;
}

void cntlr_inform_steer_modules(struct controller *c, struct sta *s, uint16_t rxcmdu_type)
{
	struct steer_control *sc = NULL;

	cntlr_trace(LOG_STEER, "%s: for " MACFMT", and cmdu = %s\n", __func__,
		    MAC2STR(s->macaddr), map_cmdu_type2str(rxcmdu_type));

	list_for_each_entry(sc, &c->sclist, list) {
		struct steer_sta *ss = s->steer_data;
		int ret;

		ret = cntlr_maybe_steer_sta(sc, ss, rxcmdu_type);
		if (ret)
			continue;

		switch (ss->verdict) {
		case STEER_VERDICT_OK:
			if (c->cfg.steer.plugin_policy == STEER_PLUGIN_POLICY_OR) {
				char fmt[256] = {0};

				snprintf(fmt, sizeof(fmt),  "Steer STA: " MACFMT
					 ", src-BSSID:" MACFMT ", target-BSSID:"
					 MACFMT ", reason: %s",
					 MAC2STR(s->macaddr),
					 MAC2STR(s->bssid),
					 MAC2STR(ss->target.bssid),
					 ss->reason == STEER_REASON_LOW_RCPI ? "link quality" :
					 ss->reason == STEER_REASON_LOW_THPUT ? "phyrate" :
					 "unknown");

				if (ss->reason == STEER_REASON_LOW_RCPI) {
					/* log RCPI values for RCPI steer decision */
					snprintf(fmt + strlen(fmt), sizeof(fmt) - strlen(fmt),
						 " src-RCPI: %u, target-RCPI (%s): %u",
						 (unsigned int) s->de_sta->rcpi,
						 (ss->target.ul_rcpi ? "UL" : "DL"),
						 (unsigned int) (ss->target.ul_rcpi ? ss->target.ul_rcpi : ss->target.dl_rcpi));
				}
				snprintf(fmt + strlen(fmt), sizeof(fmt) - strlen(fmt), "\n");
				cntlr_info(LOG_STEER, "%s", fmt);

				if (s->is_bsta) {
					cntlr_steer_bsta(c,
							 s->macaddr,
							 ss->target.bssid,
							 ss->target.opclass,
							 ss->target.channel);
				} else {
					cntlr_steer_sta(c,
							s,
							ss->target.bssid,
							STEER_MODE_BTM_REQ,
							ss->reason);
				}

				return;
			}

			break;
		case STEER_VERDICT_NOK:
			if (c->cfg.steer.plugin_policy == STEER_PLUGIN_POLICY_AND)
				return;

			break;
		default:
			break;
		}
	}
}
