/*
 * This implementation of Salsa20 is ported from the reference implementation
 * by D. J. Bernstein, which can be found at:
 *   http://cr.yp.to/snuffle/salsa20/ref/salsa20.c
 *
 * This work is hereby released into the Public Domain. To view a copy of the public domain dedication,
 * visit http://creativecommons.org/licenses/publicdomain/ or send a letter to
 * Creative Commons, 171 Second Street, Suite 300, San Francisco, California, 94105, USA.
 */

using System;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;

namespace Jd.ACES.Common
{
    /// <summary>
    /// Implements a stream cipher built on the Salsa20 core, as defined at http://cr.yp.to/snuffle.html.
    /// </summary>
    /// <remarks>
    /// The transform uses an extended 24-byte IV. The key and IV bytes 0-15 are run through the Salsa20
    /// core without the final feed-forward on the output words (the HSalsa20 construction) to derive a
    /// 256-bit subkey. Salsa20 is then keyed with that subkey, using IV bytes 16-23 as the nonce and a
    /// 64-bit little-endian block counter that starts at zero. The configured <see cref="Rounds"/> value
    /// is used for both steps. Encryption and decryption are the same operation (XOR with the keystream).
    /// </remarks>
    public sealed class Salsa20 : SymmetricAlgorithm
    {
        // Minimum IV length: 16 bytes for the subkey derivation plus 8 bytes of Salsa20 nonce.
        private const int c_minIVSize = 24;

        /// <summary>
        /// Initializes a new instance of the <see cref="Salsa20"/> class with a 256-bit key size,
        /// a 512-bit block size and 20 rounds.
        /// </summary>
        public Salsa20()
        {
            // set legal values
            LegalBlockSizesValue = new[] { new KeySizes(512, 512, 0) };
            LegalKeySizesValue = new[] { new KeySizes(128, 256, 128) };

            // set default values
            BlockSizeValue = 512;
            KeySizeValue = 256;
            m_rounds = 20;
        }

        /// <summary>
        /// Creates a symmetric decryptor object with the specified key and initialization vector.
        /// </summary>
        /// <param name="rgbKey">The secret key to use for the symmetric algorithm (16 or 32 bytes).</param>
        /// <param name="rgbIV">The initialization vector to use for the symmetric algorithm (at least 24 bytes).</param>
        /// <returns>A symmetric decryptor object.</returns>
        public override ICryptoTransform CreateDecryptor(byte[] rgbKey, byte[] rgbIV)
        {
            // decryption and encryption are symmetrical
            return CreateEncryptor(rgbKey, rgbIV);
        }

        /// <summary>
        /// Creates a symmetric encryptor object with the specified key and initialization vector.
        /// </summary>
        /// <param name="rgbKey">The secret key to use for the symmetric algorithm (16 or 32 bytes).</param>
        /// <param name="rgbIV">The initialization vector to use for the symmetric algorithm (at least 24 bytes).</param>
        /// <returns>A symmetric encryptor object.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rgbKey"/> or <paramref name="rgbIV"/> is null.</exception>
        /// <exception cref="CryptographicException">The key or IV has an invalid size.</exception>
        public override ICryptoTransform CreateEncryptor(byte[] rgbKey, byte[] rgbIV)
        {
            if (rgbKey == null)
                throw new ArgumentNullException("rgbKey");
            if (!ValidKeySize(rgbKey.Length * 8))
                throw new CryptographicException("Invalid key size; it must be 128 or 256 bits.");
            CheckValidIV(rgbIV, "rgbIV");

            return new Salsa20CryptoTransform(rgbKey, rgbIV, m_rounds);
        }

        /// <summary>
        /// Generates a random initialization vector to use for the algorithm.
        /// </summary>
        public override void GenerateIV()
        {
            // The transform consumes 24 IV bytes: 16 for the subkey derivation and 8 for the nonce.
            IVValue = GetRandomBytes(c_minIVSize);
        }

        /// <summary>
        /// Generates a random key of <see cref="SymmetricAlgorithm.KeySize"/> bits to use for the algorithm.
        /// </summary>
        public override void GenerateKey()
        {
            KeyValue = GetRandomBytes(KeySize / 8);
        }

        /// <summary>
        /// Gets or sets the initialization vector for the symmetric algorithm.
        /// </summary>
        /// <value>The initialization vector.</value>
        /// <exception cref="ArgumentNullException">An attempt was made to set the initialization vector to null.</exception>
        /// <exception cref="CryptographicException">An attempt was made to set the initialization vector to an invalid size.</exception>
        public override byte[] IV
        {
            get { return base.IV; }
            set
            {
                CheckValidIV(value, "value");
                IVValue = (byte[])value.Clone();
            }
        }

        /// <summary>
        /// Gets or sets the number of rounds used by the Salsa20 algorithm.
        /// </summary>
        /// <value>The number of rounds: 8, 12, or 20.</value>
        /// <exception cref="ArgumentOutOfRangeException">The value is not 8, 12, or 20.</exception>
        public int Rounds
        {
            get { return m_rounds; }
            set
            {
                if (value != 8 && value != 12 && value != 20)
                    throw new ArgumentOutOfRangeException("value", "The number of rounds must be 8, 12, or 20.");
                m_rounds = value;
            }
        }

        // Verifies that iv is a legal value for this cipher's IV: non-null, a multiple of 8 bytes,
        // and long enough to supply both the 16 subkey-derivation bytes and the 8 nonce bytes.
        private static void CheckValidIV(byte[] iv, string paramName)
        {
            if (iv == null)
                throw new ArgumentNullException(paramName);
            if (iv.Length % 8 != 0)
                throw new CryptographicException("Invalid IV size; it must be multiple of 8 bytes.");
            if (iv.Length < c_minIVSize)
                throw new CryptographicException("Invalid IV size; it must be at least 24 bytes.");
        }

        // Returns a new byte array containing the specified number of cryptographically random bytes.
        private static byte[] GetRandomBytes(int byteCount)
        {
            byte[] bytes = new byte[byteCount];
            // RandomNumberGenerator is IDisposable; release it once the bytes are filled.
            using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
                rng.GetBytes(bytes);
            return bytes;
        }

        int m_rounds;

        /// <summary>
        /// An <see cref="ICryptoTransform"/> that XORs data with the keystream described on <see cref="Salsa20"/>.
        /// </summary>
        /// <remarks>
        /// The block counter is kept in <c>m_state</c> and advances across calls. Successive
        /// <see cref="TransformBlock"/> calls therefore continue the keystream instead of restarting it.
        /// A partial block consumes a whole 64-byte keystream block, so only the final call should
        /// pass a length that is not a multiple of 64.
        /// </remarks>
        private sealed class Salsa20CryptoTransform : ICryptoTransform
        {
            public Salsa20CryptoTransform(byte[] key, byte[] iv, int rounds)
            {
                Debug.Assert(key.Length == 16 || key.Length == 32, "abyKey.Length == 16 || abyKey.Length == 32", "Invalid key size.");
                Debug.Assert(iv.Length % 8 == 0, "abyIV.Length % 8 == 0", "Invalid IV size.");
                Debug.Assert(rounds == 8 || rounds == 12 || rounds == 20, "rounds == 8 || rounds == 12 || rounds == 20", "Invalid number of rounds.");

                m_rounds = rounds;

                // The derivation always yields eight 32-bit words (32 bytes), even for a 16-byte key.
                byte[] subkey = new byte[32];
                DeriveSubkey(subkey, key, iv);

                // Salsa20 input block: subkey, nonce = IV bytes 16..23, 64-bit counter = 0.
                // The caller's iv array is only read, never retained, so Dispose cannot clobber it.
                byte[] nonceAndCounter = new byte[16];
                Array.Copy(iv, 16, nonceAndCounter, 0, 8);
                m_state = CreateState(subkey, nonceAndCounter);

                Array.Clear(subkey, 0, subkey.Length);
            }

            public bool CanReuseTransform
            {
                get { return false; }
            }

            public bool CanTransformMultipleBlocks
            {
                get { return true; }
            }

            public int InputBlockSize
            {
                get { return 64; }
            }

            public int OutputBlockSize
            {
                get { return 64; }
            }

            /// <summary>
            /// XORs <paramref name="inputCount"/> bytes of input with the next keystream bytes.
            /// </summary>
            /// <returns>The number of bytes written to <paramref name="outputBuffer"/>.</returns>
            /// <exception cref="ArgumentNullException">A buffer is null.</exception>
            /// <exception cref="ArgumentOutOfRangeException">An offset or count is out of range.</exception>
            /// <exception cref="ObjectDisposedException">The transform has been disposed.</exception>
            public int TransformBlock(byte[] inputBuffer, int inputOffset, int inputCount, byte[] outputBuffer, int outputOffset)
            {
                // check arguments (subtraction form avoids int overflow; an empty range at the end is allowed)
                if (inputBuffer == null)
                    throw new ArgumentNullException("inputBuffer");
                if (inputOffset < 0 || inputOffset > inputBuffer.Length)
                    throw new ArgumentOutOfRangeException("inputOffset");
                if (inputCount < 0 || inputCount > inputBuffer.Length - inputOffset)
                    throw new ArgumentOutOfRangeException("inputCount");
                if (outputBuffer == null)
                    throw new ArgumentNullException("outputBuffer");
                if (outputOffset < 0 || outputOffset > outputBuffer.Length - inputCount)
                    throw new ArgumentOutOfRangeException("outputOffset");
                if (m_state == null)
                    throw new ObjectDisposedException(GetType().Name);

                byte[] keystream = new byte[64];
                int bytesTransformed = 0;

                while (inputCount > 0)
                {
                    Hash(keystream, m_state);

                    // Advance the 64-bit block counter (word 8 = low half, word 9 = high half).
                    m_state[8] = AddOne(m_state[8]);
                    if (m_state[8] == 0)
                        m_state[9] = AddOne(m_state[9]);

                    int blockSize = Math.Min(64, inputCount);
                    for (int i = 0; i < blockSize; i++)
                        outputBuffer[outputOffset + i] = (byte)(inputBuffer[inputOffset + i] ^ keystream[i]);
                    bytesTransformed += blockSize;

                    inputCount -= blockSize;
                    inputOffset += blockSize;
                    outputOffset += blockSize;
                }

                Array.Clear(keystream, 0, keystream.Length);
                return bytesTransformed;
            }

            /// <summary>
            /// Transforms the final (possibly partial or empty) block and returns it as a new array.
            /// </summary>
            public byte[] TransformFinalBlock(byte[] inputBuffer, int inputOffset, int inputCount)
            {
                if (inputCount < 0)
                    throw new ArgumentOutOfRangeException("inputCount");

                byte[] output = new byte[inputCount];
                TransformBlock(inputBuffer, inputOffset, inputCount, output, 0);
                return output;
            }

            /// <summary>
            /// Clears the keyed state; further transforms throw <see cref="ObjectDisposedException"/>.
            /// </summary>
            public void Dispose()
            {
                if (m_state != null)
                    Array.Clear(m_state, 0, m_state.Length);
                m_state = null;
            }

            private static uint Rotate(uint v, int c)
            {
                return (v << c) | (v >> (32 - c));
            }

            private static uint Add(uint v, uint w)
            {
                return unchecked(v + w);
            }

            private static uint AddOne(uint v)
            {
                return unchecked(v + 1);
            }

            // Applies m_rounds rounds (m_rounds / 2 column+row double rounds) of the Salsa20 core to x in place.
            private void ApplyRounds(uint[] x)
            {
                for (int round = m_rounds; round > 0; round -= 2)
                {
                    // column round
                    x[4] ^= Rotate(Add(x[0], x[12]), 7);
                    x[8] ^= Rotate(Add(x[4], x[0]), 9);
                    x[12] ^= Rotate(Add(x[8], x[4]), 13);
                    x[0] ^= Rotate(Add(x[12], x[8]), 18);
                    x[9] ^= Rotate(Add(x[5], x[1]), 7);
                    x[13] ^= Rotate(Add(x[9], x[5]), 9);
                    x[1] ^= Rotate(Add(x[13], x[9]), 13);
                    x[5] ^= Rotate(Add(x[1], x[13]), 18);
                    x[14] ^= Rotate(Add(x[10], x[6]), 7);
                    x[2] ^= Rotate(Add(x[14], x[10]), 9);
                    x[6] ^= Rotate(Add(x[2], x[14]), 13);
                    x[10] ^= Rotate(Add(x[6], x[2]), 18);
                    x[3] ^= Rotate(Add(x[15], x[11]), 7);
                    x[7] ^= Rotate(Add(x[3], x[15]), 9);
                    x[11] ^= Rotate(Add(x[7], x[3]), 13);
                    x[15] ^= Rotate(Add(x[11], x[7]), 18);

                    // row round
                    x[1] ^= Rotate(Add(x[0], x[3]), 7);
                    x[2] ^= Rotate(Add(x[1], x[0]), 9);
                    x[3] ^= Rotate(Add(x[2], x[1]), 13);
                    x[0] ^= Rotate(Add(x[3], x[2]), 18);
                    x[6] ^= Rotate(Add(x[5], x[4]), 7);
                    x[7] ^= Rotate(Add(x[6], x[5]), 9);
                    x[4] ^= Rotate(Add(x[7], x[6]), 13);
                    x[5] ^= Rotate(Add(x[4], x[7]), 18);
                    x[11] ^= Rotate(Add(x[10], x[9]), 7);
                    x[8] ^= Rotate(Add(x[11], x[10]), 9);
                    x[9] ^= Rotate(Add(x[8], x[11]), 13);
                    x[10] ^= Rotate(Add(x[9], x[8]), 18);
                    x[12] ^= Rotate(Add(x[15], x[14]), 7);
                    x[13] ^= Rotate(Add(x[12], x[15]), 9);
                    x[14] ^= Rotate(Add(x[13], x[12]), 13);
                    x[15] ^= Rotate(Add(x[14], x[13]), 18);
                }
            }

            // Salsa20 block function: output (64 bytes) = little-endian words of rounds(input) + input.
            private void Hash(byte[] output, uint[] input)
            {
                uint[] state = (uint[])input.Clone();
                ApplyRounds(state);

                for (int index = 0; index < 16; index++)
                    ToBytes(Add(state[index], input[index]), output, 4 * index);

                Array.Clear(state, 0, state.Length);
            }

            // Builds the 16-word Salsa20 input block from a 16- or 32-byte key and a 16-byte iv
            // (nonce and/or counter), using the "expand 32-byte k" or "expand 16-byte k" constants.
            private static uint[] CreateState(byte[] key, byte[] iv)
            {
                byte[] constants = key.Length == 32 ? c_sigma : c_tau;
                // A 16-byte key is used for both key halves.
                int keyIndex = key.Length - 16;

                uint[] state = new uint[16];
                state[0] = ToUInt32(constants, 0);
                state[1] = ToUInt32(key, 0);
                state[2] = ToUInt32(key, 4);
                state[3] = ToUInt32(key, 8);
                state[4] = ToUInt32(key, 12);
                state[5] = ToUInt32(constants, 4);
                state[6] = ToUInt32(iv, 0);
                state[7] = ToUInt32(iv, 4);
                state[8] = ToUInt32(iv, 8);
                state[9] = ToUInt32(iv, 12);
                state[10] = ToUInt32(constants, 8);
                state[11] = ToUInt32(key, keyIndex + 0);
                state[12] = ToUInt32(key, keyIndex + 4);
                state[13] = ToUInt32(key, keyIndex + 8);
                state[14] = ToUInt32(key, keyIndex + 12);
                state[15] = ToUInt32(constants, 12);
                return state;
            }

            // HSalsa20-style subkey derivation: run the core over (key, iv[0..16)) and emit words
            // 0, 5, 10, 15, 6, 7, 8, 9 of the result WITHOUT the feed-forward addition. This equals the
            // reference formulation "add the input block, then subtract the constants and iv words".
            // Writes exactly 32 bytes to output.
            private void DeriveSubkey(byte[] output, byte[] key, byte[] iv)
            {
                uint[] x = CreateState(key, iv);
                ApplyRounds(x);

                ToBytes(x[0], output, 0);
                ToBytes(x[5], output, 4);
                ToBytes(x[10], output, 8);
                ToBytes(x[15], output, 12);
                ToBytes(x[6], output, 16);
                ToBytes(x[7], output, 20);
                ToBytes(x[8], output, 24);
                ToBytes(x[9], output, 28);

                Array.Clear(x, 0, x.Length);
            }

            // Reads a little-endian 32-bit word from input at inputOffset.
            private static uint ToUInt32(byte[] input, int inputOffset)
            {
                return unchecked((uint)(((input[inputOffset] | (input[inputOffset + 1] << 8)) | (input[inputOffset + 2] << 16)) | (input[inputOffset + 3] << 24)));
            }

            // Writes input as a little-endian 32-bit word to output at outputOffset.
            private static void ToBytes(uint input, byte[] output, int outputOffset)
            {
                unchecked
                {
                    output[outputOffset] = (byte)input;
                    output[outputOffset + 1] = (byte)(input >> 8);
                    output[outputOffset + 2] = (byte)(input >> 16);
                    output[outputOffset + 3] = (byte)(input >> 24);
                }
            }

            static readonly byte[] c_sigma = Encoding.ASCII.GetBytes("expand 32-byte k");
            static readonly byte[] c_tau = Encoding.ASCII.GetBytes("expand 16-byte k");

            // Salsa20 input block for the keystream (subkey, nonce, block counter); null once disposed.
            uint[] m_state;
            readonly int m_rounds;
        }
    }
}