#include <stdio.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <easy/easy.h>
#include <wifidefs.h>
#include <wifiutils.h>

//#include <dpp_api.h>

#include "ztdpp.h"

/*
 * Hex-dump debug hook. When the build does not supply an implementation,
 * every call collapses to a no-op expression; the arguments are never
 * evaluated, so secrets passed to it are not touched in release builds.
 */
#ifndef bufprintf
#define bufprintf(...)	((void)0)
#endif

/*
 * Derive the AES-SIV key k from an input secret:
 *
 *   k = HKDF-SHA256(salt = <>, info = "bootstrap transfer key", IKM = z)
 *
 * @z:       input keying material (ECDH shared secret or passphrase bytes)
 * @zlen:    number of bytes in @z
 * @k:       output buffer; receives @hashlen bytes on success
 * @hashlen: hash output length; only 32 (SHA-256 / prime256v1) is supported
 *
 * Returns 0 on success, -1 for an unsupported @hashlen or when the
 * HMAC/KDF primitives fail. The intermediate PRK is wiped on every exit
 * path, including failures.
 */
static int derive_k(const uint8_t *z, size_t zlen, uint8_t *k, unsigned int hashlen)
{
	const char *info = "bootstrap transfer key";
	uint8_t salt[64], prk[64];
	const uint8_t *addr[1];
	size_t addrlen[1];
	int ret = -1;

	if (hashlen != 32)	/* for prime256v1 */
		return -1;

	/* HKDF-Extract(<>, z): an absent salt is hashlen bytes of zeros */
	memset(salt, 0, hashlen);
	addr[0] = z;
	addrlen[0] = zlen;

	if (PLATFORM_HMAC_SHA256(salt, hashlen, 1, addr, addrlen, prk) < 0)
		goto out;

	bufprintf("HKDF-Extract(<>, IKM=z)", prk, hashlen);

	/* HKDF-Expand(PRK, info, L) */
	if (hmac_sha256_kdf(prk, hashlen, NULL, (const uint8_t *)info,
			    strlen(info), k, hashlen) < 0)
		goto out;

	bufprintf("k = HKDF-Expand(PRK, info, L)", k, hashlen);
	ret = 0;

out:
	/* PRK is key material: clear it whether or not derivation succeeded */
	memset(prk, 0, sizeof(prk));
	return ret;
}


/*
 * Parse and decrypt a DPP bootstrap URI frame of OUI type 0xa, i.e. one
 * built by dpp_gen_pa_vsie_frame() around a dpp_gen_encoded_uri_payload()
 * payload.
 *
 * Frame layout (length fields are little-endian 16-bit):
 *   04 09 b4:56:fa 0a | epk1_len | epk1 (SPKI) | wrapped_len | wrapped
 *
 * The AES-SIV key is derived with derive_k() from the ECDH shared secret
 * between @ap_key and the sender's epk1.
 *
 * @ap_key:   own key passed to ecc_ecdh()
 * @frame:    received frame; treated as untrusted input
 * @framelen: number of valid bytes in @frame; every length field is
 *            checked against it before it is used
 * @uri:      optional out; on success receives the NUL-terminated URI in a
 *            heap buffer the caller must free(). When NULL, the decrypted
 *            buffer is released before returning.
 *
 * Returns 0 on success, -1 on a malformed frame or crypto/allocation error.
 */
int dpp_process_pa_vsie_frame(void *ap_key, uint8_t *frame, size_t framelen, char **uri)
{
	const size_t hdrlen = 6;	/* category, action, OUI(3), OUI type */
	size_t wrapped_len, unwrapped_len;
	unsigned int hashlen = 32;
	uint8_t *unwrapped = NULL;
	size_t epk1_len = 0;
	uint8_t *epk1_pos;
	uint8_t *pos, *end;
	uint8_t *z = NULL;
	size_t zlen = 0;
	uint8_t k[64];
	void *epk1;
	int ret = -1;

	/* validate frame leading bytes and presence of the epk1_len field */
	if (!frame || framelen < hdrlen + 2 || frame[0] != 0x4 || frame[1] != 0x9 ||
	    memcmp(&frame[2], "\xb4\x56\xfa", 3) || frame[5] != 0xa) {
		printf("Ignore invalid dpp-pa-vsie-frame\n");
		return -1;
	}
	pos = frame + hdrlen;
	end = frame + framelen;

	epk1_len = buf_get_le16(pos);
	pos += 2;
	/* epk1 must fit in the frame and be followed by the 2-byte wrapped_len */
	if (epk1_len > (size_t)(end - pos) || (size_t)(end - pos) - epk1_len < 2) {
		printf("error getting epk1 from frame (epk1_len = %zu)\n", epk1_len);
		return -1;
	}
	epk1_pos = pos;
	pos += epk1_len;

	wrapped_len = buf_get_le16(pos);
	pos += 2;
	printf("Ciphertext len = %zu\n", wrapped_len);
	if (wrapped_len < AES_BLOCK_SIZE || wrapped_len > (size_t)(end - pos)) {
		printf("invalid wrapped len\n");
		return -1;
	}

	/* parsed only after the length checks, so those failure paths allocate nothing */
	epk1 = ecc_key_gen_from_spki(epk1_pos, epk1_len);
	if (!epk1) {
		printf("error getting epk1 from frame (epk1_len = %zu)\n", epk1_len);
		return -1;
	}

	/* ECDH */
	if (ecc_ecdh(ap_key, epk1, &z, &zlen) < 0) {
		printf("ecdh error!\n");
		return -1;
	}

	bufprintf("DPP: ECDH shared secret (z')", z, zlen);

	/* HKDF-SHA256 */
	if (derive_k(z, zlen, k, hashlen) < 0) {
		printf("error derive_k\n");
		goto out;
	}

	/* Decrypt C; the extra byte keeps the plaintext NUL-terminated */
	unwrapped_len = wrapped_len - AES_BLOCK_SIZE;
	unwrapped = calloc(1, unwrapped_len + 1);
	if (!unwrapped) {
		printf("error calloc unwrapped\n");
		goto out;
	}

	if (AES_SIV_DECRYPT(k, 32, pos, wrapped_len, 0, NULL, NULL, unwrapped)) {
		printf("AES-SIV decrypt failed\n");
		goto out;
	}
	bufprintf("AES-SIV plaintext", unwrapped, unwrapped_len);

	printf("URI (len = %zu): %s\n", unwrapped_len, unwrapped);
	printf("SUCCESS !!!\n");

	if (uri) {
		*uri = (char *)unwrapped;
		unwrapped = NULL;	/* ownership passed to the caller */
	}
	ret = 0;

out:
	memset(k, 0, sizeof(k));
	if (z) {
		memset(z, 0, zlen);	/* shared secret */
		free(z);
	}
	free(unwrapped);
	/*
	 * NOTE(review): epk1 returned by ecc_key_gen_from_spki() is never
	 * released; call the matching ecc key free routine here once confirmed.
	 */
	return ret;
}

/*
 * Wrap a payload in a Public Action vendor-specific frame of OUI type 0xa.
 * This is the ECDH-protected bootstrap URI variant that
 * dpp_process_pa_vsie_frame() parses.
 *
 * Frame layout:
 *   [0]     0x04      Public Action category
 *   [1]     0x09      vendor-specific public action
 *   [2..4]  b4:56:fa  IOPSYS OUI
 *   [5]     0x0a      OUI type
 *   [6..]   payload (@len bytes, copied verbatim)
 *
 * @payload: bytes to append after the header; may be NULL only if @len == 0
 * @len:     number of payload bytes
 * @olen:    out; total frame length on success, 0 on failure
 *
 * Returns a newly allocated frame the caller must free(), or NULL on
 * invalid arguments, size overflow or allocation failure.
 */
uint8_t *dpp_gen_pa_vsie_frame(uint8_t *payload, size_t len, size_t *olen)
{
	static const uint8_t oui[3] = {0xb4, 0x56, 0xfa};	/* IOPSYS OUI */
	const uint8_t ouitype = 0xa;	/* ECDH-protected dpp bootstrap uri */
	const size_t hdrlen = 6;
	uint8_t *frame;

	if (!olen)
		return NULL;
	*olen = 0;

	if ((!payload && len) || len > SIZE_MAX - hdrlen)
		return NULL;

	frame = calloc(1, hdrlen + len);
	if (!frame)
		return NULL;

	frame[0] = 0x4;		/* public action */
	frame[1] = 0x9;		/* public action vendor specific */
	memcpy(&frame[2], oui, sizeof(oui));
	frame[5] = ouitype;
	if (len)
		memcpy(&frame[hdrlen], payload, len);

	*olen = hdrlen + len;
	return frame;
}

/*
 * Encrypt a DPP bootstrap URI for @peer_pubkey and build the payload
 * carried in an OUI type 0xa frame (see dpp_gen_pa_vsie_frame()).
 *
 * Payload layout (length fields are little-endian 16-bit):
 *   epk_len | epk (SPKI of @own_bi_key) | msglen | AES-SIV(k, uri)
 *
 * k is derived with derive_k() from the ECDH secret between @own_bi_key and
 * @peer_pubkey.
 *
 * @peer_pubkey: receiver's public key
 * @own_bi_key:  own key; its SPKI is sent as epk and it is used for ECDH
 * @uri:         NUL-terminated URI to protect (terminator not encrypted)
 * @olen:        out; payload length on success, 0 on failure
 *
 * Returns a newly allocated payload the caller must free(), or NULL on
 * invalid arguments, oversize fields, or crypto/allocation failure.
 */
uint8_t *dpp_gen_encoded_uri_payload(void *peer_pubkey, void *own_bi_key, const char *uri, size_t *olen)
{
	unsigned int hashlen = 32;	/* for P-256 curve */
	uint8_t *z = NULL, *msg = NULL;
	uint8_t *payload = NULL;
	size_t zlen = 0, msglen;
	size_t epk_len = 0;
	size_t urilen;
	size_t plen;
	uint8_t k[64];
	uint8_t *epk;
	int ret;

	if (!uri || !peer_pubkey || !own_bi_key || !olen)
		return NULL;

	*olen = 0;
	urilen = strlen(uri);
	printf("URI (len = %zu): %s\n", urilen, uri);

	/* the ciphertext length is carried in a 16-bit field */
	if (urilen > UINT16_MAX - AES_BLOCK_SIZE) {
		printf("uri too long\n");
		return NULL;
	}
	msglen = urilen + AES_BLOCK_SIZE;

	/* epk */
	/* NOTE(review): ownership of the buffer returned by ecc_key_get_spki()
	 * is not visible here and it is never freed; confirm whether it must be. */
	epk = ecc_key_get_spki(own_bi_key, &epk_len);
	if (!epk) {
		printf("failed to get epk\n");
		return NULL;
	}
	if (epk_len > UINT16_MAX) {
		printf("epk too long\n");
		return NULL;
	}

	/* ECDH */
	if (ecc_ecdh(own_bi_key, peer_pubkey, &z, &zlen) < 0) {
		printf("ecdh error!\n");
		return NULL;
	}

	bufprintf("ECDH shared secret (z)", z, zlen);

	/* HKDF-SHA256 */
	if (derive_k(z, zlen, k, hashlen) < 0) {
		printf("error derive_k\n");
		goto out;
	}

	/* AES-SIV encrypt: output is the SIV followed by the ciphertext */
	msg = calloc(1, msglen);
	if (!msg)
		goto out;

	ret = AES_SIV_ENCRYPT(k, 32, (const uint8_t *)uri, urilen, 0, NULL, NULL, msg);
	if (ret) {
		printf("AES-SIV encrypt failed\n");
		goto out;
	}

	/* construct payload = le16(epk_len) | epk | le16(msglen) | msg */
	plen = 2 + epk_len + 2 + msglen;
	payload = calloc(1, plen);
	if (!payload)
		goto out;

	buf_put_le16(payload, epk_len);
	memcpy(payload + 2, epk, epk_len);
	buf_put_le16(payload + 2 + epk_len, msglen);
	memcpy(payload + 4 + epk_len, msg, msglen);
	*olen = plen;

out:
	memset(k, 0, sizeof(k));
	if (z) {
		memset(z, 0, zlen);	/* shared secret */
		free(z);
	}
	free(msg);

	return payload;
}

/*
 * Wrap a payload in a Public Action vendor-specific frame of OUI type 0xb.
 * This is the passphrase-protected bootstrap URI variant that
 * dpp_process_pa_vsie_frame2() parses.
 *
 * Frame layout:
 *   [0]     0x04      Public Action category
 *   [1]     0x09      vendor-specific public action
 *   [2..4]  b4:56:fa  IOPSYS OUI
 *   [5]     0x0b      OUI type
 *   [6..]   payload (@len bytes, copied verbatim)
 *
 * @payload: bytes to append after the header; may be NULL only if @len == 0
 * @len:     number of payload bytes
 * @olen:    out; total frame length on success, 0 on failure
 *
 * Returns a newly allocated frame the caller must free(), or NULL on
 * invalid arguments, size overflow or allocation failure.
 */
uint8_t *dpp_gen_pa_vsie_frame2(uint8_t *payload, size_t len, size_t *olen)
{
	static const uint8_t oui[3] = {0xb4, 0x56, 0xfa};	/* IOPSYS OUI */
	const uint8_t ouitype = 0xb;	/* passphrase-protected dpp bootstrap uri */
	const size_t hdrlen = 6;
	uint8_t *frame;

	if (!olen)
		return NULL;
	*olen = 0;

	if ((!payload && len) || len > SIZE_MAX - hdrlen)
		return NULL;

	frame = calloc(1, hdrlen + len);
	if (!frame)
		return NULL;

	frame[0] = 0x4;		/* public action */
	frame[1] = 0x9;		/* public action vendor specific */
	memcpy(&frame[2], oui, sizeof(oui));
	frame[5] = ouitype;
	if (len)
		memcpy(&frame[hdrlen], payload, len);

	*olen = hdrlen + len;
	return frame;
}

/*
 * Encrypt a DPP bootstrap URI under a key derived from @passphrase and
 * build the payload carried in an OUI type 0xb frame (see
 * dpp_gen_pa_vsie_frame2()).
 *
 * Payload layout:
 *   le16(msglen) | AES-SIV(k, uri)      where k = derive_k(passphrase)
 *
 * @passphrase: non-empty NUL-terminated shared passphrase
 * @uri:        NUL-terminated URI to protect (terminator not encrypted)
 * @olen:       out; payload length on success, 0 on failure
 *
 * Returns a newly allocated payload the caller must free(), or NULL on
 * invalid arguments, an oversize URI, or crypto/allocation failure.
 */
uint8_t *dpp_gen_encoded_uri_payload2(char *passphrase, const char *uri, size_t *olen)
{
	unsigned int hashlen = 32;
	uint8_t *payload = NULL;
	uint8_t *msg = NULL;
	size_t urilen;
	size_t msglen;
	uint8_t k[64];
	int ret;

	if (!uri || !passphrase || !strlen(passphrase) || !olen)
		return NULL;

	*olen = 0;
	urilen = strlen(uri);
	printf("URI (len = %zu): %s\n", urilen, uri);

	/* the ciphertext length is carried in a 16-bit field */
	if (urilen > UINT16_MAX - AES_BLOCK_SIZE) {
		printf("uri too long\n");
		return NULL;
	}
	msglen = urilen + AES_BLOCK_SIZE;

	/* HKDF-SHA256 */
	if (derive_k((uint8_t *)passphrase, strlen(passphrase), k, hashlen) < 0) {
		printf("error derive_k\n");
		goto out;
	}

	/* AES-SIV encrypt: output is the SIV followed by the ciphertext */
	msg = calloc(1, msglen);
	if (!msg)
		goto out;

	ret = AES_SIV_ENCRYPT(k, 32, (const uint8_t *)uri, urilen, 0, NULL, NULL, msg);
	if (ret) {
		printf("AES-SIV encrypt failed\n");
		goto out;
	}

	/* construct payload = le16(msglen) | msg */
	payload = calloc(1, 2 + msglen);
	if (!payload)
		goto out;

	buf_put_le16(payload, msglen);
	memcpy(payload + 2, msg, msglen);
	*olen = 2 + msglen;

out:
	memset(k, 0, sizeof(k));
	free(msg);

	return payload;
}

/*
 * Parse and decrypt a passphrase-protected DPP bootstrap URI frame of OUI
 * type 0xb, i.e. one built by dpp_gen_pa_vsie_frame2() around a
 * dpp_gen_encoded_uri_payload2() payload.
 *
 * Frame layout:
 *   04 09 b4:56:fa 0b | le16(wrapped_len) | wrapped
 *
 * @passphrase: non-empty NUL-terminated shared passphrase (key via derive_k())
 * @frame:      received frame; treated as untrusted input
 * @framelen:   number of valid bytes in @frame; wrapped_len is checked
 *              against it before decryption
 * @uri:        optional out; on success receives the NUL-terminated URI in a
 *              heap buffer the caller must free(). When NULL, the decrypted
 *              buffer is released before returning.
 *
 * Returns 0 on success, -1 on bad arguments, a malformed frame, or a
 * crypto/allocation error.
 */
int dpp_process_pa_vsie_frame2(char *passphrase, uint8_t *frame, size_t framelen, char **uri)
{
	const size_t hdrlen = 6;	/* category, action, OUI(3), OUI type */
	size_t wrapped_len, unwrapped_len;
	unsigned int hashlen = 32;
	uint8_t *unwrapped = NULL;
	uint8_t *pos;
	uint8_t k[64];
	int ret = -1;

	if (!passphrase || !strlen(passphrase))
		return -1;

	/* validate frame leading bytes and presence of the wrapped_len field */
	if (!frame || framelen < hdrlen + 2 || frame[0] != 0x4 || frame[1] != 0x9 ||
	    memcmp(&frame[2], "\xb4\x56\xfa", 3) || frame[5] != 0xb) {
		printf("Ignore invalid dpp-pa-vsie-frame2\n");
		return -1;
	}

	pos = frame + hdrlen;
	wrapped_len = buf_get_le16(pos);
	pos += 2;
	printf("Ciphertext len = %zu\n", wrapped_len);
	if (wrapped_len < AES_BLOCK_SIZE || wrapped_len > framelen - (hdrlen + 2)) {
		printf("invalid wrapped len\n");
		return -1;
	}

	/* HKDF-SHA256 */
	if (derive_k((uint8_t *)passphrase, strlen(passphrase), k, hashlen) < 0) {
		printf("error derive_k\n");
		goto out;
	}

	/* Decrypt C; the extra byte keeps the plaintext NUL-terminated */
	unwrapped_len = wrapped_len - AES_BLOCK_SIZE;
	unwrapped = calloc(1, unwrapped_len + 1);
	if (!unwrapped) {
		printf("error calloc unwrapped\n");
		goto out;
	}

	if (AES_SIV_DECRYPT(k, 32, pos, wrapped_len, 0, NULL, NULL, unwrapped)) {
		printf("AES-SIV decrypt failed\n");
		goto out;
	}
	bufprintf("AES-SIV plaintext", unwrapped, unwrapped_len);

	printf("URI (len = %zu): %s\n", unwrapped_len, unwrapped);
	printf("SUCCESS !!!\n");

	if (uri) {
		*uri = (char *)unwrapped;
		unwrapped = NULL;	/* ownership passed to the caller */
	}
	ret = 0;

out:
	memset(k, 0, sizeof(k));
	free(unwrapped);
	return ret;
}
