/*
* This implementation of Salsa20 is ported from the reference implementation
* by D. J. Bernstein, which can be found at:
* http://cr.yp.to/snuffle/salsa20/ref/salsa20.c
*
* This work is hereby released into the Public Domain. To view a copy of the public domain dedication,
* visit http://creativecommons.org/licenses/publicdomain/ or send a letter to
* Creative Commons, 171 Second Street, Suite 300, San Francisco, California, 94105, USA.
*/
using System;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
namespace Jd.ACES.Common
{
/// <summary>
/// Implements the Salsa20 stream encryption cipher, as defined at http://cr.yp.to/snuffle.html.
/// </summary>
/// <remarks>
/// The transforms produced by this class use the extended-nonce (XSalsa20) construction:
/// the key and IV bytes 0..15 are run through HSalsa20 to derive a 256-bit subkey, and
/// IV bytes 16..23 become the 64-bit Salsa20 nonce. The IV is therefore (at least) 24 bytes.
/// With a 128-bit key the HSalsa20 step uses the "expand 16-byte k" constants; that variant
/// is not covered by the XSalsa20 specification and exists only because 128-bit keys are
/// advertised as legal by this class.
/// See Salsa20 Implementation in C#.
/// </remarks>
public sealed class Salsa20 : SymmetricAlgorithm
{
    /// <summary>
    /// Initializes a new instance of the <see cref="Salsa20"/> class with a 256-bit key size,
    /// a 512-bit block size and 20 rounds.
    /// </summary>
    public Salsa20()
    {
        // set legal values
        LegalBlockSizesValue = new[] { new KeySizes(512, 512, 0) };
        LegalKeySizesValue = new[] { new KeySizes(128, 256, 128) };

        // set default values
        BlockSizeValue = 512;
        KeySizeValue = 256;
        m_rounds = 20;
    }

    /// <summary>
    /// Creates a symmetric decryptor object with the specified key and initialization vector.
    /// </summary>
    /// <param name="rgbKey">The secret key to use for the symmetric algorithm (16 or 32 bytes).</param>
    /// <param name="rgbIV">The initialization vector to use for the symmetric algorithm (at least 24 bytes).</param>
    /// <returns>A symmetric decryptor object.</returns>
    public override ICryptoTransform CreateDecryptor(byte[] rgbKey, byte[] rgbIV)
    {
        // A stream cipher XORs the same keystream in both directions.
        return CreateEncryptor(rgbKey, rgbIV);
    }

    /// <summary>
    /// Creates a symmetric encryptor object with the specified key and initialization vector.
    /// </summary>
    /// <param name="rgbKey">The secret key to use for the symmetric algorithm (16 or 32 bytes).</param>
    /// <param name="rgbIV">The initialization vector to use for the symmetric algorithm (at least 24 bytes).</param>
    /// <returns>A symmetric encryptor object.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="rgbKey"/> or <paramref name="rgbIV"/> is null.</exception>
    /// <exception cref="CryptographicException">The key or IV has an invalid size.</exception>
    public override ICryptoTransform CreateEncryptor(byte[] rgbKey, byte[] rgbIV)
    {
        if (rgbKey == null)
            throw new ArgumentNullException("rgbKey");
        if (!ValidKeySize(rgbKey.Length * 8))
            throw new CryptographicException("Invalid key size; it must be 128 or 256 bits.");
        CheckValidIV(rgbIV, "rgbIV");

        return new Salsa20CryptoTransform(rgbKey, rgbIV, m_rounds);
    }

    /// <summary>
    /// Generates a random 24-byte initialization vector to use for the algorithm.
    /// </summary>
    public override void GenerateIV()
    {
        // 16 bytes feed HSalsa20, the remaining 8 are the Salsa20 nonce.
        IVValue = GetRandomBytes(c_ivSize);
    }

    /// <summary>
    /// Generates a random key of <c>KeySize</c> bits to use for the algorithm.
    /// </summary>
    public override void GenerateKey()
    {
        KeyValue = GetRandomBytes(KeySize / 8);
    }

    /// <summary>
    /// Gets or sets the initialization vector for the symmetric algorithm.
    /// </summary>
    /// <value>The initialization vector (at least 24 bytes; only the first 24 are used).</value>
    /// <exception cref="ArgumentNullException">An attempt was made to set the initialization vector to null.</exception>
    /// <exception cref="CryptographicException">An attempt was made to set the initialization vector to an invalid size.</exception>
    public override byte[] IV
    {
        get
        {
            return base.IV;
        }
        set
        {
            CheckValidIV(value, "value");
            IVValue = (byte[])value.Clone();
        }
    }

    /// <summary>
    /// Gets or sets the number of rounds used by the Salsa20 algorithm (8, 12 or 20).
    /// </summary>
    /// <value>The number of rounds.</value>
    /// <exception cref="ArgumentOutOfRangeException">The value is not 8, 12 or 20.</exception>
    public int Rounds
    {
        get
        {
            return m_rounds;
        }
        set
        {
            if (value != 8 && value != 12 && value != 20)
                throw new ArgumentOutOfRangeException("value", "The number of rounds must be 8, 12, or 20.");
            m_rounds = value;
        }
    }

    // Verifies that iv is a legal value for an IV: non-null, at least 24 bytes (16 for HSalsa20
    // plus an 8-byte Salsa20 nonce), and a multiple of 8 bytes as previously required.
    private static void CheckValidIV(byte[] iv, string paramName)
    {
        if (iv == null)
            throw new ArgumentNullException(paramName);
        if (iv.Length < c_ivSize || iv.Length % 8 != 0)
            throw new CryptographicException("Invalid IV size; it must be at least 24 bytes and a multiple of 8 bytes.");
    }

    // Returns a new byte array containing the specified number of cryptographically random bytes.
    private static byte[] GetRandomBytes(int byteCount)
    {
        byte[] bytes = new byte[byteCount];
        using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
        {
            rng.GetBytes(bytes);
        }
        return bytes;
    }

    // Number of IV bytes the transform consumes: 16 for HSalsa20 + 8 for the Salsa20 nonce.
    private const int c_ivSize = 24;

    int m_rounds;

    /// <summary>
    /// <see cref="ICryptoTransform"/> that XORs data with the XSalsa20 keystream.
    /// The keystream position persists across calls, so data may be fed in arbitrary chunks.
    /// </summary>
    private sealed class Salsa20CryptoTransform : ICryptoTransform
    {
        public Salsa20CryptoTransform(byte[] key, byte[] iv, int rounds)
        {
            Debug.Assert(key.Length == 16 || key.Length == 32, "key.Length == 16 || key.Length == 32", "Invalid key size.");
            Debug.Assert(iv.Length >= c_ivSize, "iv.Length >= 24", "Invalid IV size.");
            Debug.Assert(rounds == 8 || rounds == 12 || rounds == 20, "rounds == 8 || rounds == 12 || rounds == 20", "Invalid number of rounds.");
            m_rounds = rounds;

            // Step 1 (HSalsa20): derive a 32-byte subkey from the key and IV bytes 0..15.
            // The subkey is always 32 bytes, regardless of the input key size.
            byte[] subKey = new byte[32];
            DeriveSubKey(subKey, key, iv);

            // Step 2: Salsa20 input matrix keyed with the subkey; IV bytes 16..23 are the nonce
            // and the 64-bit block counter (bytes 8..15 here) starts at zero.
            byte[] nonceAndCounter = new byte[16];
            Array.Copy(iv, 16, nonceAndCounter, 0, 8);
            m_state = CreateState(subKey, nonceAndCounter, 0);
            Array.Clear(subKey, 0, subKey.Length);
            Array.Clear(nonceAndCounter, 0, nonceAndCounter.Length);

            m_working = new uint[16];
            m_keyStream = new byte[64];
            m_keyStreamIndex = m_keyStream.Length; // nothing buffered yet
        }

        public bool CanReuseTransform
        {
            get { return false; }
        }

        public bool CanTransformMultipleBlocks
        {
            get { return true; }
        }

        public int InputBlockSize
        {
            get { return 64; }
        }

        public int OutputBlockSize
        {
            get { return 64; }
        }

        /// <summary>
        /// XORs <paramref name="inputCount"/> bytes of input with the next keystream bytes.
        /// </summary>
        /// <returns>The number of bytes written, always <paramref name="inputCount"/>.</returns>
        public int TransformBlock(byte[] inputBuffer, int inputOffset, int inputCount, byte[] outputBuffer, int outputOffset)
        {
            // check arguments (written as subtractions so offset + count cannot overflow)
            if (inputBuffer == null)
                throw new ArgumentNullException("inputBuffer");
            if (inputOffset < 0 || inputOffset > inputBuffer.Length)
                throw new ArgumentOutOfRangeException("inputOffset");
            if (inputCount < 0 || inputCount > inputBuffer.Length - inputOffset)
                throw new ArgumentOutOfRangeException("inputCount");
            if (outputBuffer == null)
                throw new ArgumentNullException("outputBuffer");
            if (outputOffset < 0 || outputOffset > outputBuffer.Length - inputCount)
                throw new ArgumentOutOfRangeException("outputOffset");
            ThrowIfDisposed();

            for (int i = 0; i < inputCount; i++)
            {
                // Refill only when the buffered block is exhausted, so partial blocks carry over.
                if (m_keyStreamIndex == m_keyStream.Length)
                    NextKeyStreamBlock();
                outputBuffer[outputOffset + i] = (byte)(inputBuffer[inputOffset + i] ^ m_keyStream[m_keyStreamIndex++]);
            }
            return inputCount;
        }

        public byte[] TransformFinalBlock(byte[] inputBuffer, int inputOffset, int inputCount)
        {
            if (inputCount < 0)
                throw new ArgumentOutOfRangeException("inputCount");
            byte[] output = new byte[inputCount];
            TransformBlock(inputBuffer, inputOffset, inputCount, output, 0);
            return output;
        }

        // Wipes all key-dependent state. Does not touch arrays supplied by the caller.
        public void Dispose()
        {
            if (m_state != null)
                Array.Clear(m_state, 0, m_state.Length);
            m_state = null;
            if (m_working != null)
                Array.Clear(m_working, 0, m_working.Length);
            m_working = null;
            if (m_keyStream != null)
                Array.Clear(m_keyStream, 0, m_keyStream.Length);
            m_keyStream = null;
        }

        private void ThrowIfDisposed()
        {
            if (m_state == null)
                throw new ObjectDisposedException(GetType().Name);
        }

        // Produces the next 64-byte keystream block and advances the block counter.
        private void NextKeyStreamBlock()
        {
            Hash(m_keyStream, m_state);
            m_keyStreamIndex = 0;

            // 64-bit little-endian counter: word 8 is the low half, word 9 the high half.
            m_state[8] = AddOne(m_state[8]);
            if (m_state[8] == 0)
                m_state[9] = AddOne(m_state[9]);
        }

        private static uint Rotate(uint v, int c)
        {
            return (v << c) | (v >> (32 - c));
        }

        private static uint Add(uint v, uint w)
        {
            return unchecked(v + w);
        }

        private static uint AddOne(uint v)
        {
            return unchecked(v + 1);
        }

        // Applies rounds/2 Salsa20 double rounds (column round then row round) to x in place.
        private static void DoubleRounds(uint[] x, int rounds)
        {
            for (int round = rounds; round > 0; round -= 2)
            {
                // column round
                x[4] ^= Rotate(Add(x[0], x[12]), 7);
                x[8] ^= Rotate(Add(x[4], x[0]), 9);
                x[12] ^= Rotate(Add(x[8], x[4]), 13);
                x[0] ^= Rotate(Add(x[12], x[8]), 18);
                x[9] ^= Rotate(Add(x[5], x[1]), 7);
                x[13] ^= Rotate(Add(x[9], x[5]), 9);
                x[1] ^= Rotate(Add(x[13], x[9]), 13);
                x[5] ^= Rotate(Add(x[1], x[13]), 18);
                x[14] ^= Rotate(Add(x[10], x[6]), 7);
                x[2] ^= Rotate(Add(x[14], x[10]), 9);
                x[6] ^= Rotate(Add(x[2], x[14]), 13);
                x[10] ^= Rotate(Add(x[6], x[2]), 18);
                x[3] ^= Rotate(Add(x[15], x[11]), 7);
                x[7] ^= Rotate(Add(x[3], x[15]), 9);
                x[11] ^= Rotate(Add(x[7], x[3]), 13);
                x[15] ^= Rotate(Add(x[11], x[7]), 18);

                // row round
                x[1] ^= Rotate(Add(x[0], x[3]), 7);
                x[2] ^= Rotate(Add(x[1], x[0]), 9);
                x[3] ^= Rotate(Add(x[2], x[1]), 13);
                x[0] ^= Rotate(Add(x[3], x[2]), 18);
                x[6] ^= Rotate(Add(x[5], x[4]), 7);
                x[7] ^= Rotate(Add(x[6], x[5]), 9);
                x[4] ^= Rotate(Add(x[7], x[6]), 13);
                x[5] ^= Rotate(Add(x[4], x[7]), 18);
                x[11] ^= Rotate(Add(x[10], x[9]), 7);
                x[8] ^= Rotate(Add(x[11], x[10]), 9);
                x[9] ^= Rotate(Add(x[8], x[11]), 13);
                x[10] ^= Rotate(Add(x[9], x[8]), 18);
                x[12] ^= Rotate(Add(x[15], x[14]), 7);
                x[13] ^= Rotate(Add(x[12], x[15]), 9);
                x[14] ^= Rotate(Add(x[13], x[12]), 13);
                x[15] ^= Rotate(Add(x[14], x[13]), 18);
            }
        }

        // Salsa20 core: writes the 64 bytes of (rounds(input) + input) to output.
        private void Hash(byte[] output, uint[] input)
        {
            Array.Copy(input, m_working, 16);
            DoubleRounds(m_working, m_rounds);
            for (int index = 0; index < 16; index++)
                ToBytes(Add(m_working[index], input[index]), output, 4 * index);
        }

        // HSalsa20: runs the rounds over (key, iv[0..15]) without the final feed-forward and
        // writes words 0, 5, 10, 15, 6, 7, 8, 9 to output as the 32-byte subkey.
        private void DeriveSubKey(byte[] output, byte[] key, byte[] iv)
        {
            uint[] x = CreateState(key, iv, 0);
            DoubleRounds(x, m_rounds);
            ToBytes(x[0], output, 0);
            ToBytes(x[5], output, 4);
            ToBytes(x[10], output, 8);
            ToBytes(x[15], output, 12);
            ToBytes(x[6], output, 16);
            ToBytes(x[7], output, 20);
            ToBytes(x[8], output, 24);
            ToBytes(x[9], output, 28);
            Array.Clear(x, 0, x.Length);
        }

        // Builds the 16-word Salsa20 input matrix: constants in words 0/5/10/15, key in 1..4 and
        // 11..14 (a 16-byte key is used twice), and 16 bytes of input at inputOffset in 6..9.
        private static uint[] CreateState(byte[] key, byte[] input, int inputOffset)
        {
            byte[] constants = key.Length == 32 ? c_sigma : c_tau;
            int keyIndex = key.Length - 16;
            uint[] state = new uint[16];
            state[0] = ToUInt32(constants, 0);
            state[1] = ToUInt32(key, 0);
            state[2] = ToUInt32(key, 4);
            state[3] = ToUInt32(key, 8);
            state[4] = ToUInt32(key, 12);
            state[5] = ToUInt32(constants, 4);
            state[6] = ToUInt32(input, inputOffset);
            state[7] = ToUInt32(input, inputOffset + 4);
            state[8] = ToUInt32(input, inputOffset + 8);
            state[9] = ToUInt32(input, inputOffset + 12);
            state[10] = ToUInt32(constants, 8);
            state[11] = ToUInt32(key, keyIndex + 0);
            state[12] = ToUInt32(key, keyIndex + 4);
            state[13] = ToUInt32(key, keyIndex + 8);
            state[14] = ToUInt32(key, keyIndex + 12);
            state[15] = ToUInt32(constants, 12);
            return state;
        }

        // Reads a little-endian 32-bit word.
        private static uint ToUInt32(byte[] input, int inputOffset)
        {
            return unchecked((uint)(((input[inputOffset] | (input[inputOffset + 1] << 8)) | (input[inputOffset + 2] << 16)) | (input[inputOffset + 3] << 24)));
        }

        // Writes a 32-bit word in little-endian order.
        private static void ToBytes(uint input, byte[] output, int outputOffset)
        {
            unchecked
            {
                output[outputOffset] = (byte)input;
                output[outputOffset + 1] = (byte)(input >> 8);
                output[outputOffset + 2] = (byte)(input >> 16);
                output[outputOffset + 3] = (byte)(input >> 24);
            }
        }

        static readonly byte[] c_sigma = Encoding.ASCII.GetBytes("expand 32-byte k");
        static readonly byte[] c_tau = Encoding.ASCII.GetBytes("expand 16-byte k");

        uint[] m_state;      // Salsa20 input matrix (subkey, nonce, running block counter)
        uint[] m_working;    // scratch buffer for the core, reused to avoid per-block allocation
        byte[] m_keyStream;  // current 64-byte keystream block
        int m_keyStreamIndex; // next unused byte in m_keyStream; 64 means "refill"
        readonly int m_rounds;
    }
}
}