// simd_adler32/imp/avx2.rs — AVX2 (x86/x86_64) implementation of Adler-32.

1use super::Adler32Imp;
2
/// Resolves update implementation if CPU supports avx2 instructions.
///
/// Returns `Some(update)` when the AVX2 code path is available — either via
/// runtime detection (`std` feature on x86/x86_64) or compile-time
/// `target_feature = "avx2"` — and `None` otherwise, so the caller can fall
/// back to a different implementation.
pub fn get_imp() -> Option<Adler32Imp> {
  get_imp_inner()
}
7
/// Runtime-detection variant: compiled when `std` is available on x86/x86_64,
/// where CPU features can be queried at runtime.
#[inline]
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
fn get_imp_inner() -> Option<Adler32Imp> {
  // Guard clause: bail out unless the running CPU actually supports AVX2.
  if !std::is_x86_feature_detected!("avx2") {
    return None;
  }

  Some(imp::update)
}
17
/// Compile-time variant: used when AVX2 was enabled at build time
/// (`target_feature = "avx2"`) but runtime detection is unavailable
/// (no `std`, or the runtime-detection cfg above did not apply).
/// AVX2 is guaranteed by the compiler flags, so always return the impl.
#[inline]
#[cfg(all(
  target_feature = "avx2",
  not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
  Some(imp::update)
}
26
/// Fallback variant: neither compile-time AVX2 nor runtime detection is
/// available, so the AVX2 implementation can never be used safely.
#[inline]
#[cfg(all(
  not(target_feature = "avx2"),
  not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
  None
}
35
// The AVX2 kernel. Only compiled on x86/x86_64 when it could actually be
// selected: either runtime detection is possible (`std`) or AVX2 is a
// compile-time target feature.
#[cfg(all(
  any(target_arch = "x86", target_arch = "x86_64"),
  any(feature = "std", target_feature = "avx2")
))]
mod imp {
  /// Adler-32 modulus: the largest prime below 2^16.
  const MOD: u32 = 65521;
  /// Max bytes that can be summed into a u32 accumulator before a `% MOD`
  /// reduction is required to avoid overflow (the classic zlib NMAX bound).
  const NMAX: usize = 5552;
  /// Bytes processed per 256-bit AVX2 register.
  const BLOCK_SIZE: usize = 32;
  /// NMAX rounded down to a whole number of SIMD blocks, so a full chunk
  /// can be processed without intermediate modular reduction.
  const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

  #[cfg(target_arch = "x86")]
  use core::arch::x86::*;
  #[cfg(target_arch = "x86_64")]
  use core::arch::x86_64::*;

  /// Safe entry point with the `Adler32Imp` signature: takes the current
  /// `(a, b)` state and returns it updated over `data`.
  pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    // SAFETY: this function is only handed out by `get_imp`, whose cfg/runtime
    // gating (see `get_imp_inner` variants) ensures AVX2 is available before
    // `update_imp`'s `#[target_feature(enable = "avx2")]` code runs.
    unsafe { update_imp(a, b, data) }
  }

  #[inline]
  #[target_feature(enable = "avx2")]
  unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    // Widen to u32 so sums can grow for up to NMAX bytes between reductions.
    let mut a = a as u32;
    let mut b = b as u32;

    // Process full CHUNK_SIZE chunks (each reduced mod MOD afterwards),
    // then handle whatever is left in a final partial pass.
    let chunks = data.chunks_exact(CHUNK_SIZE);
    let remainder = chunks.remainder();
    for chunk in chunks {
      update_chunk_block(&mut a, &mut b, chunk);
    }

    update_block(&mut a, &mut b, remainder);

    // Both sums are < MOD < 2^16 after the final reduction, so the casts are lossless.
    (a as u16, b as u16)
  }

  /// Processes exactly one full chunk and reduces both sums mod MOD.
  #[inline]
  unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    debug_assert_eq!(
      chunk.len(),
      CHUNK_SIZE,
      "Unexpected chunk size (expected {}, got {})",
      CHUNK_SIZE,
      chunk.len()
    );

    // A full chunk is a multiple of BLOCK_SIZE, so the SIMD loop consumes it
    // entirely and the returned remainder is empty.
    reduce_add_blocks(a, b, chunk);

    *a %= MOD;
    *b %= MOD;
  }

  /// Processes a final partial chunk (<= CHUNK_SIZE): SIMD over whole blocks,
  /// then a scalar loop over the leftover tail bytes, then one reduction.
  #[inline]
  unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    debug_assert!(
      chunk.len() <= CHUNK_SIZE,
      "Unexpected chunk size (expected <= {}, got {})",
      CHUNK_SIZE,
      chunk.len()
    );

    // `reduce_add_blocks` returns the sub-BLOCK_SIZE tail; fold it in with
    // the plain scalar Adler-32 recurrence.
    for byte in reduce_add_blocks(a, b, chunk) {
      *a += *byte as u32;
      *b += *a;
    }

    *a %= MOD;
    *b %= MOD;
  }

  /// SIMD core: folds all whole 32-byte blocks of `chunk` into `a` and `b`,
  /// returning the unprocessed tail slice (< BLOCK_SIZE bytes).
  ///
  /// Vector state per iteration:
  /// - `a_v` accumulates raw byte sums (the starting `a` is deliberately
  ///   excluded, hence `*a +=` at the end).
  /// - `b_v` accumulates within-block weighted sums, seeded with the incoming
  ///   `b` (hence `*b =` at the end).
  /// - `p_v` accumulates the running block-prefix `a` totals; each of those
  ///   contributes BLOCK_SIZE times to `b`, applied in one shot via `<< 5`
  ///   (BLOCK_SIZE == 32 == 2^5). It is seeded with `a * blocks.len()` to
  ///   account for the incoming `a` over every block.
  #[inline(always)]
  unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
    // Too small for even one SIMD block: let the caller handle it scalar.
    if chunk.len() < BLOCK_SIZE {
      return chunk;
    }

    let blocks = chunk.chunks_exact(BLOCK_SIZE);
    let blocks_remainder = blocks.remainder();

    let one_v = _mm256_set1_epi16(1);
    let zero_v = _mm256_setzero_si256();
    let weights = get_weights();

    let mut p_v = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, (*a * blocks.len() as u32) as _);
    let mut a_v = _mm256_setzero_si256();
    let mut b_v = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, *b as _);

    for block in blocks {
      let block_ptr = block.as_ptr() as *const _;
      // Unaligned 256-bit load; `chunks_exact` guarantees 32 readable bytes.
      let block = _mm256_loadu_si256(block_ptr);

      // Accumulate the pre-block `a` total before this block's bytes are added.
      p_v = _mm256_add_epi32(p_v, a_v);

      // Sum-of-absolute-differences against zero = per-8-byte-group byte sums.
      a_v = _mm256_add_epi32(a_v, _mm256_sad_epu8(block, zero_v));
      // Multiply each byte by its positional weight (32..1) and pair-sum to
      // i16, then widen/pair-sum to i32 lanes via madd with ones.
      let mad = _mm256_maddubs_epi16(block, weights);
      b_v = _mm256_add_epi32(b_v, _mm256_madd_epi16(mad, one_v));
    }

    // Fold the prefix contributions into b: each counts BLOCK_SIZE (=2^5) times.
    b_v = _mm256_add_epi32(b_v, _mm256_slli_epi32(p_v, 5));

    *a += reduce_add(a_v);
    *b = reduce_add(b_v);

    blocks_remainder
  }

  /// Horizontal sum of the eight 32-bit lanes of `v`.
  #[inline(always)]
  unsafe fn reduce_add(v: __m256i) -> u32 {
    // 256 -> 128: add high and low 128-bit halves.
    let sum = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
    // 128 -> 64: add high and low 64-bit halves.
    let hi = _mm_unpackhi_epi64(sum, sum);

    let sum = _mm_add_epi32(hi, sum);
    // 64 -> 32: swap adjacent 32-bit lanes and add. NOTE(review): `_MM_SHUFFLE`
    // is a crate-local helper; assumed to match the standard Intel macro.
    let hi = _mm_shuffle_epi32(sum, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));

    let sum = _mm_add_epi32(sum, hi);

    _mm_cvtsi128_si32(sum) as _
  }

  /// Positional weights for the `b` sum. `_mm256_set_epi8` lists arguments
  /// from the highest byte down, so byte 0 of a loaded block (the earliest
  /// input byte) gets weight 32 and the last byte gets weight 1 — matching
  /// how many times each byte's `a` contribution is added into `b`.
  #[inline(always)]
  unsafe fn get_weights() -> __m256i {
    _mm256_set_epi8(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
      24, 25, 26, 27, 28, 29, 30, 31, 32,
    )
  }
}
163
#[cfg(test)]
mod tests {
  use rand::Rng;

  #[test]
  fn zeroes() {
    let inputs: [&[u8]; 6] = [&[], &[0], &[0, 0], &[0; 100], &[0; 1024], &[0; 1024 * 1024]];

    for data in inputs.iter() {
      assert_sum_eq(data);
    }
  }

  #[test]
  fn ones() {
    let inputs: [&[u8]; 6] = [&[], &[1], &[1, 1], &[1; 100], &[1; 1024], &[1; 1024 * 1024]];

    for data in inputs.iter() {
      assert_sum_eq(data);
    }
  }

  #[test]
  fn random() {
    let mut random = [0; 1024 * 1024];
    rand::thread_rng().fill(&mut random[..]);

    for &len in &[1, 100, 1024, 1024 * 1024] {
      assert_sum_eq(&random[..len]);
    }
  }

  /// Example calculation from https://en.wikipedia.org/wiki/Adler-32.
  #[test]
  fn wiki() {
    assert_sum_eq(b"Wikipedia");
  }

  /// Checks the SIMD implementation against the reference `adler` crate.
  /// Skips silently when the AVX2 path is unavailable on this machine.
  fn assert_sum_eq(data: &[u8]) {
    let update = match super::get_imp() {
      Some(update) => update,
      None => return,
    };

    let (a, b) = update(1, 0, data);
    let left = u32::from(b) << 16 | u32::from(a);
    let right = adler::adler32_slice(data);

    assert_eq!(left, right, "len({})", data.len());
  }
}
214}