diff --git a/drivers/crypto/vmx/aes_ctr.c b/drivers/crypto/vmx/aes_ctr.c index 38ed10d761d0..7cf6d31c1123 100644 --- a/drivers/crypto/vmx/aes_ctr.c +++ b/drivers/crypto/vmx/aes_ctr.c @@ -80,11 +80,13 @@ static int p8_aes_ctr_setkey(struct crypto_tfm *tfm, const u8 *key, int ret; struct p8_aes_ctr_ctx *ctx = crypto_tfm_ctx(tfm); + preempt_disable(); pagefault_disable(); enable_kernel_vsx(); ret = aes_p8_set_encrypt_key(key, keylen * 8, &ctx->enc_key); disable_kernel_vsx(); pagefault_enable(); + preempt_enable(); ret += crypto_blkcipher_setkey(ctx->fallback, key, keylen); return ret; @@ -99,11 +101,13 @@ static void p8_aes_ctr_final(struct p8_aes_ctr_ctx *ctx, u8 *dst = walk->dst.virt.addr; unsigned int nbytes = walk->nbytes; + preempt_disable(); pagefault_disable(); enable_kernel_vsx(); aes_p8_encrypt(ctrblk, keystream, &ctx->enc_key); disable_kernel_vsx(); pagefault_enable(); + preempt_enable(); crypto_xor(keystream, src, nbytes); memcpy(dst, keystream, nbytes); @@ -132,6 +136,7 @@ static int p8_aes_ctr_crypt(struct blkcipher_desc *desc, blkcipher_walk_init(&walk, dst, src, nbytes); ret = blkcipher_walk_virt_block(desc, &walk, AES_BLOCK_SIZE); while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) { + preempt_disable(); pagefault_disable(); enable_kernel_vsx(); aes_p8_ctr32_encrypt_blocks(walk.src.virt.addr, @@ -143,6 +148,7 @@ static int p8_aes_ctr_crypt(struct blkcipher_desc *desc, walk.iv); disable_kernel_vsx(); pagefault_enable(); + preempt_enable(); /* We need to update IV mostly for last bytes/round */ inc = (nbytes & AES_BLOCK_MASK) / AES_BLOCK_SIZE;