/* BEGIN_HEADER */

#include <psa/crypto.h>

#include <test/psa_crypto_helpers.h>
#include <test/psa_exercise_key.h>

#include <psa_crypto_its.h>

/* Bits for the `flags` argument of test_read_key() and key_storage_read(). */

/* When set, test_read_key() also exercises the key through
 * mbedtls_test_psa_exercise_key(), provided can_exercise() allows it. */
#define TEST_FLAG_EXERCISE      0x00000001
/* When set, test_read_key() expects psa_destroy_key() to fail with
 * PSA_ERROR_NOT_PERMITTED instead of destroying the key. */
#define TEST_FLAG_READ_ONLY     0x00000002

/** Import a key with the given attributes and key material, then read
 * the storage file \p uid directly and compare its content with
 * \p expected_representation.
 *
 * The identifier returned by the import must match the identifier set
 * in \p attributes.
 *
 * Return 1 if every check passes. On any failure (import error, storage
 * error, size or content mismatch), mark the test case as failed and
 * return 0.
 */
static int test_written_key( const psa_key_attributes_t *attributes,
                             const data_t *material,
                             psa_storage_uid_t uid,
                             const data_t *expected_representation )
{
    int ok = 0;
    struct psa_storage_info_t file_info;
    uint8_t *stored_bytes = NULL;
    size_t stored_length;
    mbedtls_svc_key_id_t imported_id = MBEDTLS_SVC_KEY_ID_INIT;

    /* Import the key and make sure we got the identifier we asked for. */
    PSA_ASSERT( psa_import_key( attributes, material->x, material->len,
                                &imported_id ) );
    TEST_ASSERT( mbedtls_svc_key_id_equal( psa_get_key_id( attributes ),
                                           imported_id ) );

    /* Fetch the raw file behind the key. The size is checked first so
     * that the buffer allocated below has the expected length. */
    PSA_ASSERT( psa_its_get_info( uid, &file_info ) );
    TEST_EQUAL( file_info.size, expected_representation->len );
    ASSERT_ALLOC( stored_bytes, file_info.size );
    PSA_ASSERT( psa_its_get( uid, 0, file_info.size,
                             stored_bytes, &stored_length ) );

    /* The stored content must be exactly the expected representation. */
    ASSERT_COMPARE( expected_representation->x, expected_representation->len,
                    stored_bytes, stored_length );

    ok = 1;

exit:
    /* Reached on success and from any failing assertion above. */
    mbedtls_free( stored_bytes );
    return( ok );
}

/** Check if a key is exportable. */
static int can_export( const psa_key_attributes_t *attributes )
{
    if( psa_get_key_usage_flags( attributes ) & PSA_KEY_USAGE_EXPORT )
        return( 1 );
    else if( PSA_KEY_TYPE_IS_PUBLIC_KEY( psa_get_key_type( attributes ) ) )
        return( 1 );
    else
        return( 0 );
}

#if defined(MBEDTLS_TEST_LIBTESTDRIVER1)
/* Whether alg is an RSA algorithm (PKCS#1 v1.5 signature, PSS or OAEP)
 * for which the corresponding MBEDTLS_PSA_ACCEL_ALG_xxx symbol is
 * defined. Return 1 if so, 0 otherwise. */
static int is_accelerated_rsa( psa_algorithm_t alg )
{
    int accelerated = 0;

#if defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN)
    accelerated = accelerated || PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg );
#endif
#if defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS)
    accelerated = accelerated || PSA_ALG_IS_RSA_PSS( alg );
#endif
#if defined(MBEDTLS_PSA_ACCEL_ALG_RSA_OAEP)
    accelerated = accelerated || PSA_ALG_IS_RSA_OAEP( alg );
#endif

    /* alg is otherwise unused when none of the symbols above is defined. */
    (void) alg;
    return( accelerated );
}

/* Whether alg is handled by a built-in (i.e. not accelerated)
 * implementation that calls mbedtls_md() functions, which in turn
 * require the hash algorithm to be built-in as well.
 *
 * Return 1 for RSA PKCS#1 v1.5 signature, RSA PSS, RSA OAEP and
 * deterministic ECDSA when the matching MBEDTLS_PSA_BUILTIN_ALG_xxx
 * symbol is defined, 0 otherwise. */
static int is_builtin_calling_md( psa_algorithm_t alg )
{
    int uses_md = 0;

#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
    uses_md = uses_md || PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg );
#endif
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
    uses_md = uses_md || PSA_ALG_IS_RSA_PSS( alg );
#endif
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
    uses_md = uses_md || PSA_ALG_IS_RSA_OAEP( alg );
#endif
#if defined(MBEDTLS_PSA_BUILTIN_ALG_DETERMINISTIC_ECDSA)
    uses_md = uses_md || PSA_ALG_IS_DETERMINISTIC_ECDSA( alg );
#endif

    /* alg is otherwise unused when none of the symbols above is defined. */
    (void) alg;
    return( uses_md );
}

/* Whether the hash algorithm alg has a built-in implementation.
 *
 * Return 0 if alg is one of the hashes listed below and the
 * corresponding MBEDTLS_xxx_C module is not enabled. Return 1 for
 * everything else, including values that are not in the list. */
static int has_builtin_hash( psa_algorithm_t alg )
{
    int missing = 0;

#if !defined(MBEDTLS_MD2_C)
    missing = missing || ( alg == PSA_ALG_MD2 );
#endif
#if !defined(MBEDTLS_MD4_C)
    missing = missing || ( alg == PSA_ALG_MD4 );
#endif
#if !defined(MBEDTLS_MD5_C)
    missing = missing || ( alg == PSA_ALG_MD5 );
#endif
#if !defined(MBEDTLS_RIPEMD160_C)
    missing = missing || ( alg == PSA_ALG_RIPEMD160 );
#endif
#if !defined(MBEDTLS_SHA1_C)
    missing = missing || ( alg == PSA_ALG_SHA_1 );
#endif
#if !defined(MBEDTLS_SHA224_C)
    missing = missing || ( alg == PSA_ALG_SHA_224 );
#endif
#if !defined(MBEDTLS_SHA256_C)
    missing = missing || ( alg == PSA_ALG_SHA_256 );
#endif
#if !defined(MBEDTLS_SHA384_C)
    missing = missing || ( alg == PSA_ALG_SHA_384 );
#endif
#if !defined(MBEDTLS_SHA512_C)
    missing = missing || ( alg == PSA_ALG_SHA_512 );
#endif

    /* alg is otherwise unused when every hash module is enabled. */
    (void) alg;
    return( ! missing );
}
#endif

/* Mbed TLS doesn't support certain combinations of algorithm, hash and
 * key usage in certain configurations.
 *
 * Return 1 if a key with the given attributes can be exercised in the
 * current build configuration, 0 if operations with it must be skipped.
 * Outside of MBEDTLS_TEST_LIBTESTDRIVER1 builds, this is always 1.
 */
static int can_exercise( const psa_key_attributes_t *attributes )
{
    psa_algorithm_t alg = psa_get_key_algorithm( attributes );
    /* Hash embedded in the algorithm, or PSA_ALG_NONE if the algorithm
     * is neither hash-and-sign nor RSA OAEP. */
    psa_algorithm_t hash_alg =
        PSA_ALG_IS_HASH_AND_SIGN( alg ) ? PSA_ALG_SIGN_GET_HASH( alg ) :
        PSA_ALG_IS_RSA_OAEP( alg ) ? PSA_ALG_RSA_OAEP_GET_HASH( alg ) :
        PSA_ALG_NONE;
    psa_key_usage_t usage = psa_get_key_usage_flags( attributes );

#if defined(MBEDTLS_TEST_LIBTESTDRIVER1)
    /* We test some configurations using drivers where the driver doesn't
     * support certain hash algorithms, but declares that it supports
     * compound algorithms that use those hashes. Until this is fixed,
     * in those configurations, don't try to actually perform operations.
     *
     * Hash-and-sign algorithms where the asymmetric part doesn't use
     * a hash operation are ok. So randomized ECDSA signature is fine,
     * ECDSA verification is fine, but deterministic ECDSA signature is
     * affected. All RSA signatures are affected except raw PKCS#1v1.5.
     * OAEP is also affected.
     */
    if( PSA_ALG_IS_DETERMINISTIC_ECDSA( alg ) &&
        ! ( usage & ( PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_SIGN_MESSAGE ) ) )
    {
        /* Verification only. Verification doesn't use the hash algorithm. */
        return( 1 );
    }

#if defined(MBEDTLS_PSA_ACCEL_ALG_DETERMINISTIC_ECDSA)
    if( PSA_ALG_IS_DETERMINISTIC_ECDSA( alg ) &&
        ( hash_alg == PSA_ALG_MD5 ||
          hash_alg == PSA_ALG_RIPEMD160 ||
          hash_alg == PSA_ALG_SHA_1 ) )
    {
        return( 0 );
    }
#endif

    /* is_accelerated_rsa() reports OAEP as well as the RSA signature
     * algorithms when they are accelerated, so this single check covers
     * all accelerated RSA algorithms; no separate OAEP check is needed. */
    if( is_accelerated_rsa( alg ) &&
        ( hash_alg == PSA_ALG_RIPEMD160 || hash_alg == PSA_ALG_SHA_384 ) )
    {
        return( 0 );
    }

    /* The built-in implementation of asymmetric algorithms that use a
     * hash internally only dispatch to the internal md module, not to
     * PSA. Until this is supported, don't try to actually perform
     * operations when the operation is built-in and the hash isn't.  */
    if( is_builtin_calling_md( alg ) && ! has_builtin_hash( hash_alg ) )
    {
        return( 0 );
    }
#endif /* MBEDTLS_TEST_LIBTESTDRIVER1 */

    /* These are only consulted in MBEDTLS_TEST_LIBTESTDRIVER1 builds. */
    (void) alg;
    (void) hash_alg;
    (void) usage;
    return( 1 );
}

/** Store \p representation as the storage file \p uid, then check that
 * the key read back through the PSA API has the expected attributes
 * and, if can_export() allows it, the expected key material.
 *
 * Depending on \p flags:
 * - TEST_FLAG_EXERCISE: also exercise the key, unless can_exercise()
 *   rules it out.
 * - TEST_FLAG_READ_ONLY: expect psa_destroy_key() to be refused with
 *   PSA_ERROR_NOT_PERMITTED. Without this flag, the key is destroyed
 *   and the storage file must no longer exist afterwards.
 *
 * The storage file is removed before returning in all cases.
 *
 * Return 1 if every check passes. On any failure, including a mismatch
 * in the attributes or in the exported key material, mark the test case
 * as failed and return 0.
 */
static int test_read_key( const psa_key_attributes_t *expected_attributes,
                          const data_t *expected_material,
                          psa_storage_uid_t uid,
                          const data_t *representation,
                          int flags )
{
    int ok = 0;
    size_t exported_length;
    uint8_t *exported = NULL;
    struct psa_storage_info_t file_info;
    psa_key_attributes_t loaded_attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = psa_get_key_id( expected_attributes );

    /* Inject the key file directly into storage. */
    PSA_ASSERT( psa_its_set( uid, representation->len, representation->x, 0 ) );

    /* Read the attributes back through the API and compare each one
     * with what the caller expects. */
    PSA_ASSERT( psa_get_key_attributes( key_id, &loaded_attributes ) );
    TEST_ASSERT( mbedtls_svc_key_id_equal( key_id,
                                           psa_get_key_id( &loaded_attributes ) ) );
    TEST_EQUAL( psa_get_key_lifetime( expected_attributes ),
                psa_get_key_lifetime( &loaded_attributes ) );
    TEST_EQUAL( psa_get_key_type( expected_attributes ),
                psa_get_key_type( &loaded_attributes ) );
    TEST_EQUAL( psa_get_key_bits( expected_attributes ),
                psa_get_key_bits( &loaded_attributes ) );
    TEST_EQUAL( psa_get_key_usage_flags( expected_attributes ),
                psa_get_key_usage_flags( &loaded_attributes ) );
    TEST_EQUAL( psa_get_key_algorithm( expected_attributes ),
                psa_get_key_algorithm( &loaded_attributes ) );
    TEST_EQUAL( psa_get_key_enrollment_algorithm( expected_attributes ),
                psa_get_key_enrollment_algorithm( &loaded_attributes ) );

    /* The key material can only be compared if it can be exported. */
    if( can_export( expected_attributes ) )
    {
        ASSERT_ALLOC( exported, expected_material->len );
        PSA_ASSERT( psa_export_key( key_id,
                                    exported, expected_material->len,
                                    &exported_length ) );
        ASSERT_COMPARE( expected_material->x, expected_material->len,
                        exported, exported_length );
    }

    /* Optionally use the key for the operations its policy permits. */
    if( ( flags & TEST_FLAG_EXERCISE ) != 0 &&
        can_exercise( &loaded_attributes ) )
    {
        TEST_ASSERT( mbedtls_test_psa_exercise_key(
                         key_id,
                         psa_get_key_usage_flags( expected_attributes ),
                         psa_get_key_algorithm( expected_attributes ) ) );
    }

    if( ( flags & TEST_FLAG_READ_ONLY ) == 0 )
    {
        /* Destroy the key. Confirm through direct access to the storage. */
        PSA_ASSERT( psa_destroy_key( key_id ) );
        TEST_EQUAL( PSA_ERROR_DOES_NOT_EXIST,
                    psa_its_get_info( uid, &file_info ) );
    }
    else
    {
        /* Read-only keys cannot be removed through the API.
         * The key will be removed through ITS in the cleanup code below. */
        TEST_EQUAL( PSA_ERROR_NOT_PERMITTED, psa_destroy_key( key_id ) );
    }

    ok = 1;

exit:
    /* Reached on success and from any failing assertion above. */
    psa_reset_key_attributes( &loaded_attributes );
    psa_its_remove( uid );
    mbedtls_free( exported );
    return( ok );
}

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PSA_CRYPTO_C:MBEDTLS_PSA_CRYPTO_STORAGE_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void key_storage_save( int lifetime_arg, int type_arg, int bits_arg,
                       int usage_arg, int alg_arg, int alg2_arg,
                       data_t *material,
                       data_t *representation )
{
    /* Forward compatibility: save a key in the current format and
     * check that it has the expected format so that future versions
     * will still be able to read it. */

    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_storage_uid_t uid = 1;
    mbedtls_svc_key_id_t key_id = mbedtls_svc_key_id_make( 0, 1 );

    PSA_INIT( );
    TEST_USES_KEY_ID( key_id );

    /* Build the attributes from the test arguments.
     * NOTE(review): the lifetime is deliberately set before the key id,
     * as in the original code; the two setters may interact, so keep
     * this order. */
    psa_set_key_lifetime( &attributes, (psa_key_lifetime_t) lifetime_arg );
    psa_set_key_id( &attributes, key_id );
    psa_set_key_type( &attributes, (psa_key_type_t) type_arg );
    psa_set_key_bits( &attributes, (size_t) bits_arg );
    psa_set_key_usage_flags( &attributes, (psa_key_usage_t) usage_arg );
    psa_set_key_algorithm( &attributes, (psa_algorithm_t) alg_arg );
    psa_set_key_enrollment_algorithm( &attributes,
                                      (psa_algorithm_t) alg2_arg );

    /* This is the current storage format. Test that we know exactly how
     * the key is stored. The stability of the test data in future
     * versions of Mbed TLS will guarantee that future versions
     * can read back what this version wrote. */
    TEST_ASSERT( test_written_key( &attributes, material,
                                   uid, representation ) );

exit:
    /* Best-effort cleanup: the key may or may not have been created. */
    psa_reset_key_attributes( &attributes );
    psa_destroy_key( key_id );
    PSA_DONE( );
}
/* END_CASE */

/* BEGIN_CASE */
void key_storage_read( int lifetime_arg, int type_arg, int bits_arg,
                       int usage_arg, int alg_arg, int alg2_arg,
                       data_t *material,
                       data_t *representation, int flags )
{
    /* Backward compatibility: read a key in the format of a past version
     * and check that this version can use it. */

    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_storage_uid_t uid = 1;
    mbedtls_svc_key_id_t key_id = mbedtls_svc_key_id_make( 0, 1 );

    PSA_INIT( );
    TEST_USES_KEY_ID( key_id );

    /* Build the expected attributes from the test arguments.
     * NOTE(review): the lifetime is deliberately set before the key id,
     * as in the original code; the two setters may interact, so keep
     * this order. */
    psa_set_key_lifetime( &attributes, (psa_key_lifetime_t) lifetime_arg );
    psa_set_key_id( &attributes, key_id );
    psa_set_key_type( &attributes, (psa_key_type_t) type_arg );
    psa_set_key_bits( &attributes, (size_t) bits_arg );
    psa_set_key_usage_flags( &attributes, (psa_key_usage_t) usage_arg );
    psa_set_key_algorithm( &attributes, (psa_algorithm_t) alg_arg );
    psa_set_key_enrollment_algorithm( &attributes,
                                      (psa_algorithm_t) alg2_arg );

    /* Test that we can use a key with the given representation. This
     * guarantees backward compatibility with keys that were stored by
     * past versions of Mbed TLS. */
    TEST_ASSERT( test_read_key( &attributes, material,
                                uid, representation, flags ) );

exit:
    /* test_read_key() takes care of removing the key and its file. */
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */
