/*
 * sta.c - STA management
 *
 * Copyright (C) 2020-2024 IOPSYS Software Solutions AB.
 *
 * See LICENSE file for source code license information.
 *
 */

#include "sta.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <easy/easy.h>
#include <wifidefs.h>
#include <wifiutils.h>

#include "cntlr.h"
#include "config.h"
#include "steer_module.h"
#include "timer.h"
#include "utils/debug.h"
#include "wifi_dataelements.h"
#include "wifi_opclass.h"


/*
 * Timer callbacks registered on every new STA via timer_init() in
 * sta_alloc(); they are defined elsewhere in the controller.
 */
extern void cntlr_bcn_metrics_timer_cb(atimer_t *t);
extern void cntlr_btm_req_timer_cb(atimer_t *t);
extern void cntlr_sta_ageout_timer_cb(atimer_t *t);

/*
 * Look up a STA by its 6-byte MAC address in a table bucketed by sta_hash().
 *
 * Returns the first entry in the matching bucket whose MAC address equals
 * @macaddr, or NULL when there is none.
 */
struct sta *cntlr_find_sta(struct hlist_head *table, uint8_t *macaddr)
{
	struct hlist_head *bucket = &table[sta_hash(macaddr)];
	struct sta *entry = NULL;

	if (hlist_empty(bucket))
		return NULL;

	hlist_for_each_entry(entry, bucket, hlist) {
		if (memcmp(entry->macaddr, macaddr, 6) == 0)
			return entry;
	}

	return NULL;
}

/*
 * Allocate a STA together with its data-element and steering records.
 *
 * The three records share one zero-filled allocation laid out as
 *     [struct sta][struct wifi_sta_element][struct steer_sta]
 * so a single free() of the returned pointer releases all three (entries
 * linked onto the embedded lists are not covered by that).
 *
 * @macaddr: 6-byte MAC address, copied into every record.
 * Returns the new STA, or NULL if the allocation fails.
 */
static struct sta *sta_alloc(uint8_t *macaddr)
{
	struct wifi_sta_element *de;
	struct steer_sta *steer;
	struct sta *sta;

	sta = calloc(1, sizeof(*sta) + sizeof(*de) + sizeof(*steer));
	if (sta == NULL)
		return NULL;

	/* Carve the trailing records out of the same block. */
	de = (struct wifi_sta_element *)(sta + 1);
	steer = (struct steer_sta *)(de + 1);
	sta->de_sta = de;
	sta->steer_data = steer;

	/* Core STA record */
	memcpy(sta->macaddr, macaddr, 6);
	timer_init(&sta->bcn_metrics_timer, cntlr_bcn_metrics_timer_cb);
	timer_init(&sta->btm_req_timer, cntlr_btm_req_timer_cb);
	timer_init(&sta->ageout_timer, cntlr_sta_ageout_timer_cb);
	time(&sta->lookup_time);

	/* Data-element record: empty lists, steering stats marked "no data". */
	memcpy(de->macaddr, macaddr, 6);
	INIT_LIST_HEAD(&de->meas_reportlist);
	INIT_LIST_HEAD(&de->umetriclist);
	de->mapsta.steer_summary.blacklist_attempt_cnt = STEER_STATS_NO_DATA;
	de->mapsta.steer_summary.blacklist_success_cnt = STEER_STATS_NO_DATA;
	de->mapsta.steer_summary.blacklist_failure_cnt = STEER_STATS_NO_DATA;

	/* Steering record points at the data-element lists instead of owning copies. */
	memcpy(steer->macaddr, macaddr, 6);
	steer->meas_reportlist = &de->meas_reportlist;
	steer->unassoc_metriclist = &de->umetriclist;

	/*
	 * num_meas_reports, mapsta.first/next/num_steer_hist and every
	 * sta_caps field (valid, rm_caps, btm_support, mbo_ie_cdc_attr,
	 * multiband, agile_multiband) start at 0 / false from calloc().
	 */
	return sta;
}

static void sta_free(struct sta *s)
{
	free(s);
}

/*
 * Allocate a new STA for @macaddr and insert it at the head of its
 * sta_hash() bucket in @table.
 *
 * The new STA records @cntlr as its owner and starts in STA_UNKNOWN state.
 * An all-zero MAC address is rejected (WARN_ON) and yields NULL, as does an
 * allocation failure.
 *
 * NOTE(review): no duplicate check is done here; callers presumably look
 * the MAC up with cntlr_find_sta() first — confirm.
 */
struct sta *cntlr_add_sta(void *cntlr, struct hlist_head *table, uint8_t *macaddr)
{
	struct sta *s;

#if 0	//TODO
	if (n->sta_count >= LIMIT_STA_COUNT) {
		time_t least_used_sta_time = (time_t)(~(time_t)0);
		struct sta *least_used_sta = NULL;

		list_for_each_entry(s, &n->stalist, list) {
			if ((uint64_t)s->lookup_time < (uint64_t)least_used_sta_time) {
				least_used_sta = s;
				least_used_sta_time = s->lookup_time;
			}
		}

		if (least_used_sta) {
			cntlr_dbg(LOG_STA, "%s: remove least used STA " MACFMT
					" to add new STA " MACFMT "\n", __func__,
					MAC2STR(least_used_sta->de_sta->macaddr),
					MAC2STR(macaddr));

			node_remove_sta(c, n, least_used_sta);
		} else {
			//Why ?
			dbg("%s: failed to find least used sta\n", __func__);
		}
	}
#endif
	/* An all-zero MAC address is never a valid STA key. */
	if (WARN_ON(hwaddr_is_zero(macaddr)))
		return NULL;

	s = sta_alloc(macaddr);
	if (s == NULL)
		return NULL;

	s->cntlr = cntlr;
	s->state = STA_UNKNOWN;
	hlist_add_head(&s->hlist, &table[sta_hash(macaddr)]);

	info("%s: New client " MACFMT " connected\n",
	     __func__, MAC2STR(macaddr));

	return s;
}

/*
 * Unlink every entry whose MAC address equals @macaddr from its sta_hash()
 * bucket in @table. The entries themselves are not freed.
 *
 * Logs a warning when no matching entry was found.
 */
void cntlr_del_sta_hash(struct hlist_head *table, uint8_t *macaddr)
{
	struct hlist_head *bucket = &table[sta_hash(macaddr)];
	struct hlist_node *next = NULL;
	struct sta *cur = NULL;
	bool removed = false;

	if (!hlist_empty(bucket)) {
		/* _safe variant: the current node is unlinked while iterating. */
		hlist_for_each_entry_safe(cur, next, bucket, hlist) {
			if (memcmp(cur->macaddr, macaddr, 6) != 0)
				continue;

			hlist_del(&cur->hlist, bucket);
			removed = true;
		}
	}

	if (removed)
		return;

	cntlr_warn(LOG_STA, "%s: STA " MACFMT " not found in table\n",
		   __func__, MAC2STR(macaddr));
}

/*
 * Release @del. The controller back-pointer is cleared before the memory is
 * handed to sta_free().
 *
 * This does not touch any hash table; callers that keep the STA in one
 * should unlink it first (as cntlr_del_sta() does).
 */
void cntlr_free_sta(struct sta *del)
{
	struct sta *victim = del;

	victim->cntlr = NULL;
	sta_free(victim);
}

/*
 * Unlink @del (looked up by its MAC address) from @table, then release it.
 */
void cntlr_del_sta(struct hlist_head *table, struct sta *del)
{
	/* Unlink first: the MAC key lives inside the memory freed below. */
	cntlr_del_sta_hash(table, del->macaddr);

	cntlr_free_sta(del);
}

/*
 * Placeholder for processing a STA's link metrics.
 * Not implemented yet: ignores @s and always returns 0.
 */
int sta_link_metrics_process(struct sta *s)
{
	(void)s;	/* TODO: implement link metrics processing */

	return 0;
}

/*
 * Free every beacon measurement report queued on the STA's meas_reportlist,
 * decrementing num_meas_reports once per freed report.
 *
 * A NULL STA, or one without a data-element record, is rejected with a
 * warning.
 */
void sta_free_bcn_metrics(struct sta *s)
{
	struct wifi_sta_meas_report *report = NULL, *next;

	cntlr_trace(LOG_STA, "%s: --->\n", __func__);

	if (s == NULL || s->de_sta == NULL) {
		cntlr_warn(LOG_STA, "%s: Unexpected empty STA reference!\n", __func__);
		return;
	}

	/* _safe variant: each node is unlinked and freed while walking. */
	list_for_each_entry_safe(report, next, &s->de_sta->meas_reportlist, list) {
		list_del(&report->list);
		s->de_sta->num_meas_reports--;
		free(report);
	}
}

void sta_free_usta_metrics(struct sta *s)
{
	struct unassoc_sta_metrics *u = NULL, *tmp;

	list_for_each_entry_safe(u, tmp, &s->de_sta->umetriclist, list) {
		list_del(&u->list);
		free(u);
	}
}

/*
 * Release the stored (re)association request frame body of a STA.
 *
 * Both the pointer and the length are reset, so later readers see "no frame".
 * For example, sta_update_capabilities() returns early when either is unset.
 * A repeated call is then a harmless free(NULL) rather than a double free.
 * Cached sta_caps derived from the frame are left untouched.
 *
 * A NULL STA, or one without a data-element record, is rejected with a
 * warning.
 */
void sta_free_assoc_frame(struct sta *s)
{
	cntlr_trace(LOG_STA, "%s: --->\n", __func__);

	if (!s || !s->de_sta) {
		cntlr_warn(LOG_STA, "%s: Unexpected empty STA reference!\n", __func__);
		return;
	}

	free(s->de_sta->reassoc_frame);
	/* Do not leave a dangling pointer behind (double free / use-after-free). */
	s->de_sta->reassoc_frame = NULL;
	s->de_sta->reassoc_framelen = 0;
}

/*
 * Check whether the MBO IE in @ies carries a Cellular Data Capabilities
 * attribute (ID 0x03, length 1). The MBO IE is the vendor-specific IE with
 * OUI 50:6F:9A and OUI type 0x16.
 *
 * A true result therefore also implies that an MBO IE is present. False is
 * returned both when there is no MBO IE and when it lacks that attribute.
 * Attribute parsing stops at the first attribute that would overrun the IE.
 *
 * NOTE(review): assumes wifi_find_vsie() returns a pointer to the IE header
 * (ID, length) and only returns IEs that fit inside @ies — confirm.
 */
static bool reassoc_mbo_cell_attr_present(uint8_t *ies, int ies_len)
{
	uint8_t mbo_oui[3] = {0x50, 0x6F, 0x9A}; /* Wi-Fi Alliance OUI */
	uint8_t *vsie;
	size_t ie_end;	/* offset one past the last byte of the IE */
	size_t pos;	/* offset of the current attribute header */

	/* Find MBO IE (OUI: 50:6F:9A, type: 0x16) */
	vsie = wifi_find_vsie(ies, ies_len, mbo_oui, 0x16, 0xff);
	if (!vsie)
		return false;  /* No MBO IE found */

	/*
	 * The IE body must at least hold OUI (3) + OUI type (1). Otherwise the
	 * attribute start (offset 6) lies beyond the IE, and a pointer formed
	 * there could point past the end of the buffer (undefined behavior).
	 */
	if (vsie[1] < 4)
		return false;

	ie_end = 2 + (size_t)vsie[1];

	/* Walk the attributes: ID (1), length (1), value (length). */
	for (pos = 6; pos + 2 <= ie_end; pos += 2 + (size_t)vsie[pos + 1]) {
		uint8_t attr_id = vsie[pos];
		uint8_t attr_len = vsie[pos + 1];

		if (pos + 2 + attr_len > ie_end)
			break;  /* Malformed: attribute overruns the IE */

		if (attr_id == 0x03 && attr_len == 1) {
			/* Cellular Data Capabilities attr found */
			return true;
		}
	}

	/* MBO IE found but no cellular attribute */
	return false;
}

/* Check multi-band capability from Supported Operating Classes IE */
static bool reassoc_multiband_capable(uint8_t *ies, int ies_len)
{
	uint8_t *ie;
	uint8_t ie_len;
	uint8_t *opclass;
	uint8_t *opclass_end;
	bool has_2_4ghz = false;
	bool has_5ghz = false;
	bool has_6ghz = false;
	int band_count = 0;

	/* IE 0x3B = Supported Operating Classes */
	ie = wifi_find_ie(ies, ies_len, 0x3B);
	if (!ie || ie[1] < 2)
		return false;

	ie_len = ie[1];
	opclass = ie + 3;  /* Skip IE ID, len, current operating class */
	opclass_end = ie + 2 + ie_len;

	/* Parse operating class list */
	while (opclass < opclass_end) {
		enum wifi_band band = wifi_opclass_get_band(opclass[0]);

		if (band == BAND_2)
			has_2_4ghz = true;
		else if (band == BAND_5)
			has_5ghz = true;
		else if (band == BAND_6)
			has_6ghz = true;

		opclass++;
	}

	/* Device is multi-band if it supports 2+ bands */
	band_count = (has_2_4ghz ? 1 : 0) + (has_5ghz ? 1 : 0) + (has_6ghz ? 1 : 0);

	return band_count >= 2;
}

void sta_update_capabilities(struct sta *s)
{
	struct wifi_caps caps;
	uint8_t cbitmap[32];
	struct assoc_frame {
		uint16_t cap_info;
		uint16_t listen_int;
		uint8_t tagged[0];
	} __attribute__((packed)) *af;
	bool has_beacon_modes;
	uint8_t *af_ies;
	size_t fix_len;
	int ies_len;
	int ret;

	/* Initialize capabilities as invalid */
	s->sta_caps.valid = false;
	s->sta_caps.rm_caps = 0;
	s->sta_caps.btm_support = false;
	s->sta_caps.mbo_ie_cdc_attr = false;
	s->sta_caps.multiband = false;
	s->sta_caps.agile_multiband = false;

	if (!s->de_sta->reassoc_framelen || !s->de_sta->reassoc_frame)
		return;

	af = (struct assoc_frame *)s->de_sta->reassoc_frame;
	fix_len = sizeof(af->cap_info) + sizeof(af->listen_int);
	ies_len = s->de_sta->reassoc_framelen - fix_len;

	if (ies_len <= 0)
		return;

	af_ies = af->tagged;

	/* Parse WiFi capabilities */
	memset(&caps, 0, sizeof(caps));
	memset(cbitmap, 0, 32);
	ret = wifi_caps_from_ies(af_ies, ies_len, &caps, cbitmap);

	if (ret >= 0) {
		/* Extract RM capabilities */
		if (caps.valid & WIFI_CAP_RM_VALID) {
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_BCN_ACTIVE))
				s->sta_caps.rm_caps |= RM_CAP_BCN_ACTIVE;
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_BCN_PASSIVE))
				s->sta_caps.rm_caps |= RM_CAP_BCN_PASSIVE;
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_BCN_TABLE))
				s->sta_caps.rm_caps |= RM_CAP_BCN_TABLE;
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_NBR_REPORT))
				s->sta_caps.rm_caps |= RM_CAP_NBR_REPORT;
		}

		/* 1. Extract BTM support */
		s->sta_caps.btm_support = wifi_cap_isset(cbitmap, WIFI_CAP_11V_BSS_TRANS);
	}

	/* 2. Check for MBO IE */
	s->sta_caps.mbo_ie_cdc_attr = reassoc_mbo_cell_attr_present(af_ies, ies_len);

	/* 3. Check if STA advertises support of operation in multiple bands */
	s->sta_caps.multiband = reassoc_multiband_capable(af_ies, ies_len);

	/* 4. Check if at least one beacon mode is supported (beacon metrics supported) */
	has_beacon_modes = (s->sta_caps.rm_caps &
			(RM_CAP_BCN_ACTIVE | RM_CAP_BCN_PASSIVE | RM_CAP_BCN_TABLE));

	/* Agile multiband STA requires ALL 4 above conditions */
	s->sta_caps.agile_multiband = s->sta_caps.btm_support &&
				      has_beacon_modes &&
				      s->sta_caps.mbo_ie_cdc_attr &&
				      s->sta_caps.multiband;

	s->sta_caps.valid = true;
}

/*
 * Report whether @s qualifies as an agile multiband STA.
 *
 * Capabilities are computed on demand: while sta_caps is not marked valid,
 * sta_update_capabilities() is re-run from the stored frame first.
 */
bool sta_is_agile_multiband(struct sta *s)
{
	bool cached = s->sta_caps.valid;

	if (!cached)
		sta_update_capabilities(s);

	return s->sta_caps.agile_multiband;
}

/*
 * Take a reference on @s and return the new reference count.
 *
 * A referenced STA must not age out, so any pending ageout timer (which
 * sta_dec_ref() arms when the count drops to zero) is cancelled.
 */
int sta_inc_ref(struct sta *s)
{
	if (timer_pending(&s->ageout_timer))
		timer_del(&s->ageout_timer);

	return ++s->nref;
}

/*
 * Drop one reference on @s and return the remaining reference count.
 *
 * When the last reference goes away, the current time is stored in
 * s->disassoc_time and s->ageout_timer is armed. The timeout is the
 * controller's cfg.stale_sta_timeout, capped at ULOOP_MAX_TIMEOUT. Both
 * appear to be in seconds, judging by the *1000 conversion to
 * milliseconds. Calling this with nref already 0 triggers WARN_ON and
 * returns 0 without changing anything.
 */
int sta_dec_ref(struct sta *s)
{
	struct controller *c = s->cntlr;
	int ageout_ms;

	if (WARN_ON(s->nref == 0))
		return 0;

	if (--s->nref != 0)
		return s->nref;

	/* Last reference dropped: remember when, then schedule the ageout. */
	time(&s->disassoc_time);

	if (c->cfg.stale_sta_timeout <= ULOOP_MAX_TIMEOUT) {
		/* For timeouts <= 24 days, use configured timeout directly */
		ageout_ms = c->cfg.stale_sta_timeout * 1000;
	} else {
		/* For timeouts > 24 days, use maximum possible uloop timeout */
		ageout_ms = ULOOP_MAX_TIMEOUT * 1000;
	}
	timer_set(&s->ageout_timer, ageout_ms);

	return s->nref;
}
