/* BEGIN_HEADER */

#include "psa/crypto.h"
#include "test/psa_crypto_helpers.h"

/** Compare two PSA status codes, tolerating one specific mismatch.
 *
 * A pair made of #PSA_ERROR_INVALID_ARGUMENT and #PSA_ERROR_NOT_SUPPORTED
 * (in either order) is accepted immediately, without calling into the
 * test framework. Every other pair is handed to mbedtls_test_equal().
 *
 * \param test      Description of the comparison, forwarded unchanged.
 * \param line_no   Line number of the comparison, forwarded unchanged.
 * \param filename  File name of the comparison, forwarded unchanged.
 * \param value1    First status code.
 * \param value2    Second status code.
 *
 * \return \c 1 if the two values form the tolerated pair, otherwise
 *         whatever mbedtls_test_equal() returns for them.
 */
static int test_equal_status( const char *test,
                              int line_no, const char* filename,
                              psa_status_t value1,
                              psa_status_t value2 )
{
    const int invalid_then_unsupported =
        ( value1 == PSA_ERROR_INVALID_ARGUMENT ) &&
        ( value2 == PSA_ERROR_NOT_SUPPORTED );
    const int unsupported_then_invalid =
        ( value1 == PSA_ERROR_NOT_SUPPORTED ) &&
        ( value2 == PSA_ERROR_INVALID_ARGUMENT );

    /* Anything other than the tolerated pair goes through the regular
     * equality check, which is responsible for reporting a mismatch. */
    if( ! invalid_then_unsupported && ! unsupported_then_invalid )
    {
        return( mbedtls_test_equal( test, line_no, filename,
                                    value1, value2 ) );
    }

    return( 1 );
}

/** Like #TEST_EQUAL, but expects #psa_status_t values and treats
 * #PSA_ERROR_INVALID_ARGUMENT and #PSA_ERROR_NOT_SUPPORTED as
 * interchangeable.
 *
 * This test suite currently allows NOT_SUPPORTED and INVALID_ARGUMENT
 * to be interchangeable in places where the library's behavior does not
 * match the strict expectations of the test case generator. In the long
 * run, it would be better to clarify the expectations and reconcile the
 * library and the test case generator.
 *
 * Each argument is evaluated exactly once. Its source text is also
 * stringified into the "expr1 == expr2" description that is passed to
 * test_equal_status(), so the spelling of the arguments shows up in that
 * description. If test_equal_status() returns 0, control jumps to the
 * \c exit label, which the enclosing function must therefore define.
 *
 * \param expr1     First status expression (in this file, always the
 *                  expected status).
 * \param expr2     Second status expression (in this file, always the
 *                  call under test).
 */
#define TEST_STATUS( expr1, expr2 )                                     \
    do {                                                                \
        if( ! test_equal_status( #expr1 " == " #expr2, __LINE__, __FILE__, \
                                 expr1, expr2 ) )                       \
            goto exit;                                                  \
    } while( 0 )

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PSA_CRYPTO_C
 * END_DEPENDENCIES
 */

/* Test case: hash entry points must fail for the given algorithm.
 *
 * Calls psa_hash_setup(), psa_hash_compute() and psa_hash_compare() with
 * `alg_arg` and requires each of them to return `expected_status_arg`.
 *
 * The comparisons use TEST_EQUAL rather than TEST_STATUS, so
 * NOT_SUPPORTED and INVALID_ARGUMENT are not treated as interchangeable
 * in this test case.
 *
 * Parameters:
 *   alg_arg              Algorithm identifier to try (a psa_algorithm_t).
 *   expected_status_arg  Status that every call must return
 *                        (a psa_status_t).
 */
/* BEGIN_CASE */
void hash_fail( int alg_arg, int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_algorithm_t alg = alg_arg;
    psa_hash_operation_t operation = PSA_HASH_OPERATION_INIT;
    uint8_t input[1] = {'A'};
    uint8_t output[PSA_HASH_MAX_SIZE] = {0};
    /* Output length is never inspected by this test. It starts at
     * SIZE_MAX, presumably so that a value left unwritten by the
     * library stands out when debugging. */
    size_t length = SIZE_MAX;

    PSA_INIT( );

    /* Multipart entry point. */
    TEST_EQUAL( expected_status,
                psa_hash_setup( &operation, alg ) );
    /* One-shot entry points. For the comparison, the all-zero `output`
     * buffer (unless the call above wrote to it) is the reference hash. */
    TEST_EQUAL( expected_status,
                psa_hash_compute( alg, input, sizeof( input ),
                                  output, sizeof( output ), &length ) );
    TEST_EQUAL( expected_status,
                psa_hash_compare( alg, input, sizeof( input ),
                                  output, sizeof( output ) ) );

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_hash_abort( &operation );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: MAC entry points must fail for a key type / algorithm pair.
 *
 * Imports `key_data` as a key of type `key_type_arg` whose policy names
 * `alg_arg` and allows SIGN_HASH | VERIFY_HASH, then requires
 * psa_mac_sign_setup(), psa_mac_verify_setup(), psa_mac_compute() and
 * psa_mac_verify() to each return `expected_status_arg`. The checks use
 * TEST_STATUS, so NOT_SUPPORTED and INVALID_ARGUMENT are accepted in
 * place of each other. The import itself is checked with PSA_ASSERT.
 *
 * Parameters:
 *   key_type_arg         Key type to import (a psa_key_type_t).
 *   key_data             Key material passed to psa_import_key().
 *   alg_arg              Algorithm to try (a psa_algorithm_t).
 *   expected_status_arg  Status that every MAC call must return.
 */
/* BEGIN_CASE */
void mac_fail( int key_type_arg, data_t *key_data,
               int alg_arg, int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    uint8_t input[1] = {'A'};
    uint8_t output[PSA_MAC_MAX_SIZE] = {0};
    /* Output length is never inspected by this test. */
    size_t length = SIZE_MAX;

    PSA_INIT( );

    /* NOTE(review): the policy grants the *_HASH usage flags although MAC
     * functions are exercised; presumably the library extends them to the
     * corresponding *_MESSAGE flags -- confirm. */
    psa_set_key_type( &attributes, key_type );
    psa_set_key_usage_flags( &attributes,
                             PSA_KEY_USAGE_SIGN_HASH |
                             PSA_KEY_USAGE_VERIFY_HASH );
    psa_set_key_algorithm( &attributes, alg );
    PSA_ASSERT( psa_import_key( &attributes,
                                key_data->x, key_data->len,
                                &key_id ) );

    /* Multipart entry points, both on the same operation object. */
    TEST_STATUS( expected_status,
                 psa_mac_sign_setup( &operation, key_id, alg ) );
    TEST_STATUS( expected_status,
                 psa_mac_verify_setup( &operation, key_id, alg ) );
    /* One-shot entry points. For the verification, the full `output`
     * buffer serves as the candidate MAC. */
    TEST_STATUS( expected_status,
                 psa_mac_compute( key_id, alg,
                                  input, sizeof( input ),
                                  output, sizeof( output ), &length ) );
    TEST_STATUS( expected_status,
                 psa_mac_verify( key_id, alg,
                                 input, sizeof( input ),
                                 output, sizeof( output ) ) );

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_mac_abort( &operation );
    psa_destroy_key( key_id );
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: cipher entry points must fail for a key type / algorithm pair.
 *
 * Imports `key_data` as a key of type `key_type_arg` whose policy names
 * `alg_arg` and allows ENCRYPT | DECRYPT, then requires
 * psa_cipher_encrypt_setup(), psa_cipher_decrypt_setup(),
 * psa_cipher_encrypt() and psa_cipher_decrypt() to each return
 * `expected_status_arg`. The checks use TEST_STATUS, so NOT_SUPPORTED and
 * INVALID_ARGUMENT are accepted in place of each other. The import itself
 * is checked with PSA_ASSERT.
 *
 * Parameters:
 *   key_type_arg         Key type to import (a psa_key_type_t).
 *   key_data             Key material passed to psa_import_key().
 *   alg_arg              Algorithm to try (a psa_algorithm_t).
 *   expected_status_arg  Status that every cipher call must return.
 */
/* BEGIN_CASE */
void cipher_fail( int key_type_arg, data_t *key_data,
                  int alg_arg, int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    uint8_t input[1] = {'A'};
    /* Fixed-size scratch buffer, not derived from a PSA size macro.
     * NOTE(review): presumably 64 bytes is meant to be ample for a
     * one-byte input -- confirm if the calls are ever expected to get
     * far enough to check the output size. */
    uint8_t output[64] = {0};
    /* Output length is never inspected by this test. */
    size_t length = SIZE_MAX;

    PSA_INIT( );

    psa_set_key_type( &attributes, key_type );
    psa_set_key_usage_flags( &attributes,
                             PSA_KEY_USAGE_ENCRYPT |
                             PSA_KEY_USAGE_DECRYPT );
    psa_set_key_algorithm( &attributes, alg );
    PSA_ASSERT( psa_import_key( &attributes,
                                key_data->x, key_data->len,
                                &key_id ) );

    /* Multipart entry points, both on the same operation object. */
    TEST_STATUS( expected_status,
                 psa_cipher_encrypt_setup( &operation, key_id, alg ) );
    TEST_STATUS( expected_status,
                 psa_cipher_decrypt_setup( &operation, key_id, alg ) );
    /* One-shot entry points. */
    TEST_STATUS( expected_status,
                 psa_cipher_encrypt( key_id, alg,
                                     input, sizeof( input ),
                                     output, sizeof( output ), &length ) );
    TEST_STATUS( expected_status,
                 psa_cipher_decrypt( key_id, alg,
                                     input, sizeof( input ),
                                     output, sizeof( output ), &length ) );

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_cipher_abort( &operation );
    psa_destroy_key( key_id );
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: one-shot AEAD entry points must fail for a key type /
 * algorithm pair.
 *
 * Imports `key_data` as a key of type `key_type_arg` whose policy names
 * `alg_arg` and allows ENCRYPT | DECRYPT, then requires psa_aead_encrypt()
 * and psa_aead_decrypt() to each return `expected_status_arg`. The checks
 * use TEST_STATUS, so NOT_SUPPORTED and INVALID_ARGUMENT are accepted in
 * place of each other. The import itself is checked with PSA_ASSERT.
 * Only the one-shot functions are exercised here.
 *
 * Parameters:
 *   key_type_arg         Key type to import (a psa_key_type_t).
 *   key_data             Key material passed to psa_import_key().
 *   alg_arg              Algorithm to try (a psa_algorithm_t).
 *   expected_status_arg  Status that both AEAD calls must return.
 */
/* BEGIN_CASE */
void aead_fail( int key_type_arg, data_t *key_data,
                int alg_arg, int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    /* 15 letters plus the string's terminating null byte fill all 16
     * bytes of the array. */
    uint8_t input[16] = "ABCDEFGHIJKLMNO";
    /* Fixed-size scratch buffer, not derived from a PSA size macro. */
    uint8_t output[64] = {0};
    /* Output length is never inspected by this test. */
    size_t length = SIZE_MAX;

    PSA_INIT( );

    psa_set_key_type( &attributes, key_type );
    psa_set_key_usage_flags( &attributes,
                             PSA_KEY_USAGE_ENCRYPT |
                             PSA_KEY_USAGE_DECRYPT );
    psa_set_key_algorithm( &attributes, alg );
    PSA_ASSERT( psa_import_key( &attributes,
                                key_data->x, key_data->len,
                                &key_id ) );

    /* In both calls the same `input` buffer is passed twice: following
     * the PSA API parameter order, first as the nonce and then as the
     * message (plaintext, respectively ciphertext). The NULL, 0 pair in
     * between is the (empty) additional data. */
    TEST_STATUS( expected_status,
                 psa_aead_encrypt( key_id, alg,
                                   input, sizeof( input ),
                                   NULL, 0, input, sizeof( input ),
                                   output, sizeof( output ), &length ) );
    TEST_STATUS( expected_status,
                 psa_aead_decrypt( key_id, alg,
                                   input, sizeof( input ),
                                   NULL, 0, input, sizeof( input ),
                                   output, sizeof( output ), &length ) );

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_destroy_key( key_id );
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: hash-signature entry points must fail for a key type /
 * algorithm pair.
 *
 * Imports `key_data` as a key of type `key_type_arg` whose policy names
 * `alg_arg` and allows SIGN_HASH | VERIFY_HASH, then requires
 * psa_sign_hash() to return `expected_status_arg`. Unless `private_only`
 * is nonzero, psa_verify_hash() must return the same status. The checks
 * use TEST_STATUS, so NOT_SUPPORTED and INVALID_ARGUMENT are accepted in
 * place of each other. The import itself is checked with PSA_ASSERT.
 *
 * Parameters:
 *   key_type_arg         Key type to import (a psa_key_type_t).
 *   key_data             Key material passed to psa_import_key().
 *   alg_arg              Algorithm to try (a psa_algorithm_t).
 *   private_only         Nonzero to skip the verification call.
 *   expected_status_arg  Status that the exercised calls must return.
 */
/* BEGIN_CASE */
void sign_fail( int key_type_arg, data_t *key_data,
                int alg_arg, int private_only,
                int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    /* A single byte stands in for the hash to sign or verify. */
    uint8_t input[1] = {'A'};
    uint8_t output[PSA_SIGNATURE_MAX_SIZE] = {0};
    /* Output length is never inspected by this test. */
    size_t length = SIZE_MAX;

    PSA_INIT( );

    psa_set_key_type( &attributes, key_type );
    psa_set_key_usage_flags( &attributes,
                             PSA_KEY_USAGE_SIGN_HASH |
                             PSA_KEY_USAGE_VERIFY_HASH );
    psa_set_key_algorithm( &attributes, alg );
    PSA_ASSERT( psa_import_key( &attributes,
                                key_data->x, key_data->len,
                                &key_id ) );

    TEST_STATUS( expected_status,
                 psa_sign_hash( key_id, alg,
                                input, sizeof( input ),
                                output, sizeof( output ), &length ) );
    if( ! private_only )
    {
        /* Determine a plausible signature size to avoid an INVALID_SIGNATURE
         * error based on this. */
        /* The attributes are re-read from the imported key to obtain its
         * size in bits, which the import call was not told explicitly. */
        PSA_ASSERT( psa_get_key_attributes( key_id, &attributes ) );
        size_t key_bits = psa_get_key_bits( &attributes );
        /* Default for key types other than RSA and ECC: the whole buffer. */
        size_t output_length = sizeof( output );
        /* RSA: one key-size worth of bytes. ECC: two of them (presumably
         * matching the library's signature encodings). */
        if( PSA_KEY_TYPE_IS_RSA( key_type ) )
            output_length = PSA_BITS_TO_BYTES( key_bits );
        else if( PSA_KEY_TYPE_IS_ECC( key_type ) )
            output_length = 2 * PSA_BITS_TO_BYTES( key_bits );
        /* Guard against passing a length larger than the buffer. */
        TEST_ASSERT( output_length <= sizeof( output ) );
        TEST_STATUS( expected_status,
                     psa_verify_hash( key_id, alg,
                                      input, sizeof( input ),
                                      output, output_length ) );
    }

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_destroy_key( key_id );
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: asymmetric encryption entry points must fail for a key type /
 * algorithm pair.
 *
 * Imports `key_data` as a key of type `key_type_arg` whose policy names
 * `alg_arg` and allows ENCRYPT | DECRYPT. Unless `private_only` is
 * nonzero, psa_asymmetric_encrypt() must return `expected_status_arg`;
 * psa_asymmetric_decrypt() must return it in all cases. The checks use
 * TEST_STATUS, so NOT_SUPPORTED and INVALID_ARGUMENT are accepted in
 * place of each other. The import itself is checked with PSA_ASSERT.
 *
 * Parameters:
 *   key_type_arg         Key type to import (a psa_key_type_t).
 *   key_data             Key material passed to psa_import_key().
 *   alg_arg              Algorithm to try (a psa_algorithm_t).
 *   private_only         Nonzero to skip the encryption call.
 *   expected_status_arg  Status that the exercised calls must return.
 */
/* BEGIN_CASE */
void asymmetric_encryption_fail( int key_type_arg, data_t *key_data,
                                 int alg_arg, int private_only,
                                 int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    uint8_t plaintext[PSA_ASYMMETRIC_DECRYPT_OUTPUT_MAX_SIZE] = {0};
    uint8_t ciphertext[PSA_ASYMMETRIC_ENCRYPT_OUTPUT_MAX_SIZE] = {0};
    /* Output length is never inspected by this test. */
    size_t length = SIZE_MAX;

    PSA_INIT( );

    psa_set_key_type( &attributes, key_type );
    psa_set_key_usage_flags( &attributes,
                             PSA_KEY_USAGE_ENCRYPT |
                             PSA_KEY_USAGE_DECRYPT );
    psa_set_key_algorithm( &attributes, alg );
    PSA_ASSERT( psa_import_key( &attributes,
                                key_data->x, key_data->len,
                                &key_id ) );

    if( ! private_only )
    {
        /* Encrypt a single byte of the plaintext buffer. The NULL, 0
         * pair is the salt argument in the PSA API parameter order. */
        TEST_STATUS( expected_status,
                     psa_asymmetric_encrypt( key_id, alg,
                                             plaintext, 1,
                                             NULL, 0,
                                             ciphertext, sizeof( ciphertext ),
                                             &length ) );
    }
    /* Decryption is attempted regardless of private_only, with the whole
     * ciphertext buffer as input. */
    TEST_STATUS( expected_status,
                 psa_asymmetric_decrypt( key_id, alg,
                                         ciphertext, sizeof( ciphertext ),
                                         NULL, 0,
                                         plaintext, sizeof( plaintext ),
                                         &length ) );

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_destroy_key( key_id );
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: key derivation setup must fail for the given algorithm.
 *
 * Calls psa_key_derivation_setup() with `alg_arg` and requires it to
 * return `expected_status_arg`.
 *
 * The comparison uses TEST_EQUAL rather than TEST_STATUS, so
 * NOT_SUPPORTED and INVALID_ARGUMENT are not treated as interchangeable
 * in this test case.
 *
 * Parameters:
 *   alg_arg              Algorithm identifier to try (a psa_algorithm_t).
 *   expected_status_arg  Status that the setup call must return
 *                        (a psa_status_t).
 */
/* BEGIN_CASE */
void key_derivation_fail( int alg_arg, int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_derivation_operation_t operation = PSA_KEY_DERIVATION_OPERATION_INIT;

    PSA_INIT( );

    TEST_EQUAL( expected_status,
                psa_key_derivation_setup( &operation, alg ) );

exit:
    /* Reached both on normal completion and when the check jumps here. */
    psa_key_derivation_abort( &operation );
    PSA_DONE( );
}
/* END_CASE */

/* Test case: key agreement entry points must fail for a key type /
 * algorithm pair.
 *
 * Imports `key_data` as a key of type `key_type_arg` whose policy names
 * `alg_arg` and allows DERIVE. If the key is a key pair or a public key,
 * its public part is exported and used as the peer key; otherwise an
 * empty peer key (length 0) is used. The test then requires
 * psa_raw_key_agreement() to return `expected_status_arg` and, when HKDF
 * and SHA-256 are both enabled, psa_key_derivation_key_agreement() on an
 * HKDF-SHA-256 operation to return it as well. The checks use
 * TEST_STATUS, so NOT_SUPPORTED and INVALID_ARGUMENT are accepted in
 * place of each other. The import, the export and the derivation setup
 * are checked with PSA_ASSERT.
 *
 * Parameters:
 *   key_type_arg         Key type to import (a psa_key_type_t).
 *   key_data             Key material passed to psa_import_key().
 *   alg_arg              Algorithm to try (a psa_algorithm_t).
 *   private_only         Unused: there is no public-key-only operation
 *                        to skip.
 *   expected_status_arg  Status that the key agreement calls must return.
 */
/* BEGIN_CASE */
void key_agreement_fail( int key_type_arg, data_t *key_data,
                         int alg_arg, int private_only,
                         int expected_status_arg )
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    uint8_t public_key[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE] = {0};
    /* Must start at 0, not SIZE_MAX: when the key is neither a key pair
     * nor a public key, nothing is exported below and this value is
     * passed unchanged as the peer key length. It therefore must never
     * exceed the size of the public_key buffer. */
    size_t public_key_length = 0;
    uint8_t output[PSA_SIGNATURE_MAX_SIZE] = {0};
    /* Output length is never inspected by this test. */
    size_t length = SIZE_MAX;
    psa_key_derivation_operation_t operation = PSA_KEY_DERIVATION_OPERATION_INIT;

    PSA_INIT( );

    psa_set_key_type( &attributes, key_type );
    psa_set_key_usage_flags( &attributes,
                             PSA_KEY_USAGE_DERIVE );
    psa_set_key_algorithm( &attributes, alg );
    PSA_ASSERT( psa_import_key( &attributes,
                                key_data->x, key_data->len,
                                &key_id ) );
    /* Use the key's own public part as the peer key when it has one. */
    if( PSA_KEY_TYPE_IS_KEY_PAIR( key_type ) ||
        PSA_KEY_TYPE_IS_PUBLIC_KEY( key_type ) )
    {
        PSA_ASSERT( psa_export_public_key( key_id,
                                           public_key, sizeof( public_key ),
                                           &public_key_length ) );
    }

    TEST_STATUS( expected_status,
                 psa_raw_key_agreement( alg, key_id,
                                        public_key, public_key_length,
                                        output, sizeof( output ), &length ) );

#if defined(PSA_WANT_ALG_HKDF) && defined(PSA_WANT_ALG_SHA_256)
    /* The derivation operation is set up with a fixed HKDF-SHA-256
     * algorithm, which must succeed; only the key agreement step on top
     * of it is expected to fail. */
    PSA_ASSERT( psa_key_derivation_setup( &operation,
                                          PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ) );
    TEST_STATUS( expected_status,
                 psa_key_derivation_key_agreement(
                     &operation,
                     PSA_KEY_DERIVATION_INPUT_SECRET,
                     key_id,
                     public_key, public_key_length ) );
#endif

    /* There are no public-key operations. */
    (void) private_only;

exit:
    /* Reached both on normal completion and when a check jumps here. */
    psa_key_derivation_abort( &operation );
    psa_destroy_key( key_id );
    psa_reset_key_attributes( &attributes );
    PSA_DONE( );
}
/* END_CASE */
