/*
 * wifi_caps.c - agent wifi capabilities
 *
 * Copyright (C) 2025 IOPSYS Software Solutions AB. All rights reserved.
 *
 */

#include "wifi_caps.h"

#include <easy/utils.h>
#include <easymesh.h>
#include <libubox/blob.h>
#include <libubox/blobmsg.h>
#include <libubox/blobmsg_json.h>
#include <libubus.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ubusmsg.h>

#include "agent.h"
#include "config.h"
#include "utils/debug.h"
#include "utils/utils.h"
#include "wifi.h"


/*
 * caps_ht_max_nss - number of spatial streams present in an HT Rx MCS mask
 * @mcs_set: HT Supported MCS Set; only its first four bytes are examined
 *
 * Return: how many of the first four bytes have at least one bit set (0..4).
 */
static uint8_t caps_ht_max_nss(const uint8_t *mcs_set)
{
	const uint8_t *p = mcs_set;
	const uint8_t *const end = mcs_set + 4;
	uint8_t count = 0;

	while (p < end) {
		if (*p++)
			count++;
	}

	return count;
}

/*
 * generate_agent_caps_ht - fill the EasyMesh AP HT capability byte from raw
 * HT capabilities.
 * @in:  HT Capabilities Info bytes (byte[]) and Supported MCS Set (supp_mcs[])
 * @out: TLV to update; bits are OR-ed into out->cap
 *
 * Max Tx NSS goes to bits 7-6 and Max Rx NSS to bits 5-4, each as (NSS - 1).
 */
static void generate_agent_caps_ht(struct wifi_caps_ht *in, struct tlv_ap_ht_cap *out)
{
	/* Supported MCS Set octet 12 holds the Tx MCS parameters */
	const uint8_t tx_params = in->supp_mcs[12];
	uint8_t rx_nss, tx_nss;

	if (in->byte[0] & BIT(1))
		out->cap |= HT_HT40_MASK;
	if (in->byte[0] & BIT(5))
		out->cap |= HT_SGI20_MASK;
	if (in->byte[0] & BIT(6))
		out->cap |= HT_SGI40_MASK;

	rx_nss = caps_ht_max_nss(&in->supp_mcs[0]);

	if (!(tx_params & BIT(1))) {
		/* "Tx Rx MCS Set Not Equal" clear: Tx uses the Rx MCS set */
		tx_nss = rx_nss;
	} else {
		/* "Tx Maximum Number Spatial Streams Supported", bits 2-3, 0 == 1 SS */
		tx_nss = ((tx_params >> 2) & 0x3) + 1;
	}

	/* encode as NSS - 1, keeping 0 for an empty set */
	if (tx_nss)
		tx_nss--;
	if (rx_nss)
		rx_nss--;

	out->cap |= (tx_nss & 0x3) << 6;
	out->cap |= (rx_nss & 0x3) << 4;
}

/*
 * caps_vht_he_max_nss - count spatial streams in a VHT/HE MCS map
 * @mcs_set: 16-bit MCS map made of eight 2-bit per-stream fields
 *
 * A field equal to 0x3 marks the stream as not supported; every other
 * value counts as one supported stream.
 *
 * Return: number of supported streams (0..8).
 */
static uint8_t caps_vht_he_max_nss(uint16_t mcs_set)
{
	uint16_t map = mcs_set;
	uint8_t streams = 0;
	int n;

	for (n = 0; n < 8; n++, map >>= 2) {
		if ((map & 0x3) != 0x3)
			streams++;
	}

	return streams;
}

/*
 * generate_agent_caps_vht - fill the EasyMesh AP VHT Capabilities TLV from
 * raw VHT capabilities.
 * @in:  VHT Capabilities Info (byte[0..3], bit Bn is byte[n / 8] bit n % 8)
 *       and Supported VHT-MCS and NSS Set (Rx MCS map at supp_mcs[0],
 *       Tx MCS map at supp_mcs[4])
 * @out: TLV to update; capability bits are OR-ed into out->cap[]
 */
static void generate_agent_caps_vht(struct wifi_caps_vht *in, struct tlv_ap_vht_cap *out)
{
	uint16_t rx_mcs, tx_mcs;
	uint8_t rx_nss, tx_nss;
	uint8_t width, ext_nss_bw;

	/* memcpy instead of a uint16_t* cast: no unaligned or aliasing access */
	memcpy(&rx_mcs, &in->supp_mcs[0], sizeof(rx_mcs));
	memcpy(&tx_mcs, &in->supp_mcs[4], sizeof(tx_mcs));
	out->rx_mcs_supported = BUF_GET_BE16(rx_mcs);
	out->tx_mcs_supported = BUF_GET_BE16(tx_mcs);

	rx_nss = caps_vht_he_max_nss(rx_mcs);
	tx_nss = caps_vht_he_max_nss(tx_mcs);

	/* NSS is encoded as (NSS - 1); an empty map must not wrap to 7 (8 SS) */
	out->cap[0] |= ((tx_nss ? tx_nss - 1 : 0) & 0x07) << 5;
	out->cap[0] |= ((rx_nss ? rx_nss - 1 : 0) & 0x07) << 2;

	if (in->byte[0] & BIT(5))	/* B5: Short GI for 80 MHz */
		out->cap[0] |= VHT_SGI80_MASK;
	if (in->byte[0] & BIT(6))	/* B6: Short GI for 160 and 80+80 MHz */
		out->cap[0] |= VHT_SGI160_8080_MASK;
	if (in->byte[1] & BIT(3))	/* B11: SU Beamformer Capable */
		out->cap[1] |= VHT_SU_BFR;
	if (in->byte[2] & BIT(3))	/* B19: MU Beamformer Capable */
		out->cap[1] |= VHT_MU_BFR;

	/* B2-B3: Supported Channel Width Set; B30-B31: Extended NSS BW Support */
	width = (in->byte[0] >> 2) & 0x3;
	ext_nss_bw = (in->byte[3] >> 6) & 0x3;

	if (width >= 1)
		out->cap[1] |= VHT_160_MASK;
	if (width >= 2 || (width == 1 && ext_nss_bw == 3))
		out->cap[1] |= VHT_8080_MASK;
}

#if (EASYMESH_VERSION > 2)
/*
 * generate_agent_caps_he - fill the EasyMesh AP HE Capabilities TLV from raw
 * HE capabilities.
 * @in:  HE PHY/MAC capability bytes and Supported HE-MCS and NSS Set
 *       (byte_mcs[]: Rx then Tx map for <= 80 MHz, followed by further
 *       Rx/Tx pairs when 160 MHz and/or 80+80 MHz are supported)
 * @out: agent caps; bits are OR-ed into out->he_cap.cap[] and
 *       out->he_cap.mcs[] / mcs_len are written
 */
static void generate_agent_caps_he(struct wifi_caps_he *in, struct agent_wifi_caps *out)
{
	uint16_t rx_mcs, tx_mcs;
	uint8_t rx_nss, tx_nss;
	size_t mcs_len = 4;	/* Rx + Tx maps for <= 80 MHz */
	size_t i;

	if (in->byte_phy[0] & BIT(3)) {	/* B3: 160 MHz in 5 GHz */
		mcs_len += 4;
		out->he_cap.cap[0] |= HE_160_MASK;
	}

	if (in->byte_phy[0] & BIT(4)) {	/* B4: 160/80+80 MHz in 5 GHz */
		mcs_len += 4;
		out->he_cap.cap[0] |= HE_8080_MASK;
	}

	for (i = 0; i < mcs_len; i += 2) {
		/* each step touches index i + 1: stop before either buffer overflows */
		if (i + 1 >= sizeof(out->he_cap.mcs) ||
		    i + 1 >= sizeof(in->byte_mcs))
			break;
		/* swap the byte order of each 16-bit MCS map */
		out->he_cap.mcs[i] = in->byte_mcs[i + 1];
		out->he_cap.mcs[i + 1] = in->byte_mcs[i];
	}
	out->he_cap.mcs_len = i;

	if (in->byte_phy[3] & BIT(7))	/* B31: SU Beamformer */
		out->he_cap.cap[1] |= HE_SU_BFR;
	if (in->byte_phy[4] & BIT(1))	/* B33: MU Beamformer */
		out->he_cap.cap[1] |= HE_MU_BFR;
	if (in->byte_phy[2] & BIT(6))	/* B22: Full Bandwidth UL MU-MIMO */
		out->he_cap.cap[1] |= HE_UL_MUMIMO;
	/*
	 * B23: Partial Bandwidth UL MU-MIMO.
	 * NOTE(review): this single bit also drives the UL/DL OFDMA flags,
	 * exactly as before - confirm this is the intended source bit.
	 */
	if (in->byte_phy[2] & BIT(7))
		out->he_cap.cap[1] |= HE_UL_MUMIMO_OFDMA | HE_UL_OFDMA | HE_DL_OFDMA;
	if (in->byte_phy[6] & BIT(6))	/* B54: Partial Bandwidth DL MU-MIMO */
		out->he_cap.cap[1] |= HE_DL_MUMIMO_OFDMA;

	/* memcpy instead of a uint16_t* cast: no unaligned or aliasing access */
	memcpy(&rx_mcs, &in->byte_mcs[0], sizeof(rx_mcs));
	memcpy(&tx_mcs, &in->byte_mcs[2], sizeof(tx_mcs));

	rx_nss = caps_vht_he_max_nss(rx_mcs);
	tx_nss = caps_vht_he_max_nss(tx_mcs);

	/*
	 * Max Tx NSS in bits 7-5, Max Rx NSS in bits 4-2, as (NSS - 1) - the
	 * same layout as the VHT TLV. Guard NSS == 0 so no negative value is
	 * shifted.
	 */
	out->he_cap.cap[0] |= ((tx_nss ? tx_nss - 1 : 0) & 0x07) << 5;
	out->he_cap.cap[0] |= ((rx_nss ? rx_nss - 1 : 0) & 0x07) << 2;
}

static void generate_agent_caps_wifi6(struct wifi_caps_he *in, struct agent_wifi_caps *out)
{
	int i;

	/* First copy from caps_he */
	if (out->he_cap.cap[0] & HE_160_MASK)
		out->wifi6_cap.caps |= HE160_SUPPORTED;
	if (out->he_cap.cap[0] & HE_8080_MASK)
		out->wifi6_cap.caps |= HE8080_SUPPORTED;

	out->wifi6_cap.mcs_len = out->he_cap.mcs_len;
	out->wifi6_cap.caps |= (out->wifi6_cap.mcs_len & MCS_NSS_LEN_MASK);

	for (i = 0; i < out->he_cap.mcs_len; i++) {
		if (i > sizeof(out->wifi6_cap.mcs))
			break;
		out->wifi6_cap.mcs[i] = out->he_cap.mcs[i];
	}

	/* Parse wifi_caps phy/mac and set more */
	if (in->byte_phy[3] & BIT(7))
		out->wifi6_cap.beamform_caps |= SU_BEAMFORMER_SUPPORTED;
	if (in->byte_phy[4] & BIT(9))
		out->wifi6_cap.beamform_caps |= SU_BEAMFORMEE_SUPPORTED;
	if (in->byte_phy[4] & BIT(1))
		out->wifi6_cap.beamform_caps |= MU_B_FORMER_STATUS_SUPPORTED;
	if (in->byte_phy[4] & (BIT(2) | BIT(3) | BIT(4)))
		out->wifi6_cap.beamform_caps |= B_FORMEE_STS_LE_80_SUPPORTED;
	if (in->byte_phy[4] & (BIT(5) | BIT(6) | BIT(7)))
		out->wifi6_cap.beamform_caps |= B_FORMEE_STS_GT_80_SUPPORTED;
	if (in->byte_phy[2] & BIT(6))
		out->wifi6_cap.beamform_caps |= UL_MU_MIMO_SUPPORTED;
	if (in->byte_mac[3] & BIT(2))
		out->wifi6_cap.beamform_caps |= DL_OFDMA_SUPPORTED;
	if (in->byte_mac[3] & BIT(2))
		out->wifi6_cap.beamform_caps |= UL_OFDMA_SUPPORTED;

	if (in->byte_mac[0] & BIT(1))
		out->wifi6_cap.other_caps |= TWT_REQUESTER_SUPPORTED;
	if (in->byte_mac[0] & BIT(2))
		out->wifi6_cap.other_caps |= TWT_RESPONDER_SUPPORTED;
}
#endif

static int generate_agent_wifi_caps(struct agent_wifi_caps *caps)
{
#if (EASYMESH_VERSION > 2)
	if (caps->wifi_caps.valid & WIFI_CAP_HE_VALID) {
		generate_agent_caps_he(&caps->wifi_caps.he, caps);
		generate_agent_caps_wifi6(&caps->wifi_caps.he, caps);
	}
#endif

	if (caps->wifi_caps.valid & WIFI_CAP_VHT_VALID)
		generate_agent_caps_vht(&caps->wifi_caps.vht, &caps->vht_cap);

	if (caps->wifi_caps.valid & WIFI_CAP_HT_VALID)
		generate_agent_caps_ht(&caps->wifi_caps.ht, &caps->ht_cap);

	return 0;
}

/*
 * parse_radio_wifi_caps - decode the "ht", "vht", "he", "eht" and "ml" tables
 * of a radio capabilities reply into raw capability bytes.
 * @msg:  reply payload holding the capability tables
 * @caps: destination; for every table present, the matching
 *        WIFI_CAP_*_VALID bit is OR-ed into caps->valid and each string
 *        attribute found is converted into its byte buffer with blobattrtob()
 *
 * Nested tables carry a blobmsg name header. Their payload length is
 * therefore taken with blobmsg_data_len(), which matches the blobmsg_data()
 * start pointer (blob_len() would include the header and overrun the table).
 * Absent or non-string attributes are skipped, leaving their buffers as is.
 */
static void parse_radio_wifi_caps(struct blob_attr *msg, struct wifi_caps *caps)
{
	static const struct blobmsg_policy caps_attr[] = {
		[0] = { .name = "ht", .type = BLOBMSG_TYPE_TABLE },
		[1] = { .name = "vht", .type = BLOBMSG_TYPE_TABLE },
		[2] = { .name = "he", .type = BLOBMSG_TYPE_TABLE },
		[3] = { .name = "eht", .type = BLOBMSG_TYPE_TABLE },
		[4] = { .name = "ml", .type = BLOBMSG_TYPE_TABLE },
	};
	struct blob_attr *tb[ARRAY_SIZE(caps_attr)];

	blobmsg_parse(caps_attr, ARRAY_SIZE(caps_attr), tb, blobmsg_data(msg), blob_len(msg));

	if (tb[0]) {
		static const struct blobmsg_policy ht_attr[] = {
			[0] = { .name = "caps", .type = BLOBMSG_TYPE_STRING },
			[1] = { .name = "mcs", .type = BLOBMSG_TYPE_STRING },
		};
		struct blob_attr *tb_ht[ARRAY_SIZE(ht_attr)];

		blobmsg_parse(ht_attr, ARRAY_SIZE(ht_attr), tb_ht,
			      blobmsg_data(tb[0]), blobmsg_data_len(tb[0]));

		caps->valid |= WIFI_CAP_HT_VALID;
		if (tb_ht[0])
			blobattrtob(tb_ht[0], caps->ht.byte, sizeof(caps->ht.byte));
		if (tb_ht[1])
			blobattrtob(tb_ht[1], caps->ht.supp_mcs, sizeof(caps->ht.supp_mcs));
	}

	if (tb[1]) {
		static const struct blobmsg_policy vht_attr[] = {
			[0] = { .name = "caps", .type = BLOBMSG_TYPE_STRING },
			[1] = { .name = "mcs", .type = BLOBMSG_TYPE_STRING },
		};
		struct blob_attr *tb_vht[ARRAY_SIZE(vht_attr)];

		blobmsg_parse(vht_attr, ARRAY_SIZE(vht_attr), tb_vht,
			      blobmsg_data(tb[1]), blobmsg_data_len(tb[1]));

		caps->valid |= WIFI_CAP_VHT_VALID;
		if (tb_vht[0])
			blobattrtob(tb_vht[0], caps->vht.byte, sizeof(caps->vht.byte));
		if (tb_vht[1])
			blobattrtob(tb_vht[1], caps->vht.supp_mcs, sizeof(caps->vht.supp_mcs));
	}

#if (EASYMESH_VERSION > 2)
	if (tb[2]) {
		static const struct blobmsg_policy he_attr[] = {
			[0] = { .name = "phy_caps", .type = BLOBMSG_TYPE_STRING },
			[1] = { .name = "mac_caps", .type = BLOBMSG_TYPE_STRING },
			[2] = { .name = "mcs", .type = BLOBMSG_TYPE_STRING },
			[3] = { .name = "ppe", .type = BLOBMSG_TYPE_STRING },
		};
		struct blob_attr *tb_he[ARRAY_SIZE(he_attr)];

		blobmsg_parse(he_attr, ARRAY_SIZE(he_attr), tb_he,
			      blobmsg_data(tb[2]), blobmsg_data_len(tb[2]));

		caps->valid |= WIFI_CAP_HE_VALID;
		if (tb_he[0])
			blobattrtob(tb_he[0], caps->he.byte_phy, sizeof(caps->he.byte_phy));
		if (tb_he[1])
			blobattrtob(tb_he[1], caps->he.byte_mac, sizeof(caps->he.byte_mac));
		if (tb_he[2])
			blobattrtob(tb_he[2], caps->he.byte_mcs, sizeof(caps->he.byte_mcs));
		if (tb_he[3])
			blobattrtob(tb_he[3], caps->he.byte_ppe, sizeof(caps->he.byte_ppe));
	}
#endif

#if (EASYMESH_VERSION >= 6)
	if (tb[3]) {
		static const struct blobmsg_policy eht_attr[] = {
			[0] = { .name = "phy_caps", .type = BLOBMSG_TYPE_STRING },
			[1] = { .name = "mac_caps", .type = BLOBMSG_TYPE_STRING },
			[2] = { .name = "mcs", .type = BLOBMSG_TYPE_STRING },
			[3] = { .name = "ppe", .type = BLOBMSG_TYPE_STRING },
		};
		struct blob_attr *tb_eht[ARRAY_SIZE(eht_attr)];

		blobmsg_parse(eht_attr, ARRAY_SIZE(eht_attr), tb_eht,
			      blobmsg_data(tb[3]), blobmsg_data_len(tb[3]));

		caps->valid |= WIFI_CAP_EHT_VALID;
		if (tb_eht[0])
			blobattrtob(tb_eht[0], caps->eht.byte_phy, sizeof(caps->eht.byte_phy));
		if (tb_eht[1])
			blobattrtob(tb_eht[1], caps->eht.byte_mac, sizeof(caps->eht.byte_mac));
		if (tb_eht[2])
			blobattrtob(tb_eht[2], caps->eht.supp_mcs, sizeof(caps->eht.supp_mcs));
		if (tb_eht[3])
			blobattrtob(tb_eht[3], caps->eht.byte_ppe_th, sizeof(caps->eht.byte_ppe_th));
	}

	if (tb[4]) {
		static const struct blobmsg_policy ml_attr[] = {
			[0] = { .name = "eml", .type = BLOBMSG_TYPE_STRING },
			[1] = { .name = "mld", .type = BLOBMSG_TYPE_STRING },
			[2] = { .name = "emld", .type = BLOBMSG_TYPE_STRING },
		};
		struct blob_attr *tb_ml[ARRAY_SIZE(ml_attr)];

		blobmsg_parse(ml_attr, ARRAY_SIZE(ml_attr), tb_ml,
			      blobmsg_data(tb[4]), blobmsg_data_len(tb[4]));

		caps->valid |= WIFI_CAP_ML_VALID;
		if (tb_ml[0]) {
			caps->ml.valid |= WIFI_CAP_ML_EML_VALID;
			blobattrtob(tb_ml[0], caps->ml.eml, sizeof(caps->ml.eml));
		}

		if (tb_ml[1]) {
			caps->ml.valid |= WIFI_CAP_ML_MLD_VALID;
			blobattrtob(tb_ml[1], caps->ml.mld, sizeof(caps->ml.mld));
		}

		if (tb_ml[2]) {
			caps->ml.valid |= WIFI_CAP_ML_EMLD_VALID;
			blobattrtob(tb_ml[2], caps->ml.emld, sizeof(caps->ml.emld));
		}
	}
#endif
}

#if (EASYMESH_VERSION >= 6)
/*
 * parse_radio_mlo_ap_caps - ubus data callback that records the AP-side MLO
 * capabilities of a radio in re->wifi7_caps.
 * @req:  ubus request; req->priv is the struct wifi_radio_element to update
 * @type: ubus message type (not used here)
 * @msg:  reply payload
 *
 * Nothing is stored when the radio has no agent config section. The
 * EMLSR/EMLMR/NSTR/STR flags are stored only if that config has
 * mlo_capable set; max_links and ttlm are stored whenever present.
 */
static void parse_radio_mlo_ap_caps(struct ubus_request *req, int type,
			      struct blob_attr *msg)
{
	static const struct blobmsg_policy mlo_attr[] = {
		[0] = { .name = "emlsr_supported", .type = BLOBMSG_TYPE_INT8 },
		[1] = { .name = "emlmr_supported", .type = BLOBMSG_TYPE_INT8 },
		[2] = { .name = "nstr", .type = BLOBMSG_TYPE_INT8 },
		[3] = { .name = "str", .type = BLOBMSG_TYPE_INT8 },
		[4] = { .name = "max_links", .type = BLOBMSG_TYPE_INT32 },
		[5] = { .name = "ttlm", .type = BLOBMSG_TYPE_INT32 },
	};
	struct wifi_radio_element *re = req->priv;
	struct blob_attr *tb[ARRAY_SIZE(mlo_attr)];
	struct agent_config_radio *rcfg;

	rcfg = get_agent_config_radio(&re->agent->cfg, re->name);
	if (!rcfg)
		return;

	blobmsg_parse(mlo_attr, ARRAY_SIZE(mlo_attr), tb, blobmsg_data(msg), blob_len(msg));

	/* per-mode flags are honoured only for MLO-capable radios */
	if (rcfg->mlo_capable) {
		if (tb[0])
			re->wifi7_caps.ap_emlsr_support = !!blobmsg_get_u8(tb[0]);
		if (tb[1])
			re->wifi7_caps.ap_emlmr_support = !!blobmsg_get_u8(tb[1]);
		if (tb[2])
			re->wifi7_caps.ap_nstr_support = !!blobmsg_get_u8(tb[2]);
		if (tb[3])
			re->wifi7_caps.ap_str_support = !!blobmsg_get_u8(tb[3]);
	}

	if (tb[4])
		re->wifi7_caps.max_ap_links = blobmsg_get_u32(tb[4]);
	if (tb[5])
		re->wifi7_caps.ap_ttlm = blobmsg_get_u32(tb[5]);
}

/*
 * parse_radio_mlo_bsta_caps - ubus data callback that records the
 * backhaul-STA-side MLO capabilities of a radio in re->wifi7_caps.
 * @req:  ubus request; req->priv is the struct wifi_radio_element to update
 * @type: ubus message type (not used here)
 * @msg:  reply payload
 *
 * Nothing is stored when the radio has no agent config section. The
 * EMLSR/EMLMR/NSTR/STR flags are stored only if that config has
 * mlo_capable set; max_links and ttlm are stored whenever present.
 */
static void parse_radio_mlo_bsta_caps(struct ubus_request *req, int type,
			      struct blob_attr *msg)
{
	static const struct blobmsg_policy mlo_attr[] = {
		[0] = { .name = "emlsr_supported", .type = BLOBMSG_TYPE_INT8 },
		[1] = { .name = "emlmr_supported", .type = BLOBMSG_TYPE_INT8 },
		[2] = { .name = "nstr", .type = BLOBMSG_TYPE_INT8 },
		[3] = { .name = "str", .type = BLOBMSG_TYPE_INT8 },
		[4] = { .name = "max_links", .type = BLOBMSG_TYPE_INT32 },
		[5] = { .name = "ttlm", .type = BLOBMSG_TYPE_INT32 },
	};
	struct wifi_radio_element *re = req->priv;
	struct blob_attr *tb[ARRAY_SIZE(mlo_attr)];
	struct agent_config_radio *rcfg;

	rcfg = get_agent_config_radio(&re->agent->cfg, re->name);
	if (!rcfg)
		return;

	blobmsg_parse(mlo_attr, ARRAY_SIZE(mlo_attr), tb, blobmsg_data(msg), blob_len(msg));

	/* per-mode flags are honoured only for MLO-capable radios */
	if (rcfg->mlo_capable) {
		if (tb[0])
			re->wifi7_caps.bsta_emlsr_support = !!blobmsg_get_u8(tb[0]);
		if (tb[1])
			re->wifi7_caps.bsta_emlmr_support = !!blobmsg_get_u8(tb[1]);
		if (tb[2])
			re->wifi7_caps.bsta_nstr_support = !!blobmsg_get_u8(tb[2]);
		if (tb[3])
			re->wifi7_caps.bsta_str_support = !!blobmsg_get_u8(tb[3]);
	}

	if (tb[4])
		re->wifi7_caps.max_bsta_links = blobmsg_get_u32(tb[4]);
	if (tb[5])
		re->wifi7_caps.bsta_ttlm = blobmsg_get_u32(tb[5]);
}
#endif

void parse_radio_ap_caps(struct ubus_request *req, int type,
			 struct blob_attr *msg)
{
	struct wifi_radio_element *re = (struct wifi_radio_element *)req->priv;

	parse_radio_wifi_caps(msg, &re->ap_caps.wifi_caps);
	generate_agent_wifi_caps(&re->ap_caps);

#if (EASYMESH_VERSION >= 6)
	parse_radio_mlo_ap_caps(req, type, msg);
#endif
}

void parse_radio_sta_caps(struct ubus_request *req, int type,
			  struct blob_attr *msg)
{
	struct wifi_radio_element *re = (struct wifi_radio_element *)req->priv;

	parse_radio_wifi_caps(msg, &re->sta_caps.wifi_caps);
	generate_agent_wifi_caps(&re->sta_caps);

#if (EASYMESH_VERSION >= 6)
	parse_radio_mlo_bsta_caps(req, type, msg);
#endif
}
