/* BEGIN_HEADER */
#include "mbedtls/hkdf.h"
#include "mbedtls/md_internal.h"
/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_HKDF_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void test_hkdf(int md_alg, data_t *ikm, data_t *salt, data_t *info,
               data_t *expected_okm)
{
    /*
     * One-shot HKDF check: derive expected_okm->len bytes from ikm, salt and
     * info using the hash selected by md_alg, then compare the result with
     * expected_okm.
     */
    const mbedtls_md_info_t *md = mbedtls_md_info_from_type(md_alg);
    unsigned char okm[128] = { 0 };
    int ret;

    TEST_ASSERT(md != NULL);

    /* The derived key lands in a fixed-size stack buffer; refuse vectors
     * that would not fit. */
    TEST_ASSERT(expected_okm->len <= sizeof(okm));

    ret = mbedtls_hkdf(md,
                       salt->x, salt->len,
                       ikm->x, ikm->len,
                       info->x, info->len,
                       okm, expected_okm->len);
    TEST_ASSERT(ret == 0);

    TEST_MEMORY_COMPARE(okm, expected_okm->len,
                        expected_okm->x, expected_okm->len);
}
/* END_CASE */

/* BEGIN_CASE */
void test_hkdf_extract(int md_alg, char *hex_ikm_string,
                       char *hex_salt_string, char *hex_prk_string)
{
    /*
     * HKDF-Extract check: decode the hex IKM and salt, run
     * mbedtls_hkdf_extract() with the hash selected by md_alg, and compare
     * the produced PRK with the hex-encoded expected PRK.
     */
    int ret;
    unsigned char *ikm = NULL;
    unsigned char *salt = NULL;
    unsigned char *prk = NULL;
    unsigned char *output_prk = NULL;
    size_t ikm_len, salt_len, prk_len, output_prk_len;

    const mbedtls_md_info_t *md = mbedtls_md_info_from_type(md_alg);
    TEST_ASSERT(md != NULL);

    /* The PRK is one digest long (HashLen octets in RFC 5869). */
    output_prk_len = mbedtls_md_get_size(md);
    output_prk = mbedtls_calloc(1, output_prk_len);
    /* Fail the test cleanly rather than letting the extract call write
     * through a NULL output pointer. */
    TEST_ASSERT(output_prk != NULL);

    ikm = mbedtls_test_unhexify_alloc(hex_ikm_string, &ikm_len);
    salt = mbedtls_test_unhexify_alloc(hex_salt_string, &salt_len);
    prk = mbedtls_test_unhexify_alloc(hex_prk_string, &prk_len);

    ret = mbedtls_hkdf_extract(md, salt, salt_len, ikm, ikm_len, output_prk);
    TEST_ASSERT(ret == 0);

    /* Checks both the lengths and the contents. */
    TEST_MEMORY_COMPARE(output_prk, output_prk_len, prk, prk_len);

exit:
    mbedtls_free(ikm);
    mbedtls_free(salt);
    mbedtls_free(prk);
    mbedtls_free(output_prk);
}
/* END_CASE */

/* BEGIN_CASE */
void test_hkdf_expand(int md_alg, char *hex_info_string,
                      char *hex_prk_string, char *hex_okm_string)
{
    /*
     * HKDF-Expand check: decode the hex PRK and info, run
     * mbedtls_hkdf_expand() with the hash selected by md_alg, and compare the
     * output with the hex-encoded expected OKM.
     *
     * OKM_LEN bytes are always derived, and only the first okm_len bytes are
     * compared. This is valid because the HKDF-Expand output length does not
     * feed into the computation, so a shorter output is a prefix of a longer
     * one (RFC 5869, section 2.3).
     */
    enum { OKM_LEN  = 1024 };
    int ret;
    unsigned char *info = NULL;
    unsigned char *prk = NULL;
    unsigned char *okm = NULL;
    unsigned char *output_okm = NULL;
    size_t info_len, prk_len, okm_len;

    const mbedtls_md_info_t *md = mbedtls_md_info_from_type(md_alg);
    TEST_ASSERT(md != NULL);

    output_okm = mbedtls_calloc(OKM_LEN, 1);
    /* Fail the test cleanly rather than expanding into a NULL buffer. */
    TEST_ASSERT(output_okm != NULL);

    prk = mbedtls_test_unhexify_alloc(hex_prk_string, &prk_len);
    info = mbedtls_test_unhexify_alloc(hex_info_string, &info_len);
    okm = mbedtls_test_unhexify_alloc(hex_okm_string, &okm_len);
    TEST_ASSERT(prk_len == mbedtls_md_get_size(md));
    /* The comparison below reads okm_len bytes of output_okm, so an expected
     * OKM of exactly OKM_LEN bytes still fits. */
    TEST_ASSERT(okm_len <= OKM_LEN);

    ret = mbedtls_hkdf_expand(md, prk, prk_len, info, info_len,
                              output_okm, OKM_LEN);
    TEST_ASSERT(ret == 0);
    TEST_MEMORY_COMPARE(output_okm, okm_len, okm, okm_len);

exit:
    mbedtls_free(info);
    mbedtls_free(prk);
    mbedtls_free(okm);
    mbedtls_free(output_okm);
}
/* END_CASE */

/* BEGIN_CASE */
void test_hkdf_extract_ret(int hash_len, int ret)
{
    /*
     * Error-path check for mbedtls_hkdf_extract(). It calls the function with
     * a hand-built md_info (type MBEDTLS_MD_NONE, size hash_len), a NULL salt
     * of length 0 and a NULL IKM of length 0, and checks that the return code
     * equals ret.
     */
    int output_ret;
    unsigned char *salt = NULL;
    unsigned char *ikm = NULL;
    unsigned char *prk = NULL;
    size_t salt_len, ikm_len;
    struct mbedtls_md_info_t fake_md_info;

    memset(&fake_md_info, 0, sizeof(fake_md_info));
    fake_md_info.type = MBEDTLS_MD_NONE;
    fake_md_info.size = hash_len;

    /* Output buffer sized by MBEDTLS_MD_MAX_SIZE; must exist so the call
     * under test is not handed a NULL PRK pointer by accident. */
    prk = mbedtls_calloc(MBEDTLS_MD_MAX_SIZE, 1);
    TEST_ASSERT(prk != NULL);
    salt_len = 0;
    ikm_len = 0;

    output_ret = mbedtls_hkdf_extract(&fake_md_info, salt, salt_len,
                                      ikm, ikm_len, prk);
    TEST_ASSERT(output_ret == ret);

exit:
    mbedtls_free(prk);
}
/* END_CASE */

/* BEGIN_CASE */
void test_hkdf_expand_ret(int hash_len, int prk_len, int okm_len, int ret)
{
    /*
     * Error-path check for mbedtls_hkdf_expand(). It calls the function with
     * a hand-built md_info (type MBEDTLS_MD_NONE, size hash_len), a
     * zero-filled PRK of prk_len bytes, an empty info string and an output
     * buffer of okm_len bytes, and checks that the return code equals ret.
     * A zero length leaves the corresponding pointer NULL.
     */
    int output_ret;
    unsigned char *info = NULL;
    unsigned char *prk = NULL;
    unsigned char *okm = NULL;
    size_t info_len;
    struct mbedtls_md_info_t fake_md_info;

    memset(&fake_md_info, 0, sizeof(fake_md_info));
    fake_md_info.type = MBEDTLS_MD_NONE;
    fake_md_info.size = hash_len;

    info_len = 0;

    /* The lengths are converted to size_t below; a negative test vector
     * would silently become a huge allocation/length. */
    TEST_ASSERT(prk_len >= 0);
    TEST_ASSERT(okm_len >= 0);

    if (prk_len > 0) {
        prk = mbedtls_calloc(prk_len, 1);
        TEST_ASSERT(prk != NULL);
    }

    if (okm_len > 0) {
        okm = mbedtls_calloc(okm_len, 1);
        TEST_ASSERT(okm != NULL);
    }

    output_ret = mbedtls_hkdf_expand(&fake_md_info, prk, prk_len,
                                     info, info_len, okm, okm_len);
    TEST_ASSERT(output_ret == ret);

exit:
    mbedtls_free(prk);
    mbedtls_free(okm);
}
/* END_CASE */
