/*
 * sta.c - STA management
 *
 * Copyright (C) 2020-2024 IOPSYS Software Solutions AB.
 *
 * See LICENSE file for source code license information.
 *
 */

#include "sta.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <easy/easy.h>
#include <wifidefs.h>
#include <wifiutils.h>

#include "cntlr.h"
#include "steer_module.h"
#include "timer.h"
#include "utils/debug.h"
#include "wifi_dataelements.h"


/*
 * Timer callbacks declared extern here; sta_alloc() installs them on each new
 * STA's bcn_metrics_timer, btm_req_timer and ageout_timer respectively.
 */
extern void cntlr_bcn_metrics_timer_cb(atimer_t *t);
extern void cntlr_btm_req_timer_cb(atimer_t *t);
extern void cntlr_sta_ageout_timer_cb(atimer_t *t);

/*
 * Look up a STA by MAC address.
 *
 * @table: array of hash buckets, indexed by sta_hash() of the address
 * @macaddr: 6-byte MAC address to search for
 *
 * Returns the first entry in the bucket whose macaddr matches, or NULL when
 * the bucket is empty or holds no such entry.
 */
struct sta *cntlr_find_sta(struct hlist_head *table, uint8_t *macaddr)
{
	struct hlist_head *bucket = &table[sta_hash(macaddr)];
	struct sta *entry = NULL;

	if (!hlist_empty(bucket)) {
		hlist_for_each_entry(entry, bucket, hlist) {
			if (memcmp(entry->macaddr, macaddr, 6) == 0)
				return entry;
		}
	}

	return NULL;
}

/*
 * Allocate and initialise a STA object for the given MAC address.
 *
 * One zeroed allocation holds three objects back to back: the struct sta, its
 * wifi_sta_element (reachable as ->de_sta) and its steer_sta (reachable as
 * ->steer_data). Freeing the struct sta therefore releases all three.
 *
 * @macaddr: 6-byte MAC address, copied into all three objects
 *
 * Returns the new object, or NULL if the allocation fails.
 */
static struct sta *sta_alloc(uint8_t *macaddr)
{
	struct wifi_sta_element *elem;
	struct steer_sta *steer;
	struct sta *sta;

	sta = calloc(1, sizeof(*sta) + sizeof(*elem) + sizeof(*steer));
	if (sta == NULL)
		return NULL;

	/*
	 * Carve the trailing sub-objects out of the same allocation.
	 * NOTE(review): assumes the sizes of struct sta and struct
	 * wifi_sta_element keep each following object suitably aligned -
	 * confirm.
	 */
	elem = (struct wifi_sta_element *)(sta + 1);
	steer = (struct steer_sta *)(elem + 1);
	sta->de_sta = elem;
	sta->steer_data = steer;

	/* every sub-object carries its own copy of the address */
	memcpy(sta->macaddr, macaddr, 6);
	memcpy(elem->macaddr, macaddr, 6);
	memcpy(steer->macaddr, macaddr, 6);

	timer_init(&sta->bcn_metrics_timer, cntlr_bcn_metrics_timer_cb);
	timer_init(&sta->btm_req_timer, cntlr_btm_req_timer_cb);
	timer_init(&sta->ageout_timer, cntlr_sta_ageout_timer_cb);
	time(&sta->lookup_time);

	/* the steering view shares the list heads owned by the data element */
	INIT_LIST_HEAD(&elem->meas_reportlist);
	INIT_LIST_HEAD(&elem->umetriclist);
	steer->meas_reportlist = &elem->meas_reportlist;
	steer->unassoc_metriclist = &elem->umetriclist;

	elem->mapsta.steer_summary.blacklist_attempt_cnt = STEER_STATS_NO_DATA;
	elem->mapsta.steer_summary.blacklist_success_cnt = STEER_STATS_NO_DATA;
	elem->mapsta.steer_summary.blacklist_failure_cnt = STEER_STATS_NO_DATA;

	/*
	 * calloc() already left the remaining state at zero/false:
	 * num_meas_reports, mapsta.first/next/num_steer_hist and every
	 * sta_caps field (valid, rm_caps, btm_support, mbo_support,
	 * agile_multiband).
	 */

	return sta;
}

/*
 * Release a STA object.
 *
 * This is a plain free() of @s. sta_alloc() places the wifi_sta_element and
 * steer_sta in the same allocation as the struct sta, so they are released
 * with it. Nothing reachable through list heads or other pointers is freed
 * here.
 */
static void sta_free(struct sta *s)
{
	free(s);
}

/*
 * Create a STA for @macaddr and insert it into the hash table.
 *
 * @cntlr: opaque controller pointer stored in the new STA's ->cntlr
 * @table: array of hash buckets, indexed by sta_hash() of the address
 * @macaddr: 6-byte MAC address of the STA
 *
 * The new entry starts in state STA_UNKNOWN. No lookup for an existing entry
 * with the same address is performed here.
 *
 * Returns the new STA, or NULL if @macaddr is all zeroes or allocation fails.
 */
struct sta *cntlr_add_sta(void *cntlr, struct hlist_head *table, uint8_t *macaddr)
{
	struct hlist_head *bucket;
	struct sta *s;

	/* Disabled draft of least-recently-used eviction; not compiled. */
#if 0	//TODO
	if (n->sta_count >= LIMIT_STA_COUNT) {
		time_t least_used_sta_time = (time_t)(~(time_t)0);
		struct sta *least_used_sta = NULL;

		list_for_each_entry(s, &n->stalist, list) {
			if ((uint64_t)s->lookup_time < (uint64_t)least_used_sta_time) {
				least_used_sta = s;
				least_used_sta_time = s->lookup_time;
			}
		}

		if (least_used_sta) {
			cntlr_dbg(LOG_STA, "%s: remove least used STA " MACFMT
					" to add new STA " MACFMT "\n", __func__,
					MAC2STR(least_used_sta->de_sta->macaddr),
					MAC2STR(macaddr));

			node_remove_sta(c, n, least_used_sta);
		} else {
			//Why ?
			dbg("%s: failed to find least used sta\n", __func__);
		}
	}
#endif
	/* an all-zero address should never reach this point */
	if (WARN_ON(hwaddr_is_zero(macaddr)))
		return NULL;

	s = sta_alloc(macaddr);
	if (s == NULL)
		return NULL;

	s->cntlr = cntlr;
	s->state = STA_UNKNOWN;

	bucket = &table[sta_hash(macaddr)];
	hlist_add_head(&s->hlist, bucket);

	info("%s: New client " MACFMT " connected\n",
	     __func__, MAC2STR(macaddr));

	return s;
}

/*
 * Unlink the STA with the given MAC address from the hash table.
 *
 * @table: array of hash buckets, indexed by sta_hash() of the address
 * @macaddr: 6-byte MAC address to remove
 *
 * Every entry in the bucket whose macaddr matches is unlinked. The entries
 * themselves are not freed. A warning is logged when nothing matched.
 */
void cntlr_del_sta_hash(struct hlist_head *table, uint8_t *macaddr)
{
	struct hlist_head *bucket = &table[sta_hash(macaddr)];
	struct hlist_node *next = NULL;
	struct sta *entry = NULL;
	bool unlinked = false;

	if (!hlist_empty(bucket)) {
		hlist_for_each_entry_safe(entry, next, bucket, hlist) {
			if (memcmp(entry->macaddr, macaddr, 6) != 0)
				continue;

			hlist_del(&entry->hlist, bucket);
			unlinked = true;
		}
	}

	if (unlinked)
		return;

	cntlr_warn(LOG_STA, "%s: STA " MACFMT " not found in table\n",
			  __func__, MAC2STR(macaddr));
}

/*
 * Clear the STA's controller back-reference and free the object.
 *
 * This does not unlink @del from any hash table (cntlr_del_sta() does both),
 * and it neither stops the STA's timers nor releases its measurement lists or
 * stored association frame.
 * NOTE(review): presumably callers cancel pending timers and call the
 * sta_free_*() helpers beforehand - confirm, since a timer left pending would
 * reference freed memory.
 */
void cntlr_free_sta(struct sta *del)
{
	del->cntlr = NULL;
	sta_free(del);
}

/*
 * Remove a STA from the hash table and free it.
 *
 * @table: array of hash buckets the STA was added to
 * @del: STA to remove; must not be used after this call returns
 *
 * The entry is unlinked by MAC address via cntlr_del_sta_hash(), then released
 * via cntlr_free_sta().
 */
void cntlr_del_sta(struct hlist_head *table, struct sta *del)
{
	cntlr_del_sta_hash(table, del->macaddr);
	cntlr_free_sta(del);
}

/*
 * Placeholder for STA link metrics processing.
 *
 * Not implemented yet: @s is ignored and 0 is always returned.
 */
int sta_link_metrics_process(struct sta *s)
{
	//TODO:

	return 0;
}

/*
 * Free every measurement report queued on the STA's meas_reportlist.
 *
 * Each entry is unlinked from the list and freed, after which
 * num_meas_reports is reset to 0.
 *
 * Logs a warning and returns without doing anything else when @s or
 * s->de_sta is NULL.
 */
void sta_free_bcn_metrics(struct sta *s)
{
	struct wifi_sta_meas_report *b = NULL, *tmp;

	cntlr_trace(LOG_STA, "%s: --->\n", __func__);

	if (!s || !s->de_sta) {
		cntlr_warn(LOG_STA, "%s: Unexpected empty STA reference!\n", __func__);
		return;
	}

	list_for_each_entry_safe(b, tmp, &s->de_sta->meas_reportlist, list) {
		list_del(&b->list);
		free(b);
	}

	/*
	 * The list is empty now. Reset the counter instead of decrementing it
	 * per entry, so a counter that drifted from the list length can
	 * neither underflow nor stay non-zero.
	 */
	s->de_sta->num_meas_reports = 0;
}

void sta_free_usta_metrics(struct sta *s)
{
	struct unassoc_sta_metrics *u = NULL, *tmp;

	list_for_each_entry_safe(u, tmp, &s->de_sta->umetriclist, list) {
		list_del(&u->list);
		free(u);
	}
}

/*
 * Release the (re)association frame stored for the STA.
 *
 * Frees s->de_sta->reassoc_frame, clears the pointer and sets
 * reassoc_framelen to 0, so the function is safe to call repeatedly.
 *
 * Logs a warning and returns without doing anything else when @s or
 * s->de_sta is NULL.
 */
void sta_free_assoc_frame(struct sta *s)
{
	cntlr_trace(LOG_STA, "%s: --->\n", __func__);

	if (!s || !s->de_sta) {
		cntlr_warn(LOG_STA, "%s: Unexpected empty STA reference!\n", __func__);
		return;
	}

	free(s->de_sta->reassoc_frame);
	/* drop the stale pointer so a later call cannot double-free it */
	s->de_sta->reassoc_frame = NULL;
	s->de_sta->reassoc_framelen = 0;
}

/*
 * Check whether a buffer of tagged information elements contains the
 * MBO vendor-specific IE (element ID 0xDD, OUI 50:6F:9A, OUI type 0x16).
 *
 * @ies: start of the IE buffer; may be NULL, in which case false is returned
 * @ies_len: number of bytes in @ies; values <= 0 yield false
 *
 * The buffer is walked as a sequence of id/length/value elements. Parsing
 * stops at the first element whose declared length runs past the end of the
 * buffer. Bounds are tracked as a count of remaining bytes so that no pointer
 * beyond the buffer is ever formed.
 *
 * Returns true if an MBO IE is found, false otherwise.
 */
static bool reassoc_support_mbo(uint8_t *ies, int ies_len)
{
	static const uint8_t mbo_oui[3] = {0x50, 0x6F, 0x9A}; /* Wi-Fi Alliance OUI */
	const uint8_t mbo_oui_type = 0x16; /* MBO OUI Type */
	const uint8_t vendor_ie_id = 0xDD; /* vendor-specific IE */
	const uint8_t *pos = ies;
	size_t left;

	if (!ies || ies_len <= 0)
		return false;

	left = (size_t)ies_len;

	/* each element needs at least its 2-byte id/length header */
	while (left >= 2) {
		uint8_t ie_id = pos[0];
		uint8_t ie_len = pos[1];

		pos += 2;
		left -= 2;

		/* truncated element: stop parsing */
		if (ie_len > left)
			break;

		/* payload must hold the 3-byte OUI plus the OUI type */
		if (ie_id == vendor_ie_id && ie_len >= 4 &&
		    memcmp(pos, mbo_oui, sizeof(mbo_oui)) == 0 &&
		    pos[3] == mbo_oui_type)
			return true;

		pos += ie_len;
		left -= ie_len;
	}

	return false;
}

/*
 * Derive the STA's steering-related capabilities from its stored
 * (re)association frame and cache them in s->sta_caps.
 *
 * All sta_caps fields are reset first. If no frame is stored, or the frame
 * has no bytes beyond its fixed fields, the function returns with
 * sta_caps.valid still false. Otherwise:
 *   - rm_caps and btm_support are filled from wifi_caps_from_ies() when that
 *     call reports success (return value >= 0);
 *   - mbo_support is set from reassoc_support_mbo(), independently of the
 *     above;
 *   - agile_multiband is set when MBO and BTM are supported together with at
 *     least one beacon report mode (active, passive or table);
 *   - sta_caps.valid is set to true.
 *
 * NOTE(review): the fixed part is taken to be capability info plus listen
 * interval (4 bytes). An IEEE 802.11 reassociation request body additionally
 * carries a 6-byte current-AP address before the IEs - confirm the stored
 * frame always uses the association-request layout.
 */
void sta_update_capabilities(struct sta *s)
{
	struct wifi_caps caps;
	uint8_t cbitmap[32];
	struct assoc_frame {
		uint16_t cap_info;
		uint16_t listen_int;
		uint8_t tagged[];
	} __attribute__((packed)) *af;
	uint8_t *af_ies;
	size_t fix_len;
	int ies_len;
	int ret;

	/* Initialize capabilities as invalid */
	s->sta_caps.valid = false;
	s->sta_caps.rm_caps = 0;
	s->sta_caps.btm_support = false;
	s->sta_caps.mbo_support = false;
	s->sta_caps.agile_multiband = false;

	if (!s->de_sta->reassoc_framelen || !s->de_sta->reassoc_frame)
		return;

	af = (struct assoc_frame *)s->de_sta->reassoc_frame;
	fix_len = sizeof(af->cap_info) + sizeof(af->listen_int);

	/*
	 * Compare before subtracting: the subtraction is done in an unsigned
	 * type, so a frame shorter than the fixed fields would otherwise wrap
	 * to a huge value.
	 */
	if ((size_t)s->de_sta->reassoc_framelen <= fix_len)
		return;

	ies_len = (int)((size_t)s->de_sta->reassoc_framelen - fix_len);

	/* also rejects lengths that do not fit in an int */
	if (ies_len <= 0)
		return;

	af_ies = af->tagged;

	/* Parse WiFi capabilities */
	memset(&caps, 0, sizeof(caps));
	memset(cbitmap, 0, sizeof(cbitmap));
	ret = wifi_caps_from_ies(af_ies, ies_len, &caps, cbitmap);

	if (ret >= 0) {
		/* Extract RM capabilities */
		if (caps.valid & WIFI_CAP_RM_VALID) {
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_BCN_ACTIVE))
				s->sta_caps.rm_caps |= RM_CAP_BCN_ACTIVE;
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_BCN_PASSIVE))
				s->sta_caps.rm_caps |= RM_CAP_BCN_PASSIVE;
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_BCN_TABLE))
				s->sta_caps.rm_caps |= RM_CAP_BCN_TABLE;
			if (wifi_cap_isset(cbitmap, WIFI_CAP_RM_NBR_REPORT))
				s->sta_caps.rm_caps |= RM_CAP_NBR_REPORT;
		}

		/* Extract BTM support */
		s->sta_caps.btm_support = wifi_cap_isset(cbitmap, WIFI_CAP_11V_BSS_TRANS);
	}

	/* Extract MBO support */
	s->sta_caps.mbo_support = reassoc_support_mbo(af_ies, ies_len);

	/* Agile multiband needs MBO, BTM and at least one beacon report mode */
	if (s->sta_caps.mbo_support) {
		bool has_beacon_modes = !!(s->sta_caps.rm_caps &
				(RM_CAP_BCN_ACTIVE | RM_CAP_BCN_PASSIVE | RM_CAP_BCN_TABLE));

		s->sta_caps.agile_multiband = s->sta_caps.btm_support && has_beacon_modes;
	}

	/* Mark capabilities as valid */
	s->sta_caps.valid = true;
}

/*
 * Report whether the STA is considered agile-multiband capable.
 *
 * Capabilities are parsed lazily: when sta_caps is not marked valid,
 * sta_update_capabilities() is run first. If that parse cannot complete,
 * sta_caps stays invalid and the next call tries again.
 *
 * Returns the cached sta_caps.agile_multiband flag.
 */
bool sta_is_agile_multiband(struct sta *s)
{
	if (s->sta_caps.valid)
		return s->sta_caps.agile_multiband;

	sta_update_capabilities(s);

	return s->sta_caps.agile_multiband;
}

/*
 * Take a reference on the STA.
 *
 * Increments nref and, if the ageout timer is pending, deletes it.
 *
 * Returns the reference count after the increment.
 */
int sta_inc_ref(struct sta *s)
{
	s->nref += 1;

	/* a referenced STA must not age out */
	if (timer_pending(&s->ageout_timer)) {
		timer_del(&s->ageout_timer);
	}

	return s->nref;
}

/*
 * Drop a reference on the STA.
 *
 * When the count reaches zero the ageout timer is set to the controller's
 * cfg.stale_sta_timeout multiplied by 1000. Dropping a reference on a STA
 * whose count is already zero triggers WARN_ON and changes nothing.
 *
 * Returns the reference count after the decrement (0 in the warning case).
 */
int sta_dec_ref(struct sta *s)
{
	if (WARN_ON(s->nref == 0))
		return 0;

	s->nref -= 1;
	if (s->nref != 0)
		return s->nref;

	/* last reference gone: start ageing the STA out */
	{
		struct controller *c = s->cntlr;

		timer_set(&s->ageout_timer, c->cfg.stale_sta_timeout * 1000);
	}

	return s->nref;
}
