// simd_adler32/imp/avx512.rs

1use super::Adler32Imp;
2
/// Resolves the AVX-512 update implementation, if one is usable.
///
/// Delegates to whichever `get_imp_inner` variant was selected at compile time
/// by the `cfg` attributes in this file. The result is `Some` only when the
/// `avx512f` and `avx512bw` instruction sets are available (either detected at
/// runtime or enabled as compile-time target features); otherwise it is `None`.
pub fn get_imp() -> Option<Adler32Imp> {
  get_imp_inner()
}
7
/// Runtime-detection variant, compiled when both the `std` and `nightly`
/// features are enabled on an x86 / x86_64 target.
///
/// Probes the running CPU for `avx512f` and `avx512bw` and hands out the
/// vectorized `imp::update` only when both are reported as present.
#[inline]
#[cfg(all(
  feature = "std",
  feature = "nightly",
  any(target_arch = "x86", target_arch = "x86_64")
))]
fn get_imp_inner() -> Option<Adler32Imp> {
  // Both probes are evaluated before matching, mirroring an unconditional
  // two-step detection; only the (true, true) combination enables the kernel.
  match (
    std::is_x86_feature_detected!("avx512f"),
    std::is_x86_feature_detected!("avx512bw"),
  ) {
    (true, true) => Some(imp::update),
    _ => None,
  }
}
24
/// Compile-time variant, used when runtime detection is not compiled in
/// (i.e. not `std` on x86 / x86_64) but the build itself enables both the
/// `avx512f` and `avx512bw` target features under the `nightly` feature.
///
/// The instructions are guaranteed by the build configuration, so the
/// vectorized implementation is returned unconditionally.
#[inline]
#[cfg(all(
  feature = "nightly",
  target_feature = "avx512f",
  target_feature = "avx512bw",
  not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
  Some(imp::update)
}
34
35#[inline]
36#[cfg(all(
37  not(all(feature = "nightly", target_feature = "avx512f", target_feature = "avx512bw")),
38  not(all(
39    feature = "std",
40    feature = "nightly",
41    any(target_arch = "x86", target_arch = "x86_64")
42  ))
43))]
44fn get_imp_inner() -> Option<Adler32Imp> {
45  None
46}
47
/// AVX-512 implementation of the Adler-32 running-sum update.
///
/// Only compiled with the `nightly` feature on x86 / x86_64, and only when the
/// required CPU features can either be detected at runtime (`std`) or are
/// enabled at compile time (`avx512f` + `avx512bw` target features).
#[cfg(all(
  feature = "nightly",
  any(target_arch = "x86", target_arch = "x86_64"),
  any(
    feature = "std",
    all(target_feature = "avx512f", target_feature = "avx512bw")
  )
))]
mod imp {
  /// Modulus applied to both running sums (65521 is the largest prime below 2^16).
  const MOD: u32 = 65521;
  /// Upper bound on the number of bytes summed between two modulo reductions.
  /// NOTE(review): 5552 matches zlib's `NMAX`, the largest byte count for which
  /// the unreduced sums are guaranteed to fit in 32 bits — confirm against zlib.
  const NMAX: usize = 5552;
  /// Bytes consumed per vector iteration (one 512-bit load).
  const BLOCK_SIZE: usize = 64;
  /// `NMAX` rounded down to a whole number of blocks (86 * 64 = 5504 bytes).
  const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

  #[cfg(target_arch = "x86")]
  use core::arch::x86::*;
  #[cfg(target_arch = "x86_64")]
  use core::arch::x86_64::*;

  /// Safe entry point: feeds `data` into the running sums `a` and `b` and
  /// returns the updated `(a, b)` pair.
  ///
  /// This calls a `#[target_feature]` function, so it must only become
  /// reachable once `avx512f` and `avx512bw` are known to be available. The
  /// `get_imp_inner` variants in this file are what establish that before
  /// handing this function out.
  pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    // SAFETY: relies on the precondition described above (CPU features were
    // detected at runtime or enabled at compile time before this is called).
    unsafe { update_imp(a, b, data) }
  }

  /// Feature-gated body of [`update`].
  ///
  /// Splits `data` into `CHUNK_SIZE`-byte chunks, reducing both sums modulo
  /// `MOD` after each one, then processes whatever is left over.
  ///
  /// # Safety
  ///
  /// The CPU must support `avx512f` and `avx512bw`.
  #[inline]
  #[target_feature(enable = "avx512f")]
  #[target_feature(enable = "avx512bw")]
  unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    // Work in 32 bits so many bytes can be accumulated between reductions.
    let mut a = a as u32;
    let mut b = b as u32;

    let chunks = data.chunks_exact(CHUNK_SIZE);
    let remainder = chunks.remainder();
    for chunk in chunks {
      update_chunk_block(&mut a, &mut b, chunk);
    }

    // Always runs (even for an empty remainder), so both sums are reduced
    // modulo `MOD` before being narrowed below.
    update_block(&mut a, &mut b, remainder);

    // Both values are < MOD < 2^16 here, so the casts do not lose information.
    (a as u16, b as u16)
  }

  /// Folds exactly one full `CHUNK_SIZE`-byte chunk into `a` and `b`, then
  /// reduces both modulo `MOD`.
  ///
  /// `CHUNK_SIZE` is a multiple of `BLOCK_SIZE`, so the tail returned by
  /// `reduce_add_blocks` is empty and is ignored.
  ///
  /// # Safety
  ///
  /// Executes AVX-512 intrinsics; the CPU must support `avx512f` and `avx512bw`.
  #[inline]
  unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    debug_assert_eq!(
      chunk.len(),
      CHUNK_SIZE,
      "Unexpected chunk size (expected {}, got {})",
      CHUNK_SIZE,
      chunk.len()
    );

    reduce_add_blocks(a, b, chunk);

    *a %= MOD;
    *b %= MOD;
  }

  /// Folds a partial chunk (at most `CHUNK_SIZE` bytes) into `a` and `b`, then
  /// reduces both modulo `MOD`.
  ///
  /// Whole 64-byte blocks go through the vector path; the bytes that remain
  /// are handled with the scalar recurrence.
  ///
  /// # Safety
  ///
  /// Executes AVX-512 intrinsics; the CPU must support `avx512f` and `avx512bw`.
  #[inline]
  unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    debug_assert!(
      chunk.len() <= CHUNK_SIZE,
      "Unexpected chunk size (expected <= {}, got {})",
      CHUNK_SIZE,
      chunk.len()
    );

    // Scalar tail: fewer than BLOCK_SIZE bytes are left after the vector pass.
    for byte in reduce_add_blocks(a, b, chunk) {
      *a += *byte as u32;
      *b += *a;
    }

    *a %= MOD;
    *b %= MOD;
  }

  /// Accumulates every whole `BLOCK_SIZE`-byte block of `chunk` into `a` and
  /// `b` (without any modulo reduction) and returns the unprocessed tail.
  ///
  /// If `chunk` is shorter than one block, nothing is accumulated and `chunk`
  /// is returned unchanged.
  ///
  /// For `n` blocks the scalar recurrence expands to:
  ///
  /// * `a' = a + (sum of all bytes)`
  /// * `b' = b + 64 * (n * a + sum over blocks of "bytes seen before this block")
  ///            + sum over blocks of (64 * x0 + 63 * x1 + ... + 1 * x63)`
  ///
  /// `a_v` tracks the byte sums, `p_v` the parenthesised prefix term and `b_v`
  /// the weighted term; the pieces are combined after the loop.
  ///
  /// # Safety
  ///
  /// Executes AVX-512 intrinsics; the CPU must support `avx512f` and `avx512bw`.
  #[inline(always)]
  unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
    if chunk.len() < BLOCK_SIZE {
      return chunk;
    }

    let blocks = chunk.chunks_exact(BLOCK_SIZE);
    let blocks_remainder = blocks.remainder();

    let one_v = _mm512_set1_epi16(1);
    let zero_v = _mm512_setzero_si512();
    let weights = get_weights();

    // The incoming `a` contributes to `b` once per byte, i.e. 64 times per
    // block; the factor 64 is applied by the shift after the loop, so only
    // `a * n` is seeded here.
    let p_v = (*a * blocks.len() as u32) as _;
    // `_mm512_set_epi32` lists lanes from highest to lowest, so the final
    // argument lands in lane 0; all other lanes start at zero.
    let mut p_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, p_v);
    let mut a_v = _mm512_setzero_si512();
    // Seed lane 0 with the incoming `b` so the final horizontal sum includes it.
    let mut b_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *b as _);

    for block in blocks {
      let block_ptr = block.as_ptr() as *const _;
      // Unaligned load: `block` is exactly BLOCK_SIZE (64) bytes long.
      let block = _mm512_loadu_si512(block_ptr);

      // Add the byte sums of all *previous* blocks (must happen before `a_v`
      // is updated for the current block).
      p_v = _mm512_add_epi32(p_v, a_v);

      // SAD against zero yields the sum of each group of 8 bytes.
      a_v = _mm512_add_epi32(a_v, _mm512_sad_epu8(block, zero_v));
      // byte * weight, adjacent pairs summed into 16-bit lanes ...
      let mad = _mm512_maddubs_epi16(block, weights);
      // ... then widened to 32-bit lanes by multiplying with 1 and pair-adding.
      b_v = _mm512_add_epi32(b_v, _mm512_madd_epi16(mad, one_v));
    }

    // Apply the per-block factor of 64 (shift left by 6) to the prefix term.
    b_v = _mm512_add_epi32(b_v, _mm512_slli_epi32(p_v, 6));

    *a += reduce_add(a_v);
    // Plain assignment: the previous `*b` was already seeded into `b_v`.
    *b = reduce_add(b_v);

    blocks_remainder
  }

  /// Horizontal sum of the sixteen 32-bit lanes of `v`, computed by summing
  /// its two 256-bit halves.
  #[inline(always)]
  unsafe fn reduce_add(v: __m512i) -> u32 {
    let v: [__m256i; 2] = core::mem::transmute(v);

    reduce_add_256(v[0]) + reduce_add_256(v[1])
  }

  /// Horizontal sum of the eight 32-bit lanes of `v`.
  #[inline(always)]
  unsafe fn reduce_add_256(v: __m256i) -> u32 {
    // Add the two 128-bit halves lane-wise: 8 lanes -> 4 lanes.
    let v: [__m128i; 2] = core::mem::transmute(v);
    let sum = _mm_add_epi32(v[0], v[1]);
    // Bring the upper 64 bits down and add: 4 lanes -> 2 lanes.
    let hi = _mm_unpackhi_epi64(sum, sum);

    let sum = _mm_add_epi32(hi, sum);
    // Swap adjacent 32-bit lanes and add: 2 lanes -> 1 lane.
    // NOTE(review): assumes `crate::imp::_MM_SHUFFLE` follows Intel's
    // `_MM_SHUFFLE` macro convention — confirm.
    let hi = _mm_shuffle_epi32(sum, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));

    let sum = _mm_add_epi32(sum, hi);
    // Lane 0 now holds the total.
    let sum = _mm_cvtsi128_si32(sum) as _;

    sum
  }

  /// Per-byte weights for the `b` sum of one block.
  ///
  /// `_mm512_set_epi8` takes its arguments from the highest byte down to the
  /// lowest, so the first byte in memory gets weight 64 and the last gets 1.
  #[inline(always)]
  unsafe fn get_weights() -> __m512i {
    _mm512_set_epi8(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
      24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
    )
  }
}
191
#[cfg(test)]
mod tests {
  use rand::Rng;

  /// Largest input length exercised by these tests (one mebibyte).
  const MIB: usize = 1024 * 1024;

  /// Inputs made only of `0` bytes, from the empty slice up to 1 MiB.
  #[test]
  fn zeroes() {
    static ZEROES: [u8; MIB] = [0; MIB];

    for &len in &[0, 1, 2, 100, 1024, MIB] {
      assert_sum_eq(&ZEROES[..len]);
    }
  }

  /// Inputs made only of `1` bytes, from the empty slice up to 1 MiB.
  #[test]
  fn ones() {
    static ONES: [u8; MIB] = [1; MIB];

    for &len in &[0, 1, 2, 100, 1024, MIB] {
      assert_sum_eq(&ONES[..len]);
    }
  }

  /// Prefixes of a single randomly filled 1 MiB buffer.
  #[test]
  fn random() {
    let mut noise = [0u8; MIB];
    rand::thread_rng().fill(&mut noise[..]);

    for &len in &[1, 100, 1024, MIB] {
      assert_sum_eq(&noise[..len]);
    }
  }

  /// The worked example from https://en.wikipedia.org/wiki/Adler-32.
  #[test]
  fn wiki() {
    assert_sum_eq(b"Wikipedia");
  }

  /// Compares this module's implementation against the `adler` crate for
  /// `data`, starting from the initial state `(a, b) = (1, 0)`.
  ///
  /// When `get_imp` reports no implementation the check is skipped, so the
  /// calling test passes without asserting anything.
  fn assert_sum_eq(data: &[u8]) {
    let update = match super::get_imp() {
      Some(update) => update,
      None => return,
    };

    let (a, b) = update(1, 0, data);
    // Pack the two 16-bit sums into one checksum: `b` high, `a` low.
    let actual = (u32::from(b) << 16) | u32::from(a);
    let expected = adler::adler32_slice(data);

    assert_eq!(actual, expected, "len({})", data.len());
  }
}